Logging support: in correction_log
parent
173e77a64a
commit
bd35e69534
291
correction.py
291
correction.py
|
|
@ -165,6 +165,47 @@ MODEL_ID_pro = "gemini-3.1-pro-preview"
|
|||
MODEL_ID_flash = "gemini-3-flash-preview"
|
||||
api_key = os.environ["GEMINI_API_KEY"]
|
||||
|
||||
import signal
|
||||
import sys
|
||||
|
||||
# --- Thread-safe Logging ---
|
||||
log_lock = threading.Lock()
|
||||
thread_logs = {}
|
||||
|
||||
def tprint(*args, **kwargs):
    """Buffer a message for the calling thread and echo it to the console.

    The positional arguments are converted with ``str`` and joined into a
    single message. That message is appended to ``thread_logs`` under the
    current thread's name (while holding ``log_lock``) so that it can later
    be written out contiguously, and is also printed immediately, prefixed
    with ``[<thread name>]``.

    Args:
        *args: Values that make up the message.
        **kwargs: Keyword arguments forwarded to ``print`` (for example
            ``file=sys.stderr``). When ``sep`` is given it is also used to
            join the buffered message, so the buffer and the console agree.
    """
    tid = threading.current_thread().name

    # Mirror print(): a missing or None separator means a single space.
    sep = kwargs.get("sep")
    if sep is None:
        sep = " "
    msg = sep.join(map(str, args))

    with log_lock:
        thread_logs.setdefault(tid, []).append(msg)

    # Optional: keep printing to console but prefix with thread name.
    print(f"[{tid}] {msg}", **kwargs)
|
||||
|
||||
def flush_thread_log(tid=None, log_path="correction_log"):
    """Append one thread's buffered messages to the log file as a single group.

    The whole operation runs while holding ``log_lock`` so that the header
    and the messages of one thread are written contiguously, without lines
    from other threads in between. Nothing is written when the thread has no
    buffered messages. The buffer is emptied only after the write succeeded.

    Args:
        tid: Name of the thread whose buffer is flushed. Defaults to the
            name of the calling thread.
        log_path: File the messages are appended to (UTF-8). Defaults to
            ``"correction_log"`` in the current working directory.
    """
    tid = tid or threading.current_thread().name
    with log_lock:
        messages = thread_logs.get(tid)
        if not messages:
            return
        with open(log_path, "a", encoding="utf-8") as f:
            f.write(f"--- Task Log [{tid}] ---\n")
            f.write("\n".join(messages) + "\n\n")
        # Clear in place so the same list object stays registered for the thread.
        messages.clear()
|
||||
|
||||
def handle_interrupt(sig, frame, lock_timeout=1.0):
    """Flush all partial/unfinished thread logs, then exit with status 1.

    Intended to be installed with ``signal.signal``, which calls it with the
    signal number and the interrupted stack frame.

    Args:
        sig: Number of the received signal (unused).
        frame: Stack frame that was interrupted (unused).
        lock_timeout: Maximum number of seconds to wait for ``log_lock``
            before giving up on flushing.

    Raises:
        SystemExit: Always, with exit code 1.
    """
    print("\nInterrupt received. Flushing partial logs...", file=sys.stderr)

    # Python runs signal handlers in the main thread, possibly while that
    # very thread is inside a ``with log_lock:`` section. ``log_lock`` is a
    # plain (non-reentrant) lock, so calling flush_thread_log() in that
    # situation would block forever. Probe the lock with a bounded wait
    # first: if it can be taken, the main thread is not holding it and the
    # flush below cannot self-deadlock.
    if log_lock.acquire(timeout=lock_timeout):
        log_lock.release()
        # Snapshot the keys: worker threads may still register new buffers.
        for tid in list(thread_logs):
            flush_thread_log(tid)
    else:
        print("Could not acquire the log lock; partial logs were not flushed.",
              file=sys.stderr)

    sys.exit(1)
|
||||
|
||||
signal.signal(signal.SIGINT, handle_interrupt)
|
||||
signal.signal(signal.SIGTERM, handle_interrupt)
|
||||
# ---------------------------
|
||||
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
|
|
@ -183,7 +224,6 @@ class EvaluationEntry(BaseModel):
|
|||
result: ResultData = Field(description="Result details")
|
||||
|
||||
# The root model for parsing will be: List[EvaluationEntry]
|
||||
|
||||
def generate_request(file, full_label):
|
||||
"""Generates request for Gemini."""
|
||||
prompt = make_prompt(full_label)
|
||||
|
|
@ -271,16 +311,16 @@ def call_gemini_with_retries(model_id, contents, config,
|
|||
|
||||
# Immediately fallback to Flash without waiting if it's a Pro quota error
|
||||
if is_quota_error and model_id == MODEL_ID_pro and fallback_model_id:
|
||||
print(f"\tGemini Pro quota hit ({e}). Falling back to Flash permanently...")
|
||||
tprint(f"\tGemini Pro quota hit ({e}). Falling back to Flash permanently...")
|
||||
model_id = fallback_model_id
|
||||
pro_quota_exhausted = True
|
||||
continue # Retry immediately with Flash
|
||||
|
||||
if attempt < 2:
|
||||
print(f"\tGemini API failure: {e}. Retrying in {delays[attempt]} seconds...")
|
||||
tprint(f"\tGemini API failure: {e}. Retrying in {delays[attempt]} seconds...")
|
||||
time.sleep(delays[attempt])
|
||||
else:
|
||||
print(f"\tGemini API failure: {e}. Maximum retries reached.")
|
||||
tprint(f"\tGemini API failure: {e}. Maximum retries reached.")
|
||||
raise
|
||||
|
||||
import io
|
||||
|
|
@ -395,7 +435,7 @@ def handle_label_errors(pid, label, res, pdf_path):
|
|||
enonce = enonce_total(INPUT_DIR)
|
||||
|
||||
if error_type == "wrong-label":
|
||||
print(f"\tHandling wrong-label for {pid} {label}")
|
||||
tprint(f"\tHandling wrong-label for {pid} {label}")
|
||||
prompt = f"""This image is a part of the answer of a student to a written exam.
|
||||
|
||||
It was initially labeled '{label}' but I suspect this label is wrong. Perhaps the student himself wrote the wrong label.
|
||||
|
|
@ -419,7 +459,7 @@ Here is a list of all possible lables. You need to answer with one of these :
|
|||
config = types.GenerateContentConfig(temperature=0.0)
|
||||
new_label = call_gemini_with_retries(MODEL_ID_flash, contents, config).strip().strip('"\'')
|
||||
if new_label not in all_labels:
|
||||
print(f"\t\tCopie{pid} returned an incorrect label {new_label} from an initial wrong label {label}. Ignoring")
|
||||
tprint(f"\t\tCopie{pid} returned an incorrect label {new_label} from an initial wrong label {label}. Ignoring")
|
||||
res["error"] = "wrg-lbl:cldtfix"
|
||||
return []
|
||||
if new_label == label:
|
||||
|
|
@ -427,17 +467,17 @@ Here is a list of all possible lables. You need to answer with one of these :
|
|||
return []
|
||||
new_pdf_path = Path(INPUT_DIR) / f"Copie{pid}" / f"{new_label}.pdf"
|
||||
if new_pdf_path.exists():
|
||||
print(f"\t\tCopie{pid} tried to move wrong {label} to {new_label}, but it already exists.")
|
||||
tprint(f"\t\tCopie{pid} tried to move wrong {label} to {new_label}, but it already exists.")
|
||||
res["error"] = f"wrg-lbl:{new_label}?exists"
|
||||
else:
|
||||
print(f"\t\tCopie{pid} : moving wrong {label} to {new_label}.")
|
||||
tprint(f"\t\tCopie{pid} : moving wrong {label} to {new_label}.")
|
||||
shutil.move(str(pdf_path), str(new_pdf_path))
|
||||
# Since we moved the file, this Copie/label should not be taken
|
||||
# into account in the future, I think
|
||||
idx = get_next_group_idx(INPUT_DIR, new_label)
|
||||
height = grouping.get_pdf_height(str(new_pdf_path))
|
||||
grouping.create_jpg(new_label, idx, [(pid, str(new_pdf_path), height)], INPUT_DIR)
|
||||
print(f"\t\tMaking {new_label} group {idx+1}")
|
||||
tprint(f"\t\tMaking {new_label} group {idx+1}")
|
||||
new_tasks.append((str(Path(INPUT_DIR) / new_label / f"Group_{idx+1}.jpg"),
|
||||
new_label, False))
|
||||
|
||||
|
|
@ -458,7 +498,7 @@ Here is a list of all possible labels. You need to answer with a list one of the
|
|||
|
||||
{labels_txt}
|
||||
"""
|
||||
print(f"\tHandling additional-answer for {pid} {label}")
|
||||
tprint(f"\tHandling additional-answer for {pid} {label}")
|
||||
contents = [types.Content(role="user", parts=[
|
||||
types.Part.from_bytes(data=get_single_image_bytes(pdf_path), mime_type="image/jpeg"),
|
||||
types.Part.from_text(text=prompt)
|
||||
|
|
@ -469,28 +509,28 @@ Here is a list of all possible labels. You need to answer with a list one of the
|
|||
except Exception:
|
||||
add_labels = []
|
||||
|
||||
print(f"\tHandling additional-answer for {pid} {label}")
|
||||
tprint(f"\tHandling additional-answer for {pid} {label}")
|
||||
keep_error = False
|
||||
for add_label in add_labels:
|
||||
if add_label == label:
|
||||
continue
|
||||
if add_label not in all_labels:
|
||||
print(f"\t\t Inexistent label from additional-answer processing {pid} {label}. Ignoring")
|
||||
tprint(f"\t\t Inexistent label from additional-answer processing {pid} {label}. Ignoring")
|
||||
keep_error = True
|
||||
continue
|
||||
new_pdf_path = Path(INPUT_DIR) / f"Copie{pid}" / f"{add_label}.pdf"
|
||||
if not new_pdf_path.exists():
|
||||
shutil.copy(str(pdf_path), str(new_pdf_path))
|
||||
print(f"\t\tCopying Copie{pid} : {label} -> {add_label}")
|
||||
tprint(f"\t\tCopying Copie{pid} : {label} -> {add_label}")
|
||||
idx = get_next_group_idx(INPUT_DIR, add_label)
|
||||
print(f"\t\tMaking {add_label} group {idx+1}")
|
||||
tprint(f"\t\tMaking {add_label} group {idx+1}")
|
||||
height = grouping.get_pdf_height(str(new_pdf_path))
|
||||
grouping.create_jpg(add_label, idx, [(pid, str(new_pdf_path), height)], INPUT_DIR)
|
||||
new_tasks.append((str(Path(INPUT_DIR) / add_label / f"Group_{idx+1}.jpg"),
|
||||
add_label, False))
|
||||
else:
|
||||
keep_error = True
|
||||
print(f"\t\tAlready present (not copied) Copie{pid} : {label} -> {add_label}")
|
||||
tprint(f"\t\tAlready present (not copied) Copie{pid} : {label} -> {add_label}")
|
||||
|
||||
|
||||
if not keep_error:
|
||||
|
|
@ -499,130 +539,133 @@ Here is a list of all possible labels. You need to answer with a list one of the
|
|||
return new_tasks
|
||||
|
||||
def process_single_task(task_tuple):
|
||||
global pro_count, flash_count, pro_quota_exhausted
|
||||
file_path = task_tuple[0]
|
||||
label = task_tuple[1]
|
||||
can_spawn_tasks = task_tuple[2] if len(task_tuple) > 2 else True
|
||||
|
||||
group_name = os.path.splitext(file_path)[0]
|
||||
json_path = group_name + '.json'
|
||||
new_tasks = []
|
||||
|
||||
with open(json_path, 'r') as f:
|
||||
group_data = json.load(f)
|
||||
|
||||
n = len(group_data)
|
||||
d_data = {l[0]: (l[1], l[2], l[3]) for l in group_data}
|
||||
total_height = group_data[-1][2]
|
||||
use_flash = n >= 4 or total_height <= 500
|
||||
|
||||
if not use_flash:
|
||||
with pro_lock:
|
||||
if pro_quota_exhausted:
|
||||
use_flash = True
|
||||
elif limit is None or pro_count < limit:
|
||||
pro_count += 1
|
||||
else:
|
||||
use_flash = True
|
||||
|
||||
if use_flash:
|
||||
with pro_lock:
|
||||
flash_count += 1
|
||||
|
||||
try:
|
||||
contents, config = generate_request(file_path, label)
|
||||
model_to_use = MODEL_ID_flash if use_flash else MODEL_ID_pro
|
||||
print(f"Asking Gemini {'Flash' if use_flash else 'Pro '}: {label} {group_name}")
|
||||
global pro_count, flash_count, pro_quota_exhausted
|
||||
file_path = task_tuple[0]
|
||||
label = task_tuple[1]
|
||||
can_spawn_tasks = task_tuple[2] if len(task_tuple) > 2 else True
|
||||
|
||||
full_response_text = call_gemini_with_retries(model_to_use, contents, config)
|
||||
json_data = json.loads(full_response_text)
|
||||
group_name = os.path.splitext(file_path)[0]
|
||||
json_path = group_name + '.json'
|
||||
new_tasks = []
|
||||
|
||||
# Ensure consistency of answer placements
|
||||
for p in json_data:
|
||||
pid = p["id"]
|
||||
res = p["result"]
|
||||
yming, ymaxg, width_r = d_data[pid]
|
||||
with open(json_path, 'r') as f:
|
||||
group_data = json.load(f)
|
||||
|
||||
pdf_path = Path(INPUT_DIR) / f"Copie{pid}" / f"{label}.pdf"
|
||||
if res["error"] != "":
|
||||
print("\tError :", res["error"], "for Copie", pid, group_name)
|
||||
n = len(group_data)
|
||||
d_data = {l[0]: (l[1], l[2], l[3]) for l in group_data}
|
||||
total_height = group_data[-1][2]
|
||||
use_flash = n >= 4 or total_height <= 500
|
||||
|
||||
if can_spawn_tasks and res.get("error") in ["wrong-label", "additional-answer"]:
|
||||
new_tasks.extend(handle_label_errors(pid, label, res, pdf_path))
|
||||
if not use_flash:
|
||||
with pro_lock:
|
||||
if pro_quota_exhausted:
|
||||
use_flash = True
|
||||
elif limit is None or pro_count < limit:
|
||||
pro_count += 1
|
||||
else:
|
||||
use_flash = True
|
||||
|
||||
needs_correction = []
|
||||
for (i,f) in enumerate(res["feedback"]):
|
||||
b = f["box_2d"]
|
||||
if b:
|
||||
ymin, xmin, ymax, xmax = b
|
||||
ymin = ymin * total_height // 1000
|
||||
ymax = ymax * total_height // 1000
|
||||
if use_flash:
|
||||
with pro_lock:
|
||||
flash_count += 1
|
||||
|
||||
if pid not in d_data:
|
||||
print("Error : Gemini answered a copie id not present",
|
||||
pid, label, group_name)
|
||||
continue
|
||||
try:
|
||||
contents, config = generate_request(file_path, label)
|
||||
model_to_use = MODEL_ID_flash if use_flash else MODEL_ID_pro
|
||||
tprint(f"Asking Gemini {'Flash' if use_flash else 'Pro '}: {label} {group_name}")
|
||||
|
||||
if (ymin < yming - 50 or
|
||||
ymax > ymaxg + 50 or
|
||||
xmax / 1000 > width_r):
|
||||
needs_correction.append(i)
|
||||
break
|
||||
full_response_text = call_gemini_with_retries(model_to_use, contents, config)
|
||||
json_data = json.loads(full_response_text)
|
||||
|
||||
#
|
||||
# if ymin < yming-50 or ymax > ymaxg+50:
|
||||
# print("Error : Gemini answered box2d too low/up", pid, label, group_name)
|
||||
# if ymax < yming or ymin > ymaxg:
|
||||
# print("Removing the box.")
|
||||
# f["box_2d"] = None
|
||||
# continue
|
||||
# nymin = max(ymin, yming) * 1000 // total_height
|
||||
# nymax = min(ymax, ymaxg) * 1000 // total_height
|
||||
# f["box_2d"] = [nymin, xmin, nymax, xmax]
|
||||
# Ensure consistency of answer placements
|
||||
for p in json_data:
|
||||
pid = p["id"]
|
||||
res = p["result"]
|
||||
yming, ymaxg, width_r = d_data[pid]
|
||||
|
||||
# if f["box_2d"] and xmax / 1000 > width_r:
|
||||
# print("Error : Gemini answered box2d too right", pid, label, group_name)
|
||||
# if xmin / 1000 > width_r:
|
||||
# print("Removing the box.")
|
||||
# f["box_2d"] = None
|
||||
# continue
|
||||
# f["box_2d"][3] = int(width_r * 1000)
|
||||
pdf_path = Path(INPUT_DIR) / f"Copie{pid}" / f"{label}.pdf"
|
||||
if res["error"] != "":
|
||||
tprint("\tError :", res["error"], "for Copie", pid, group_name)
|
||||
|
||||
if needs_correction:
|
||||
print(f"\tBox anomalies detected for Copie {pid} {group_name}. \n\tRequesting isolated correction from Gemini Flash...")
|
||||
try:
|
||||
res["feedback"] = correct_boxes_with_gemini(
|
||||
pid, label, res["feedback"], INPUT_DIR,
|
||||
yming, ymaxg, width_r, total_height)
|
||||
except Exception as e:
|
||||
print(f"\tCorrection failed for Copie {pid}, {group_name} : {e}\n\tRemoving the boxes")
|
||||
# Fallback if the second request fails entirely
|
||||
for (i, f) in enumerate(res["feedback"]):
|
||||
if i in needs_correction:
|
||||
f["box_2d"] = None
|
||||
if can_spawn_tasks and res.get("error") in ["wrong-label", "additional-answer"]:
|
||||
new_tasks.extend(handle_label_errors(pid, label, res, pdf_path))
|
||||
|
||||
# --- Use Lock for writing shared data ---
|
||||
with io_lock:
|
||||
if label not in results:
|
||||
results[label] = []
|
||||
results[label].append(json_data)
|
||||
needs_correction = []
|
||||
for (i,f) in enumerate(res["feedback"]):
|
||||
b = f["box_2d"]
|
||||
if b:
|
||||
ymin, xmin, ymax, xmax = b
|
||||
ymin = ymin * total_height // 1000
|
||||
ymax = ymax * total_height // 1000
|
||||
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
json.dump(results, f, indent=2)
|
||||
if pid not in d_data:
|
||||
tprint("Error : Gemini answered a copie id not present",
|
||||
pid, label, group_name)
|
||||
continue
|
||||
|
||||
# To track progress
|
||||
completed_tasks.append((file_path, label))
|
||||
with open(progress_path, "w", encoding="utf-8") as f:
|
||||
json.dump(completed_tasks, f, indent=2)
|
||||
if (ymin < yming - 50 or
|
||||
ymax > ymaxg + 50 or
|
||||
xmax / 1000 > width_r):
|
||||
needs_correction.append(i)
|
||||
break
|
||||
|
||||
except json.JSONDecodeError:
|
||||
print(f"Error decoding JSON for {file_path}", file=sys.stderr)
|
||||
except Exception as e:
|
||||
error_msg = f"Exception processing {file_path}: {e}"
|
||||
print(error_msg, file=sys.stderr)
|
||||
with io_lock:
|
||||
errors_summary.append((error_msg, file_path))
|
||||
return new_tasks
|
||||
#
|
||||
# if ymin < yming-50 or ymax > ymaxg+50:
|
||||
# print("Error : Gemini answered box2d too low/up", pid, label, group_name)
|
||||
# if ymax < yming or ymin > ymaxg:
|
||||
# print("Removing the box.")
|
||||
# f["box_2d"] = None
|
||||
# continue
|
||||
# nymin = max(ymin, yming) * 1000 // total_height
|
||||
# nymax = min(ymax, ymaxg) * 1000 // total_height
|
||||
# f["box_2d"] = [nymin, xmin, nymax, xmax]
|
||||
|
||||
# if f["box_2d"] and xmax / 1000 > width_r:
|
||||
# print("Error : Gemini answered box2d too right", pid, label, group_name)
|
||||
# if xmin / 1000 > width_r:
|
||||
# print("Removing the box.")
|
||||
# f["box_2d"] = None
|
||||
# continue
|
||||
# f["box_2d"][3] = int(width_r * 1000)
|
||||
|
||||
if needs_correction:
|
||||
tprint(f"\tBox anomalies detected for Copie {pid} {group_name}. \n\tRequesting isolated correction from Gemini Flash...")
|
||||
try:
|
||||
res["feedback"] = correct_boxes_with_gemini(
|
||||
pid, label, res["feedback"], INPUT_DIR,
|
||||
yming, ymaxg, width_r, total_height)
|
||||
except Exception as e:
|
||||
tprint(f"\tCorrection failed for Copie {pid}, {group_name} : {e}\n\tRemoving the boxes")
|
||||
# Fallback if the second request fails entirely
|
||||
for (i, f) in enumerate(res["feedback"]):
|
||||
if i in needs_correction:
|
||||
f["box_2d"] = None
|
||||
|
||||
# --- Use Lock for writing shared data ---
|
||||
with io_lock:
|
||||
if label not in results:
|
||||
results[label] = []
|
||||
results[label].append(json_data)
|
||||
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
json.dump(results, f, indent=2)
|
||||
|
||||
# To track progress
|
||||
completed_tasks.append((file_path, label))
|
||||
with open(progress_path, "w", encoding="utf-8") as f:
|
||||
json.dump(completed_tasks, f, indent=2)
|
||||
|
||||
except json.JSONDecodeError:
|
||||
tprint(f"Error decoding JSON for {file_path}", file=sys.stderr)
|
||||
except Exception as e:
|
||||
error_msg = f"Exception processing {file_path}: {e}"
|
||||
print(error_msg, file=sys.stderr)
|
||||
with io_lock:
|
||||
errors_summary.append((error_msg, file_path))
|
||||
return new_tasks
|
||||
finally:
|
||||
flush_thread_log()
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(f"Starting processing on {len(tasks_to_process)} tasks with {NB_THREADS} threads...")
|
||||
|
|
|
|||
Loading…
Reference in New Issue