# Copies/correction.py
# Grades groups of scanned exam answers with Gemini and stores the scores
# and feedback in <InputDir>/correction.json.

import sys
import os
import time
from pathlib import Path
import argparse
# Parse arguments. The input (a single group image or a directory) is a
# positional argument so that option flags may come before or after it;
# unknown arguments are tolerated (parse_known_args).
parser = argparse.ArgumentParser()
parser.add_argument("input_path", nargs="?",
                    help="A group image such as 'InterroTest/Ex 2/Group_1.jpg', "
                         "or the input directory")
parser.add_argument("--overwrite", action="store_true",
                    help="Force redo requests even if output exists")
parser.add_argument("--limit", type=int,
                    help="Limit the number of calls to Gemini Pro (integer)")
parser.add_argument("--refaire", action="store_true",
                    help="Redo specific copies/labels defined in refaire.json")
parser.add_argument("--batch", action="store_true",
                    help="Generate a JSONL file of requests to send to the Gemini Batch API")
parser.add_argument("--deal-with-batched", type=str, metavar="FILE",
                    help="Process a JSONL file containing completed batch results")
args, _ = parser.parse_known_args()
if args.input_path is None:
    sys.exit("Usage: python script.py InterroTest/Ex 2/Group_1.jpg OR <InputDir>")
arg_path = Path(args.input_path)

tasks = []  # List of tuples: (filepath_str, label_str)
results = {}  # label -> list of graded groups

if arg_path.suffix.lower() == ".jpg":
    # Single group image, laid out as <InputDir>/<label>/Group_<k>.jpg
    INPUT_DIR = str(arg_path.parents[1])
    FULL_LABEL = arg_path.parent.name
    tasks.append((str(arg_path), FULL_LABEL))
    results[FULL_LABEL] = []
else:
    # Directory: grade every *.jpg found in each "Ex*" sub-directory.
    INPUT_DIR = str(arg_path)
    if not arg_path.is_dir():
        sys.exit(f"Directory {INPUT_DIR} not found.")
    # Sorted so that the task order is reproducible between runs.
    for sub in sorted(arg_path.iterdir()):
        if sub.is_dir() and sub.name.startswith("Ex"):
            results[sub.name] = []
            tasks.extend((str(img), sub.name) for img in sorted(sub.glob("*.jpg")))
# Grading prompt template. The placeholders <<text>>, <<corr>>, <<persp>>
# and <<label>> are substituted by make_prompt(). The JSON example below
# must match the response schema (FeedbackItem / ResultData /
# EvaluationEntry): feedback items use the keys "text" and "box_2d".
my_prompt = """I'm giving you an image of several written answers to an exam.
Each answer is separated by a black horizontal line, and underneath,
to the left, is indicated the ID of the answer, from `01` to `50`.
I want you to score each answer, from 0 to 4, you may score half
points, such as 2.5. Even if a result is wrong, if the reasoning is
correct and could lead to a right answer, you should give at least
half the points.
You also need to give feedback to the student, in French:
- which part of his answer is wrong,
- why it is wrong,
- possibly, what he should have done instead.
Your feedback may contain LaTeX fragments written like `$a^2 + b^2 = c^2$`.
If your score is not 4, you should always provide some feedback
explaining what's missing.
For each piece of feedback, if it is related to a specific part of the
answer that is wrong, you may provide a `box_2d`, to locate this
specific part of the answer. This `box_2d` should be in the form
[ymin, xmin, ymax, xmax] normalized to 0-1000. If you do not provide
one, set `box_2d` to `null`.
If the answer is correct, there is no need to provide feedback. You do
not have to give positive feedback, but if you do, do not provide a
`box_2d` for it.
For example, if the student says a function is continuous when it
isn't, provide the coordinates where the word «continuous» is. If a
calculation went wrong, give the coordinates of the step where it
goes wrong, and as feedback, what went wrong.
Avoid giving feedback about confusing letters `n` with `m`, `x` with
`n` or `h` with `k`. If it looks wrong, assume you read it wrong,
unless the distinction is very important.
You should also give me a measure of confidence, from 0 to 1, that you
were able to correctly understand the answer. A score below 0.5 means
that you think it is likely that you couldn't understand an important
part.
In some cases, you may find that either
- The student didn't answer the right question. Set the score to 0.
  Since it could be a labeling error, indicate it by setting `error`
  to "wrong-label".
- You can find an answer to another question of the exercise (taking
  more than a couple of lines). Score the question you are supposed
  to score, but set `error` to "additional-answer".
- The answer to the question is empty, or the student has only
  rewritten the statement of the question. In this case, set `error`
  to "empty-answer" and do not provide any kind of feedback.
If there's no error, set `error` to `""`.
You will answer using JSON describing a list of dictionaries with a key
"id", and a key "result" that contains the "score", the "confidence", a
list "feedback", and an "error". Like this example:
[{"id": "01",
  "result": {"score": 2.5,
             "confidence": 0.8,
             "feedback": [{"text": "Un retour générique. Il faut apprendre le cours.", "box_2d": null},
                          {"text": "Non, la fonction n'est pas forcément continue", "box_2d": [145, 280, 340, 500]}],
             "error": ""}
 },
 {"id": "04",
  "result": {"score": 4.0,
             "confidence": 0.9,
             "feedback": [],
             "error": ""}
 }
]
Here is the text of the exercise (or the relevant part of the problem)
of the exam:
```
<<text>>
```
Here is a possible correct answer:
```
<<corr>>
```
<<persp>>
You are asked to score the question or exercise labeled `<<label>>`,
do not score or give feedback to any other question."""
def make_prompt(full_label):
    """Fill the ``my_prompt`` template for the question ``full_label``.

    The exercise statement, a reference solution and optional extra grading
    instructions are read from the ``Text``, ``Sol`` and ``Persp`` folders of
    ``INPUT_DIR``. In each folder, the file used is the one whose name is the
    longest prefix of ``full_label`` (e.g. a file named ``Ex 2`` serves the
    label ``Ex 2 Q3``). A missing folder or a folder without a matching file
    contributes an empty string.

    Returns:
        The prompt text with every placeholder substituted.
    """
    def read_longest_prefix_file(subdir):
        dir_path = Path(INPUT_DIR) / subdir
        if not dir_path.is_dir():
            return ""
        matches = [f for f in dir_path.iterdir()
                   if f.is_file() and full_label.startswith(f.name)]
        if not matches:
            return ""
        # The longest matching name is the most specific one.
        return max(matches, key=lambda f: len(f.name)).read_text(encoding="utf-8")

    text = read_longest_prefix_file("Text")
    corr = read_longest_prefix_file("Sol")
    persp = read_longest_prefix_file("Persp")
    if persp != "":
        persp = "\n\nHere are additional scoring instructions : \n\n```\n" + persp + "\n```\n"
    return (my_prompt.replace("<<text>>", text)
                     .replace("<<corr>>", corr)
                     .replace("<<persp>>", persp)
                     .replace("<<label>>", full_label))
from google import genai
from google.genai import types
import base64
import shlex
import json
from pathlib import Path
import os
import threading
import concurrent.futures
NB_THREADS = 12  # worker threads grading groups concurrently

# Optional HTTP(S) proxy for the Gemini API, e.g. "http://192.168.241.1:3128".
PROXY_URL = None
if PROXY_URL:
    os.environ["http_proxy"] = PROXY_URL
    os.environ["https_proxy"] = PROXY_URL

MODEL_ID_pro = "gemini-3.1-pro-preview"
MODEL_ID_flash = "gemini-3-flash-preview"

# Fail with a readable message instead of a bare KeyError traceback.
api_key = os.environ.get("GEMINI_API_KEY")
if not api_key:
    sys.exit("The GEMINI_API_KEY environment variable is not set.")
import signal
import sys
# --- Thread-safe Logging ---
# Reentrant on purpose: the SIGINT/SIGTERM handler runs on the main thread and
# flushes logs under this lock, so a plain Lock would deadlock if the signal
# arrived while the main thread was already inside tprint/flush_thread_log.
log_lock = threading.RLock()
# Thread name -> list of log lines not yet written to the log file.
thread_logs = {}
def tprint(*args, **kwargs):
    """Print a message prefixed with the thread name and buffer it for the log file.

    The space-joined message is appended to the calling thread's list in
    ``thread_logs`` (so ``flush_thread_log`` can write it out as one chunk)
    and printed as ``[<thread name>] <message>``. Keyword arguments such as
    ``file=`` are forwarded to ``print``.
    """
    name = threading.current_thread().name
    text = " ".join(str(part) for part in args)
    with log_lock:
        thread_logs.setdefault(name, []).append(text)
        print(f"[{name}] {text}", **kwargs)
def flush_thread_log(tid=None):
    """Append one thread's buffered lines to ``<INPUT_DIR>/correction_log``.

    The lines are written under a ``--- Task Log [<tid>] ---`` header so that
    a thread's messages stay contiguous in the file. ``tid`` defaults to the
    calling thread's name. Nothing is written when the buffer is empty; the
    buffer is emptied after writing.
    """
    owner = tid or threading.current_thread().name
    with log_lock:
        buffered = thread_logs.get(owner)
        if not buffered:
            return
        header = f"--- Task Log [{owner}] ---\n"
        body = "\n".join(buffered) + "\n\n"
        with open(Path(INPUT_DIR) / "correction_log", "a", encoding="utf-8") as log_file:
            log_file.write(header)
            log_file.write(body)
        buffered.clear()
def handle_interrupt(sig, frame):
    """Signal handler: write every thread's pending log lines, then exit with status 1."""
    print("\nInterrupt received. Flushing partial logs...", file=sys.stderr)
    for thread_name in list(thread_logs):
        flush_thread_log(thread_name)
    # NOTE(review): sys.exit raises SystemExit on the main thread only; worker
    # threads of a running ThreadPoolExecutor may keep going until they finish.
    sys.exit(1)
# Flush buffered per-thread logs before exiting on Ctrl-C or a termination request.
for _signum in (signal.SIGINT, signal.SIGTERM):
    signal.signal(_signum, handle_interrupt)
del _signum
# ---------------------------
from pydantic import BaseModel, Field, TypeAdapter
from typing import List, Optional, Tuple
# Shape of the JSON answer expected from Gemini; the same models are turned
# into the response schema sent with each request. Class docstrings are
# deliberately absent: pydantic uses a model's docstring as its schema
# "description", which would change that schema.
class FeedbackItem(BaseModel):
    # One remark for the student, optionally anchored to a region of the image.
    text: str = Field(description="Feedback content")
    # Per the prompts: [ymin, xmin, ymax, xmax] normalized to 0-1000, or None.
    box_2d: Optional[List[int]] = Field(None, description="box coordinates or null")


class ResultData(BaseModel):
    # Grade of one answer.
    score: float = Field(description="The numeric score")
    confidence: float = Field(description="Confidence level")
    feedback: List[FeedbackItem] = Field(description="List of feedback items")
    # "" when there is no problem; the grading prompt asks for
    # "wrong-label", "additional-answer" or "empty-answer" otherwise.
    error: str = Field(description="Indicates if an error occurred")


class EvaluationEntry(BaseModel):
    # "id" is the copy identifier (used to build the Copie<id> folder name).
    id: str = Field(description="Entry identifier")
    result: ResultData = Field(description="Result details")

# The root model for parsing is: List[EvaluationEntry]
def generate_request(file, full_label):
    """Build the ``(contents, config)`` pair used to grade one group image.

    ``file`` is the path of the group JPEG and ``full_label`` the question it
    answers (used to build the prompt). The config requests JSON output that
    follows ``List[EvaluationEntry]``.
    """
    prompt = make_prompt(full_label)
    image_bytes = Path(file).read_bytes()
    user_turn = types.Content(
        role="user",
        parts=[
            types.Part.from_bytes(data=image_bytes, mime_type="image/jpeg"),
            types.Part.from_text(text=prompt),
        ],
    )
    response_schema = TypeAdapter(List[EvaluationEntry]).json_schema()
    config = types.GenerateContentConfig(
        temperature=1.0,
        top_p=0.95,
        seed=0,
        max_output_tokens=65535,
        response_mime_type="application/json",
        response_json_schema=response_schema,
    )
    return [user_turn], config
# --- Shared run state --------------------------------------------------------
client = genai.Client(api_key=api_key)

output_path = Path(INPUT_DIR) / "correction.json"  # graded groups, by label
progress_path = Path(INPUT_DIR) / "correction_progress.json"  # finished (file, label) pairs

start_time = time.time()
overwrite, limit = args.overwrite, args.limit

completed_tasks = []
errors_summary = []  # (message, file_path) for each task that failed

io_lock = threading.Lock()   # guards results / completed_tasks and their files
pro_lock = threading.Lock()  # guards the request counters below
pro_count = flash_count = 0
pro_quota_exhausted = False

if overwrite:
    # Fresh start: forget earlier results and progress.
    for stale_file in (output_path, progress_path):
        if stale_file.exists():
            stale_file.unlink()
else:
    # Resume: reload what earlier runs recorded.
    if progress_path.exists():
        completed_tasks = json.loads(progress_path.read_text(encoding="utf-8"))
    if output_path.exists():
        results = json.loads(output_path.read_text(encoding="utf-8"))

# Skip every (file, label) pair an earlier run already completed.
completed_set = {(str(done_file), done_label) for done_file, done_label in completed_tasks}
tasks_to_process = [task for task in tasks
                    if (str(task[0]), task[1]) not in completed_set]
def call_gemini_with_retries(model_id, contents, config,
                             fallback_model_id=MODEL_ID_flash,
                             delays=(60, 300)):
    """Stream a Gemini request and return the concatenated response text.

    Transient failures are retried after each delay in ``delays`` (1 min, then
    5 min by default), i.e. ``len(delays) + 1`` attempts in total. A quota
    error ("429" / "quota" / "exhausted") on the Pro model does NOT use up an
    attempt: the call switches to ``fallback_model_id`` immediately and sets
    the global ``pro_quota_exhausted`` so other threads skip Pro as well.

    Returns:
        The full response text (empty if the model streamed no text).

    Raises:
        The last exception raised by the client once every attempt failed.
    """
    global pro_quota_exhausted
    failures = 0
    while True:
        # Another thread may have discovered that Pro's quota is exhausted.
        if model_id == MODEL_ID_pro and pro_quota_exhausted and fallback_model_id:
            model_id = fallback_model_id
        try:
            pieces = []
            for chunk in client.models.generate_content_stream(
                    model=model_id,
                    contents=contents,
                    config=config,
            ):
                if chunk.text:
                    pieces.append(chunk.text)
            return "".join(pieces)
        except Exception as e:
            error_msg = str(e).lower()
            is_quota_error = ("429" in error_msg or "quota" in error_msg
                              or "exhausted" in error_msg)
            can_fall_back = (model_id == MODEL_ID_pro and fallback_model_id
                             and fallback_model_id != model_id)
            if is_quota_error and can_fall_back:
                tprint(f"\tGemini Pro quota hit ({e}). \n\n\tFalling back to Flash permanently...")
                model_id = fallback_model_id
                pro_quota_exhausted = True
                continue  # retry immediately; this does not count as an attempt
            if failures < len(delays):
                tprint(f"\tGemini API failure: {e}. Retrying in {delays[failures]} seconds...")
                time.sleep(delays[failures])
                failures += 1
            else:
                tprint(f"\tGemini API failure: {e}. Maximum retries reached.")
                raise
import io
from pdf2image import convert_from_path
from PIL import Image
def get_single_image_bytes(pdf_path):
    """Render ``pdf_path`` and return all of its pages as one JPEG (bytes).

    Pages are rendered at 200 dpi and, when there are several, pasted top to
    bottom and left-aligned on a white RGB canvas as wide as the widest page.
    The result is encoded as JPEG with quality 85.

    Raises:
        ValueError: if the PDF yields no page.
    """
    pages = convert_from_path(pdf_path, dpi=200)  # Same DPI as grouping.py
    if not pages:
        raise ValueError(f"No pages in {pdf_path}")
    if len(pages) > 1:
        canvas_width = max(page.width for page in pages)
        canvas_height = sum(page.height for page in pages)
        stitched = Image.new('RGB', (canvas_width, canvas_height), 'white')
        top = 0
        for page in pages:
            stitched.paste(page, (0, top))
            top += page.height
    else:
        stitched = pages[0]
    out = io.BytesIO()
    stitched.save(out, format='JPEG', quality=85)
    return out.getvalue()
def _map_box_to_group(box, yming, ymaxg, width_r, total_height):
    """Map a 0-1000 box relative to one copy's image onto the group image.

    y values are placed inside the copy's band [yming, ymaxg] of the group
    image and re-normalized by ``total_height``; x values are scaled by
    ``width_r``. Returns None when ``box`` is not a 4-element list/tuple.
    """
    if not isinstance(box, (list, tuple)) or len(box) != 4:
        return None
    ymin_s, xmin_s, ymax_s, xmax_s = box
    single_h = ymaxg - yming
    new_ymin = int((yming + (ymin_s * single_h / 1000.0)) * 1000.0 / total_height)
    new_ymax = int((yming + (ymax_s * single_h / 1000.0)) * 1000.0 / total_height)
    new_xmin = int(xmin_s * width_r)
    new_xmax = int(xmax_s * width_r)
    return [new_ymin, new_xmin, new_ymax, new_xmax]


def correct_boxes_with_gemini(pid, label, original_feedbacks,
                              root_dir, yming, ymaxg, width_r, total_height):
    """Ask Gemini Flash to relocate the boxes of one copy's feedback.

    The copy's PDF ``<root_dir>/Copie<pid>/<label>.pdf`` is rendered on its
    own, so Gemini answers in coordinates normalized to that single image;
    they are mapped back onto the group image with ``_map_box_to_group``.

    Args:
        original_feedbacks: feedback dicts with ``text`` and an optional
            ``box_2d`` (the key may be missing).
        yming, ymaxg: vertical extent of this copy inside the group image.
        width_r: ratio applied to x coordinates (this copy's width relative
            to the group image, as used by the anomaly check).
        total_height: height of the group image, same unit as yming/ymaxg.

    Returns:
        The feedback items without a box, unchanged, followed by the items
        returned by Gemini with remapped boxes (None where no valid box).

    Raises:
        Whatever rendering or the Gemini call raises, JSONDecodeError on a
        non-JSON answer, and ValueError if the answer is not a JSON list.
    """
    localized_feedbacks = [f for f in original_feedbacks if f.get("box_2d")]
    global_feedbacks = [f for f in original_feedbacks if not f.get("box_2d")]
    if not localized_feedbacks:
        return global_feedbacks  # nothing to relocate

    pdf_path = Path(root_dir) / f"Copie{pid}" / f"{label}.pdf"
    img_bytes = get_single_image_bytes(pdf_path)
    prompt = f"""
Here is a single student's submission to a question in a written exam. The following JSON contains feedback items with bounding boxes (box_2d) that are incorrect. Each piece of feedback is supposed to be related to a piece of the answer that is wrong.
For example, if the student says a function is continuous when it
isn't, the coordinates should be where the word «continuous» is. If a
calculation went wrong, the coordinates should be those of the step where
it goes wrong, and the feedback is what went wrong.
Please analyze the image and return the exact same feedback text, but with ONLY the box_2d coordinates corrected for this specific image.
Coordinates must be [ymin, xmin, ymax, xmax] scaled to 1000. If a box is invalid/not found, return null for it.
Original feedback:
{json.dumps(localized_feedbacks, indent=2, ensure_ascii=False)}
"""
    contents = [
        types.Content(
            role="user",
            parts=[
                types.Part.from_bytes(data=img_bytes, mime_type="image/jpeg"),
                types.Part.from_text(text=prompt),
            ],
        )
    ]
    config = types.GenerateContentConfig(
        temperature=0.0,  # Low temperature for accurate correction
        response_mime_type="application/json",
        response_json_schema=TypeAdapter(List[FeedbackItem]).json_schema()
    )
    response_text = call_gemini_with_retries(MODEL_ID_flash, contents, config)
    corrected_feedbacks = json.loads(response_text)
    if not isinstance(corrected_feedbacks, list):
        raise ValueError("Expected a JSON list of feedback items, got "
                         f"{type(corrected_feedbacks).__name__}")

    # Map the coordinates back from the single image to the group canvas.
    for item in corrected_feedbacks:
        box = item.get("box_2d")
        if box:
            item["box_2d"] = _map_box_to_group(box, yming, ymaxg, width_r, total_height)
    return global_feedbacks + corrected_feedbacks
import shutil
import grouping
def get_next_group_idx(root_dir, label):
    """Return the index to use when creating a new group image for ``label``.

    Group images are named ``Group_<k>.jpg`` with k starting at 1, and
    callers create ``Group_{idx+1}.jpg``; this therefore returns the highest
    existing k, or 0 when there is none. The ``<root_dir>/<label>`` folder
    is created if missing. Files whose part after ``Group_`` is not an
    integer are ignored.

    NOTE(review): two threads calling this for the same label before either
    creates its image get the same index; confirm callers serialize this.
    """
    target_folder = Path(root_dir) / label
    target_folder.mkdir(exist_ok=True)
    numbers = []
    for jpg in target_folder.glob("Group_*.jpg"):
        try:
            numbers.append(int(jpg.stem.split("_")[1]))
        except ValueError:
            continue  # not one of the numbered group images
    return max(numbers, default=0)
from utils import read_all_labels, enonce_total
def _ask_flash_about_copy(pdf_path, prompt, config):
    """Send the stitched image of ``pdf_path`` and ``prompt`` to Gemini Flash; return its text."""
    contents = [types.Content(role="user", parts=[
        types.Part.from_bytes(data=get_single_image_bytes(pdf_path), mime_type="image/jpeg"),
        types.Part.from_text(text=prompt),
    ])]
    return call_gemini_with_retries(MODEL_ID_flash, contents, config)


def _queue_single_copy_group(pid, label, pdf_path):
    """Create a one-copy group image for ``label`` and return its (non-spawning) task."""
    idx = get_next_group_idx(INPUT_DIR, label)
    height = grouping.get_pdf_height(str(pdf_path))
    grouping.create_jpg(label, idx, [(pid, str(pdf_path), height)], INPUT_DIR)
    tprint(f"\t\tMaking {label} group {idx+1}")
    return (str(Path(INPUT_DIR) / label / f"Group_{idx+1}.jpg"), label, False)


def _fix_wrong_label(pid, label, res, pdf_path, all_labels, labels_txt, enonce):
    """Ask Gemini for the real label of a mislabeled copy and move its PDF there."""
    tprint(f"\tHandling wrong-label for {pid} {label}")
    prompt = f"""This image is a part of the answer of a student to a written exam.
It was initially labeled '{label}' but I suspect this label is wrong. Perhaps the student himself wrote the wrong label.
You need to analyse this image, and find the label of the question it answers. Do not trust the label written by the student but instead check the content of his answer and the notation he uses to identify the correct label of the question the student answered.
Return ONLY the exact label string.
Here is the full content of the exam :
{enonce}
Here is a list of all possible labels. You need to answer with one of these :
{labels_txt}
"""
    try:
        answer = _ask_flash_about_copy(pdf_path, prompt,
                                       types.GenerateContentConfig(temperature=0.0))
    except Exception as e:
        # Keep "wrong-label" so the copy stays flagged; raising here would make
        # the caller drop the grades of the whole group.
        tprint(f"\t\tCould not relabel Copie{pid} {label}: {e}")
        return []
    new_label = answer.strip().strip('"\'')
    if new_label not in all_labels:
        tprint(f"\t\tCopie{pid} returned an incorrect label {new_label} from an initial wrong label {label}. Ignoring")
        res["error"] = "wrg-lbl:cldtfix"
        return []
    if new_label == label:
        res["error"] = ""
        return []
    new_pdf_path = Path(INPUT_DIR) / f"Copie{pid}" / f"{new_label}.pdf"
    if new_pdf_path.exists():
        tprint(f"\t\tCopie{pid} tried to move wrong {label} to {new_label}, but it already exists.")
        res["error"] = f"wrg-lbl:{new_label}?exists"
        return []
    res["error"] = f"wrg-lbl-moved-to:{new_label}"
    tprint(f"\t\tCopie{pid} : moving wrong {label} to {new_label}.")
    # Moved rather than copied: this copy no longer has a {label}.pdf.
    shutil.move(str(pdf_path), str(new_pdf_path))
    return [_queue_single_copy_group(pid, new_label, new_pdf_path)]


def _split_additional_answers(pid, label, res, pdf_path, all_labels, labels_txt, enonce):
    """Copy a PDF that also answers other questions to those labels and queue them."""
    tprint(f"\tHandling additional-answer for {pid} {label}")
    prompt = f"""This image is a part of the answer of a student to a written exam.
It was initially labeled '{label}' but I suspect this image also contains answers to another, or several other questions.
You need to analyse this image, and find the list of the labels of the questions it answers. Return ONLY the list of the exact label strings, as a JSON array.
If the end of the image only contains the first line of an answer to another question, ignore it.
Here is the full content of the exam :
{enonce}
Here is a list of all possible labels. You need to answer with a list of some of these :
{labels_txt}
"""
    config = types.GenerateContentConfig(temperature=0.0, response_mime_type="application/json")
    try:
        add_labels = json.loads(_ask_flash_about_copy(pdf_path, prompt, config))
    except Exception as e:
        # Keep "additional-answer" so the copy stays flagged for review.
        tprint(f"\t\tCould not list the answers of Copie{pid} {label}: {e}")
        return []
    if isinstance(add_labels, str):
        add_labels = [add_labels]
    if not isinstance(add_labels, list):
        tprint(f"\t\tUnexpected answer {add_labels!r} for Copie{pid} {label}. Ignoring")
        return []

    new_tasks = []
    error = "al:"
    keep_error = False
    for add_label in add_labels:
        if add_label == label:
            continue
        keep_error = True
        if not isinstance(add_label, str) or add_label not in all_labels:
            tprint(f"\t\t Inexistent label ({add_label}) from additional-answer processing {pid} {label}. Ignoring")
            error += f"{add_label}??"
            continue
        new_pdf_path = Path(INPUT_DIR) / f"Copie{pid}" / f"{add_label}.pdf"
        if new_pdf_path.exists():
            error += f"(xx){add_label}"
            tprint(f"\t\tAlready present (not copied) Copie{pid} : {label} -> {add_label}")
            continue
        shutil.copy(str(pdf_path), str(new_pdf_path))
        tprint(f"\t\tCopying Copie{pid} : {label} -> {add_label}")
        new_tasks.append(_queue_single_copy_group(pid, add_label, new_pdf_path))
        error += f"(->){add_label}"
    # Clear the flag only when Gemini found no label other than this one.
    res["error"] = error if keep_error else ""
    return new_tasks


def handle_label_errors(pid, label, res, pdf_path):
    """Handle a "wrong-label" or "additional-answer" error reported by Gemini.

    For "wrong-label", Gemini is asked for the real label and the PDF is
    moved there; for "additional-answer", the PDF is copied to every other
    label Gemini finds. ``res["error"]`` is rewritten to record the outcome
    (left unchanged if the follow-up Gemini query fails). Other error types
    are ignored.

    Returns:
        New tasks ``(group_jpg_path, label, False)`` for the groups created.
    """
    error_type = res.get("error")
    if error_type not in ("wrong-label", "additional-answer"):
        return []
    all_labels = read_all_labels(INPUT_DIR)
    labels_txt = (Path(INPUT_DIR) / "labels").read_text(encoding="utf-8")
    enonce = enonce_total(INPUT_DIR)
    if error_type == "wrong-label":
        return _fix_wrong_label(pid, label, res, pdf_path, all_labels, labels_txt, enonce)
    return _split_additional_answers(pid, label, res, pdf_path, all_labels, labels_txt, enonce)
def _reserve_model(use_flash):
    """Count one live grading request and decide whether it must use Flash.

    ``use_flash`` is the size-based preference. A Pro request is downgraded
    to Flash when Pro's quota is known to be exhausted or the ``--limit``
    budget (global ``limit``) is spent; otherwise it increments ``pro_count``.
    Flash requests increment ``flash_count``. Returns the final choice.
    """
    global pro_count, flash_count
    with pro_lock:
        if not use_flash:
            if pro_quota_exhausted or (limit is not None and pro_count >= limit):
                use_flash = True
            else:
                pro_count += 1
        if use_flash:
            flash_count += 1
    return use_flash


def _find_box_anomalies(feedback, total_height, yming, ymaxg, width_r):
    """Return the indices of feedback items whose box lies outside the copy's band.

    A box (0-1000 over the group image) is anomalous when, converted with
    ``total_height``, it starts more than 50 units above ``yming``, ends more
    than 50 units below ``ymaxg``, or its xmax exceeds ``width_r``; a box
    that is not 4 numbers is anomalous too.
    """
    anomalies = []
    for i, item in enumerate(feedback):
        box = item.get("box_2d")
        if not box:
            continue
        if not isinstance(box, (list, tuple)) or len(box) != 4:
            anomalies.append(i)
            continue
        ymin, _xmin, ymax, xmax = box
        ymin = ymin * total_height // 1000
        ymax = ymax * total_height // 1000
        if ymin < yming - 50 or ymax > ymaxg + 50 or xmax / 1000 > width_r:
            anomalies.append(i)
    return anomalies


def _record_result(file_path, label, json_data):
    """Append a graded group to ``results`` and rewrite the result and progress files."""
    with io_lock:
        results.setdefault(label, []).append(json_data)
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2)
        # To track progress across runs
        completed_tasks.append((file_path, label))
        with open(progress_path, "w", encoding="utf-8") as f:
            json.dump(completed_tasks, f, indent=2)


def process_single_task(task_tuple, precomputed_response=None):
    """Grade one group image with Gemini and save the result.

    Args:
        task_tuple: ``(file_path, label)`` or ``(file_path, label,
            can_spawn_tasks)``. ``file_path`` is a group JPEG whose sibling
            ``.json`` holds ``[pid, ymin, ymax, width_ratio]`` rows.
            ``can_spawn_tasks`` (default True) lets wrong-label /
            additional-answer handling create follow-up tasks.
        precomputed_response: response text from a batch run; when falsy a
            live Gemini call is made.

    Returns:
        The follow-up tasks to schedule (possibly empty). Processing failures
        are logged and recorded in ``errors_summary`` rather than raised.
    """
    file_path = task_tuple[0]
    label = task_tuple[1]
    can_spawn_tasks = task_tuple[2] if len(task_tuple) > 2 else True
    group_name = os.path.splitext(file_path)[0]
    new_tasks = []
    # A failed batch entry arrives as "": treat it as absent so the live call
    # below goes through the Pro limit and the request counters.
    if not precomputed_response:
        precomputed_response = None
    try:
        with open(group_name + ".json", "r", encoding="utf-8") as f:
            group_data = json.load(f)
        d_data = {row[0]: (row[1], row[2], row[3]) for row in group_data}
        # The last row's ymax is taken as the total height of the group image.
        total_height = group_data[-1][2]
        use_flash = len(group_data) >= 4 or total_height <= 500

        if precomputed_response is None:
            use_flash = _reserve_model(use_flash)
            contents, config = generate_request(file_path, label)
            model_to_use = MODEL_ID_flash if use_flash else MODEL_ID_pro
            tprint(f"Asking Gemini {'Flash' if use_flash else 'Pro '}: {label} {group_name}")
            full_response_text = call_gemini_with_retries(model_to_use, contents, config)
        else:
            tprint(f"Using batched response for: {label} {group_name}")
            full_response_text = precomputed_response
        json_data = json.loads(full_response_text)

        # Ensure consistency of answer placements
        for entry in json_data:
            pid = entry["id"]
            res = entry["result"]
            if pid not in d_data:
                tprint("Error : Gemini answered a copie id not present",
                       pid, label, group_name)
                continue
            yming, ymaxg, width_r = d_data[pid]
            pdf_path = Path(INPUT_DIR) / f"Copie{pid}" / f"{label}.pdf"

            if not can_spawn_tasks and res.get("error") == "additional-answer":
                tprint("\tSwallowing an additional-answer from a subsequent task.")
                res["error"] = ""
            error = res.get("error")
            if error:
                tprint("\tError :", error, "for Copie", pid, group_name)
                if can_spawn_tasks and error in ("wrong-label", "additional-answer"):
                    new_tasks.extend(handle_label_errors(pid, label, res, pdf_path))

            feedback = res.get("feedback") or []
            needs_correction = _find_box_anomalies(feedback, total_height,
                                                   yming, ymaxg, width_r)
            if not needs_correction:
                continue
            tprint(f"\tBox anomalies detected for Copie {pid} {group_name}. \n\tRequesting isolated correction from Gemini Flash...")
            try:
                res["feedback"] = correct_boxes_with_gemini(
                    pid, label, feedback, INPUT_DIR,
                    yming, ymaxg, width_r, total_height)
            except Exception as e:
                tprint(f"\tCorrection failed for Copie {pid}, {group_name} : {e}\n\tRemoving the boxes")
                # Fallback if the second request fails entirely
                for i in needs_correction:
                    feedback[i]["box_2d"] = None

        _record_result(file_path, label, json_data)
    except json.JSONDecodeError as e:
        tprint(f"Error decoding JSON for {file_path}", file=sys.stderr)
        with io_lock:
            errors_summary.append((f"Error decoding JSON for {file_path}: {e}", file_path))
    except Exception as e:
        error_msg = f"Exception processing {file_path}: {e}"
        tprint(error_msg, file=sys.stderr)
        with io_lock:
            errors_summary.append((error_msg, file_path))
    finally:
        flush_thread_log()
    return new_tasks
def _apply_refaire():
    """Back up and drop the corrections listed in refaire.json, then queue regrading.

    refaire.json holds ``[copie_name, labels]`` pairs; an empty ``labels``
    means every ``*.pdf`` of that copy. Removed entries are appended to
    overwritten_correction.json, and a one-copy group is created and added
    to ``tasks_to_process`` for every label whose PDF exists.
    """
    refaire_path = Path(INPUT_DIR) / "refaire.json"
    overwritten_path = Path(INPUT_DIR) / "overwritten_correction.json"
    if not refaire_path.exists():
        print(f"Warning: --refaire flag used, but {refaire_path} not found.", file=sys.stderr)
        return
    with open(refaire_path, "r", encoding="utf-8") as f:
        refaire_list = json.load(f)
    overwritten_data = []
    if overwritten_path.exists():
        with open(overwritten_path, "r", encoding="utf-8") as f:
            overwritten_data = json.load(f)

    dirty_results = False
    for copie_name, labels in refaire_list:
        pid = copie_name.replace("Copie", "")
        copie_dir = Path(INPUT_DIR) / copie_name
        # If list is empty, redo all labels available for this Copie
        if not labels:
            labels = [p.stem for p in copie_dir.glob("*.pdf")]
        for label in labels:
            # 1. Extract and backup old corrections
            if label in results:
                for batch in results[label]:
                    to_remove = next((item for item in batch if item.get("id") == pid), None)
                    if to_remove is None:
                        continue
                    batch.remove(to_remove)
                    overwritten_data.append({
                        "pid": pid,
                        "label": label,
                        "data": to_remove,
                        "timestamp": time.time()
                    })
                    dirty_results = True
                # Clean up empty batches
                results[label] = [b for b in results[label] if b]
            # 2. Make new group and add to tasks
            pdf_path = copie_dir / f"{label}.pdf"
            if pdf_path.exists():
                idx = get_next_group_idx(INPUT_DIR, label)
                height = grouping.get_pdf_height(str(pdf_path))
                grouping.create_jpg(label, idx, [(pid, str(pdf_path), height)], INPUT_DIR)
                new_group_path = str(Path(INPUT_DIR) / label / f"Group_{idx+1}.jpg")
                tasks_to_process.append((new_group_path, label))

    if dirty_results:
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2)
        with open(overwritten_path, "w", encoding="utf-8") as f:
            json.dump(overwritten_data, f, indent=2)


def _write_batch_file():
    """Write one JSONL request per pending task to ``<INPUT_DIR>/batch_requests.jsonl``.

    Each line carries the group image path as ``custom_id`` and the same
    generation settings as the live calls (including ``seed``). The schema
    is a JSON Schema, so it goes in ``responseJsonSchema``.
    NOTE(review): adapt the envelope if your Gemini Batch endpoint expects a
    different line format.
    """
    batch_file = Path(INPUT_DIR) / "batch_requests.jsonl"
    response_schema = TypeAdapter(List[EvaluationEntry]).json_schema()
    with open(batch_file, "w", encoding="utf-8") as f:
        for task in tasks_to_process:
            file_path, label = task[0], task[1]
            json_path = os.path.splitext(file_path)[0] + ".json"
            with open(json_path, "r", encoding="utf-8") as jf:
                group_data = json.load(jf)
            use_flash = len(group_data) >= 4 or group_data[-1][2] <= 500
            model_to_use = MODEL_ID_flash if use_flash else MODEL_ID_pro
            b64_img = base64.b64encode(Path(file_path).read_bytes()).decode("utf-8")
            req = {
                "custom_id": file_path,  # Mapping ID
                "method": "POST",
                "url": f"/v1beta/models/{model_to_use}:generateContent",
                "body": {
                    "contents": [{
                        "role": "user",
                        "parts": [
                            {"inlineData": {"mimeType": "image/jpeg", "data": b64_img}},
                            {"text": make_prompt(label)}
                        ]
                    }],
                    "generationConfig": {
                        "temperature": 1.0,
                        "topP": 0.95,
                        "seed": 0,
                        "maxOutputTokens": 65535,
                        "responseMimeType": "application/json",
                        "responseJsonSchema": response_schema
                    }
                }
            }
            f.write(json.dumps(req) + "\n")
    print(f"Batch generation complete. {len(tasks_to_process)} requests saved to {batch_file}")


def _load_batched_responses(path_str):
    """Read a batch-results JSONL file into ``{request_id: response_text}``.

    Two line layouts are understood: the one requests are written in
    (``custom_id`` + ``response.body.candidates``) and the Gemini Batch file
    layout (``key`` + ``response.candidates``). The text is the concatenation
    of the first candidate's non-thought parts; when none is found,
    ``response_text`` (or "") is used. An empty text leads to a live call.
    """
    batched = {}
    batch_results_path = Path(path_str)
    if not batch_results_path.exists():
        print(f"Warning: Batch results file {batch_results_path} not found.", file=sys.stderr)
        return batched
    print(f"Loading batch results from {batch_results_path}...")
    with open(batch_results_path, "r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, 1):
            if not line.strip():
                continue
            try:
                data = json.loads(line)
            except json.JSONDecodeError as e:
                print(f"Warning: skipping malformed line {line_no} of {batch_results_path}: {e}",
                      file=sys.stderr)
                continue
            task_id = data.get("custom_id") or data.get("key")
            try:
                response = data["response"]
                body = response.get("body", response)
                parts = body["candidates"][0]["content"]["parts"]
                text = "".join(p.get("text", "") for p in parts if not p.get("thought"))
            except (KeyError, IndexError, TypeError, AttributeError):
                text = ""
            batched[task_id] = text or data.get("response_text", "")
    return batched


def _run_all_tasks(batched_responses):
    """Grade every queued task on a thread pool, including the tasks they spawn.

    ``as_completed`` only watches the futures it was given, so follow-up
    tasks are tracked with ``wait(FIRST_COMPLETED)`` on a growing pending set.
    """
    with concurrent.futures.ThreadPoolExecutor(max_workers=NB_THREADS) as executor:
        pending = {executor.submit(process_single_task, task, batched_responses.get(task[0]))
                   for task in tasks_to_process}
        while pending:
            done, pending = concurrent.futures.wait(
                pending, return_when=concurrent.futures.FIRST_COMPLETED)
            for future in done:
                try:
                    new_generated_tasks = future.result()
                except Exception as e:
                    print(f"Exception during task execution: {e}", file=sys.stderr)
                    continue
                # New tasks from wrong-label/additional-answer use the live API
                for new_task in new_generated_tasks or []:
                    pending.add(executor.submit(process_single_task, new_task))


def _print_summary():
    """Print timing, request counts and, for each failed task, a rerun command."""
    end_time = time.time()
    print("Time elapsed : ", end_time - start_time)
    print("Requests to pro / flash : ", pro_count, flash_count)
    if not errors_summary:
        return
    print("\n--- Summary of Exceptions ---", file=sys.stderr)
    for err, failed_file in errors_summary:
        print(err, file=sys.stderr)
        print(f"Run : python correction.py {shlex.quote(str(failed_file))}")


if __name__ == "__main__":
    if args.refaire:
        _apply_refaire()
    if args.batch:
        _write_batch_file()
        sys.exit(0)
    batched_responses = (_load_batched_responses(args.deal_with_batched)
                         if args.deal_with_batched else {})
    print(f"Starting processing on {len(tasks_to_process)} tasks with {NB_THREADS} threads...")
    _run_all_tasks(batched_responses)
    _print_summary()