Some fixing of boxes

master
Sébastien Miquel 2026-02-08 19:54:54 +01:00
parent b95e55d088
commit c1451a7a99
1 changed file with 54 additions and 12 deletions

View File

@@ -142,6 +142,8 @@ import os
import threading import threading
import concurrent.futures import concurrent.futures
NB_THREADS = 8
# PROXY_URL = "http://192.168.241.1:3128" # PROXY_URL = "http://192.168.241.1:3128"
PROXY_URL = None PROXY_URL = None
@@ -149,8 +151,8 @@ if PROXY_URL:
os.environ["http_proxy"] = PROXY_URL os.environ["http_proxy"] = PROXY_URL
os.environ["https_proxy"] = PROXY_URL os.environ["https_proxy"] = PROXY_URL
MODEL_ID = "gemini-3-pro-preview" MODEL_ID = "gemini-3-pro-preview"
MODEL_ID_BIS = "gemini-3-flash-preview"
api_key="REMOVED_API_KEY" api_key="REMOVED_API_KEY"
from pydantic import BaseModel, Field, TypeAdapter from pydantic import BaseModel, Field, TypeAdapter
@@ -242,16 +244,27 @@ tasks_to_process = [t for t in tasks if (str(t[0]), t[1]) not in completed_set]
def process_single_task(task_tuple): def process_single_task(task_tuple):
file_path, label = task_tuple file_path, label = task_tuple
group_name = os.path.splitext(file_path)[0]
json_path = group_name + '.json'
with open(json_path, 'r') as f:
# List of (groupid, start, end), in pixels
group_data = json.load(f)
n = len(group_data)
d_data = {l[0]: (l[1], l[2]) for l in group_data}
total_height = group_data[-1][2]
use_flash = n >= 5
try: try:
contents, config = generate_request(file_path, label) contents, config = generate_request(file_path, label)
if use_flash:
print(f"Asking Flash Gemini: {label} {file_path}")
else:
print(f"Asking Gemini: {label} {file_path}") print(f"Asking Gemini: {label} {file_path}")
full_response_text = "" full_response_text = ""
# Assuming client is thread-safe (usually is). # Assuming client is thread-safe (usually is).
# If not, create a new client instance inside this function. # If not, create a new client instance inside this function.
for chunk in client.models.generate_content_stream( for chunk in client.models.generate_content_stream(
model=MODEL_ID, model=MODEL_ID_BIS if use_flash else MODEL_ID,
contents=contents, contents=contents,
config=config, config=config,
): ):
@@ -260,7 +273,41 @@ def process_single_task(task_tuple):
# Parse JSON # Parse JSON
json_data = json.loads(full_response_text) json_data = json.loads(full_response_text)
print(f"Gemini answered correctly for {file_path}")
if use_flash:
print(f"Gemini Flash answered for {file_path}")
else:
print(f"Gemini answered for {file_path}")
# print("Debug : ", json_data)
# Ensure consistency of answer placements
for p in json_data:
pid = p["id"]
res = p["result"]
if res["error"] != "":
print("Error :", res["error"], "for Copie", pid, label, group_name)
for f in res["feedback"]:
b = f["box_2d"]
if b:
ymin, xmin, ymax, xmax = b
ymin = ymin * total_height // 1000
ymax = ymax * total_height // 1000
if pid not in d_data:
print("Error : Gemini answered a copie id not present", pid, label, group_name)
continue
yming,ymaxg = d_data[pid]
if ymin < yming-50 or ymax > ymaxg+50:
print("Error : Gemini answered box2d not at the right position", pid, label, group_name)
if ymax < yming or ymin > ymaxg:
print("Removing the box.")
f["box_2d"] = None
continue
nymin = max(ymin, yming) * 1000 // total_height
nymax = min(ymax, ymaxg) * 1000 // total_height
f["box_2d"] = [nymin, xmin, nymax, xmax]
# print("Group :", yming, ymaxg, "Answered:", ymin, ymax)
# --- CRITICAL: Use Lock for writing shared data --- # --- CRITICAL: Use Lock for writing shared data ---
with io_lock: with io_lock:
@@ -272,19 +319,14 @@ def process_single_task(task_tuple):
with open(output_path, "w", encoding="utf-8") as f: with open(output_path, "w", encoding="utf-8") as f:
json.dump(results, f, indent=2) json.dump(results, f, indent=2)
# Save Progress (Optional, based on your logic)
# completed_tasks.append((str(file_path), label))
# with open(progress_path, "w", encoding="utf-8") as f:
# json.dump(completed_tasks, f)
except json.JSONDecodeError: except json.JSONDecodeError:
print(f"Error decoding JSON for {file_path}", file=sys.stderr) print(f"Error decoding JSON for {file_path}", file=sys.stderr)
except Exception as e: except Exception as e:
print(f"Exception processing {file_path}: {e}", file=sys.stderr) print(f"Exception processing {file_path}: {e}", file=sys.stderr)
print(f"Starting processing on {len(tasks_to_process)} tasks with 6 threads...") print(f"Starting processing on {len(tasks_to_process)} tasks with {NB_THREADS} threads...")
with concurrent.futures.ThreadPoolExecutor(max_workers=6) as executor: with concurrent.futures.ThreadPoolExecutor(max_workers=NB_THREADS) as executor:
executor.map(process_single_task, tasks_to_process) executor.map(process_single_task, tasks_to_process)
end_time = time.time() end_time = time.time()