786 lines
34 KiB
Python
786 lines
34 KiB
Python
import sys
|
|
import os
|
|
import time
|
|
from pathlib import Path
|
|
import argparse
|
|
import prompting
|
|
import signal
|
|
from google import genai
|
|
import base64
|
|
import shlex
|
|
import json
|
|
import threading
|
|
import concurrent.futures
|
|
|
|
if len(sys.argv) < 2:
    sys.exit("Usage: python script.py 'InterroTest/Ex 2/Group_1.jpg' OR <InputDir> OR 'file1' 'file2'")

# --- Parse Arguments ---
parser = argparse.ArgumentParser()
parser.add_argument("paths", nargs="+", help="List of images or directories")
parser.add_argument("--overwrite", action="store_true",
                    help="Force redo requests even if output exists")
parser.add_argument("--limit", type=int,
                    help="Maximum number of Gemini Pro calls (integer)")
parser.add_argument("--refaire", action="store_true",
                    help="Redo specific copies/labels defined in refaire.json")
parser.add_argument("--batch", action="store_true",
                    help="Generate a JSONL file of requests to send to the Gemini Batch API")
parser.add_argument("--batch-from", type=str, metavar="LABEL",
                    help="Do live requests before LABEL, and batch requests from LABEL onwards")
parser.add_argument("--deal-with-batched", action="store_true",
                    help="Process a JSONL file containing completed batch results")

# parse_known_args tolerates extra arguments, but a mistyped flag (for example
# "--overwirte") would then be dropped without notice: report what is ignored.
args, unknown_args = parser.parse_known_args()
if unknown_args:
    print(f"Warning: ignoring unrecognized arguments: {' '.join(unknown_args)}",
          file=sys.stderr)
|
|
|
|
# --- Collect the group images to process ---
# A single image is expected at <INPUT_DIR>/Par label/<label>/Group_N.jpg; a
# directory argument is <INPUT_DIR> itself. INPUT_DIR / COPIES_DIR / GROUPS_DIR
# are taken from the last valid argument.
tasks = []  # List of tuples: (filepath_str, label_str)
results = {}
INPUT_DIR = None
COPIES_DIR = None
GROUPS_DIR = None

for path_str in args.paths:
    arg_path = Path(path_str)

    if not arg_path.exists():
        print(f"Warning: {path_str} not found. Skipping.")
        continue

    if arg_path.is_file() and arg_path.suffix.lower() == ".jpg":
        # Individual group image: its parent folder is the label and three
        # levels up is the input directory.
        label = arg_path.parent.name
        INPUT_DIR = arg_path.parent.parent.parent
        COPIES_DIR = INPUT_DIR / "Copies"
        GROUPS_DIR = INPUT_DIR / "Par label"
        tasks.append((str(arg_path), label))
        results.setdefault(label, [])

    elif arg_path.is_dir():
        INPUT_DIR = arg_path
        COPIES_DIR = INPUT_DIR / "Copies"
        GROUPS_DIR = INPUT_DIR / "Par label"
        # Without this check iterdir() would raise FileNotFoundError.
        if not GROUPS_DIR.is_dir():
            print(f"Warning: {GROUPS_DIR} not found. No task added for {path_str}.")
            continue
        # One sub-folder per label, one task per group image inside it
        for sub in GROUPS_DIR.iterdir():
            if sub.is_dir():
                label = sub.name
                results.setdefault(label, [])
                for img in sub.glob("*.jpg"):
                    tasks.append((str(img), label))

    else:
        print(f"Warning: {path_str} is neither a .jpg image nor a directory. Skipping.")

# Every later step builds paths from INPUT_DIR: stop with a clear message
# instead of a NameError further down.
if INPUT_DIR is None:
    sys.exit("Error: no valid image or input directory was given.")
|
|
|
|
|
|
# --- Runtime configuration ---
NB_THREADS = 12  # Worker threads used for live Gemini requests

# PROXY_URL = "http://192.168.241.1:3128"
PROXY_URL = None

if PROXY_URL:
    # Exported so that HTTP clients honouring these variables use the proxy
    os.environ["http_proxy"] = PROXY_URL
    os.environ["https_proxy"] = PROXY_URL

MODEL_ID_pro = "gemini-3.1-pro-preview"
MODEL_ID_flash = "gemini-3-flash-preview"

# Fail with an explicit message rather than a bare KeyError traceback.
api_key = os.environ.get("GEMINI_API_KEY")
if not api_key:
    sys.exit("Error: the GEMINI_API_KEY environment variable is not set.")
|
|
|
|
# --- Thread-safe Logging ---
# RLock rather than Lock: handle_interrupt() runs in the main thread and calls
# flush_thread_log(), which takes this lock. If the signal arrives while the main
# thread is itself inside tprint()/flush_thread_log(), a plain Lock deadlocks.
log_lock = threading.RLock()
thread_logs = {}  # thread name -> buffered messages not yet written to the log file


def tprint(*args, **kwargs):
    """Print a message prefixed with the thread name and buffer it for the log.

    The positional arguments are converted with str() and joined with spaces.
    The message is appended to the calling thread's buffer in ``thread_logs``
    so that flush_thread_log() can later write it contiguously. Keyword
    arguments (e.g. ``file=sys.stderr``) are forwarded to print().
    """
    tid = threading.current_thread().name
    msg = " ".join(map(str, args))

    with log_lock:
        thread_logs.setdefault(tid, []).append(msg)

    # Also echo to the console, prefixed with the thread name
    print(f"[{tid}] {msg}", **kwargs)
|
|
|
|
def flush_thread_log(tid=None):
    """Write one thread's buffered messages to ``INPUT_DIR/correction_log``.

    ``tid`` defaults to the calling thread's name. The buffered messages are
    appended as one section under a ``--- Task Log [<tid>] ---`` header, then
    the buffer is emptied. Nothing is written when the buffer is missing or
    empty.
    """
    if not tid:
        tid = threading.current_thread().name
    with log_lock:
        pending = thread_logs.get(tid)
        if not pending:
            return
        log_file = INPUT_DIR / "correction_log"
        with open(log_file, "a", encoding="utf-8") as handle:
            handle.write(f"--- Task Log [{tid}] ---\n")
            handle.write("\n".join(pending) + "\n\n")
        pending.clear()
|
|
|
|
def handle_interrupt(sig, frame):
    """Signal handler: write every thread's buffered log, then exit with status 1.

    ``sig`` and ``frame`` are the standard signal-handler arguments and are not
    used.
    NOTE(review): sys.exit() raises SystemExit in the main thread only; worker
    threads of a running ThreadPoolExecutor are not stopped by it.
    """
    print("\nInterrupt received. Flushing partial logs...", file=sys.stderr)
    pending_threads = list(thread_logs)
    for thread_name in pending_threads:
        flush_thread_log(thread_name)
    sys.exit(1)
|
|
|
|
|
|
# Route both Ctrl+C (SIGINT) and termination requests (SIGTERM) to the same
# handler so that buffered per-thread logs are written before exiting.
for _signum in (signal.SIGINT, signal.SIGTERM):
    signal.signal(_signum, handle_interrupt)
del _signum
# ---------------------------
|
|
|
|
# --- Shared state for this run ---
client = genai.Client(api_key=api_key)
output_path = INPUT_DIR / "correction.json"             # results, keyed by label
progress_path = INPUT_DIR / "correction_progress.json"  # (image path, label) pairs done
start_time = time.time()
overwrite = args.overwrite
limit = args.limit
completed_tasks = []
errors_summary = []

# --- Lock for thread-safe file writing ---
io_lock = threading.Lock()
# pro_lock guards the request counters below
pro_lock = threading.Lock()
pro_count = 0
flash_count = 0
pro_quota_exhausted = False

if overwrite:
    # Start from scratch: forget previous results and progress
    for stale in (output_path, progress_path):
        if stale.exists():
            stale.unlink()
else:
    # Resume: reload what previous runs already produced
    if progress_path.exists():
        with open(progress_path, "r", encoding="utf-8") as fh:
            completed_tasks = json.load(fh)
    if output_path.exists():
        with open(output_path, "r", encoding="utf-8") as fh:
            results = json.load(fh)

# Keep only the tasks whose (image path, label) pair is not recorded as done
completed_set = {(str(done_path), done_label) for done_path, done_label in completed_tasks}
tasks_to_process = [task for task in tasks if (str(task[0]), task[1]) not in completed_set]
|
|
|
|
def call_gemini_with_retries(model_id, contents, config,
                             fallback_model_id=MODEL_ID_flash):
    """Stream a Gemini request and return the concatenated response text.

    Failures are retried up to three attempts in total:
      * an error mentioning "minute"/"rpm"/"tpm" waits for the delay suggested
        in the message ("retry in Ns", plus 1 s) or the default back-off;
      * a quota error ("429"/"quota"/"exhausted") on the Pro model switches to
        ``fallback_model_id`` at once and sets the global
        ``pro_quota_exhausted`` flag so that other threads switch too. The
        switch happens at most once and does not consume an attempt;
      * any other error waits 60 s, then 300 s, before retrying.

    Args:
        model_id: Model queried first.
        contents: Passed unchanged to ``client.models.generate_content_stream``.
        config: Passed unchanged to ``client.models.generate_content_stream``.
        fallback_model_id: Model used once the Pro quota is exhausted; a falsy
            value disables the fallback.

    Returns:
        The response text (empty if no chunk carried text). Never None.

    Raises:
        The last API exception once every attempt has failed.
    """
    import re

    global pro_quota_exhausted
    delays = [60, 300]  # back-off before the 2nd and the 3rd attempt
    max_attempts = len(delays) + 1
    attempt = 0

    while True:
        # Switch to the fallback immediately if another thread exhausted the Pro quota
        if model_id == MODEL_ID_pro and pro_quota_exhausted and fallback_model_id:
            model_id = fallback_model_id

        try:
            full_response_text = ""
            for chunk in client.models.generate_content_stream(
                model=model_id,
                contents=contents,
                config=config,
            ):
                if chunk.text:
                    full_response_text += chunk.text
            return full_response_text
        except Exception as e:
            error_msg = str(e).lower()
            is_quota_error = "429" in error_msg or "quota" in error_msg or "exhausted" in error_msg
            is_minute_limit = "minute" in error_msg or "rpm" in error_msg or "tpm" in error_msg

            # Pro quota exhausted (and not a per-minute limit): fall back right
            # away. The check on fallback_model_id != MODEL_ID_pro guarantees
            # this branch can only run once.
            if (not is_minute_limit and is_quota_error and model_id == MODEL_ID_pro
                    and fallback_model_id and fallback_model_id != MODEL_ID_pro):
                tprint(f"\tGemini Pro quota hit ({e}). \n\n\tFalling back to Flash permanently...")
                model_id = fallback_model_id
                pro_quota_exhausted = True
                continue

            attempt += 1
            if attempt >= max_attempts:
                # Re-raise instead of falling off the loop (which returned None)
                tprint(f"\tGemini API failure: {e}. Maximum retries reached.")
                raise

            if is_minute_limit:
                # Use the server-suggested delay when present, else the default back-off
                retry_match = re.search(r"retry in ([\d.]+)s", error_msg)
                wait_time = float(retry_match.group(1)) + 1.0 if retry_match else delays[attempt - 1]
                tprint(f"\tGemini {model_id} minute limit hit. Waiting {wait_time:.1f}s...")
            else:
                wait_time = delays[attempt - 1]
                tprint(f"\tGemini API failure: {e}. Retrying in {wait_time} seconds...")
            time.sleep(wait_time)
|
|
|
|
def correct_boxes_with_gemini(pid, label, pdf_path, original_feedbacks,
                              yming, ymaxg, width_r, total_height):
    """Ask Gemini Flash to re-place the feedback boxes of one copy on its own page.

    The copy's PDF and its current feedbacks are sent to Gemini. The boxes it
    returns are normalised to 0-1000 on the single page; they are mapped back to
    0-1000 coordinates on the group image.

    Args:
        pid: Copy id (not used by the body).
        label: Question label (not used by the body).
        pdf_path: PDF of this copy for this label.
        original_feedbacks: Feedback dicts. Those without a ``box_2d`` are kept
            unchanged.
        yming: Top of this copy in the group image, in pixels.
        ymaxg: Bottom of this copy in the group image, in pixels.
        width_r: Width of this copy as a fraction of the group image width.
        total_height: Height of the group image in pixels.

    Returns:
        The box-less original feedbacks, followed by the corrected feedbacks.

    Raises:
        json.JSONDecodeError if the answer is not valid JSON, and anything
        raised by call_gemini_with_retries().
    """
    contents, config = prompting.request_for_box_correction(pdf_path, original_feedbacks)
    response_text = call_gemini_with_retries(MODEL_ID_flash, contents, config)
    corrected_feedbacks = json.loads(response_text)

    # Feedbacks without a box are kept untouched. .get(): a feedback dict may
    # lack the "box_2d" key altogether (the previous f["box_2d"] raised KeyError).
    # NOTE(review): if Gemini also echoes these box-less feedbacks in its answer,
    # they appear twice; confirm against prompting.request_for_box_correction.
    global_feedbacks = [f for f in original_feedbacks if not f.get("box_2d")]

    single_h = ymaxg - yming  # pixel height of this copy inside the group image

    # Map the coordinates back from the single image to the group canvas
    for f in corrected_feedbacks:
        b = f.get("box_2d")
        if not b:
            continue
        ymin_s, xmin_s, ymax_s, xmax_s = b

        # Y: page 0-1000 -> pixels inside the copy, shifted by the copy's offset
        # in the group, then re-normalised to the group height.
        new_ymin = int((yming + (ymin_s * single_h / 1000.0)) * 1000.0 / total_height)
        new_ymax = int((yming + (ymax_s * single_h / 1000.0)) * 1000.0 / total_height)

        # X: scale by this copy's width ratio relative to the group image
        new_xmin = int(xmin_s * width_r)
        new_xmax = int(xmax_s * width_r)

        f["box_2d"] = [new_ymin, new_xmin, new_ymax, new_xmax]

    return global_feedbacks + corrected_feedbacks
|
|
|
|
import shutil
|
|
import grouping
|
|
|
|
# Serialises group-index allocation across worker threads and remembers indices
# already handed out whose image may not have been written yet.
_group_idx_lock = threading.Lock()
_reserved_group_idx = {}  # str(label folder) -> last index handed out


def get_next_group_idx(label, groups_dir=None):
    """Return the index ``idx`` to use for a new group image of ``label``.

    Callers create ``Group_{idx+1}.jpg``, so ``idx`` is the largest number N of
    the existing ``Group_N.jpg`` files in the label folder (0 when there are
    none). The label folder is created if needed.

    Worker threads may ask for the same label before either has written its
    image. Indices handed out earlier are therefore remembered, so concurrent
    callers never receive the same value. Names whose N is not an integer
    (e.g. ``Group_a.jpg``) are ignored instead of raising ValueError.

    Args:
        label: Name of the label sub-folder.
        groups_dir: Folder containing the label sub-folders. Defaults to
            GROUPS_DIR.
    """
    base = GROUPS_DIR if groups_dir is None else Path(groups_dir)
    target_folder = base / label
    with _group_idx_lock:
        target_folder.mkdir(exist_ok=True)
        numbers = []
        for f in target_folder.glob("Group_*.jpg"):
            try:
                numbers.append(int(f.stem.split("_")[1]))
            except ValueError:
                continue
        idx = max(numbers, default=0)
        key = str(target_folder)
        if key in _reserved_group_idx:
            # Stay past the index given to a caller that may not have written
            # Group_{idx+1}.jpg yet.
            idx = max(idx, _reserved_group_idx[key] + 1)
        _reserved_group_idx[key] = idx
        return idx
|
|
|
|
from utils import read_all_labels, enonce_total
|
|
|
|
def _spawn_single_copy_task(pid, target_label, pdf_file):
    """Render a one-copy group image for ``target_label`` and return its task.

    The returned tuple is (image_path, target_label, False). process_single_task
    reads the third item as ``can_spawn_tasks``, so the new task cannot spawn
    further tasks itself.
    """
    idx = get_next_group_idx(target_label)
    tprint(f"\t\tMaking {target_label} group {idx+1}")
    height = grouping.get_pdf_height(str(pdf_file))
    grouping.create_jpg(target_label, idx, [(pid, str(pdf_file), height)], GROUPS_DIR)
    return (str(GROUPS_DIR / target_label / f"Group_{idx+1}.jpg"), target_label, False)


def handle_label_errors(pid, label, res, pdf_path):
    """Handle a "wrong-label" or "additional-answer" error for one copy.

    Gemini Flash is asked which label(s) the answer really belongs to. The
    copy's PDF is then copied to ``<new_label>_new.pdf`` and a single-copy group
    image is rendered for it, producing a new task. When the target PDF already
    exists, the operation is postponed and tagged "delayed" in ``res["error"]``
    so that resolve_delayed_moves() can retry it later.

    Args:
        pid: Copy id (``Copie<pid>`` folder under COPIES_DIR).
        label: Label the copy was graded under.
        res: The copy's result dict. ``res["error"]`` is rewritten in place:
            * wrong-label: "" (same label), "wrg-lbl:cldtfix" (unknown label or
              failed request), "wrg-lbl:<new>?delayed" or
              "wrg-lbl-moved-to:<new>" (original PDF renamed to ``<label>_old.pdf``);
            * additional-answer: "" or "al:" followed by "<lbl>??" (unknown
              label), "(->)<lbl>" (copied) and "(delayed)<lbl>" (postponed).
        pdf_path: The copy's current PDF for ``label``.

    Returns:
        The new tasks (image_path, label, False) for the groups created.
    """
    new_tasks = []
    error_type = res.get("error")

    all_labels = read_all_labels(INPUT_DIR)
    labels_txt = (INPUT_DIR / "labels").read_text(encoding="utf-8", errors="replace")
    enonce = enonce_total(INPUT_DIR)
    copie_dir = COPIES_DIR / f"Copie{pid}"

    if error_type == "wrong-label":
        tprint(f"\tHandling wrong-label for {pid} {label}")
        contents, config = prompting.request_for_wrong_label(pdf_path, label, enonce, labels_txt)
        try:
            answer = call_gemini_with_retries(MODEL_ID_flash, contents, config)
        except Exception as e:
            # Propagating would reach process_single_task's handler and discard
            # the results of every copy of the group: degrade to "could not fix".
            tprint(f"\t\tCopie{pid} : wrong-label request failed ({e}). Ignoring")
            res["error"] = "wrg-lbl:cldtfix"
            return []
        new_label = (answer or "").strip().strip('"\'')

        if new_label not in all_labels:
            tprint(f"\t\tCopie{pid} returned an incorrect label {new_label} from an initial wrong label {label}. Ignoring")
            res["error"] = "wrg-lbl:cldtfix"
            return []
        if new_label == label:
            res["error"] = ""
            return []

        base_new_pdf_path = copie_dir / f"{new_label}.pdf"
        new_pdf_path = copie_dir / f"{new_label}_new.pdf"

        if base_new_pdf_path.exists() or new_pdf_path.exists():
            tprint(f"\t\tCopie{pid} tried to move wrong {label} to {new_label}, but it already exists. Delaying.")
            res["error"] = f"wrg-lbl:{new_label}?delayed"
        else:
            res["error"] = f"wrg-lbl-moved-to:{new_label}"
            tprint(f"\t\tCopie{pid} : moving wrong {label} to {new_label}.")

            # Copy to _new, then rename the original to _old
            shutil.copy(str(pdf_path), str(new_pdf_path))
            old_pdf_path = pdf_path.with_name(f"{label}_old.pdf")
            if pdf_path != old_pdf_path:
                shutil.move(str(pdf_path), str(old_pdf_path))

            new_tasks.append(_spawn_single_copy_task(pid, new_label, new_pdf_path))

    elif error_type == "additional-answer":
        contents, config = prompting.request_for_additional_answer(pdf_path, label, enonce, labels_txt)
        tprint(f"\tHandling additional-answer for {pid} {label}")
        try:
            add_labels = json.loads(call_gemini_with_retries(MODEL_ID_flash, contents, config))
        except Exception as e:
            tprint(f"\t\tCopie{pid} : additional-answer request failed ({e}).")
            add_labels = []
        # A JSON list of labels is expected; a bare string would otherwise be
        # iterated character by character.
        if isinstance(add_labels, str):
            add_labels = [add_labels]
        elif not isinstance(add_labels, list):
            add_labels = []

        keep_error = False
        error = "al:"
        for add_label in add_labels:
            if add_label == label:
                continue
            if add_label not in all_labels:
                tprint(f"\t\t Inexistent label ({add_label}) from additional-answer processing {pid} {label}. Ignoring")
                error += f"{add_label}??"
                keep_error = True
                continue

            keep_error = True
            base_add_pdf_path = copie_dir / f"{add_label}.pdf"
            add_pdf_path = copie_dir / f"{add_label}_new.pdf"

            if base_add_pdf_path.exists() or add_pdf_path.exists():
                error += f"(delayed){add_label}"
                tprint(f"\t\tAlready present (not copied) Copie{pid} : {label} -> {add_label}. Delaying.")
                continue

            shutil.copy(str(pdf_path), str(add_pdf_path))
            tprint(f"\t\tCopying Copie{pid} : {label} -> {add_label}")
            new_tasks.append(_spawn_single_copy_task(pid, add_label, add_pdf_path))
            error += f"(->){add_label}"

        res["error"] = error if keep_error else ""

    return new_tasks
|
|
|
|
def _locate_copy_pdf(pid, label):
    """Return (pdf_path, suffix) for copy ``pid`` and ``label``.

    Prefers ``<label>.pdf``; when it is missing, falls back to
    ``<label>_new.pdf`` and then ``<label>_old.pdf``. ``suffix`` is "", "_new"
    or "_old". If none of them exists, the plain path is returned with "".
    """
    pdf_path = COPIES_DIR / f"Copie{pid}" / f"{label}.pdf"
    if pdf_path.exists():
        return pdf_path, ""
    for suffix in ("_new", "_old"):
        candidate = pdf_path.with_name(f"{label}{suffix}.pdf")
        if candidate.exists():
            return candidate, suffix
    return pdf_path, ""


def _check_boxes(feedbacks, yming, ymaxg, width_r, total_height):
    """Check one copy's feedback boxes against its slot in the group image.

    Boxes are [ymin, xmin, ymax, xmax], normalised to 0-1000 of the group image.
    A box is an anomaly when it lies more than 50 px above or below the copy's
    vertical slot [yming, ymaxg], or when xmax/1000 exceeds ``width_r``. Other
    boxes overflowing the slot by more than 5 px are clamped in place to that
    5 px margin.

    Returns:
        The indices (into ``feedbacks``) of every anomalous box.
    """
    anomalies = []
    for i, f in enumerate(feedbacks):
        b = f.get("box_2d")
        if not b:
            continue
        ymin, xmin, ymax, xmax = b
        ymin = ymin * total_height // 1000
        ymax = ymax * total_height // 1000

        if ymin < yming - 50 or ymax > ymaxg + 50 or xmax / 1000 > width_r:
            # Keep scanning (previously `break`): the fallback removes exactly
            # these boxes, so every anomalous one must be reported.
            anomalies.append(i)
            continue
        if ymin < yming - 5:
            b[0] = (yming - 5) * 1000 // total_height
        if ymax > ymaxg + 5:
            b[2] = (ymaxg + 5) * 1000 // total_height
    return anomalies


def process_single_task(task_tuple, precomputed_response=None):
    """Correct one group image with Gemini and record the result.

    The group layout is read from the .json file next to the image. Gemini's
    per-copy answers are post-processed: PDF suffixes are resolved, label
    errors are handled, and boxes are validated. The results are then appended
    to correction.json and the task is marked done in correction_progress.json.

    Args:
        task_tuple: (image_path, label) or (image_path, label, can_spawn_tasks).
            When can_spawn_tasks is False, label errors create no new tasks.
        precomputed_response: Response text from a batch job. When it is empty
            or None, a live request is made and counted in pro_count/flash_count.

    Returns:
        The new tasks produced by handle_label_errors() (possibly []).
    """
    global pro_count, flash_count
    try:
        file_path = task_tuple[0]
        label = task_tuple[1]
        can_spawn_tasks = task_tuple[2] if len(task_tuple) > 2 else True
        # An empty batched response is treated as missing (live request), both
        # for the counters and for the request below.
        precomputed_response = precomputed_response or None

        group_name = os.path.splitext(file_path)[0]
        json_path = group_name + '.json'
        new_tasks = []

        # Each entry is read as [pid, y_start, y_end, width_ratio, ...]; the last
        # entry's y_end is taken as the group image height.
        with open(json_path, 'r', encoding="utf-8") as f:
            group_data = json.load(f)

        n = len(group_data)
        d_data = {entry[0]: (entry[1], entry[2], entry[3]) for entry in group_data}
        total_height = group_data[-1][2]
        # Groups of 4+ copies or at most 500 px high go to Flash
        use_flash = n >= 4 or total_height <= 500

        # Only apply limits and counts if we are making a live call
        if precomputed_response is None:
            with pro_lock:
                if not use_flash:
                    if pro_quota_exhausted or (limit is not None and pro_count >= limit):
                        use_flash = True
                    else:
                        pro_count += 1
                if use_flash:
                    flash_count += 1

        try:
            contents, config = prompting.generate_request(INPUT_DIR, file_path, label)
            model_to_use = MODEL_ID_flash if use_flash else MODEL_ID_pro

            if precomputed_response is not None:
                tprint(f"Using batched response for: {label} {group_name}")
                full_response_text = precomputed_response
            else:
                tprint(f"Asking Gemini {'Flash' if use_flash else 'Pro '}: {label} {group_name}")
                full_response_text = call_gemini_with_retries(model_to_use, contents, config)

            json_data = json.loads(full_response_text)

            # Ensure consistency of answer placements, copy by copy
            for p in json_data:
                pid = p["id"]
                res = p["result"]
                if pid not in d_data:
                    # Checked before any lookup or file operation: an unknown id
                    # used to raise KeyError here and discard the whole group.
                    tprint("Error : Gemini answered a copie id not present",
                           pid, label, group_name)
                    continue
                yming, ymaxg, width_r = d_data[pid]

                pdf_path, current_suffix = _locate_copy_pdf(pid, label)

                # 1. empty-answer: set the copy's PDF aside as _old
                if res.get("error") == "empty-answer":
                    old_path = pdf_path.with_name(f"{label}_old.pdf")
                    if pdf_path.exists() and pdf_path != old_path:
                        shutil.move(str(pdf_path), str(old_path))
                    pdf_path = old_path
                    current_suffix = "_old"

                if not can_spawn_tasks and res.get("error") == "additional-answer":
                    tprint("\tSwallowing an additional-answer from a subsequent task.")
                    res["error"] = ""
                if res.get("error", "") != "":
                    tprint("\tError :", res["error"], "for Copie", pid, group_name)

                    if can_spawn_tasks and res["error"] in ("wrong-label", "additional-answer"):
                        new_tasks.extend(handle_label_errors(pid, label, res, pdf_path))
                        # wrong-label may have moved the current file to _old
                        if res.get("error", "").startswith("wrg-lbl-moved-to:"):
                            current_suffix = "_old"

                # Record in correction.json which variant of the PDF was used
                if current_suffix:
                    res["suffix"] = current_suffix

                feedbacks = res.get("feedback") or []
                needs_correction = _check_boxes(feedbacks, yming, ymaxg, width_r, total_height)
                if needs_correction:
                    tprint(f"\tBox anomalies detected for Copie {pid} {group_name}. \n\tRequesting isolated correction from Gemini Flash...")
                    try:
                        res["feedback"] = correct_boxes_with_gemini(
                            pid, label, pdf_path, feedbacks,
                            yming, ymaxg, width_r, total_height)
                    except Exception as e:
                        tprint(f"\tCorrection failed for Copie {pid}, {group_name} : {e}\n\tRemoving the boxes")
                        # Fallback if the second request fails entirely
                        for i in needs_correction:
                            feedbacks[i]["box_2d"] = None

            # --- Use Lock for writing shared data ---
            with io_lock:
                results.setdefault(label, []).append(json_data)
                with open(output_path, "w", encoding="utf-8") as f:
                    json.dump(results, f, indent=2)

                # To track progress
                completed_tasks.append((file_path, label))
                with open(progress_path, "w", encoding="utf-8") as f:
                    json.dump(completed_tasks, f, indent=2)

        except json.JSONDecodeError:
            tprint(f"Error decoding JSON for {file_path}", file=sys.stderr)
            with io_lock:
                errors_summary.append(("Error decoding JSON response", file_path))
        except Exception as e:
            error_msg = f"Exception processing {file_path}: {e}"
            # tprint (not print) so that the failure also reaches correction_log
            tprint(error_msg, file=sys.stderr)
            with io_lock:
                errors_summary.append((error_msg, file_path))
        return new_tasks
    finally:
        flush_thread_log()
|
|
|
|
def resolve_delayed_moves():
    """Retry the wrong-label / additional-answer operations that were postponed.

    An operation is postponed when its target PDF (``<new_label>.pdf`` or
    ``<new_label>_new.pdf``) already existed. The copy's error then carries a
    ``wrg-lbl:<label>?delayed`` or ``(delayed)<label>`` marker. For each marker
    whose target is now free:
      * the PDF is copied;
      * a single-copy group image is rendered;
      * a new task is returned;
      * the marker becomes ``wrg-lbl-moved-to:<label>`` / ``(->)<label>``.
    correction.json is rewritten when at least one task was created. A failure
    on one copy is logged and does not stop the others.

    Returns:
        The new tasks (image_path, label, False).
    """
    import re

    new_tasks = []
    with io_lock:
        for label, batches in results.items():
            for batch in batches:
                for p in batch:
                    result = p.get("result", {})
                    err = result.get("error", "")
                    if not err or ("?delayed" not in err and "(delayed)" not in err):
                        continue

                    pid = p["id"]
                    copie_dir = COPIES_DIR / f"Copie{pid}"
                    pdf_path = copie_dir / f"{label}.pdf"
                    if not pdf_path.exists():
                        if pdf_path.with_name(f"{label}_new.pdf").exists():
                            pdf_path = pdf_path.with_name(f"{label}_new.pdf")
                        elif pdf_path.with_name(f"{label}_old.pdf").exists():
                            pdf_path = pdf_path.with_name(f"{label}_old.pdf")
                    if not pdf_path.exists():
                        # shutil.copy below would raise and abort the whole run
                        tprint(f"Cannot resolve delayed move for Copie{pid} {label}: no source PDF.")
                        continue

                    try:
                        # 1. wrong-label resolution
                        if err.startswith("wrg-lbl:") and "?delayed" in err:
                            new_label = err.split(":")[1].split("?")[0]
                            base_new_pdf_path = copie_dir / f"{new_label}.pdf"
                            new_pdf_path = copie_dir / f"{new_label}_new.pdf"

                            # Only once the target slot has been freed
                            if base_new_pdf_path.exists() or new_pdf_path.exists():
                                continue
                            tprint(f"Resolving delayed move: Copie{pid} {label} -> {new_label}")

                            # Copy to _new, then rename the original to _old
                            shutil.copy(str(pdf_path), str(new_pdf_path))
                            old_pdf_path = pdf_path.with_name(f"{label}_old.pdf")
                            if pdf_path != old_pdf_path:
                                shutil.move(str(pdf_path), str(old_pdf_path))

                            # Updated only once the file really moved. The key is
                            # "suffix", as written by process_single_task (was "suffixe").
                            result["error"] = f"wrg-lbl-moved-to:{new_label}"
                            result["suffix"] = "_old"

                            idx = get_next_group_idx(new_label)
                            height = grouping.get_pdf_height(str(new_pdf_path))
                            grouping.create_jpg(new_label, idx, [(pid, str(new_pdf_path), height)], GROUPS_DIR)
                            new_tasks.append((str(GROUPS_DIR / new_label / f"Group_{idx+1}.jpg"), new_label, False))

                        # 2. additional-answer resolution
                        elif err.startswith("al:") and "(delayed)" in err:
                            for add_label in re.findall(r'\(delayed\)([^?()]+)', err):
                                base_add_pdf_path = copie_dir / f"{add_label}.pdf"
                                add_pdf_path = copie_dir / f"{add_label}_new.pdf"
                                if base_add_pdf_path.exists() or add_pdf_path.exists():
                                    continue

                                tprint(f"Resolving delayed additional-answer: Copie{pid} {label} -> {add_label}")
                                shutil.copy(str(pdf_path), str(add_pdf_path))

                                # Rewrite exactly this marker. A plain str.replace would
                                # also rewrite "(delayed)B21" when resolving "B2".
                                marker = re.compile(
                                    r"\(delayed\)" + re.escape(add_label) + r"(?=[?()]|$)")
                                result["error"] = marker.sub(
                                    lambda _m: "(->)" + add_label, result["error"], count=1)

                                idx = get_next_group_idx(add_label)
                                height = grouping.get_pdf_height(str(add_pdf_path))
                                grouping.create_jpg(add_label, idx, [(pid, str(add_pdf_path), height)], GROUPS_DIR)
                                new_tasks.append((str(GROUPS_DIR / add_label / f"Group_{idx+1}.jpg"), add_label, False))
                    except Exception as e:
                        tprint(f"Failed to resolve delayed move for Copie{pid} {label}: {e}")

        if new_tasks:
            # Save the updated error markers (delayed tags removed)
            with open(output_path, "w", encoding="utf-8") as f:
                json.dump(results, f, indent=2)

    # This runs in the main thread, whose tprint() buffer is otherwise only
    # written on interrupt.
    flush_thread_log()
    return new_tasks
|
|
|
|
if __name__ == "__main__":
    # --- --refaire: redo the copies/labels listed in refaire.json ---
    # refaire.json holds [copie_name, labels] pairs; an empty label list means
    # every label of that copy. Previous results are archived in
    # overwritten_correction.json, then the copy is re-queued in a new group.
    if args.refaire:
        refaire_path = INPUT_DIR / "refaire.json"
        overwritten_path = INPUT_DIR / "overwritten_correction.json"

        if refaire_path.exists():
            with open(refaire_path, "r", encoding="utf-8") as f:
                refaire_list = json.load(f)

            overwritten_data = []
            if overwritten_path.exists():
                with open(overwritten_path, "r", encoding="utf-8") as f:
                    overwritten_data = json.load(f)

            dirty_results = False

            for copie_name, labels in refaire_list:
                pid = copie_name.replace("Copie", "")
                copie_dir = COPIES_DIR / copie_name

                # If the list is empty, redo all labels available for this copy.
                # "<label>_new.pdf" belongs to <label>; "_old" PDFs were set aside
                # and are not redone (using the raw stems would queue bogus
                # labels such as "B2_new").
                if not labels:
                    labels = []
                    for pdf in copie_dir.glob("*.pdf"):
                        stem = pdf.stem
                        if stem.endswith("_old"):
                            continue
                        if stem.endswith("_new"):
                            stem = stem[:-len("_new")]
                        if stem not in labels:
                            labels.append(stem)

                for label in labels:
                    # 1. Extract and back up old corrections
                    if label in results:
                        for batch in results[label]:
                            # Compared as strings in case the stored ids are not
                            # strings (pid comes from the folder name).
                            to_remove = next(
                                (item for item in batch if str(item.get("id")) == pid), None)
                            if to_remove:
                                batch.remove(to_remove)
                                overwritten_data.append({
                                    "pid": pid,
                                    "label": label,
                                    "data": to_remove,
                                    "timestamp": time.time()
                                })
                                dirty_results = True
                        # Clean up empty batches
                        results[label] = [b for b in results[label] if b]

                    # 2. Make a new group and add it to the tasks
                    pdf_path = copie_dir / f"{label}.pdf"
                    if not pdf_path.exists() and (copie_dir / f"{label}_new.pdf").exists():
                        pdf_path = copie_dir / f"{label}_new.pdf"

                    if pdf_path.exists():
                        idx = get_next_group_idx(label)
                        height = grouping.get_pdf_height(str(pdf_path))
                        grouping.create_jpg(label, idx, [(pid, str(pdf_path), height)], GROUPS_DIR)
                        new_group_path = str(GROUPS_DIR / label / f"Group_{idx+1}.jpg")
                        tasks_to_process.append((new_group_path, label))

            if dirty_results:
                with open(output_path, "w", encoding="utf-8") as f:
                    json.dump(results, f, indent=2)
                with open(overwritten_path, "w", encoding="utf-8") as f:
                    json.dump(overwritten_data, f, indent=2)
        else:
            print(f"Warning: --refaire flag used, but {refaire_path} not found.", file=sys.stderr)

    # --- --batch / --batch-from: write Gemini Batch API request files ---
    if args.batch or args.batch_from:
        all_labels = read_all_labels(INPUT_DIR)
        batch_tasks = []
        if args.batch_from:
            # A label prefix is accepted: use the first label starting with it
            for label in all_labels:
                if label.startswith(args.batch_from):
                    args.batch_from = label
                    input(f"About to batch from: {args.batch_from}. Press Enter to confirm...")
                    break
            if args.batch_from not in all_labels:
                sys.exit(f"Error: Label '{args.batch_from}' not found. Available labels: {all_labels}")

            target_idx = all_labels.index(args.batch_from)
            live_tasks = []

            for task in tasks_to_process:
                lbl = task[1]
                # Any label found sequentially equal or after `args.batch_from` gets batched
                if lbl in all_labels and all_labels.index(lbl) >= target_idx:
                    batch_tasks.append(task)
                else:
                    live_tasks.append(task)

            tasks_to_process = live_tasks  # Keep live tasks to be run right after
        else:
            batch_tasks = tasks_to_process
            tasks_to_process = []  # Run nothing live if just `--batch`

        if batch_tasks:
            batch_flash_file = INPUT_DIR / "batch_requests_flash.jsonl"
            batch_pro_file = INPUT_DIR / "batch_requests_pro.jsonl"

            count_flash = 0
            count_pro = 0

            with open(batch_flash_file, "w", encoding="utf-8") as f_flash, \
                    open(batch_pro_file, "w", encoding="utf-8") as f_pro:

                for task in batch_tasks:
                    file_path, label = task[0], task[1]
                    group_name = os.path.splitext(file_path)[0]
                    json_path = group_name + '.json'

                    with open(json_path, 'r') as jf:
                        group_data = json.load(jf)
                    # Same Flash/Pro routing rule as process_single_task
                    use_flash = len(group_data) >= 4 or group_data[-1][2] <= 500

                    image_data = Path(file_path).read_bytes()
                    b64_img = base64.b64encode(image_data).decode("utf-8")

                    # Format payload matching Gemini Batch API file requirements
                    req = {
                        "key": file_path,  # The ID returned in the output file
                        "request": {
                            "contents": [{
                                "role": "user",
                                "parts": [
                                    {"inlineData": {"mimeType": "image/jpeg", "data": b64_img}},
                                    {"text": prompting.make_prompt(INPUT_DIR, label)}
                                ]
                            }],
                            "generation_config": {
                                "temperature": 1.0,
                                "topP": 0.95,
                                "maxOutputTokens": 65535,
                                "responseMimeType": "application/json",
                                "responseSchema": prompting.UNROLLED_SCHEMA
                            }
                        }
                    }

                    if use_flash:
                        f_flash.write(json.dumps(req) + "\n")
                        count_flash += 1
                    else:
                        f_pro.write(json.dumps(req) + "\n")
                        count_pro += 1

            print("Batch generation complete.")
            print(f" - {count_flash} requests saved to {batch_flash_file} (for {MODEL_ID_flash})")
            print(f" - {count_pro} requests saved to {batch_pro_file} (for {MODEL_ID_pro})")
            print("Upload these files via the File API and create two separate batch jobs.")

        # If there's no live tasks to do, and we aren't doing a batched ingestion, exit right away
        if not tasks_to_process and not args.deal_with_batched:
            sys.exit(0)

    # --- --deal-with-batched: load completed batch responses keyed by image path ---
    batched_responses = {}
    if args.deal_with_batched:
        batch_results_path = INPUT_DIR / "batched_correction_result.jsonl"
        if batch_results_path.exists():
            print(f"Loading batch results from {batch_results_path}...")
            with open(batch_results_path, "r", encoding="utf-8") as f:
                for line in f:
                    if not line.strip():
                        continue
                    data = json.loads(line)
                    task_id = data.get("key")  # Corresponds to the key sent in the request

                    if "response" in data:
                        try:
                            parts = data["response"]["candidates"][0]["content"]["parts"]
                            # Join every answer part (previously only parts[0],
                            # which truncated multi-part answers); skip thoughts.
                            resp_text = "".join(
                                part.get("text", "") for part in parts if not part.get("thought"))
                        except (KeyError, IndexError, TypeError, AttributeError) as e:
                            print(f"Warning: Could not parse response for {task_id}: {e}", file=sys.stderr)
                            continue
                        if resp_text:
                            batched_responses[task_id] = resp_text
                        else:
                            print(f"Warning: Empty response for {task_id}", file=sys.stderr)
                    elif "error" in data:
                        print(f"Batch API Error for {task_id}: {data['error']}", file=sys.stderr)
        else:
            print(f"Warning: Batch results file {batch_results_path} not found.", file=sys.stderr)

    # --- Live processing, then resolution of delayed moves, until stable ---
    made_progress = True
    while tasks_to_process or made_progress:
        if tasks_to_process:
            print(f"Starting processing on {len(tasks_to_process)} tasks with {NB_THREADS} threads...")
            with concurrent.futures.ThreadPoolExecutor(max_workers=NB_THREADS) as executor:
                pending = {
                    executor.submit(process_single_task, task, batched_responses.get(task[0]))
                    for task in tasks_to_process
                }
                # as_completed() only watches the futures it was first given, so
                # tasks spawned along the way are tracked with an explicit set.
                while pending:
                    done, pending = concurrent.futures.wait(
                        pending, return_when=concurrent.futures.FIRST_COMPLETED)
                    for future in done:
                        try:
                            new_generated_tasks = future.result()
                        except Exception as e:
                            print(f"Exception during task execution: {e}", file=sys.stderr)
                            continue
                        for new_task in new_generated_tasks or []:
                            pending.add(executor.submit(process_single_task, new_task))

            tasks_to_process = []  # Empty the list once processed

        # After every current task (live or batched) has been processed, try
        # to unblock the moves that were waiting
        delayed_tasks = resolve_delayed_moves()
        if delayed_tasks:
            print(f"Resolved {len(delayed_tasks)} delayed moves! Running executor for new tasks...")
            tasks_to_process.extend(delayed_tasks)
            made_progress = True
        else:
            made_progress = False

    end_time = time.time()
    print("Time elapsed : ", end_time - start_time)
    print("Requests to pro / flash : ", pro_count, flash_count)
    if errors_summary:
        print("\n--- Summary of Exceptions (You can use several images on one instance) ---", file=sys.stderr)
        for (err, file) in errors_summary:
            print(err, file=sys.stderr)
            escaped_path = shlex.quote(str(file))
            print(f"Run : python correction.py {escaped_path}")

    # Write whatever the main thread buffered through tprint()
    flush_thread_log()
|