Add Gemini retries
parent
2677e41b04
commit
a98375f1e8
|
|
@ -236,7 +236,8 @@ def render_real_latex_text(text, width_px, bg_color=(255, 255, 255, 255), max_li
|
|||
\\usepackage[T1]{{fontenc}}
|
||||
\\usepackage{{lmodern}} % Enables arbitrary font scaling
|
||||
\\usepackage{{amsmath, amssymb}}
|
||||
%\\usepackage{{anyfontsize}} % replaces by lmodern
|
||||
\\usepackage{{commands}}
|
||||
%\\usepackage{{anyfontsize}} % replaced by lmodern
|
||||
\\begin{{document}}
|
||||
\\fontsize{{{fontsize}}}{{{line_spacing}}}\\selectfont
|
||||
{text}
|
||||
|
|
|
|||
|
|
@ -217,8 +217,8 @@ completed_tasks = []
|
|||
|
||||
# --- Lock for thread-safe file writing ---
|
||||
io_lock = threading.Lock()
|
||||
pro_lock = threading.Lock() # New lock for counter
|
||||
pro_count = 0 # New counter
|
||||
pro_lock = threading.Lock()
|
||||
pro_count = 0
|
||||
flash_count = 0
|
||||
|
||||
if overwrite:
|
||||
|
|
@ -230,63 +230,68 @@ else:
|
|||
if progress_path.exists():
|
||||
with open(progress_path, "r", encoding="utf-8") as f:
|
||||
completed_tasks = json.load(f)
|
||||
# Reload existing results to avoid overwriting them with partial data
|
||||
if output_path.exists():
|
||||
with open(output_path, "r", encoding="utf-8") as f:
|
||||
results = json.load(f)
|
||||
|
||||
# Create a set for O(1) lookup. Normalize paths to strings.
|
||||
completed_set = set((str(f), l) for f, l in completed_tasks)
|
||||
|
||||
# Filter tasks first to avoid overhead in threads
|
||||
tasks_to_process = [t for t in tasks if (str(t[0]), t[1]) not in completed_set]
|
||||
|
||||
def call_gemini_with_retries(model_id, contents, config, delays=(60, 300)):
    """Stream a Gemini response, retrying failed requests after fixed delays.

    Args:
        model_id: Model identifier forwarded as ``model`` to the client.
        contents: Request contents, forwarded unchanged.
        config: Generation config, forwarded unchanged.
        delays: Seconds to sleep before each retry. There is one retry per
            entry, so the request is attempted ``len(delays) + 1`` times.
            Defaults to a 1-minute then a 5-minute wait.

    Returns:
        The concatenated text of every streamed chunk that carries text.

    Raises:
        Exception: Whatever the final attempt raised, once all retries
            are exhausted.
    """
    # Derive the attempt count from the schedule so the two cannot drift apart.
    delays = tuple(delays)
    max_retries = len(delays)

    for attempt in range(max_retries + 1):
        try:
            stream = client.models.generate_content_stream(
                model=model_id,
                contents=contents,
                config=config,
            )
            # The text is rebuilt from scratch on every attempt, so a stream
            # that fails midway never leaks partial output into the result.
            return "".join(chunk.text for chunk in stream if chunk.text)
        # NOTE(review): deliberately broad, as in the original; the SDK's
        # transient-error types are not visible here, so every failure is
        # treated as retryable.
        except Exception as e:
            if attempt >= max_retries:
                print(f"\tGemini API failure: {e}. Maximum retries reached.")
                raise
            print(f"\tGemini API failure: {e}. Retrying in {delays[attempt]} seconds...")
            time.sleep(delays[attempt])
|
||||
|
||||
def process_single_task(task_tuple):
|
||||
global pro_count, flash_count
|
||||
file_path, label = task_tuple
|
||||
group_name = os.path.splitext(file_path)[0]
|
||||
json_path = group_name + '.json'
|
||||
|
||||
with open(json_path, 'r') as f:
|
||||
# List of (groupid, start, end), in pixels
|
||||
group_data = json.load(f)
|
||||
|
||||
n = len(group_data)
|
||||
# l[3] is ratio of width to width of group
|
||||
d_data = {l[0]: (l[1], l[2], l[3]) for l in group_data}
|
||||
total_height = group_data[-1][2]
|
||||
use_flash = n >= 4 or total_height <= 500
|
||||
|
||||
if not use_flash:
|
||||
with pro_lock:
|
||||
if limit is None or pro_count < limit:
|
||||
pro_count += 1
|
||||
else:
|
||||
# Limit reached, force switch to Flash
|
||||
use_flash = True
|
||||
|
||||
if use_flash:
|
||||
with pro_lock:
|
||||
flash_count += 1
|
||||
|
||||
try:
|
||||
contents, config = generate_request(file_path, label)
|
||||
if use_flash:
|
||||
print(f"Asking Gemini Flash: {label} {group_name}")
|
||||
else:
|
||||
print(f"Asking Gemini Pro : {label} {group_name}")
|
||||
model_to_use = MODEL_ID_flash if use_flash else MODEL_ID_pro
|
||||
print(f"Asking Gemini {'Flash' if use_flash else 'Pro '}: {label} {group_name}")
|
||||
|
||||
full_response_text = ""
|
||||
# Assuming client is thread-safe (usually is).
|
||||
# If not, create a new client instance inside this function.
|
||||
for chunk in client.models.generate_content_stream(
|
||||
model=MODEL_ID_flash if use_flash else MODEL_ID_pro,
|
||||
contents=contents,
|
||||
config=config,
|
||||
):
|
||||
if chunk.text:
|
||||
full_response_text += chunk.text
|
||||
|
||||
# Parse JSON
|
||||
full_response_text = call_gemini_with_retries(model_to_use, contents, config)
|
||||
json_data = json.loads(full_response_text)
|
||||
|
||||
# print("Debug : ", json_data)
|
||||
# Ensure consistency of answer placements
|
||||
for p in json_data:
|
||||
pid = p["id"]
|
||||
|
|
@ -305,20 +310,17 @@ def process_single_task(task_tuple):
|
|||
continue
|
||||
yming, ymaxg, width_r = d_data[pid]
|
||||
if ymin < yming-50 or ymax > ymaxg+50:
|
||||
print("Error : Gemini answered box2d too low/up",
|
||||
pid, label, group_name)
|
||||
print("Error : Gemini answered box2d too low/up", pid, label, group_name)
|
||||
if ymax < yming or ymin > ymaxg:
|
||||
print("Removing the box.")
|
||||
f["box_2d"] = None
|
||||
continue
|
||||
nymin = max(ymin, yming) * 1000 // total_height
|
||||
nymax = min(ymax, ymaxg) * 1000 // total_height
|
||||
|
||||
f["box_2d"] = [nymin, xmin, nymax, xmax]
|
||||
# print("Group :", yming, ymaxg, "Answered:", ymin, ymax)
|
||||
if xmax / 1000 > width_r:
|
||||
print("Error : Gemini answered box2d too right",
|
||||
pid, label, group_name)
|
||||
|
||||
if f["box_2d"] and xmax / 1000 > width_r:
|
||||
print("Error : Gemini answered box2d too right", pid, label, group_name)
|
||||
if xmin / 1000 > width_r:
|
||||
print("Removing the box.")
|
||||
f["box_2d"] = None
|
||||
|
|
@ -328,10 +330,9 @@ def process_single_task(task_tuple):
|
|||
# --- Use Lock for writing shared data ---
|
||||
with io_lock:
|
||||
if label not in results:
|
||||
results[label] = [] # Ensure key exists if not using defaultdict
|
||||
results[label] = []
|
||||
results[label].append(json_data)
|
||||
|
||||
# Save Results
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
json.dump(results, f, indent=2)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue