"""Send grouped answer images to Gemini for correction and collect the feedback.

Inputs are group images laid out as ``<INPUT_DIR>/Par label/<label>/Group_N.jpg``
(each with a sidecar ``Group_N.json`` listing the copies it contains) and the
per-copy PDFs under ``<INPUT_DIR>/Copies/Copie<pid>/<label>.pdf``.

Results are accumulated in ``<INPUT_DIR>/correction.json``; finished tasks are
tracked in ``<INPUT_DIR>/correction_progress.json`` so a run can be resumed.
Per-thread logs are appended to ``<INPUT_DIR>/correction_log``.

Modes:
  * default             -- live requests, NB_THREADS in parallel;
  * --refaire           -- redo the copies/labels listed in refaire.json;
  * --batch             -- write JSONL request files for the Gemini Batch API;
  * --batch-from LABEL  -- live requests before LABEL, batch requests from LABEL on;
  * --deal-with-batched -- use batched_correction_result.jsonl instead of live calls.
"""

import argparse
import base64
import concurrent.futures
import json
import os
import shlex
import shutil
import signal
import sys
import threading
import time
from pathlib import Path

from google import genai

import grouping
import prompting
from utils import enonce_total, read_all_labels

if len(sys.argv) < 2:
    sys.exit("Usage: python script.py 'InterroTest/Par label/Ex 2/Group_1.jpg' OR 'InterroTest' OR 'file1' 'file2'")

# Parse Arguments
parser = argparse.ArgumentParser()
parser.add_argument("paths", nargs="+", help="List of images or directories")
parser.add_argument("--overwrite", action="store_true", help="Force redo requests even if output exists")
parser.add_argument("--limit", type=int, help="limit calls to gemini pro (integer)")
parser.add_argument("--refaire", action="store_true", help="Redo specific copies/labels defined in refaire.json")
parser.add_argument("--batch", action="store_true", help="Generate a JSONL file of requests to send to the Gemini Batch API")
parser.add_argument("--batch-from", type=str, metavar="LABEL", help="Do live requests before LABEL, and batch requests from LABEL onwards")
parser.add_argument("--deal-with-batched", action="store_true", help="Process a JSONL file containing completed batch results")
args, _ = parser.parse_known_args()

tasks = []  # List of tuples: (filepath_str, label_str)
results = {}  # label -> list of Gemini answers (one list entry per group)

# Set by the path loop below; when several paths are given, the last valid one
# decides which exam directory is used.
INPUT_DIR = None
COPIES_DIR = None
GROUPS_DIR = None

for path_str in args.paths:
    arg_path = Path(path_str)
    if not arg_path.exists():
        print(f"Warning: {path_str} not found. Skipping.")
        continue

    if arg_path.is_file() and arg_path.suffix.lower() == ".jpg":
        # Individual group image: <INPUT_DIR>/Par label/<label>/Group_N.jpg
        label = arg_path.parent.name
        INPUT_DIR = arg_path.parent.parent.parent
        COPIES_DIR = INPUT_DIR / "Copies"
        GROUPS_DIR = INPUT_DIR / "Par label"
        tasks.append((str(arg_path), label))
        results.setdefault(label, [])
    elif arg_path.is_dir():
        # Whole exam directory: every group image of every label
        INPUT_DIR = arg_path
        COPIES_DIR = INPUT_DIR / "Copies"
        GROUPS_DIR = INPUT_DIR / "Par label"
        for sub in GROUPS_DIR.iterdir():
            if sub.is_dir():
                label = sub.name
                results.setdefault(label, [])
                tasks.extend((str(img), label) for img in sub.glob("*.jpg"))

if INPUT_DIR is None:
    # Everything below (logs, outputs, prompts) is rooted in INPUT_DIR.
    sys.exit("Error: no existing .jpg file or directory was given.")

NB_THREADS = 12

# PROXY_URL = "http://192.168.241.1:3128"
PROXY_URL = None
if PROXY_URL:
    os.environ["http_proxy"] = PROXY_URL
    os.environ["https_proxy"] = PROXY_URL

MODEL_ID_pro = "gemini-3.1-pro-preview"
MODEL_ID_flash = "gemini-3-flash-preview"
api_key = os.environ["GEMINI_API_KEY"]

# --- Thread-safe Logging ---
log_lock = threading.Lock()
thread_logs = {}  # thread name -> messages not yet written to correction_log


def tprint(*args, **kwargs):
    """Print a message prefixed with the thread name and buffer it for the log file.

    Extra keyword arguments (e.g. ``file=sys.stderr``) are passed to ``print``.
    """
    tid = threading.current_thread().name
    msg = " ".join(map(str, args))
    with log_lock:
        thread_logs.setdefault(tid, []).append(msg)
        # Optional: Keep printing to console but prefix with thread name
        print(f"[{tid}] {msg}", **kwargs)


def flush_thread_log(tid=None):
    """Append a thread's buffered messages to the log file contiguously.

    Defaults to the calling thread. Does nothing when the buffer is empty.
    """
    tid = tid or threading.current_thread().name
    with log_lock:
        if thread_logs.get(tid):
            with open(INPUT_DIR / "correction_log", "a", encoding="utf-8") as f:
                f.write(f"--- Task Log [{tid}] ---\n")
                f.write("\n".join(thread_logs[tid]) + "\n\n")
            thread_logs[tid].clear()


def handle_interrupt(sig, frame):
    """Flush all partial/unfinished logs if program is interrupted, then exit(1)."""
    print("\nInterrupt received. Flushing partial logs...", file=sys.stderr)
    for tid in list(thread_logs):
        flush_thread_log(tid)
    sys.exit(1)


signal.signal(signal.SIGINT, handle_interrupt)
signal.signal(signal.SIGTERM, handle_interrupt)
# ---------------------------

client = genai.Client(api_key=api_key)
output_path = INPUT_DIR / "correction.json"
progress_path = INPUT_DIR / "correction_progress.json"
start_time = time.time()
overwrite = args.overwrite
limit = args.limit
completed_tasks = []  # (image_path, label) pairs already merged into correction.json
errors_summary = []  # (message, image_path) pairs reported at the end of the run

# --- Locks for thread-safe shared state ---
io_lock = threading.Lock()  # results / completed_tasks / errors_summary and their files
pro_lock = threading.Lock()  # pro_count / flash_count / the Pro-vs-Flash decision
group_lock = threading.Lock()  # picking a Group index and creating its image
pro_count = 0
flash_count = 0
pro_quota_exhausted = False  # set once Pro reports a quota error; later calls use Flash

if overwrite:
    if output_path.exists():
        output_path.unlink()
    if progress_path.exists():
        progress_path.unlink()
else:
    if progress_path.exists():
        with open(progress_path, "r", encoding="utf-8") as f:
            completed_tasks = json.load(f)
    if output_path.exists():
        with open(output_path, "r", encoding="utf-8") as f:
            results = json.load(f)

completed_set = {(str(f), l) for f, l in completed_tasks}
tasks_to_process = [t for t in tasks if (str(t[0]), t[1]) not in completed_set]


def call_gemini_with_retries(model_id, contents, config, fallback_model_id=MODEL_ID_flash):
    """Stream a Gemini answer, retrying after 60 s and then 300 s on failure.

    A quota error (message containing "429", "quota" or "exhausted") on the Pro
    model switches to ``fallback_model_id`` immediately and sets the global
    ``pro_quota_exhausted`` flag so every later Pro call goes to the fallback
    too. That switch does not consume one of the retries.

    Returns:
        The concatenated text of all streamed chunks (never None).
    Raises:
        The last exception once the retries are exhausted.
    """
    global pro_quota_exhausted
    delays = [60, 300]
    failures = 0
    while True:
        # Switch to fallback immediately if quota was exhausted by another thread
        if model_id == MODEL_ID_pro and pro_quota_exhausted and fallback_model_id:
            model_id = fallback_model_id
        try:
            stream = client.models.generate_content_stream(
                model=model_id,
                contents=contents,
                config=config,
            )
            return "".join(chunk.text for chunk in stream if chunk.text)
        except Exception as e:
            error_msg = str(e).lower()
            is_quota_error = "429" in error_msg or "quota" in error_msg or "exhausted" in error_msg
            # Immediately fallback to Flash without waiting if it's a Pro quota error.
            # The `!=` guard prevents an endless loop if the fallback is Pro itself.
            if (is_quota_error and model_id == MODEL_ID_pro
                    and fallback_model_id and fallback_model_id != model_id):
                tprint(f"\tGemini Pro quota hit ({e}). \n\n\tFalling back to Flash permanently...")
                model_id = fallback_model_id
                pro_quota_exhausted = True
                continue  # Retry immediately with Flash, without using a retry slot
            if failures < len(delays):
                tprint(f"\tGemini API failure: {e}. Retrying in {delays[failures]} seconds...")
                time.sleep(delays[failures])
                failures += 1
            else:
                tprint(f"\tGemini API failure: {e}. Maximum retries reached.")
                raise


def correct_boxes_with_gemini(pid, label, pdf_path, original_feedbacks, yming, ymaxg, width_r, total_height):
    """Ask Gemini Flash to redraw the boxes on a single copy and map them onto the group canvas.

    Args:
        pid: copy id (used by the caller for logging; kept for interface stability).
        label: exercise label (same remark).
        pdf_path: the copy's own PDF, sent to ``prompting.request_for_box_correction``.
        original_feedbacks: feedback dicts whose optional ``box_2d`` is
            [ymin, xmin, ymax, xmax] normalised to 0-1000.
        yming, ymaxg: vertical extent of this copy on the group canvas.
        width_r: width of this copy relative to the group image.
        total_height: height of the group canvas.

    Returns:
        The box-less entries of ``original_feedbacks`` followed by Gemini's
        corrected entries, whose boxes are converted to group-canvas coordinates.
    Raises:
        Whatever the request or ``json.loads`` raises; the caller falls back.
    """
    contents, config = prompting.request_for_box_correction(pdf_path, original_feedbacks)
    response_text = call_gemini_with_retries(MODEL_ID_flash, contents, config)
    corrected_feedbacks = json.loads(response_text)

    # Global (box-less) remarks are kept unchanged
    global_feedbacks = [f for f in original_feedbacks if not f.get("box_2d")]

    single_h = ymaxg - yming
    for f in corrected_feedbacks:
        b = f.get("box_2d")
        if not b:
            continue
        ymin_s, xmin_s, ymax_s, xmax_s = b
        # Y mapping: Add the group Y-offset (yming), then normalize to total_height
        new_ymin = int((yming + ymin_s * single_h / 1000.0) * 1000.0 / total_height)
        new_ymax = int((yming + ymax_s * single_h / 1000.0) * 1000.0 / total_height)
        # X mapping: Multiply by the width ratio of this sub-image vs the group image
        new_xmin = int(xmin_s * width_r)
        new_xmax = int(xmax_s * width_r)
        f["box_2d"] = [new_ymin, new_xmin, new_ymax, new_xmax]

    return global_feedbacks + corrected_feedbacks


def get_next_group_idx(label):
    """Return the 0-based index for a new group of ``label``.

    The matching image is ``Group_{idx+1}.jpg``: group files are numbered from 1,
    so the highest existing number is also the next free 0-based index.
    Creates the label folder under GROUPS_DIR if it does not exist.
    """
    target_folder = GROUPS_DIR / label
    target_folder.mkdir(exist_ok=True)
    return max(
        (int(f.stem.split("_")[1]) for f in target_folder.glob("Group_*.jpg")),
        default=0,
    )


def _make_single_group(label, pid, pdf_path):
    """Render a one-copy group image for ``label`` and return ``(idx, image_path)``.

    Holds ``group_lock`` so concurrent workers never pick the same group index.
    """
    with group_lock:
        idx = get_next_group_idx(label)
        height = grouping.get_pdf_height(str(pdf_path))
        grouping.create_jpg(label, idx, [(pid, str(pdf_path), height)], GROUPS_DIR)
    return idx, str(GROUPS_DIR / label / f"Group_{idx+1}.jpg")


def handle_label_errors(pid, label, res, pdf_path):
    """Resolve a wrong-label / additional-answer error reported for one copy.

    wrong-label: ask Gemini for the right label. If it is a known label that has no
    PDF yet, copy the PDF to ``<new_label>_new.pdf``, move the original to
    ``<label>_old.pdf`` and queue a single-copy group for the new label.

    additional-answer: ask Gemini which other labels the page also answers. For
    each known label that has no PDF yet, copy the PDF to ``<add_label>_new.pdf``
    and queue a group for it.

    ``res["error"]`` is rewritten with a short status code (e.g.
    ``"wrg-lbl-moved-to:X"``, ``"al:(->)X"``, or ``""`` when nothing is left to report).

    Returns:
        New ``(image_path, label, False)`` tasks; ``False`` stops them from
        spawning further tasks.
    """
    new_tasks = []
    error_type = res.get("error")
    all_labels = read_all_labels(INPUT_DIR)
    labels_txt = (INPUT_DIR / "labels").read_text(encoding="utf-8", errors="replace")
    enonce = enonce_total(INPUT_DIR)
    copie_dir = COPIES_DIR / f"Copie{pid}"

    if error_type == "wrong-label":
        tprint(f"\tHandling wrong-label for {pid} {label}")
        contents, config = prompting.request_for_wrong_label(pdf_path, label, enonce, labels_txt)
        new_label = call_gemini_with_retries(MODEL_ID_flash, contents, config).strip().strip('"\'')
        if new_label not in all_labels:
            tprint(f"\t\tCopie{pid} returned an incorrect label {new_label} from an initial wrong label {label}. Ignoring")
            res["error"] = "wrg-lbl:cldtfix"
            return []
        if new_label == label:
            res["error"] = ""
            return []

        base_new_pdf_path = copie_dir / f"{new_label}.pdf"
        new_pdf_path = copie_dir / f"{new_label}_new.pdf"
        if base_new_pdf_path.exists() or new_pdf_path.exists():
            tprint(f"\t\tCopie{pid} tried to move wrong {label} to {new_label}, but it already exists.")
            res["error"] = f"wrg-lbl:{new_label}?exists"
        else:
            res["error"] = f"wrg-lbl-moved-to:{new_label}"
            tprint(f"\t\tCopie{pid} : moving wrong {label} to {new_label}.")
            # Copy to _new, then rename the original to _old
            shutil.copy(str(pdf_path), str(new_pdf_path))
            old_pdf_path = pdf_path.with_name(f"{label}_old.pdf")
            if pdf_path != old_pdf_path:
                shutil.move(str(pdf_path), str(old_pdf_path))
            idx, group_path = _make_single_group(new_label, pid, new_pdf_path)
            tprint(f"\t\tMaking {new_label} group {idx+1}")
            new_tasks.append((group_path, new_label, False))

    elif error_type == "additional-answer":
        contents, config = prompting.request_for_additional_answer(pdf_path, label, enonce, labels_txt)
        tprint(f"\tHandling additional-answer for {pid} {label}")
        try:
            add_labels = json.loads(call_gemini_with_retries(MODEL_ID_flash, contents, config))
        except Exception as e:
            # Best effort: an unusable answer only means no extra label is created
            tprint(f"\t\tCould not get additional labels for {pid} {label}: {e}")
            add_labels = []

        keep_error = False
        error = "al:"
        for add_label in add_labels:
            if add_label == label:
                continue
            if add_label not in all_labels:
                tprint(f"\t\t Inexistent label ({add_label}) from additional-answer processing {pid} {label}. Ignoring")
                error += f"{add_label}??"
                keep_error = True
                continue

            base_add_pdf_path = copie_dir / f"{add_label}.pdf"
            add_pdf_path = copie_dir / f"{add_label}_new.pdf"
            keep_error = True
            if not base_add_pdf_path.exists() and not add_pdf_path.exists():
                shutil.copy(str(pdf_path), str(add_pdf_path))
                tprint(f"\t\tCopying Copie{pid} : {label} -> {add_label}")
                idx, group_path = _make_single_group(add_label, pid, add_pdf_path)
                tprint(f"\t\tMaking {add_label} group {idx+1}")
                new_tasks.append((group_path, add_label, False))
                error += f"(->){add_label}"
            else:
                error += f"(xx){add_label}"
                tprint(f"\t\tAlready present (not copied) Copie{pid} : {label} -> {add_label}")

        res["error"] = error if keep_error else ""

    return new_tasks


def _prefers_flash(group_data):
    """Return True when a group goes to Flash rather than Pro.

    Groups with 4+ copies, or whose total height (ymax of the last entry) is
    at most 500, use Flash. Shared by the live and the batch paths.
    """
    return len(group_data) >= 4 or group_data[-1][2] <= 500


def _resolve_pdf_path(pid, label):
    """Return ``(pdf_path, suffix)`` for a copy's PDF, trying ``_new`` then ``_old``.

    Falls back to the plain ``<label>.pdf`` path with an empty suffix when no
    variant exists. NOTE(review): the original author doubted the ``_old``
    branch is ever reached.
    """
    pdf_path = COPIES_DIR / f"Copie{pid}" / f"{label}.pdf"
    if pdf_path.exists():
        return pdf_path, ""
    for suffix in ("_new", "_old"):
        candidate = pdf_path.with_name(f"{label}{suffix}.pdf")
        if candidate.exists():
            return candidate, suffix
    return pdf_path, ""


def _find_box_anomalies(feedback, yming, ymaxg, width_r, total_height):
    """Return indices of boxes lying clearly outside this copy's region; clamp the rest.

    A box is anomalous when it exceeds the copy's vertical extent by more than
    50 (canvas units) or extends past the copy's width ratio. Boxes that only
    overflow by more than 5 are clamped in place to that 5-unit margin.
    All anomalies are reported, not only the first one.
    """
    needs_correction = []
    for i, f in enumerate(feedback):
        b = f.get("box_2d")
        if not b:
            continue
        ymin, _xmin, ymax, xmax = b
        ymin = ymin * total_height // 1000
        ymax = ymax * total_height // 1000
        if ymin < yming - 50 or ymax > ymaxg + 50 or xmax / 1000 > width_r:
            needs_correction.append(i)
            continue
        if ymin < yming - 5:
            b[0] = (yming - 5) * 1000 // total_height
        if ymax > ymaxg + 5:
            b[2] = (ymaxg + 5) * 1000 // total_height
    return needs_correction


def _process_copy_result(p, label, group_name, d_data, total_height, can_spawn_tasks, new_tasks):
    """Post-process Gemini's result for one copy of a group, in place.

    Handles the empty-answer / wrong-label / additional-answer errors, records
    the PDF suffix actually used, and fixes or drops out-of-range boxes.
    Follow-up tasks are appended to ``new_tasks`` as soon as they are created,
    so they are not lost if a later step raises.
    """
    pid = p["id"]
    res = p["result"]
    if pid not in d_data:
        tprint("Error : Gemini answered a copie id not present", pid, label, group_name)
        return
    yming, ymaxg, width_r = d_data[pid]

    # Locate the actual file if it carries a suffix
    pdf_path, current_suffix = _resolve_pdf_path(pid, label)

    # 1. Handle empty-answer: set the PDF aside under the _old suffix
    if res.get("error") == "empty-answer":
        old_path = pdf_path.with_name(f"{label}_old.pdf")
        if pdf_path.exists() and pdf_path != old_path:
            shutil.move(str(pdf_path), str(old_path))
            pdf_path = old_path
            current_suffix = "_old"

    if not can_spawn_tasks and res.get("error") == "additional-answer":
        tprint("\tSwallowing an additional-answer from a subsequent task.")
        res["error"] = ""
    if res.get("error"):
        tprint("\tError :", res["error"], "for Copie", pid, group_name)

    if can_spawn_tasks and res.get("error") in ("wrong-label", "additional-answer"):
        new_tasks.extend(handle_label_errors(pid, label, res, pdf_path))
        # If wrong-label moved the current file to _old
        if (res.get("error") or "").startswith("wrg-lbl-moved-to:"):
            current_suffix = "_old"

    # Record the suffix in correction.json
    if current_suffix:
        res["suffix"] = current_suffix

    feedback = res.get("feedback") or []
    needs_correction = _find_box_anomalies(feedback, yming, ymaxg, width_r, total_height)
    if not needs_correction:
        return
    tprint(f"\tBox anomalies detected for Copie {pid} {group_name}. \n\tRequesting isolated correction from Gemini Flash...")
    try:
        res["feedback"] = correct_boxes_with_gemini(
            pid, label, pdf_path, feedback, yming, ymaxg, width_r, total_height)
    except Exception as e:
        tprint(f"\tCorrection failed for Copie {pid}, {group_name} : {e}\n\tRemoving the boxes")
        # Fallback if the second request fails entirely
        for i in needs_correction:
            feedback[i]["box_2d"] = None


def process_single_task(task_tuple, precomputed_response=None):
    """Correct one group image and merge Gemini's answer into correction.json.

    Args:
        task_tuple: ``(image_path, label)`` or ``(image_path, label, can_spawn_tasks)``;
            follow-up tasks use ``can_spawn_tasks=False`` so they cannot cascade.
        precomputed_response: batch-API answer text; when given, no live request
            is made and the Pro/Flash counters are left untouched.

    Returns:
        The follow-up tasks created by wrong-label / additional-answer handling.
        Errors while querying or parsing are logged and added to ``errors_summary``.
    """
    global pro_count, flash_count
    try:
        file_path = task_tuple[0]
        label = task_tuple[1]
        can_spawn_tasks = task_tuple[2] if len(task_tuple) > 2 else True
        group_name = os.path.splitext(file_path)[0]
        json_path = group_name + '.json'
        new_tasks = []

        # Sidecar rows: [pid, ymin, ymax, width_ratio] of each copy on the canvas
        with open(json_path, 'r', encoding="utf-8") as f:
            group_data = json.load(f)
        d_data = {row[0]: (row[1], row[2], row[3]) for row in group_data}
        total_height = group_data[-1][2]
        use_flash = _prefers_flash(group_data)

        # Only apply limits and counts if we are making a live call
        if precomputed_response is None:
            with pro_lock:
                if not use_flash:
                    if pro_quota_exhausted or (limit is not None and pro_count >= limit):
                        use_flash = True
                    else:
                        pro_count += 1
                if use_flash:
                    flash_count += 1

        try:
            if precomputed_response:
                tprint(f"Using batched response for: {label} {group_name}")
                full_response_text = precomputed_response
            else:
                contents, config = prompting.generate_request(INPUT_DIR, file_path, label)
                model_to_use = MODEL_ID_flash if use_flash else MODEL_ID_pro
                tprint(f"Asking Gemini {'Flash' if use_flash else 'Pro '}: {label} {group_name}")
                full_response_text = call_gemini_with_retries(model_to_use, contents, config)

            json_data = json.loads(full_response_text)

            # Ensure consistency of answer placements
            for p in json_data:
                _process_copy_result(p, label, group_name, d_data, total_height,
                                     can_spawn_tasks, new_tasks)

            # --- Use Lock for writing shared data ---
            with io_lock:
                results.setdefault(label, []).append(json_data)
                with open(output_path, "w", encoding="utf-8") as f:
                    json.dump(results, f, indent=2)
                # To track progress
                completed_tasks.append((file_path, label))
                with open(progress_path, "w", encoding="utf-8") as f:
                    json.dump(completed_tasks, f, indent=2)

        except json.JSONDecodeError:
            tprint(f"Error decoding JSON for {file_path}", file=sys.stderr)
            with io_lock:
                errors_summary.append(("Error decoding JSON response", file_path))
        except Exception as e:
            error_msg = f"Exception processing {file_path}: {e}"
            tprint(error_msg, file=sys.stderr)
            with io_lock:
                errors_summary.append((error_msg, file_path))

        return new_tasks
    finally:
        flush_thread_log()


def _labels_in_copie(copie_dir):
    """List the labels that have a PDF in ``copie_dir``, in glob order.

    ``<label>_new.pdf`` counts as ``<label>``; ``<label>_old.pdf`` files (set
    aside by empty-answer / wrong-label handling) are skipped.
    """
    labels = {}
    for pdf in copie_dir.glob("*.pdf"):
        stem = pdf.stem
        if stem.endswith("_old"):
            continue
        if stem.endswith("_new"):
            stem = stem[:-len("_new")]
        labels[stem] = None  # dict keeps first-seen order without duplicates
    return list(labels)


def _apply_refaire():
    """Prepare the redo requested in refaire.json and return the tasks to run.

    refaire.json is a list of ``[copie_name, labels]``; an empty ``labels`` list
    means every label that has a PDF in that Copie folder. Previous results for
    those (copy, label) pairs are moved to overwritten_correction.json, and a
    fresh single-copy group image is generated for each.
    """
    refaire_path = INPUT_DIR / "refaire.json"
    overwritten_path = INPUT_DIR / "overwritten_correction.json"
    if not refaire_path.exists():
        print(f"Warning: --refaire flag used, but {refaire_path} not found.", file=sys.stderr)
        return []

    with open(refaire_path, "r", encoding="utf-8") as f:
        refaire_list = json.load(f)
    overwritten_data = []
    if overwritten_path.exists():
        with open(overwritten_path, "r", encoding="utf-8") as f:
            overwritten_data = json.load(f)

    new_tasks = []
    dirty_results = False
    for copie_name, labels in refaire_list:
        pid = copie_name.replace("Copie", "")
        copie_dir = COPIES_DIR / copie_name
        # If list is empty, redo all labels available for this Copie
        if not labels:
            labels = _labels_in_copie(copie_dir)

        for label in labels:
            # 1. Extract and backup old corrections
            if label in results:
                for batch in results[label]:
                    to_remove = next((item for item in batch if item.get("id") == pid), None)
                    if to_remove is not None:
                        batch.remove(to_remove)
                        overwritten_data.append({
                            "pid": pid,
                            "label": label,
                            "data": to_remove,
                            "timestamp": time.time()
                        })
                        dirty_results = True
                # Clean up empty batches
                results[label] = [b for b in results[label] if b]

            # 2. Make new group and add to tasks
            pdf_path = copie_dir / f"{label}.pdf"
            if not pdf_path.exists() and (copie_dir / f"{label}_new.pdf").exists():
                pdf_path = copie_dir / f"{label}_new.pdf"
            if pdf_path.exists():
                _, new_group_path = _make_single_group(label, pid, pdf_path)
                new_tasks.append((new_group_path, label))

    if dirty_results:
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2)
        with open(overwritten_path, "w", encoding="utf-8") as f:
            json.dump(overwritten_data, f, indent=2)
    return new_tasks


def _split_for_batch(task_list):
    """Split tasks into ``(batch_tasks, live_tasks)`` for --batch / --batch-from.

    Plain --batch batches everything. With --batch-from, tasks whose label comes
    at or after the given label (in read_all_labels order) are batched and the
    others run live. The label may be given as a prefix; an exact match wins.
    """
    if not args.batch_from:
        return list(task_list), []  # Run nothing live if just `--batch`

    all_labels = read_all_labels(INPUT_DIR)
    batch_from = args.batch_from
    if batch_from not in all_labels:
        batch_from = next((lbl for lbl in all_labels if lbl.startswith(batch_from)), None)
        if batch_from is None:
            sys.exit(f"Error: Label '{args.batch_from}' not found. Available labels: {all_labels}")
    print("Batching from : ", batch_from)

    positions = {}
    for i, lbl in enumerate(all_labels):
        positions.setdefault(lbl, i)  # first occurrence, like list.index
    target_idx = positions[batch_from]

    batch_tasks, live_tasks = [], []
    for task in task_list:
        # Any label found sequentially equal or after `batch_from` gets batched
        if positions.get(task[1], -1) >= target_idx:
            batch_tasks.append(task)
        else:
            live_tasks.append(task)
    return batch_tasks, live_tasks


def _write_batch_requests(batch_tasks):
    """Write one JSONL request file per model for the Gemini Batch API.

    Each line is keyed by the group image path so --deal-with-batched can match
    the result back to its task.
    """
    batch_flash_file = INPUT_DIR / "batch_requests_flash.jsonl"
    batch_pro_file = INPUT_DIR / "batch_requests_pro.jsonl"
    count_flash = 0
    count_pro = 0

    with open(batch_flash_file, "w", encoding="utf-8") as f_flash, \
            open(batch_pro_file, "w", encoding="utf-8") as f_pro:
        for task in batch_tasks:
            file_path, label = task[0], task[1]
            json_path = os.path.splitext(file_path)[0] + '.json'
            with open(json_path, 'r', encoding="utf-8") as jf:
                group_data = json.load(jf)

            b64_img = base64.b64encode(Path(file_path).read_bytes()).decode("utf-8")

            # Format payload matching Gemini Batch API file requirements
            req = {
                "key": file_path,  # The ID returned in the output file
                "request": {
                    "contents": [{
                        "role": "user",
                        "parts": [
                            {"inlineData": {"mimeType": "image/jpeg", "data": b64_img}},
                            {"text": prompting.make_prompt(INPUT_DIR, label)}
                        ]
                    }],
                    "generation_config": {
                        "temperature": 1.0,
                        "topP": 0.95,
                        "maxOutputTokens": 65535,
                        "responseMimeType": "application/json",
                        "responseSchema": prompting.UNROLLED_SCHEMA
                    }
                }
            }

            line = json.dumps(req) + "\n"
            if _prefers_flash(group_data):
                f_flash.write(line)
                count_flash += 1
            else:
                f_pro.write(line)
                count_pro += 1

    print("Batch generation complete.")
    print(f" - {count_flash} requests saved to {batch_flash_file} (for {MODEL_ID_flash})")
    print(f" - {count_pro} requests saved to {batch_pro_file} (for {MODEL_ID_pro})")
    print("Upload these files via the File API and create two separate batch jobs.")


def _load_batched_responses():
    """Read batched_correction_result.jsonl into ``{request key: response text}``.

    Like the live path, all non-thought text parts of the first candidate are
    concatenated. Lines without usable text are reported and skipped.
    """
    batched_responses = {}
    batch_results_path = INPUT_DIR / "batched_correction_result.jsonl"
    if not batch_results_path.exists():
        print(f"Warning: Batch results file {batch_results_path} not found.", file=sys.stderr)
        return batched_responses

    print(f"Loading batch results from {batch_results_path}...")
    with open(batch_results_path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            data = json.loads(line)
            task_id = data.get("key")  # Corresponds to the key sent in the request
            if "response" in data:
                try:
                    # Extract the JSON response text per standard Batch API schema
                    parts = data["response"]["candidates"][0]["content"]["parts"]
                    texts = [part["text"] for part in parts
                             if "text" in part and not part.get("thought")]
                    if not texts:
                        raise KeyError("text")
                    batched_responses[task_id] = "".join(texts)
                except (KeyError, IndexError) as e:
                    print(f"Warning: Could not parse response for {task_id}: {e}", file=sys.stderr)
            elif "error" in data:
                print(f"Batch API Error for {task_id}: {data['error']}", file=sys.stderr)
    return batched_responses


def _run_tasks(task_list, batched_responses):
    """Run tasks on NB_THREADS workers, including the follow-up tasks they return.

    A task whose image path has an entry in ``batched_responses`` uses that text
    instead of a live request; follow-up tasks always go live.
    """
    with concurrent.futures.ThreadPoolExecutor(max_workers=NB_THREADS) as executor:
        pending = {
            executor.submit(process_single_task, task, batched_responses.get(task[0]))
            for task in task_list
        }
        # as_completed() only watches the futures it was handed, so follow-up
        # tasks are tracked by waiting repeatedly on the growing pending set.
        while pending:
            done, pending = concurrent.futures.wait(
                pending, return_when=concurrent.futures.FIRST_COMPLETED)
            for future in done:
                try:
                    new_generated_tasks = future.result()
                except Exception as e:
                    print(f"Exception during task execution: {e}", file=sys.stderr)
                    continue
                for new_task in new_generated_tasks or []:
                    # New tasks from wrong-label/additional-answer will fallback to live API
                    pending.add(executor.submit(process_single_task, new_task))


if __name__ == "__main__":
    if args.refaire:
        tasks_to_process.extend(_apply_refaire())

    if args.batch or args.batch_from:
        batch_tasks, tasks_to_process = _split_for_batch(tasks_to_process)
        if batch_tasks:
            _write_batch_requests(batch_tasks)
        # If there's no live tasks to do, and we aren't doing a batched ingestion, exit right away
        if not tasks_to_process and not args.deal_with_batched:
            sys.exit(0)

    batched_responses = _load_batched_responses() if args.deal_with_batched else {}

    print(f"Starting processing on {len(tasks_to_process)} tasks with {NB_THREADS} threads...")
    _run_tasks(tasks_to_process, batched_responses)

    end_time = time.time()
    print("Time elapsed : ", end_time - start_time)
    print("Requests to pro / flash : ", pro_count, flash_count)

    if errors_summary:
        print("\n--- Summary of Exceptions (You can use several images on one instance) ---", file=sys.stderr)
        for err, failed_path in errors_summary:
            print(err, file=sys.stderr)
            escaped_path = shlex.quote(str(failed_path))
            print(f"Run : python correction.py {escaped_path}")