Some fixing of boxes
parent
b95e55d088
commit
c1451a7a99
|
|
@ -142,6 +142,8 @@ import os
|
|||
import threading
|
||||
import concurrent.futures
|
||||
|
||||
NB_THREADS = 8
|
||||
|
||||
# PROXY_URL = "http://192.168.241.1:3128"
|
||||
PROXY_URL = None
|
||||
|
||||
|
|
@ -149,8 +151,8 @@ if PROXY_URL:
|
|||
os.environ["http_proxy"] = PROXY_URL
|
||||
os.environ["https_proxy"] = PROXY_URL
|
||||
|
||||
|
||||
MODEL_ID = "gemini-3-pro-preview"
|
||||
MODEL_ID_BIS = "gemini-3-flash-preview"
|
||||
api_key="REMOVED_API_KEY"
|
||||
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
|
|
@ -242,16 +244,27 @@ tasks_to_process = [t for t in tasks if (str(t[0]), t[1]) not in completed_set]
|
|||
|
||||
def process_single_task(task_tuple):
|
||||
file_path, label = task_tuple
|
||||
|
||||
group_name = os.path.splitext(file_path)[0]
|
||||
json_path = group_name + '.json'
|
||||
with open(json_path, 'r') as f:
|
||||
# List of (groupid, start, end), in pixels
|
||||
group_data = json.load(f)
|
||||
n = len(group_data)
|
||||
d_data = {l[0]: (l[1], l[2]) for l in group_data}
|
||||
total_height = group_data[-1][2]
|
||||
use_flash = n >= 5
|
||||
try:
|
||||
contents, config = generate_request(file_path, label)
|
||||
print(f"Asking Gemini: {label} {file_path}")
|
||||
if use_flash:
|
||||
print(f"Asking Flash Gemini: {label} {file_path}")
|
||||
else:
|
||||
print(f"Asking Gemini: {label} {file_path}")
|
||||
|
||||
full_response_text = ""
|
||||
# Assuming client is thread-safe (usually is).
|
||||
# If not, create a new client instance inside this function.
|
||||
for chunk in client.models.generate_content_stream(
|
||||
model=MODEL_ID,
|
||||
model=MODEL_ID_BIS if use_flash else MODEL_ID,
|
||||
contents=contents,
|
||||
config=config,
|
||||
):
|
||||
|
|
@ -260,7 +273,41 @@ def process_single_task(task_tuple):
|
|||
|
||||
# Parse JSON
|
||||
json_data = json.loads(full_response_text)
|
||||
print(f"Gemini answered correctly for {file_path}")
|
||||
|
||||
if use_flash:
|
||||
print(f"Gemini Flash answered for {file_path}")
|
||||
else:
|
||||
print(f"Gemini answered for {file_path}")
|
||||
|
||||
# print("Debug : ", json_data)
|
||||
# Ensure consistency of answer placements
|
||||
for p in json_data:
|
||||
pid = p["id"]
|
||||
res = p["result"]
|
||||
if res["error"] != "":
|
||||
print("Error :", res["error"], "for Copie", pid, label, group_name)
|
||||
for f in res["feedback"]:
|
||||
b = f["box_2d"]
|
||||
if b:
|
||||
ymin, xmin, ymax, xmax = b
|
||||
ymin = ymin * total_height // 1000
|
||||
ymax = ymax * total_height // 1000
|
||||
|
||||
if pid not in d_data:
|
||||
print("Error : Gemini answered a copie id not present", pid, label, group_name)
|
||||
continue
|
||||
yming,ymaxg = d_data[pid]
|
||||
if ymin < yming-50 or ymax > ymaxg+50:
|
||||
print("Error : Gemini answered box2d not at the right position", pid, label, group_name)
|
||||
if ymax < yming or ymin > ymaxg:
|
||||
print("Removing the box.")
|
||||
f["box_2d"] = None
|
||||
continue
|
||||
nymin = max(ymin, yming) * 1000 // total_height
|
||||
nymax = min(ymax, ymaxg) * 1000 // total_height
|
||||
|
||||
f["box_2d"] = [nymin, xmin, nymax, xmax]
|
||||
# print("Group :", yming, ymaxg, "Answered:", ymin, ymax)
|
||||
|
||||
# --- CRITICAL: Use Lock for writing shared data ---
|
||||
with io_lock:
|
||||
|
|
@ -272,19 +319,14 @@ def process_single_task(task_tuple):
|
|||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
json.dump(results, f, indent=2)
|
||||
|
||||
# Save Progress (Optional, based on your logic)
|
||||
# completed_tasks.append((str(file_path), label))
|
||||
# with open(progress_path, "w", encoding="utf-8") as f:
|
||||
# json.dump(completed_tasks, f)
|
||||
|
||||
except json.JSONDecodeError:
|
||||
print(f"Error decoding JSON for {file_path}", file=sys.stderr)
|
||||
except Exception as e:
|
||||
print(f"Exception processing {file_path}: {e}", file=sys.stderr)
|
||||
|
||||
print(f"Starting processing on {len(tasks_to_process)} tasks with 6 threads...")
|
||||
print(f"Starting processing on {len(tasks_to_process)} tasks with {NB_THREADS} threads...")
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=6) as executor:
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=NB_THREADS) as executor:
|
||||
executor.map(process_single_task, tasks_to_process)
|
||||
|
||||
end_time = time.time()
|
||||
|
|
|
|||
Loading…
Reference in New Issue