NOTE(review): reconstructed from a whitespace-collapsed copy of this patch.
Indentation and blank context lines are best-effort, and the index hash below
predates the review edits; verify against correction.py before applying.

diff --git a/correction.py b/correction.py
index 5084b08..049ff61 100644
--- a/correction.py
+++ b/correction.py
@@ -18,6 +18,10 @@ parser.add_argument("--overwrite", action="store_true",
 parser.add_argument("--limit", type=int, help="limit calls to gemini rpo integer")
 parser.add_argument("--refaire", action="store_true",
                     help="Redo specific copies/labels defined in refaire.json")
+parser.add_argument("--batch", action="store_true",
+                    help="Generate a JSONL file of requests to send to the Gemini Batch API")
+parser.add_argument("--deal-with-batched", type=str, metavar="FILE",
+                    help="Process a JSONL file containing completed batch results")
 args, _ = parser.parse_known_args()
 
 
@@ -548,7 +552,13 @@ Here is a list of all possible labels. You need to answer with a list one of the
     return new_tasks
 
 
-def process_single_task(task_tuple):
+def process_single_task(task_tuple, precomputed_response=None):
+    """Process one task, optionally reusing a precomputed batch response.
+
+    precomputed_response is the raw response text from a completed batch
+    job. When it is None or empty, the model is called live and the
+    Pro/Flash counters and the Pro limit apply as before.
+    """
     try:
         global pro_count, flash_count, pro_quota_exhausted
         file_path = task_tuple[0]
@@ -567,25 +577,35 @@ def process_single_task(task_tuple):
         total_height = group_data[-1][2]
         use_flash = n >= 4 or total_height <= 500
 
-        if not use_flash:
-            with pro_lock:
-                if pro_quota_exhausted:
-                    use_flash = True
-                elif limit is None or pro_count < limit:
-                    pro_count += 1
-                else:
-                    use_flash = True
+        # Only apply limits and counts if we are making a live call. An empty
+        # batched response is treated as missing, so the live-call fallback
+        # below stays subject to the Pro limit instead of bypassing it.
+        has_precomputed = bool(precomputed_response)
+        if not has_precomputed:
+            if not use_flash:
+                with pro_lock:
+                    if pro_quota_exhausted:
+                        use_flash = True
+                    elif limit is None or pro_count < limit:
+                        pro_count += 1
+                    else:
+                        use_flash = True
 
-        if use_flash:
-            with pro_lock:
-                flash_count += 1
+            if use_flash:
+                with pro_lock:
+                    flash_count += 1
 
         try:
             contents, config = generate_request(file_path, label)
             model_to_use = MODEL_ID_flash if use_flash else MODEL_ID_pro
 
-            tprint(f"Asking Gemini {'Flash' if use_flash else 'Pro '}: {label} {group_name}")
-            full_response_text = call_gemini_with_retries(model_to_use, contents, config)
+            if has_precomputed:
+                tprint(f"Using batched response for: {label} {group_name}")
+                full_response_text = precomputed_response
+            else:
+                tprint(f"Asking Gemini {'Flash' if use_flash else 'Pro '}: {label} {group_name}")
+                full_response_text = call_gemini_with_retries(model_to_use, contents, config)
+
             json_data = json.loads(full_response_text)
 
             # Ensure consistency of answer placements
@@ -723,10 +743,92 @@ if __name__ == "__main__":
         else:
             print(f"Warning: --refaire flag used, but {refaire_path} not found.", file=sys.stderr)
 
+    if args.batch:
+        # Local import: harmless if base64 is already imported at module level.
+        import base64
+
+        batch_file = Path(INPUT_DIR) / "batch_requests.jsonl"
+        with open(batch_file, "w", encoding="utf-8") as f:
+            for task in tasks_to_process:
+                file_path, label = task[0], task[1]
+                group_name = os.path.splitext(file_path)[0]
+                json_path = group_name + '.json'
+
+                with open(json_path, 'r', encoding="utf-8") as jf:
+                    group_data = json.load(jf)
+                use_flash = len(group_data) >= 4 or group_data[-1][2] <= 500
+                model_to_use = MODEL_ID_flash if use_flash else MODEL_ID_pro
+
+                image_data = Path(file_path).read_bytes()
+                b64_img = base64.b64encode(image_data).decode("utf-8")
+
+                # Format payload. NOTE: adapt the JSON format if your specific Gemini
+                # Batch API endpoint expects a slightly different schema.
+                req = {
+                    "custom_id": file_path,  # Mapping ID
+                    "method": "POST",
+                    "url": f"/v1beta/models/{model_to_use}:generateContent",
+                    "body": {
+                        "contents": [{
+                            "role": "user",
+                            "parts": [
+                                {"inlineData": {"mimeType": "image/jpeg", "data": b64_img}},
+                                {"text": make_prompt(label)}
+                            ]
+                        }],
+                        "generationConfig": {
+                            "temperature": 1.0,
+                            "topP": 0.95,
+                            "maxOutputTokens": 65535,
+                            "responseMimeType": "application/json",
+                            "responseSchema": TypeAdapter(List[EvaluationEntry]).json_schema()
+                        }
+                    }
+                }
+                f.write(json.dumps(req) + "\n")
+
+        print(f"Batch generation complete. {len(tasks_to_process)} requests saved to {batch_file}")
+        sys.exit(0)
+
+    batched_responses = {}
+    if args.deal_with_batched:
+        batch_results_path = Path(args.deal_with_batched)
+        if batch_results_path.exists():
+            print(f"Loading batch results from {batch_results_path}...")
+            with open(batch_results_path, "r", encoding="utf-8") as f:
+                for line_no, line in enumerate(f, start=1):
+                    if not line.strip():
+                        continue
+                    try:
+                        data = json.loads(line)
+                    except json.JSONDecodeError as e:
+                        print(f"Warning: skipping malformed batch result line {line_no}: {e}", file=sys.stderr)
+                        continue
+                    task_id = data.get("custom_id")
+
+                    # Extract the JSON response text. Adapt this path to match your API output schema!
+                    try:
+                        resp_text = data["response"]["body"]["candidates"][0]["content"]["parts"][0]["text"]
+                    except (KeyError, IndexError, TypeError):
+                        # TypeError covers a null or non-container value anywhere along the path.
+                        resp_text = data.get("response_text", "")
+
+                    # Store only usable results; any other task falls back to a live call.
+                    if task_id and resp_text:
+                        batched_responses[task_id] = resp_text
+                    else:
+                        print(f"Warning: no usable batch response on line {line_no}; falling back to the live API.", file=sys.stderr)
+        else:
+            print(f"Warning: Batch results file {batch_results_path} not found.", file=sys.stderr)
+
     print(f"Starting processing on {len(tasks_to_process)} tasks with {NB_THREADS} threads...")
 
     with concurrent.futures.ThreadPoolExecutor(max_workers=NB_THREADS) as executor:
-        futures = {executor.submit(process_single_task, task): task for task in tasks_to_process}
+        futures = {}
+        for task in tasks_to_process:
+            file_path = task[0]
+            precomp = batched_responses.get(file_path)
+            futures[executor.submit(process_single_task, task, precomp)] = task
 
         # Process tasks as they complete, allowing dynamic task addition
         for future in concurrent.futures.as_completed(futures):
@@ -734,11 +836,11 @@ if __name__ == "__main__":
                 new_generated_tasks = future.result()
                 if new_generated_tasks:
                     for new_task in new_generated_tasks:
+                        # New tasks from wrong-label/additional-answer fall back to the live API
                         futures[executor.submit(process_single_task, new_task)] = new_task
             except Exception as e:
                 print(f"Exception during task execution: {e}", file=sys.stderr)
 
-
     end_time = time.time()
     print("Time elapsed : ", end_time - start_time)
     print("Requests to pro / flash : ", pro_count, flash_count)