From bc47f815562e44ab4416e2739956df4ba58d41b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Miquel?= Date: Tue, 21 Apr 2026 14:32:20 +0200 Subject: [PATCH] Fix batching, hopefully --- correction.py | 55 ++++++++++++++++++++++++++++++++------------------- 1 file changed, 35 insertions(+), 20 deletions(-) diff --git a/correction.py b/correction.py index a39928b..f17286f 100644 --- a/correction.py +++ b/correction.py @@ -741,9 +741,17 @@ if __name__ == "__main__": else: print(f"Warning: --refaire flag used, but {refaire_path} not found.", file=sys.stderr) + if args.batch: - batch_file = Path(INPUT_DIR) / "batch_requests.jsonl" - with open(batch_file, "w", encoding="utf-8") as f: + batch_flash_file = Path(INPUT_DIR) / "batch_requests_flash.jsonl" + batch_pro_file = Path(INPUT_DIR) / "batch_requests_pro.jsonl" + + count_flash = 0 + count_pro = 0 + + with open(batch_flash_file, "w", encoding="utf-8") as f_flash, \ + open(batch_pro_file, "w", encoding="utf-8") as f_pro: + for task in tasks_to_process: file_path, label = task[0], task[1] group_name = os.path.splitext(file_path)[0] @@ -752,18 +760,14 @@ if __name__ == "__main__": with open(json_path, 'r') as jf: group_data = json.load(jf) use_flash = len(group_data) >= 4 or group_data[-1][2] <= 500 - model_to_use = MODEL_ID_flash if use_flash else MODEL_ID_pro image_data = Path(file_path).read_bytes() b64_img = base64.b64encode(image_data).decode("utf-8") - # Format payload. NOTE: adapt the JSON format if your specific Gemini - # Batch API endpoint expects a slightly different schema. + # Format payload matching Gemini Batch API file requirements req = { - "custom_id": file_path, # Mapping ID - "method": "POST", - "url": f"/v1beta/models/{model_to_use}:generateContent", - "body": { + "key": file_path, # The ID returned in the output file + "request": { "contents": [{ "role": "user", "parts": [ @@ -771,7 +775,7 @@ if __name__ == "__main__": {"text": make_prompt(label)} ] }], - "generationConfig": { + "generation_config": { "temperature": 1.0, "topP": 0.95, "maxOutputTokens": 65535, @@ -780,9 +784,18 @@ if __name__ == "__main__": } } } - f.write(json.dumps(req) + "\n") - print(f"Batch generation complete. {len(tasks_to_process)} requests saved to {batch_file}") + if use_flash: + f_flash.write(json.dumps(req) + "\n") + count_flash += 1 + else: + f_pro.write(json.dumps(req) + "\n") + count_pro += 1 + + print(f"Batch generation complete.") + print(f" - {count_flash} requests saved to {batch_flash_file} (for {MODEL_ID_flash})") + print(f" - {count_pro} requests saved to {batch_pro_file} (for {MODEL_ID_pro})") + print("Upload these files via the File API and create two separate batch jobs.") sys.exit(0) batched_responses = {} @@ -794,19 +807,21 @@ if __name__ == "__main__": for line in f: if not line.strip(): continue data = json.loads(line) - task_id = data.get("custom_id") + task_id = data.get("key") # Corresponds to the key sent in the request - # Extract the JSON response text. Adapt this path to match your API output schema! - try: - resp_text = data["response"]["body"]["candidates"][0]["content"]["parts"][0]["text"] - batched_responses[task_id] = resp_text - except (KeyError, IndexError): - batched_responses[task_id] = data.get("response_text", "") + if "response" in data: + try: + # Extract the JSON response text per standard Batch API schema + resp_text = data["response"]["candidates"][0]["content"]["parts"][0]["text"] + batched_responses[task_id] = resp_text + except (KeyError, IndexError) as e: + print(f"Warning: Could not parse response for {task_id}: {e}", file=sys.stderr) + elif "error" in data: + print(f"Batch API Error for {task_id}: {data['error']}", file=sys.stderr) else: print(f"Warning: Batch results file {batch_results_path} not found.", file=sys.stderr) print(f"Starting processing on {len(tasks_to_process)} tasks with {NB_THREADS} threads...") - with concurrent.futures.ThreadPoolExecutor(max_workers=NB_THREADS) as executor: futures = {} for task in tasks_to_process: