Initial untested support for batching requests
parent
882c9b64ba
commit
3673bd6fe1
111
correction.py
111
correction.py
|
|
@ -18,6 +18,10 @@ parser.add_argument("--overwrite", action="store_true",
|
|||
parser.add_argument("--limit", type=int, help="limit calls to gemini rpo integer")
|
||||
parser.add_argument("--refaire", action="store_true",
|
||||
help="Redo specific copies/labels defined in refaire.json")
|
||||
parser.add_argument("--batch", action="store_true",
|
||||
help="Generate a JSONL file of requests to send to the Gemini Batch API")
|
||||
parser.add_argument("--deal-with-batched", type=str, metavar="FILE",
|
||||
help="Process a JSONL file containing completed batch results")
|
||||
args, _ = parser.parse_known_args()
|
||||
|
||||
|
||||
|
|
@ -548,7 +552,7 @@ Here is a list of all possible labels. You need to answer with a list one of the
|
|||
|
||||
return new_tasks
|
||||
|
||||
def process_single_task(task_tuple):
|
||||
def process_single_task(task_tuple, precomputed_response=None):
|
||||
try:
|
||||
global pro_count, flash_count, pro_quota_exhausted
|
||||
file_path = task_tuple[0]
|
||||
|
|
@ -567,25 +571,32 @@ def process_single_task(task_tuple):
|
|||
total_height = group_data[-1][2]
|
||||
use_flash = n >= 4 or total_height <= 500
|
||||
|
||||
if not use_flash:
|
||||
with pro_lock:
|
||||
if pro_quota_exhausted:
|
||||
use_flash = True
|
||||
elif limit is None or pro_count < limit:
|
||||
pro_count += 1
|
||||
else:
|
||||
use_flash = True
|
||||
# Only apply limits and counts if we are making a live call
|
||||
if precomputed_response is None:
|
||||
if not use_flash:
|
||||
with pro_lock:
|
||||
if pro_quota_exhausted:
|
||||
use_flash = True
|
||||
elif limit is None or pro_count < limit:
|
||||
pro_count += 1
|
||||
else:
|
||||
use_flash = True
|
||||
|
||||
if use_flash:
|
||||
with pro_lock:
|
||||
flash_count += 1
|
||||
if use_flash:
|
||||
with pro_lock:
|
||||
flash_count += 1
|
||||
|
||||
try:
|
||||
contents, config = generate_request(file_path, label)
|
||||
model_to_use = MODEL_ID_flash if use_flash else MODEL_ID_pro
|
||||
tprint(f"Asking Gemini {'Flash' if use_flash else 'Pro '}: {label} {group_name}")
|
||||
|
||||
full_response_text = call_gemini_with_retries(model_to_use, contents, config)
|
||||
if precomputed_response:
|
||||
tprint(f"Using batched response for: {label} {group_name}")
|
||||
full_response_text = precomputed_response
|
||||
else:
|
||||
tprint(f"Asking Gemini {'Flash' if use_flash else 'Pro '}: {label} {group_name}")
|
||||
full_response_text = call_gemini_with_retries(model_to_use, contents, config)
|
||||
|
||||
json_data = json.loads(full_response_text)
|
||||
|
||||
# Ensure consistency of answer placements
|
||||
|
|
@ -723,10 +734,78 @@ if __name__ == "__main__":
|
|||
else:
|
||||
print(f"Warning: --refaire flag used, but {refaire_path} not found.", file=sys.stderr)
|
||||
|
||||
if args.batch:
|
||||
batch_file = Path(INPUT_DIR) / "batch_requests.jsonl"
|
||||
with open(batch_file, "w", encoding="utf-8") as f:
|
||||
for task in tasks_to_process:
|
||||
file_path, label = task[0], task[1]
|
||||
group_name = os.path.splitext(file_path)[0]
|
||||
json_path = group_name + '.json'
|
||||
|
||||
with open(json_path, 'r') as jf:
|
||||
group_data = json.load(jf)
|
||||
use_flash = len(group_data) >= 4 or group_data[-1][2] <= 500
|
||||
model_to_use = MODEL_ID_flash if use_flash else MODEL_ID_pro
|
||||
|
||||
image_data = Path(file_path).read_bytes()
|
||||
b64_img = base64.b64encode(image_data).decode("utf-8")
|
||||
|
||||
# Format payload. NOTE: adapt the JSON format if your specific Gemini
|
||||
# Batch API endpoint expects a slightly different schema.
|
||||
req = {
|
||||
"custom_id": file_path, # Mapping ID
|
||||
"method": "POST",
|
||||
"url": f"/v1beta/models/{model_to_use}:generateContent",
|
||||
"body": {
|
||||
"contents": [{
|
||||
"role": "user",
|
||||
"parts": [
|
||||
{"inlineData": {"mimeType": "image/jpeg", "data": b64_img}},
|
||||
{"text": make_prompt(label)}
|
||||
]
|
||||
}],
|
||||
"generationConfig": {
|
||||
"temperature": 1.0,
|
||||
"topP": 0.95,
|
||||
"maxOutputTokens": 65535,
|
||||
"responseMimeType": "application/json",
|
||||
"responseSchema": TypeAdapter(List[EvaluationEntry]).json_schema()
|
||||
}
|
||||
}
|
||||
}
|
||||
f.write(json.dumps(req) + "\n")
|
||||
|
||||
print(f"Batch generation complete. {len(tasks_to_process)} requests saved to {batch_file}")
|
||||
sys.exit(0)
|
||||
|
||||
batched_responses = {}
|
||||
if args.deal_with_batched:
|
||||
batch_results_path = Path(args.deal_with_batched)
|
||||
if batch_results_path.exists():
|
||||
print(f"Loading batch results from {batch_results_path}...")
|
||||
with open(batch_results_path, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if not line.strip(): continue
|
||||
data = json.loads(line)
|
||||
task_id = data.get("custom_id")
|
||||
|
||||
# Extract the JSON response text. Adapt this path to match your API output schema!
|
||||
try:
|
||||
resp_text = data["response"]["body"]["candidates"][0]["content"]["parts"][0]["text"]
|
||||
batched_responses[task_id] = resp_text
|
||||
except (KeyError, IndexError):
|
||||
batched_responses[task_id] = data.get("response_text", "")
|
||||
else:
|
||||
print(f"Warning: Batch results file {batch_results_path} not found.", file=sys.stderr)
|
||||
|
||||
print(f"Starting processing on {len(tasks_to_process)} tasks with {NB_THREADS} threads...")
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=NB_THREADS) as executor:
|
||||
futures = {executor.submit(process_single_task, task): task for task in tasks_to_process}
|
||||
futures = {}
|
||||
for task in tasks_to_process:
|
||||
file_path = task[0]
|
||||
precomp = batched_responses.get(file_path)
|
||||
futures[executor.submit(process_single_task, task, precomp)] = task
|
||||
|
||||
# Process tasks as they complete, allowing dynamic task addition
|
||||
for future in concurrent.futures.as_completed(futures):
|
||||
|
|
@ -734,11 +813,11 @@ if __name__ == "__main__":
|
|||
new_generated_tasks = future.result()
|
||||
if new_generated_tasks:
|
||||
for new_task in new_generated_tasks:
|
||||
# New tasks from wrong-label/additional-answer will fallback to live API
|
||||
futures[executor.submit(process_single_task, new_task)] = new_task
|
||||
except Exception as e:
|
||||
print(f"Exception during task execution: {e}", file=sys.stderr)
|
||||
|
||||
|
||||
end_time = time.time()
|
||||
print("Time elapsed : ", end_time - start_time)
|
||||
print("Requests to pro / flash : ", pro_count, flash_count)
|
||||
|
|
|
|||
Loading…
Reference in New Issue