Add Gemini retries

master
Sébastien Miquel 2026-03-06 19:03:48 +01:00
parent 2677e41b04
commit a98375f1e8
2 changed files with 40 additions and 38 deletions

View File

@ -236,7 +236,8 @@ def render_real_latex_text(text, width_px, bg_color=(255, 255, 255, 255), max_li
\\usepackage[T1]{{fontenc}} \\usepackage[T1]{{fontenc}}
\\usepackage{{lmodern}} % Enables arbitrary font scaling \\usepackage{{lmodern}} % Enables arbitrary font scaling
\\usepackage{{amsmath, amssymb}} \\usepackage{{amsmath, amssymb}}
%\\usepackage{{anyfontsize}} % replaces by lmodern \\usepackage{{commands}}
%\\usepackage{{anyfontsize}} % replaced by lmodern
\\begin{{document}} \\begin{{document}}
\\fontsize{{{fontsize}}}{{{line_spacing}}}\\selectfont \\fontsize{{{fontsize}}}{{{line_spacing}}}\\selectfont
{text} {text}

View File

@ -217,8 +217,8 @@ completed_tasks = []
# --- Lock for thread-safe file writing --- # --- Lock for thread-safe file writing ---
io_lock = threading.Lock() io_lock = threading.Lock()
pro_lock = threading.Lock() # New lock for counter pro_lock = threading.Lock()
pro_count = 0 # New counter pro_count = 0
flash_count = 0 flash_count = 0
if overwrite: if overwrite:
@ -230,63 +230,68 @@ else:
if progress_path.exists(): if progress_path.exists():
with open(progress_path, "r", encoding="utf-8") as f: with open(progress_path, "r", encoding="utf-8") as f:
completed_tasks = json.load(f) completed_tasks = json.load(f)
# Reload existing results to avoid overwriting them with partial data
if output_path.exists(): if output_path.exists():
with open(output_path, "r", encoding="utf-8") as f: with open(output_path, "r", encoding="utf-8") as f:
results = json.load(f) results = json.load(f)
# Create a set for O(1) lookup. Normalize paths to strings.
completed_set = set((str(f), l) for f, l in completed_tasks) completed_set = set((str(f), l) for f, l in completed_tasks)
# Filter tasks first to avoid overhead in threads
tasks_to_process = [t for t in tasks if (str(t[0]), t[1]) not in completed_set] tasks_to_process = [t for t in tasks if (str(t[0]), t[1]) not in completed_set]
def call_gemini_with_retries(model_id, contents, config, delays=(60, 300)):
    """Stream a Gemini response and return its concatenated text, retrying on failure.

    The request is attempted ``len(delays) + 1`` times in total. After the
    i-th failed attempt the function sleeps ``delays[i]`` seconds before
    trying again; with the default schedule that is one retry after 1 minute
    and a second retry after 5 minutes.

    Relies on a ``client`` object and the ``time`` module being available in
    the enclosing scope (neither is defined inside this function).

    Args:
        model_id: Passed as ``model`` to ``client.models.generate_content_stream``.
        contents: Passed through unchanged as ``contents``.
        config: Passed through unchanged as ``config``.
        delays: Sequence of waits, in seconds, between consecutive attempts.
            An empty sequence disables retrying.

    Returns:
        The concatenation of every non-empty ``chunk.text`` of the stream.

    Raises:
        Exception: Whatever the final attempt raised, re-raised unchanged
            once all retries are exhausted.
    """
    # Derive the attempt count from the schedule so the two cannot drift apart.
    max_attempts = len(delays) + 1
    for attempt in range(max_attempts):
        try:
            # The text is rebuilt from scratch on every attempt, so a stream
            # that fails midway never leaks partial output into the result.
            return "".join(
                chunk.text
                for chunk in client.models.generate_content_stream(
                    model=model_id,
                    contents=contents,
                    config=config,
                )
                if chunk.text
            )
        except Exception as e:
            # Deliberately broad: any failure of the request or of the stream
            # is treated as retryable; the last one is surfaced to the caller.
            if attempt == max_attempts - 1:
                print(f"\tGemini API failure: {e}. Maximum retries reached.")
                raise
            print(f"\tGemini API failure: {e}. Retrying in {delays[attempt]} seconds...")
            time.sleep(delays[attempt])
def process_single_task(task_tuple): def process_single_task(task_tuple):
global pro_count, flash_count global pro_count, flash_count
file_path, label = task_tuple file_path, label = task_tuple
group_name = os.path.splitext(file_path)[0] group_name = os.path.splitext(file_path)[0]
json_path = group_name + '.json' json_path = group_name + '.json'
with open(json_path, 'r') as f: with open(json_path, 'r') as f:
# List of (groupid, start, end), in pixels
group_data = json.load(f) group_data = json.load(f)
n = len(group_data) n = len(group_data)
# l[3] is ratio of width to width of group
d_data = {l[0]: (l[1], l[2], l[3]) for l in group_data} d_data = {l[0]: (l[1], l[2], l[3]) for l in group_data}
total_height = group_data[-1][2] total_height = group_data[-1][2]
use_flash = n >= 4 or total_height <= 500 use_flash = n >= 4 or total_height <= 500
if not use_flash: if not use_flash:
with pro_lock: with pro_lock:
if limit is None or pro_count < limit: if limit is None or pro_count < limit:
pro_count += 1 pro_count += 1
else: else:
# Limit reached, force switch to Flash
use_flash = True use_flash = True
if use_flash: if use_flash:
with pro_lock: with pro_lock:
flash_count += 1 flash_count += 1
try: try:
contents, config = generate_request(file_path, label) contents, config = generate_request(file_path, label)
if use_flash: model_to_use = MODEL_ID_flash if use_flash else MODEL_ID_pro
print(f"Asking Gemini Flash: {label} {group_name}") print(f"Asking Gemini {'Flash' if use_flash else 'Pro '}: {label} {group_name}")
else:
print(f"Asking Gemini Pro : {label} {group_name}")
full_response_text = "" full_response_text = call_gemini_with_retries(model_to_use, contents, config)
# Assuming client is thread-safe (usually is).
# If not, create a new client instance inside this function.
for chunk in client.models.generate_content_stream(
model=MODEL_ID_flash if use_flash else MODEL_ID_pro,
contents=contents,
config=config,
):
if chunk.text:
full_response_text += chunk.text
# Parse JSON
json_data = json.loads(full_response_text) json_data = json.loads(full_response_text)
# print("Debug : ", json_data)
# Ensure consistency of answer placements # Ensure consistency of answer placements
for p in json_data: for p in json_data:
pid = p["id"] pid = p["id"]
@ -305,20 +310,17 @@ def process_single_task(task_tuple):
continue continue
yming, ymaxg, width_r = d_data[pid] yming, ymaxg, width_r = d_data[pid]
if ymin < yming-50 or ymax > ymaxg+50: if ymin < yming-50 or ymax > ymaxg+50:
print("Error : Gemini answered box2d too low/up", print("Error : Gemini answered box2d too low/up", pid, label, group_name)
pid, label, group_name)
if ymax < yming or ymin > ymaxg: if ymax < yming or ymin > ymaxg:
print("Removing the box.") print("Removing the box.")
f["box_2d"] = None f["box_2d"] = None
continue continue
nymin = max(ymin, yming) * 1000 // total_height nymin = max(ymin, yming) * 1000 // total_height
nymax = min(ymax, ymaxg) * 1000 // total_height nymax = min(ymax, ymaxg) * 1000 // total_height
f["box_2d"] = [nymin, xmin, nymax, xmax] f["box_2d"] = [nymin, xmin, nymax, xmax]
# print("Group :", yming, ymaxg, "Answered:", ymin, ymax)
if xmax / 1000 > width_r: if f["box_2d"] and xmax / 1000 > width_r:
print("Error : Gemini answered box2d too right", print("Error : Gemini answered box2d too right", pid, label, group_name)
pid, label, group_name)
if xmin / 1000 > width_r: if xmin / 1000 > width_r:
print("Removing the box.") print("Removing the box.")
f["box_2d"] = None f["box_2d"] = None
@ -328,10 +330,9 @@ def process_single_task(task_tuple):
# --- Use Lock for writing shared data --- # --- Use Lock for writing shared data ---
with io_lock: with io_lock:
if label not in results: if label not in results:
results[label] = [] # Ensure key exists if not using defaultdict results[label] = []
results[label].append(json_data) results[label].append(json_data)
# Save Results
with open(output_path, "w", encoding="utf-8") as f: with open(output_path, "w", encoding="utf-8") as f:
json.dump(results, f, indent=2) json.dump(results, f, indent=2)