From c1451a7a9916135681860019a0aee023c547682c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Miquel?= Date: Sun, 8 Feb 2026 19:54:54 +0100 Subject: [PATCH] Some fixing of boxes --- correction.py | 66 +++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 54 insertions(+), 12 deletions(-) diff --git a/correction.py b/correction.py index 97e11fe..3d6a1cb 100644 --- a/correction.py +++ b/correction.py @@ -142,6 +142,8 @@ import os import threading import concurrent.futures +NB_THREADS = 8 + # PROXY_URL = "http://192.168.241.1:3128" PROXY_URL = None @@ -149,8 +151,8 @@ if PROXY_URL: os.environ["http_proxy"] = PROXY_URL os.environ["https_proxy"] = PROXY_URL - MODEL_ID = "gemini-3-pro-preview" +MODEL_ID_BIS = "gemini-3-flash-preview" api_key="REMOVED_API_KEY" from pydantic import BaseModel, Field, TypeAdapter @@ -242,16 +244,27 @@ tasks_to_process = [t for t in tasks if (str(t[0]), t[1]) not in completed_set] def process_single_task(task_tuple): file_path, label = task_tuple - + group_name = os.path.splitext(file_path)[0] + json_path = group_name + '.json' + with open(json_path, 'r') as f: + # List of (groupid, start, end), in pixels + group_data = json.load(f) + n = len(group_data) + d_data = {l[0]: (l[1], l[2]) for l in group_data} + total_height = group_data[-1][2] + use_flash = n >= 5 try: contents, config = generate_request(file_path, label) - print(f"Asking Gemini: {label} {file_path}") + if use_flash: + print(f"Asking Flash Gemini: {label} {file_path}") + else: + print(f"Asking Gemini: {label} {file_path}") full_response_text = "" # Assuming client is thread-safe (usually is). # If not, create a new client instance inside this function. for chunk in client.models.generate_content_stream( - model=MODEL_ID, + model=MODEL_ID_BIS if use_flash else MODEL_ID, contents=contents, config=config, ): @@ -260,7 +273,41 @@ def process_single_task(task_tuple): # Parse JSON json_data = json.loads(full_response_text) - print(f"Gemini answered correctly for {file_path}") + + if use_flash: + print(f"Gemini Flash answered for {file_path}") + else: + print(f"Gemini answered for {file_path}") + + # print("Debug : ", json_data) + # Ensure consistency of answer placements + for p in json_data: + pid = p["id"] + res = p["result"] + if res["error"] != "": + print("Error :", res["error"], "for Copie", pid, label, group_name) + for f in res["feedback"]: + b = f["box_2d"] + if b: + ymin, xmin, ymax, xmax = b + ymin = ymin * total_height // 1000 + ymax = ymax * total_height // 1000 + + if pid not in d_data: + print("Error : Gemini answered a copie id not present", pid, label, group_name) + continue + yming,ymaxg = d_data[pid] + if ymin < yming-50 or ymax > ymaxg+50: + print("Error : Gemini answered box2d not at the right position", pid, label, group_name) + if ymax < yming or ymin > ymaxg: + print("Removing the box.") + f["box_2d"] = None + continue + nymin = max(ymin, yming) * 1000 // total_height + nymax = min(ymax, ymaxg) * 1000 // total_height + + f["box_2d"] = [nymin, xmin, nymax, xmax] + # print("Group :", yming, ymaxg, "Answered:", ymin, ymax) # --- CRITICAL: Use Lock for writing shared data --- with io_lock: @@ -272,19 +319,14 @@ def process_single_task(task_tuple): with open(output_path, "w", encoding="utf-8") as f: json.dump(results, f, indent=2) - # Save Progress (Optional, based on your logic) - # completed_tasks.append((str(file_path), label)) - # with open(progress_path, "w", encoding="utf-8") as f: - # json.dump(completed_tasks, f) - except json.JSONDecodeError: print(f"Error decoding JSON for {file_path}", file=sys.stderr) except Exception as e: print(f"Exception processing {file_path}: {e}", file=sys.stderr) -print(f"Starting processing on {len(tasks_to_process)} tasks with 6 threads...") +print(f"Starting processing on {len(tasks_to_process)} tasks with {NB_THREADS} threads...") -with concurrent.futures.ThreadPoolExecutor(max_workers=6) as executor: +with concurrent.futures.ThreadPoolExecutor(max_workers=NB_THREADS) as executor: executor.map(process_single_task, tasks_to_process) end_time = time.time()