# Copies/correction.py
#
# 786 lines
# 34 KiB
# Python

import sys
import os
import time
from pathlib import Path
import argparse
import prompting
import signal
from google import genai
import base64
import shlex
import json
import threading
import concurrent.futures
# Bail out early with a usage hint when the script is started without any
# argument (the "paths" positional below requires at least one value).
if len(sys.argv) < 2:
    sys.exit("Usage: python script.py 'InterroTest/Ex 2/Group_1.jpg' OR <InputDir> OR 'file1' 'file2'")

# Parse Arguments
parser = argparse.ArgumentParser()
parser.add_argument("paths", nargs="+", help="List of images or directories")
parser.add_argument("--overwrite", action="store_true",
                    help="Force redo requests even if output exists")
# The limit is compared against the running count of live Gemini Pro requests.
parser.add_argument("--limit", type=int,
                    help="Maximum number of live requests sent to Gemini Pro (integer)")
parser.add_argument("--refaire", action="store_true",
                    help="Redo specific copies/labels defined in refaire.json")
parser.add_argument("--batch", action="store_true",
                    help="Generate a JSONL file of requests to send to the Gemini Batch API")
parser.add_argument("--batch-from", type=str, metavar="LABEL",
                    help="Do live requests before LABEL, and batch requests from LABEL onwards")
parser.add_argument("--deal-with-batched", action="store_true",
                    help="Process a JSONL file containing completed batch results")
# parse_known_args: unrecognised options are collected and ignored rather
# than rejected.
args, _ = parser.parse_known_args()
# Work list and result store built from the command-line paths.
tasks = []  # List of tuples: (filepath_str, label_str)
results = {}  # label -> list of per-group results

# Set from the paths below; when several paths are given the last usable
# one decides which input directory the rest of the script works in.
INPUT_DIR = None
COPIES_DIR = None
GROUPS_DIR = None

for path_str in args.paths:
    arg_path = Path(path_str)
    if not arg_path.exists():
        print(f"Warning: {path_str} not found. Skipping.")
        continue
    if arg_path.is_file() and arg_path.suffix.lower() == ".jpg":
        # Handle individual file.
        # The input directory is taken three levels above the image, i.e. the
        # expected layout is <InputDir>/Par label/<label>/Group_N.jpg.
        label = arg_path.parent.name
        INPUT_DIR = arg_path.parent.parent.parent
        COPIES_DIR = INPUT_DIR / "Copies"
        GROUPS_DIR = INPUT_DIR / "Par label"
        tasks.append((str(arg_path), label))
        results.setdefault(label, [])
    elif arg_path.is_dir():
        # Handle directory (original behavior): every sub-folder of
        # "Par label" is a label, every .jpg inside it is a task.
        INPUT_DIR = arg_path
        COPIES_DIR = INPUT_DIR / "Copies"
        GROUPS_DIR = INPUT_DIR / "Par label"
        for sub in GROUPS_DIR.iterdir():
            if not sub.is_dir():
                continue
            label = sub.name
            results.setdefault(label, [])
            for img in sub.glob("*.jpg"):
                tasks.append((str(img), label))
    else:
        print(f"Warning: {path_str} is neither a .jpg file nor a directory. Skipping.")

# Without any usable path INPUT_DIR would be undefined and the script would
# fail later with a NameError; stop here with an explicit message instead.
if INPUT_DIR is None:
    sys.exit("Error: none of the given paths is an existing .jpg file or directory.")
# Number of worker threads used to process tasks concurrently.
NB_THREADS = 12

# Optional HTTP(S) proxy; leave as None to connect directly.
# PROXY_URL = "http://192.168.241.1:3128"
PROXY_URL = None
if PROXY_URL:
    os.environ["http_proxy"] = PROXY_URL
    os.environ["https_proxy"] = PROXY_URL

MODEL_ID_pro = "gemini-3.1-pro-preview"
MODEL_ID_flash = "gemini-3-flash-preview"
api_key = os.environ["GEMINI_API_KEY"]

# --- Thread-safe Logging ---
# Re-entrant on purpose: the SIGINT/SIGTERM handler runs in the main thread
# and flushes the buffers under this lock. With a plain Lock it would block
# forever if the signal arrived while the main thread was already holding it.
log_lock = threading.RLock()
thread_logs = {}  # thread name -> list of buffered messages
def tprint(*args, **kwargs):
    """Record a message in the calling thread's buffer and echo it to the console.

    Positional arguments are converted with ``str`` and joined with single
    spaces; keyword arguments are forwarded unchanged to ``print``.
    """
    thread_name = threading.current_thread().name
    text = " ".join(str(part) for part in args)
    with log_lock:
        thread_logs.setdefault(thread_name, []).append(text)
        # Optional: keep echoing to the console, prefixed with the thread name
        print(f"[{thread_name}] {text}", **kwargs)
def flush_thread_log(tid=None):
    """Write one thread's buffered messages to the log file as a single block.

    ``tid`` is the name of the thread whose buffer is flushed; when it is
    omitted (or falsy) the calling thread's name is used. Nothing is written
    when the buffer is missing or empty; after a write the buffer is emptied.
    """
    thread_name = tid if tid else threading.current_thread().name
    with log_lock:
        buffered = thread_logs.get(thread_name)
        if not buffered:
            return
        log_file = INPUT_DIR / "correction_log"
        with open(log_file, "a", encoding="utf-8") as f:
            f.write(f"--- Task Log [{thread_name}] ---\n")
            f.write("\n".join(buffered) + "\n\n")
        buffered.clear()
def handle_interrupt(sig, frame):
    """Signal handler: write every pending log buffer to disk, then exit with status 1.

    ``sig`` and ``frame`` are the standard signal-handler arguments and are
    not used.
    """
    # NOTE(review): sys.exit raises SystemExit in the main thread; if that
    # happens inside a ThreadPoolExecutor ``with`` block the executor may still
    # wait for its workers before the process ends - confirm this is intended.
    print("\nInterrupt received. Flushing partial logs...", file=sys.stderr)
    # Snapshot the thread names so the loop does not iterate the live dict.
    pending_threads = tuple(thread_logs)
    for thread_name in pending_threads:
        flush_thread_log(thread_name)
    sys.exit(1)
# Route Ctrl-C and termination requests through the log-flushing handler.
signal.signal(signal.SIGINT, handle_interrupt)
signal.signal(signal.SIGTERM, handle_interrupt)
# ---------------------------

client = genai.Client(api_key=api_key)

# Files holding the accumulated results and the list of finished tasks.
output_path = INPUT_DIR / "correction.json"
progress_path = INPUT_DIR / "correction_progress.json"

start_time = time.time()
overwrite = args.overwrite
limit = args.limit
completed_tasks = []
errors_summary = []

# --- Locks for thread-safe file writing and request counting ---
io_lock = threading.Lock()
pro_lock = threading.Lock()
pro_count = 0
flash_count = 0
pro_quota_exhausted = False

if not overwrite:
    # Resume: reload whatever a previous run already saved.
    if progress_path.exists():
        completed_tasks = json.loads(progress_path.read_text(encoding="utf-8"))
    if output_path.exists():
        results = json.loads(output_path.read_text(encoding="utf-8"))
else:
    # Start from scratch: discard previous results and progress.
    if output_path.exists():
        output_path.unlink()
    if progress_path.exists():
        progress_path.unlink()

# Skip every (file, label) pair that is already recorded as completed.
completed_set = {(str(done_file), done_label) for done_file, done_label in completed_tasks}
tasks_to_process = [
    task for task in tasks
    if (str(task[0]), task[1]) not in completed_set
]
def call_gemini_with_retries(model_id, contents, config,
                             fallback_model_id=MODEL_ID_flash):
    """Stream a Gemini response, retrying on failure and falling back on Pro quota errors.

    Args:
        model_id: Model queried first.
        contents: Request contents, passed through to the client.
        config: Generation config, passed through to the client.
        fallback_model_id: Model used instead of Gemini Pro once a Pro quota
            error has been seen, either by this call or (through the shared
            ``pro_quota_exhausted`` flag) by another thread. A falsy value
            disables the fallback.

    Returns:
        The concatenated text of all streamed chunks.

    Raises:
        Exception: the API error of the last attempt once all attempts
            (three, with 60 s then 300 s default waits) are used up.
    """
    import re
    global pro_quota_exhausted

    delays = [60, 300]
    max_attempts = len(delays) + 1
    attempt = 0
    # Every iteration returns, raises, consumes an attempt, or performs the
    # one-time switch to the fallback model, so the loop always terminates
    # with either a response or an exception (never an implicit None).
    while attempt < max_attempts:
        # Switch to fallback immediately if quota was exhausted by another thread
        if model_id == MODEL_ID_pro and pro_quota_exhausted and fallback_model_id:
            model_id = fallback_model_id
        try:
            parts = []
            for chunk in client.models.generate_content_stream(
                model=model_id,
                contents=contents,
                config=config,
            ):
                if chunk.text:
                    parts.append(chunk.text)
            return "".join(parts)
        except Exception as e:
            error_msg = str(e).lower()
            is_quota_error = "429" in error_msg or "quota" in error_msg or "exhausted" in error_msg
            is_minute_limit = "minute" in error_msg or "rpm" in error_msg or "tpm" in error_msg

            # A Pro quota error (other than a per-minute limit) switches to the
            # fallback model right away; this does not consume an attempt, so
            # the fallback model still gets its full share of retries.
            if (not is_minute_limit and is_quota_error
                    and model_id == MODEL_ID_pro and fallback_model_id):
                tprint(f"\tGemini Pro quota hit ({e}). \n\n\tFalling back to Flash permanently...")
                model_id = fallback_model_id
                pro_quota_exhausted = True
                continue

            if attempt == max_attempts - 1:
                # Waiting again would be pointless: nothing is left to retry.
                tprint(f"\tGemini API failure: {e}. Maximum retries reached.")
                raise

            if is_minute_limit:
                # Extract wait time if present, else use default delay
                retry_match = re.search(r"retry in ([\d.]+)s", error_msg)
                wait_time = float(retry_match.group(1)) + 1.0 if retry_match else delays[attempt]
                tprint(f"\tGemini Pro minute limit hit. Waiting {wait_time:.1f}s...")
                time.sleep(wait_time)
            else:
                tprint(f"\tGemini API failure: {e}. Retrying in {delays[attempt]} seconds...")
                time.sleep(delays[attempt])
            attempt += 1
def correct_boxes_with_gemini(pid, label, pdf_path, original_feedbacks,
                              yming, ymaxg, width_r, total_height):
    """Ask Gemini Flash for corrected boxes on a single copy and map them to the group image.

    Args:
        pid, label: Identify the copy; not used by the body, kept for callers.
        pdf_path: Document handed to ``prompting.request_for_box_correction``.
        original_feedbacks: Feedback dicts of the copy; entries without a
            ``box_2d`` are carried over unchanged.
        yming, ymaxg: Vertical extent of this copy inside the group image.
        width_r: Width ratio of this copy relative to the group image.
        total_height: Height of the whole group image.

    Returns:
        The box-less original feedbacks followed by the corrected feedbacks,
        whose boxes are rewritten as ``[ymin, xmin, ymax, xmax]`` relative to
        the group image.

    Raises:
        Whatever the Gemini call or ``json.loads`` raises; the caller handles it.
    """
    contents, config = prompting.request_for_box_correction(pdf_path, original_feedbacks)
    response_text = call_gemini_with_retries(MODEL_ID_flash, contents, config)
    corrected_feedbacks = json.loads(response_text)

    # ``.get`` keeps this consistent with the other box_2d look-ups in the
    # file and tolerates feedback entries that have no "box_2d" key at all.
    global_feedbacks = [f for f in original_feedbacks if not f.get("box_2d")]

    # Map the coordinates back from the single image to the group canvas.
    # Coordinates are treated as per-mille values (hence the 1000 factors).
    single_h = ymaxg - yming
    for f in corrected_feedbacks:
        b = f.get("box_2d")
        if not b:
            continue
        ymin_s, xmin_s, ymax_s, xmax_s = b
        # Y mapping: add the group Y-offset (yming), then normalize to total_height
        new_ymin = int((yming + (ymin_s * single_h / 1000.0)) * 1000.0 / total_height)
        new_ymax = int((yming + (ymax_s * single_h / 1000.0)) * 1000.0 / total_height)
        # X mapping: multiply by the width ratio of this sub-image vs the group image
        new_xmin = int(xmin_s * width_r)
        new_xmax = int(xmax_s * width_r)
        f["box_2d"] = [new_ymin, new_xmin, new_ymax, new_xmax]
    return global_feedbacks + corrected_feedbacks
import shutil
import grouping
def get_next_group_idx(label):
    """Return the highest index among the ``Group_<n>.jpg`` files of *label* (0 if none).

    The label folder is created when missing. Callers in this file pass the
    returned value to ``grouping.create_jpg`` and then refer to the new image
    as ``Group_{idx+1}.jpg``.

    Files whose part after ``Group_`` is not an integer are ignored instead
    of aborting the whole task.
    """
    target_folder = GROUPS_DIR / label
    target_folder.mkdir(parents=True, exist_ok=True)
    indices = []
    for image in target_folder.glob("Group_*.jpg"):
        try:
            indices.append(int(image.stem.split("_")[1]))
        except ValueError:
            continue
    return max(indices, default=0)
from utils import read_all_labels, enonce_total
def handle_label_errors(pid, label, res, pdf_path):
    """Handle a "wrong-label" or "additional-answer" verdict for one copy.

    Gemini Flash is asked which label(s) the answer really belongs to. When
    the target slot of the copy is free, the PDF is copied there (and, for a
    wrong label, the original is renamed to ``<label>_old.pdf``), a
    single-copy group image is created and a follow-up task is returned.
    When the slot is taken the move is recorded as delayed.

    Args:
        pid: Copy identifier (folder ``Copie<pid>``).
        label: Label the copy was processed under.
        res: Result dict of the copy; its ``"error"`` entry is rewritten in
            place with the outcome (e.g. ``wrg-lbl-moved-to:<label>``,
            ``wrg-lbl:<label>?delayed``, ``al:...`` or ``""``).
        pdf_path: Path of the copy's PDF for ``label``.

    Returns:
        A list of ``(group_image_path, label, False)`` task tuples.
    """
    new_tasks = []
    error_type = res.get("error")
    all_labels = read_all_labels(INPUT_DIR)
    labels_txt = (INPUT_DIR / "labels").read_text(encoding="utf-8", errors="replace")
    enonce = enonce_total(INPUT_DIR)

    if error_type == "wrong-label":
        tprint(f"\tHandling wrong-label for {pid} {label}")
        contents, config = prompting.request_for_wrong_label(pdf_path, label, enonce, labels_txt)
        new_label = call_gemini_with_retries(MODEL_ID_flash, contents, config).strip().strip('"\'')
        if new_label not in all_labels:
            tprint(f"\t\tCopie{pid} returned an incorrect label {new_label} from an initial wrong label {label}. Ignoring")
            res["error"] = "wrg-lbl:cldtfix"
            return []
        if new_label == label:
            # Second opinion says the label was right after all.
            res["error"] = ""
            return []
        base_new_pdf_path = COPIES_DIR / f"Copie{pid}" / f"{new_label}.pdf"
        new_pdf_path = COPIES_DIR / f"Copie{pid}" / f"{new_label}_new.pdf"
        if base_new_pdf_path.exists() or new_pdf_path.exists():
            tprint(f"\t\tCopie{pid} tried to move wrong {label} to {new_label}, but it already exists. Delaying.")
            res["error"] = f"wrg-lbl:{new_label}?delayed"
        else:
            res["error"] = f"wrg-lbl-moved-to:{new_label}"
            tprint(f"\t\tCopie{pid} : moving wrong {label} to {new_label}.")
            # Copy to <new_label>_new.pdf, then rename the original to <label>_old.pdf
            shutil.copy(str(pdf_path), str(new_pdf_path))
            old_pdf_path = pdf_path.with_name(f"{label}_old.pdf")
            if pdf_path != old_pdf_path:
                shutil.move(str(pdf_path), str(old_pdf_path))
            idx = get_next_group_idx(new_label)
            height = grouping.get_pdf_height(str(new_pdf_path))
            grouping.create_jpg(new_label, idx, [(pid, str(new_pdf_path), height)], GROUPS_DIR)
            tprint(f"\t\tMaking {new_label} group {idx+1}")
            new_tasks.append((str(GROUPS_DIR / new_label / f"Group_{idx+1}.jpg"),
                              new_label, False))

    elif error_type == "additional-answer":
        contents, config = prompting.request_for_additional_answer(pdf_path, label, enonce, labels_txt)
        tprint(f"\tHandling additional-answer for {pid} {label}")
        try:
            add_labels = json.loads(call_gemini_with_retries(MODEL_ID_flash, contents, config))
        except Exception as e:
            # Best effort: carry on without additional labels, but leave a
            # trace in the log instead of hiding the failure.
            tprint(f"\t\tCould not obtain additional labels for Copie{pid} {label}: {e}. Ignoring.")
            add_labels = []
        if not isinstance(add_labels, list):
            tprint(f"\t\tUnexpected additional-answer reply for Copie{pid} {label}: {add_labels!r}. Ignoring.")
            add_labels = []

        keep_error = False
        error = "al:"
        for add_label in add_labels:
            if add_label == label:
                continue
            if add_label not in all_labels:
                tprint(f"\t\t Inexistent label ({add_label}) from additional-answer processing {pid} {label}. Ignoring")
                error += f"{add_label}??"
                keep_error = True
                continue
            base_add_pdf_path = COPIES_DIR / f"Copie{pid}" / f"{add_label}.pdf"
            add_pdf_path = COPIES_DIR / f"Copie{pid}" / f"{add_label}_new.pdf"
            if not base_add_pdf_path.exists() and not add_pdf_path.exists():
                shutil.copy(str(pdf_path), str(add_pdf_path))
                tprint(f"\t\tCopying Copie{pid} : {label} -> {add_label}")
                idx = get_next_group_idx(add_label)
                tprint(f"\t\tMaking {add_label} group {idx+1}")
                height = grouping.get_pdf_height(str(add_pdf_path))
                grouping.create_jpg(add_label, idx, [(pid, str(add_pdf_path), height)], GROUPS_DIR)
                new_tasks.append((str(GROUPS_DIR / add_label / f"Group_{idx+1}.jpg"),
                                  add_label, False))
                error += f"(->){add_label}"
                keep_error = True
            else:
                keep_error = True
                error += f"(delayed){add_label}"
                tprint(f"\t\tAlready present (not copied) Copie{pid} : {label} -> {add_label}. Delaying.")
        # The error is cleared when nothing was recorded for any label.
        res["error"] = error if keep_error else ""

    return new_tasks
def _locate_copy_pdf(pid, label):
    """Return ``(path, suffix)`` of the copy's PDF: plain name first, then ``_new``, then ``_old``."""
    pdf_path = COPIES_DIR / f"Copie{pid}" / f"{label}.pdf"
    if pdf_path.exists():
        return pdf_path, ""
    for suffix in ("_new", "_old"):
        candidate = pdf_path.with_name(f"{label}{suffix}.pdf")
        if candidate.exists():
            return candidate, suffix
    # Nothing on disk: keep the plain path, as callers test ``exists()`` themselves.
    return pdf_path, ""


def _clamp_and_flag_boxes(feedback, yming, ymaxg, width_r, total_height):
    """Clamp slightly overflowing boxes in place; return the indices of badly misplaced ones."""
    flagged = []
    for i, item in enumerate(feedback):
        box = item.get("box_2d")
        if not box:
            continue
        ymin, _xmin, ymax, xmax = box
        # Convert the per-mille vertical coordinates to group-image units.
        ymin = ymin * total_height // 1000
        ymax = ymax * total_height // 1000
        if ymin < yming - 50 or ymax > ymaxg + 50 or xmax / 1000 > width_r:
            # Far outside the copy's band: needs a dedicated correction.
            # Keep scanning so every anomaly is known to the fallback.
            flagged.append(i)
            continue
        if ymin < yming - 5:
            box[0] = (yming - 5) * 1000 // total_height
        if ymax > ymaxg + 5:
            box[2] = (ymaxg + 5) * 1000 // total_height
    return flagged


def process_single_task(task_tuple, precomputed_response=None):
    """Correct one group image with Gemini and persist the outcome.

    Args:
        task_tuple: ``(file_path, label)`` or ``(file_path, label,
            can_spawn_tasks)``. The flag defaults to True; when False, label
            errors reported by Gemini do not generate follow-up tasks.
        precomputed_response: Response text obtained from a batch job. When
            given (non-empty) no live request is made and the Pro/Flash
            counters are left untouched.

    Returns:
        The list of follow-up task tuples produced while handling label
        errors (possibly empty). API and decoding failures are recorded in
        ``errors_summary`` rather than raised.

    The thread's buffered log is flushed on exit in every case.
    """
    global pro_count, flash_count, pro_quota_exhausted
    try:
        file_path = task_tuple[0]
        label = task_tuple[1]
        can_spawn_tasks = task_tuple[2] if len(task_tuple) > 2 else True
        group_name = os.path.splitext(file_path)[0]
        json_path = group_name + '.json'
        new_tasks = []

        # Side-car JSON of the group image: one entry per copy, read here as
        # [id, ymin, ymax, width_ratio].
        with open(json_path, 'r') as f:
            group_data = json.load(f)
        n = len(group_data)
        d_data = {entry[0]: (entry[1], entry[2], entry[3]) for entry in group_data}
        total_height = group_data[-1][2]

        # Flash is used for groups of four or more copies and for short images.
        use_flash = n >= 4 or total_height <= 500
        use_precomputed = bool(precomputed_response)

        # Only apply limits and counts if we are making a live call
        if not use_precomputed:
            if not use_flash:
                with pro_lock:
                    if pro_quota_exhausted:
                        use_flash = True
                    elif limit is None or pro_count < limit:
                        pro_count += 1
                    else:
                        use_flash = True
            if use_flash:
                with pro_lock:
                    flash_count += 1

        try:
            contents, config = prompting.generate_request(INPUT_DIR, file_path, label)
            model_to_use = MODEL_ID_flash if use_flash else MODEL_ID_pro
            if use_precomputed:
                tprint(f"Using batched response for: {label} {group_name}")
                full_response_text = precomputed_response
            else:
                tprint(f"Asking Gemini {'Flash' if use_flash else 'Pro '}: {label} {group_name}")
                full_response_text = call_gemini_with_retries(model_to_use, contents, config)
            json_data = json.loads(full_response_text)

            # Ensure consistency of answer placements
            for p in json_data:
                pid = p["id"]
                res = p["result"]
                # Must be checked before the look-up below, otherwise a
                # KeyError aborts the whole group.
                if pid not in d_data:
                    tprint("Error : Gemini answered a copie id not present",
                           pid, label, group_name)
                    continue
                yming, ymaxg, width_r = d_data[pid]

                # Find the actual file when it carries a suffix.
                pdf_path, current_suffix = _locate_copy_pdf(pid, label)

                # 1. empty-answer: set the PDF aside as <label>_old.pdf
                if res.get("error") == "empty-answer":
                    old_path = pdf_path.with_name(f"{label}_old.pdf")
                    if pdf_path.exists() and pdf_path != old_path:
                        shutil.move(str(pdf_path), str(old_path))
                        pdf_path = old_path
                        current_suffix = "_old"

                if (not can_spawn_tasks) and res.get("error") == "additional-answer":
                    tprint("\tSwallowing an additional-answer from a subsequent task.")
                    res["error"] = ""
                if res.get("error"):
                    tprint("\tError :", res["error"], "for Copie", pid, group_name)

                if can_spawn_tasks and res.get("error") in ("wrong-label", "additional-answer"):
                    new_tasks.extend(handle_label_errors(pid, label, res, pdf_path))
                    # A handled "wrong-label" has renamed the current file to _old
                    if (res.get("error") or "").startswith("wrg-lbl-moved-to:"):
                        current_suffix = "_old"

                # Record the suffix in correction.json
                if current_suffix:
                    res["suffix"] = current_suffix

                needs_correction = _clamp_and_flag_boxes(
                    res.get("feedback") or [], yming, ymaxg, width_r, total_height)
                if needs_correction:
                    tprint(f"\tBox anomalies detected for Copie {pid} {group_name}. \n\tRequesting isolated correction from Gemini Flash...")
                    try:
                        res["feedback"] = correct_boxes_with_gemini(
                            pid, label, pdf_path, res["feedback"],
                            yming, ymaxg, width_r, total_height)
                    except Exception as e:
                        tprint(f"\tCorrection failed for Copie {pid}, {group_name} : {e}\n\tRemoving the boxes")
                        # Fallback if the second request fails entirely:
                        # drop every box that was flagged as misplaced.
                        for i in needs_correction:
                            res["feedback"][i]["box_2d"] = None

            # --- Use Lock for writing shared data ---
            with io_lock:
                results.setdefault(label, []).append(json_data)
                with open(output_path, "w", encoding="utf-8") as f:
                    json.dump(results, f, indent=2)
                # To track progress
                completed_tasks.append((file_path, label))
                with open(progress_path, "w", encoding="utf-8") as f:
                    json.dump(completed_tasks, f, indent=2)
        except json.JSONDecodeError:
            tprint(f"Error decoding JSON for {file_path}", file=sys.stderr)
            with io_lock:
                errors_summary.append(("Error decoding JSON response", file_path))
        except Exception as e:
            error_msg = f"Exception processing {file_path}: {e}"
            print(error_msg, file=sys.stderr)
            with io_lock:
                errors_summary.append((error_msg, file_path))
        return new_tasks
    finally:
        flush_thread_log()
def _make_single_copy_task(pid, target_label, target_pdf_path):
    """Create a one-copy group image for *target_label* and return its follow-up task tuple."""
    idx = get_next_group_idx(target_label)
    height = grouping.get_pdf_height(str(target_pdf_path))
    grouping.create_jpg(target_label, idx, [(pid, str(target_pdf_path), height)], GROUPS_DIR)
    return (str(GROUPS_DIR / target_label / f"Group_{idx+1}.jpg"), target_label, False)


def resolve_delayed_moves():
    """Execute the delayed moves recorded in ``results`` whose target slot is now free.

    Looks for ``wrg-lbl:<label>?delayed`` and ``al:...(delayed)<label>``
    error codes. For each target label that has neither ``<label>.pdf`` nor
    ``<label>_new.pdf`` in the copy's folder, the PDF is copied there, the
    error code is updated, and a single-copy task is created. The results
    file is rewritten when at least one move was resolved.

    Returns:
        The list of ``(group_image_path, label, False)`` tasks to run next.
    """
    import re

    new_tasks = []
    try:
        with io_lock:
            for label, batches in results.items():
                for batch in batches:
                    for p in batch:
                        err = p.get("result", {}).get("error", "")
                        if not err or ("?delayed" not in err and "(delayed)" not in err):
                            continue
                        pid = p["id"]
                        pdf_path = COPIES_DIR / f"Copie{pid}" / f"{label}.pdf"
                        if not pdf_path.exists():
                            for suffix in ("_new", "_old"):
                                candidate = pdf_path.with_name(f"{label}{suffix}.pdf")
                                if candidate.exists():
                                    pdf_path = candidate
                                    break
                        if not pdf_path.exists():
                            # Nothing to copy from: skip instead of crashing in shutil.
                            tprint(f"Cannot resolve delayed move for Copie{pid} {label}: source PDF not found.")
                            continue

                        # 1. Resolve a delayed wrong-label
                        if err.startswith("wrg-lbl:") and "?delayed" in err:
                            new_label = err.split(":")[1].split("?")[0]
                            base_new_pdf_path = COPIES_DIR / f"Copie{pid}" / f"{new_label}.pdf"
                            new_pdf_path = COPIES_DIR / f"Copie{pid}" / f"{new_label}_new.pdf"
                            # The slot is free (its previous occupant was moved to _old)
                            if not base_new_pdf_path.exists() and not new_pdf_path.exists():
                                tprint(f"Resolving delayed move: Copie{pid} {label} -> {new_label}")
                                p["result"]["error"] = f"wrg-lbl-moved-to:{new_label}"
                                # Same key as process_single_task writes, so the
                                # entry is recognised as set aside later on.
                                p["result"]["suffix"] = "_old"
                                # NOTE(review): the previous spelling is kept in case
                                # a reader outside this file relies on it - confirm
                                # and drop one of the two keys.
                                p["result"]["suffixe"] = "_old"
                                shutil.copy(str(pdf_path), str(new_pdf_path))
                                old_pdf_path = pdf_path.with_name(f"{label}_old.pdf")
                                if pdf_path != old_pdf_path:
                                    shutil.move(str(pdf_path), str(old_pdf_path))
                                new_tasks.append(_make_single_copy_task(pid, new_label, new_pdf_path))

                        # 2. Resolve delayed additional-answers
                        elif err.startswith("al:") and "(delayed)" in err:
                            delayed_matches = re.findall(r'\(delayed\)([^?()]+)', err)
                            new_err = err
                            resolved_any = False
                            for add_label in delayed_matches:
                                base_add_pdf_path = COPIES_DIR / f"Copie{pid}" / f"{add_label}.pdf"
                                add_pdf_path = COPIES_DIR / f"Copie{pid}" / f"{add_label}_new.pdf"
                                if not base_add_pdf_path.exists() and not add_pdf_path.exists():
                                    tprint(f"Resolving delayed additional-answer: Copie{pid} {label} -> {add_label}")
                                    new_err = new_err.replace(f"(delayed){add_label}", f"(->){add_label}")
                                    resolved_any = True
                                    shutil.copy(str(pdf_path), str(add_pdf_path))
                                    new_tasks.append(_make_single_copy_task(pid, add_label, add_pdf_path))
                            if resolved_any:
                                p["result"]["error"] = new_err

            if new_tasks:
                # Save the updated error codes (resolved "delayed" tags)
                with open(output_path, "w", encoding="utf-8") as f:
                    json.dump(results, f, indent=2)
    finally:
        # Messages buffered by tprint in the calling thread would otherwise
        # only reach the log file on interrupt.
        flush_thread_log()
    return new_tasks
if __name__ == "__main__":
    # ------------------------------------------------------------------
    # --refaire: archive the stored corrections of the listed copies and
    # queue a fresh single-copy group for each of them.
    # refaire.json is read as a list of [copie_name, labels] pairs.
    # ------------------------------------------------------------------
    if args.refaire:
        refaire_path = INPUT_DIR / "refaire.json"
        overwritten_path = INPUT_DIR / "overwritten_correction.json"
        if refaire_path.exists():
            with open(refaire_path, "r", encoding="utf-8") as f:
                refaire_list = json.load(f)
            overwritten_data = []
            if overwritten_path.exists():
                with open(overwritten_path, "r", encoding="utf-8") as f:
                    overwritten_data = json.load(f)
            dirty_results = False
            for copie_name, labels in refaire_list:
                pid = copie_name.replace("Copie", "")
                copie_dir = COPIES_DIR / copie_name
                # If list is empty, redo all labels available for this Copie.
                # "<label>_old.pdf" files were set aside and are skipped;
                # "<label>_new.pdf" counts for <label> (resolved again below).
                if not labels:
                    labels = []
                    for pdf in sorted(copie_dir.glob("*.pdf")):
                        stem = pdf.stem
                        if stem.endswith("_old"):
                            continue
                        if stem.endswith("_new"):
                            stem = stem[:-len("_new")]
                        if stem not in labels:
                            labels.append(stem)
                for label in labels:
                    # 1. Extract and backup old corrections
                    if label in results:
                        for batch in results[label]:
                            # Compare as strings: pid comes from the folder
                            # name, the stored id may be a number.
                            to_remove = next(
                                (item for item in batch if str(item.get("id")) == pid),
                                None)
                            if to_remove is not None:
                                batch.remove(to_remove)
                                overwritten_data.append({
                                    "pid": pid,
                                    "label": label,
                                    "data": to_remove,
                                    "timestamp": time.time()
                                })
                                dirty_results = True
                        # Clean up empty batches
                        results[label] = [b for b in results[label] if b]
                    # 2. Make new group and add to tasks
                    pdf_path = copie_dir / f"{label}.pdf"
                    if not pdf_path.exists():
                        if (copie_dir / f"{label}_new.pdf").exists():
                            pdf_path = copie_dir / f"{label}_new.pdf"
                    if pdf_path.exists():
                        idx = get_next_group_idx(label)
                        height = grouping.get_pdf_height(str(pdf_path))
                        grouping.create_jpg(label, idx, [(pid, str(pdf_path), height)], GROUPS_DIR)
                        new_group_path = str(GROUPS_DIR / label / f"Group_{idx+1}.jpg")
                        tasks_to_process.append((new_group_path, label))
            if dirty_results:
                with open(output_path, "w", encoding="utf-8") as f:
                    json.dump(results, f, indent=2)
                with open(overwritten_path, "w", encoding="utf-8") as f:
                    json.dump(overwritten_data, f, indent=2)
        else:
            print(f"Warning: --refaire flag used, but {refaire_path} not found.", file=sys.stderr)

    # ------------------------------------------------------------------
    # --batch / --batch-from: write Batch API request files instead of
    # (or in addition to) making live calls.
    # ------------------------------------------------------------------
    if args.batch or args.batch_from:
        all_labels = read_all_labels(INPUT_DIR)
        batch_tasks = []
        if args.batch_from:
            # Accept a prefix of a label and ask for confirmation.
            for label in all_labels:
                if label.startswith(args.batch_from):
                    args.batch_from = label
                    input(f"About to batch from: {args.batch_from}. Press Enter to confirm...")
                    break
            if args.batch_from not in all_labels:
                sys.exit(f"Error: Label '{args.batch_from}' not found. Available labels: {all_labels}")
            target_idx = all_labels.index(args.batch_from)
            live_tasks = []
            for task in tasks_to_process:
                lbl = task[1]
                # Any label found sequentially equal or after `args.batch_from` gets batched
                if lbl in all_labels and all_labels.index(lbl) >= target_idx:
                    batch_tasks.append(task)
                else:
                    live_tasks.append(task)
            tasks_to_process = live_tasks  # Keep live tasks to be run right after
        else:
            batch_tasks = tasks_to_process
            tasks_to_process = []  # Run nothing live if just `--batch`

        if batch_tasks:
            batch_flash_file = INPUT_DIR / "batch_requests_flash.jsonl"
            batch_pro_file = INPUT_DIR / "batch_requests_pro.jsonl"
            count_flash = 0
            count_pro = 0
            with open(batch_flash_file, "w", encoding="utf-8") as f_flash, \
                    open(batch_pro_file, "w", encoding="utf-8") as f_pro:
                for task in batch_tasks:
                    file_path, label = task[0], task[1]
                    group_name = os.path.splitext(file_path)[0]
                    json_path = group_name + '.json'
                    with open(json_path, 'r') as jf:
                        group_data = json.load(jf)
                    # Same model choice as for live requests.
                    use_flash = len(group_data) >= 4 or group_data[-1][2] <= 500
                    image_data = Path(file_path).read_bytes()
                    b64_img = base64.b64encode(image_data).decode("utf-8")
                    # Format payload matching Gemini Batch API file requirements
                    req = {
                        "key": file_path,  # The ID returned in the output file
                        "request": {
                            "contents": [{
                                "role": "user",
                                "parts": [
                                    {"inlineData": {"mimeType": "image/jpeg", "data": b64_img}},
                                    {"text": prompting.make_prompt(INPUT_DIR, label)}
                                ]
                            }],
                            "generation_config": {
                                "temperature": 1.0,
                                "topP": 0.95,
                                "maxOutputTokens": 65535,
                                "responseMimeType": "application/json",
                                "responseSchema": prompting.UNROLLED_SCHEMA
                            }
                        }
                    }
                    if use_flash:
                        f_flash.write(json.dumps(req) + "\n")
                        count_flash += 1
                    else:
                        f_pro.write(json.dumps(req) + "\n")
                        count_pro += 1
            print("Batch generation complete.")
            print(f" - {count_flash} requests saved to {batch_flash_file} (for {MODEL_ID_flash})")
            print(f" - {count_pro} requests saved to {batch_pro_file} (for {MODEL_ID_pro})")
            print("Upload these files via the File API and create two separate batch jobs.")

        # If there's no live tasks to do, and we aren't doing a batched ingestion, exit right away
        if not tasks_to_process and not args.deal_with_batched:
            sys.exit(0)

    # ------------------------------------------------------------------
    # --deal-with-batched: load finished batch responses, keyed by the
    # image path used as request key.
    # ------------------------------------------------------------------
    batched_responses = {}
    if args.deal_with_batched:
        batch_results_path = INPUT_DIR / "batched_correction_result.jsonl"
        if batch_results_path.exists():
            print(f"Loading batch results from {batch_results_path}...")
            with open(batch_results_path, "r", encoding="utf-8") as f:
                for line in f:
                    if not line.strip():
                        continue
                    data = json.loads(line)
                    task_id = data.get("key")  # Corresponds to the key sent in the request
                    if "response" in data:
                        try:
                            # Extract the JSON response text per standard Batch API schema
                            resp_text = data["response"]["candidates"][0]["content"]["parts"][0]["text"]
                            batched_responses[task_id] = resp_text
                        except (KeyError, IndexError, TypeError) as e:
                            print(f"Warning: Could not parse response for {task_id}: {e}", file=sys.stderr)
                    elif "error" in data:
                        print(f"Batch API Error for {task_id}: {data['error']}", file=sys.stderr)
        else:
            print(f"Warning: Batch results file {batch_results_path} not found.", file=sys.stderr)

    # ------------------------------------------------------------------
    # Main loop: run the queued tasks, then retry delayed moves; repeat
    # while resolving delayed moves keeps producing new tasks.
    # ------------------------------------------------------------------
    made_progress = True
    while tasks_to_process or made_progress:
        if tasks_to_process:
            print(f"Starting processing on {len(tasks_to_process)} tasks with {NB_THREADS} threads...")
            with concurrent.futures.ThreadPoolExecutor(max_workers=NB_THREADS) as executor:
                pending = {
                    executor.submit(process_single_task, task, batched_responses.get(task[0]))
                    for task in tasks_to_process
                }
                # wait() is called again on the updated set each round, so
                # futures submitted for follow-up tasks are collected too
                # (as_completed only iterates the futures it was given).
                while pending:
                    done, pending = concurrent.futures.wait(
                        pending, return_when=concurrent.futures.FIRST_COMPLETED)
                    for future in done:
                        try:
                            new_generated_tasks = future.result()
                            for new_task in new_generated_tasks or ():
                                pending.add(executor.submit(process_single_task, new_task))
                        except Exception as e:
                            print(f"Exception during task execution: {e}", file=sys.stderr)
            tasks_to_process = []  # Empty the list once processed

        # After processing all current tasks (live or batched), try to
        # unblock the moves that were put on hold.
        delayed_tasks = resolve_delayed_moves()
        if delayed_tasks:
            print(f"Resolved {len(delayed_tasks)} delayed moves! Running executor for new tasks...")
            tasks_to_process.extend(delayed_tasks)
            made_progress = True
        else:
            made_progress = False

    end_time = time.time()
    print("Time elapsed : ", end_time - start_time)
    print("Requests to pro / flash : ", pro_count, flash_count)
    if errors_summary:
        print("\n--- Summary of Exceptions (You can use several images on one instance) ---", file=sys.stderr)
        for (err, file) in errors_summary:
            print(err, file=sys.stderr)
            escaped_path = shlex.quote(str(file))
            print(f"Run : python correction.py {escaped_path}")