Add Gemini retries
parent
2677e41b04
commit
a98375f1e8
|
|
@ -236,7 +236,8 @@ def render_real_latex_text(text, width_px, bg_color=(255, 255, 255, 255), max_li
|
||||||
\\usepackage[T1]{{fontenc}}
|
\\usepackage[T1]{{fontenc}}
|
||||||
\\usepackage{{lmodern}} % Enables arbitrary font scaling
|
\\usepackage{{lmodern}} % Enables arbitrary font scaling
|
||||||
\\usepackage{{amsmath, amssymb}}
|
\\usepackage{{amsmath, amssymb}}
|
||||||
%\\usepackage{{anyfontsize}} % replaces by lmodern
|
\\usepackage{{commands}}
|
||||||
|
%\\usepackage{{anyfontsize}} % replaced by lmodern
|
||||||
\\begin{{document}}
|
\\begin{{document}}
|
||||||
\\fontsize{{{fontsize}}}{{{line_spacing}}}\\selectfont
|
\\fontsize{{{fontsize}}}{{{line_spacing}}}\\selectfont
|
||||||
{text}
|
{text}
|
||||||
|
|
|
||||||
|
|
@ -217,8 +217,8 @@ completed_tasks = []
|
||||||
|
|
||||||
# --- Lock for thread-safe file writing ---
|
# --- Lock for thread-safe file writing ---
|
||||||
io_lock = threading.Lock()
|
io_lock = threading.Lock()
|
||||||
pro_lock = threading.Lock() # New lock for counter
|
pro_lock = threading.Lock()
|
||||||
pro_count = 0 # New counter
|
pro_count = 0
|
||||||
flash_count = 0
|
flash_count = 0
|
||||||
|
|
||||||
if overwrite:
|
if overwrite:
|
||||||
|
|
@ -230,63 +230,68 @@ else:
|
||||||
if progress_path.exists():
|
if progress_path.exists():
|
||||||
with open(progress_path, "r", encoding="utf-8") as f:
|
with open(progress_path, "r", encoding="utf-8") as f:
|
||||||
completed_tasks = json.load(f)
|
completed_tasks = json.load(f)
|
||||||
# Reload existing results to avoid overwriting them with partial data
|
|
||||||
if output_path.exists():
|
if output_path.exists():
|
||||||
with open(output_path, "r", encoding="utf-8") as f:
|
with open(output_path, "r", encoding="utf-8") as f:
|
||||||
results = json.load(f)
|
results = json.load(f)
|
||||||
|
|
||||||
# Create a set for O(1) lookup. Normalize paths to strings.
|
|
||||||
completed_set = set((str(f), l) for f, l in completed_tasks)
|
completed_set = set((str(f), l) for f, l in completed_tasks)
|
||||||
|
|
||||||
# Filter tasks first to avoid overhead in threads
|
|
||||||
tasks_to_process = [t for t in tasks if (str(t[0]), t[1]) not in completed_set]
|
tasks_to_process = [t for t in tasks if (str(t[0]), t[1]) not in completed_set]
|
||||||
|
|
||||||
|
def call_gemini_with_retries(model_id, contents, config, delays=(60, 300)):
    """Stream a Gemini response and return its text, retrying on failure.

    The request is attempted once, plus one more time per entry in
    ``delays``. After a failed attempt the function sleeps for the matching
    delay before trying again. Text received during a failed attempt is
    discarded, so the returned string always comes from one complete stream.

    Args:
        model_id: Model identifier forwarded as ``model=``.
        contents: Request contents forwarded as ``contents=``.
        config: Generation config forwarded as ``config=``.
        delays: Seconds to wait before each retry. The default ``(60, 300)``
            gives a 1-minute and then a 5-minute pause, i.e. three attempts
            in total. An empty sequence disables retrying.

    Returns:
        The concatenation of every non-empty ``chunk.text`` of the stream.

    Raises:
        Exception: Whatever the last attempt raised, re-raised unchanged
            once all retries are exhausted.
    """
    # NOTE(review): relies on the module-level ``client`` and ``time`` names,
    # which are defined outside this block -- confirm both are in scope.
    # The trailing ``None`` marks the final attempt, which has no retry delay;
    # deriving the schedule from ``delays`` keeps count and list in sync.
    for delay in (*delays, None):
        try:
            return "".join(
                chunk.text
                for chunk in client.models.generate_content_stream(
                    model=model_id,
                    contents=contents,
                    config=config,
                )
                if chunk.text
            )
        except Exception as e:
            # Broad catch is deliberate: any failure is retried, and the
            # last one is re-raised so callers still see the real error.
            if delay is None:
                print(f"\tGemini API failure: {e}. Maximum retries reached.")
                raise
            print(f"\tGemini API failure: {e}. Retrying in {delay} seconds...")
            time.sleep(delay)
|
||||||
|
|
||||||
def process_single_task(task_tuple):
|
def process_single_task(task_tuple):
|
||||||
global pro_count, flash_count
|
global pro_count, flash_count
|
||||||
file_path, label = task_tuple
|
file_path, label = task_tuple
|
||||||
group_name = os.path.splitext(file_path)[0]
|
group_name = os.path.splitext(file_path)[0]
|
||||||
json_path = group_name + '.json'
|
json_path = group_name + '.json'
|
||||||
|
|
||||||
with open(json_path, 'r') as f:
|
with open(json_path, 'r') as f:
|
||||||
# List of (groupid, start, end), in pixels
|
|
||||||
group_data = json.load(f)
|
group_data = json.load(f)
|
||||||
|
|
||||||
n = len(group_data)
|
n = len(group_data)
|
||||||
# l[3] is ratio of width to width of group
|
|
||||||
d_data = {l[0]: (l[1], l[2], l[3]) for l in group_data}
|
d_data = {l[0]: (l[1], l[2], l[3]) for l in group_data}
|
||||||
total_height = group_data[-1][2]
|
total_height = group_data[-1][2]
|
||||||
use_flash = n >= 4 or total_height <= 500
|
use_flash = n >= 4 or total_height <= 500
|
||||||
|
|
||||||
if not use_flash:
|
if not use_flash:
|
||||||
with pro_lock:
|
with pro_lock:
|
||||||
if limit is None or pro_count < limit:
|
if limit is None or pro_count < limit:
|
||||||
pro_count += 1
|
pro_count += 1
|
||||||
else:
|
else:
|
||||||
# Limit reached, force switch to Flash
|
|
||||||
use_flash = True
|
use_flash = True
|
||||||
|
|
||||||
if use_flash:
|
if use_flash:
|
||||||
with pro_lock:
|
with pro_lock:
|
||||||
flash_count += 1
|
flash_count += 1
|
||||||
|
|
||||||
try:
|
try:
|
||||||
contents, config = generate_request(file_path, label)
|
contents, config = generate_request(file_path, label)
|
||||||
if use_flash:
|
model_to_use = MODEL_ID_flash if use_flash else MODEL_ID_pro
|
||||||
print(f"Asking Gemini Flash: {label} {group_name}")
|
print(f"Asking Gemini {'Flash' if use_flash else 'Pro '}: {label} {group_name}")
|
||||||
else:
|
|
||||||
print(f"Asking Gemini Pro : {label} {group_name}")
|
|
||||||
|
|
||||||
full_response_text = ""
|
full_response_text = call_gemini_with_retries(model_to_use, contents, config)
|
||||||
# Assuming client is thread-safe (usually is).
|
|
||||||
# If not, create a new client instance inside this function.
|
|
||||||
for chunk in client.models.generate_content_stream(
|
|
||||||
model=MODEL_ID_flash if use_flash else MODEL_ID_pro,
|
|
||||||
contents=contents,
|
|
||||||
config=config,
|
|
||||||
):
|
|
||||||
if chunk.text:
|
|
||||||
full_response_text += chunk.text
|
|
||||||
|
|
||||||
# Parse JSON
|
|
||||||
json_data = json.loads(full_response_text)
|
json_data = json.loads(full_response_text)
|
||||||
|
|
||||||
# print("Debug : ", json_data)
|
|
||||||
# Ensure consistency of answer placements
|
# Ensure consistency of answer placements
|
||||||
for p in json_data:
|
for p in json_data:
|
||||||
pid = p["id"]
|
pid = p["id"]
|
||||||
|
|
@ -305,20 +310,17 @@ def process_single_task(task_tuple):
|
||||||
continue
|
continue
|
||||||
yming, ymaxg, width_r = d_data[pid]
|
yming, ymaxg, width_r = d_data[pid]
|
||||||
if ymin < yming-50 or ymax > ymaxg+50:
|
if ymin < yming-50 or ymax > ymaxg+50:
|
||||||
print("Error : Gemini answered box2d too low/up",
|
print("Error : Gemini answered box2d too low/up", pid, label, group_name)
|
||||||
pid, label, group_name)
|
|
||||||
if ymax < yming or ymin > ymaxg:
|
if ymax < yming or ymin > ymaxg:
|
||||||
print("Removing the box.")
|
print("Removing the box.")
|
||||||
f["box_2d"] = None
|
f["box_2d"] = None
|
||||||
continue
|
continue
|
||||||
nymin = max(ymin, yming) * 1000 // total_height
|
nymin = max(ymin, yming) * 1000 // total_height
|
||||||
nymax = min(ymax, ymaxg) * 1000 // total_height
|
nymax = min(ymax, ymaxg) * 1000 // total_height
|
||||||
|
|
||||||
f["box_2d"] = [nymin, xmin, nymax, xmax]
|
f["box_2d"] = [nymin, xmin, nymax, xmax]
|
||||||
# print("Group :", yming, ymaxg, "Answered:", ymin, ymax)
|
|
||||||
if xmax / 1000 > width_r:
|
if f["box_2d"] and xmax / 1000 > width_r:
|
||||||
print("Error : Gemini answered box2d too right",
|
print("Error : Gemini answered box2d too right", pid, label, group_name)
|
||||||
pid, label, group_name)
|
|
||||||
if xmin / 1000 > width_r:
|
if xmin / 1000 > width_r:
|
||||||
print("Removing the box.")
|
print("Removing the box.")
|
||||||
f["box_2d"] = None
|
f["box_2d"] = None
|
||||||
|
|
@ -328,10 +330,9 @@ def process_single_task(task_tuple):
|
||||||
# --- Use Lock for writing shared data ---
|
# --- Use Lock for writing shared data ---
|
||||||
with io_lock:
|
with io_lock:
|
||||||
if label not in results:
|
if label not in results:
|
||||||
results[label] = [] # Ensure key exists if not using defaultdict
|
results[label] = []
|
||||||
results[label].append(json_data)
|
results[label].append(json_data)
|
||||||
|
|
||||||
# Save Results
|
|
||||||
with open(output_path, "w", encoding="utf-8") as f:
|
with open(output_path, "w", encoding="utf-8") as f:
|
||||||
json.dump(results, f, indent=2)
|
json.dump(results, f, indent=2)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue