935 lines
38 KiB
Python
935 lines
38 KiB
Python
import sys
|
|
import os
|
|
import time
|
|
from pathlib import Path
|
|
import argparse
|
|
|
|
if len(sys.argv) < 2:
    sys.exit("Usage: python script.py 'InterroTest/Ex 2/Group_1.jpg' OR <InputDir> OR 'file1' 'file2'")

# Parse Arguments
parser = argparse.ArgumentParser()
parser.add_argument("paths", nargs="+", help="List of images or directories")
parser.add_argument("--overwrite", action="store_true",
                    help="Force redo requests even if output exists")
parser.add_argument("--limit", type=int,
                    help="Maximum number of calls to Gemini Pro (integer)")
parser.add_argument("--refaire", action="store_true",
                    help="Redo specific copies/labels defined in refaire.json")
parser.add_argument("--batch", action="store_true",
                    help="Generate a JSONL file of requests to send to the Gemini Batch API")
parser.add_argument("--batch-from", type=str, metavar="LABEL",
                    help="Do live requests before LABEL, and batch requests from LABEL onwards")
parser.add_argument("--deal-with-batched", action="store_true",
                    help="Process a JSONL file containing completed batch results")
# Unknown options are tolerated on purpose (parse_known_args) and ignored.
args, _ = parser.parse_known_args()

tasks = []  # List of tuples: (filepath_str, label_str)
results = {}  # label -> list of per-group results, filled as tasks complete

# Bound up-front so that "no usable input at all" is reported here instead of
# surfacing later as a NameError on INPUT_DIR.
INPUT_DIR = None
COPIES_DIR = None
GROUPS_DIR = None

for path_str in args.paths:
    arg_path = Path(path_str)

    if not arg_path.exists():
        print(f"Warning: {path_str} not found. Skipping.")
        continue

    if arg_path.is_file() and arg_path.suffix.lower() == ".jpg":
        # Individual group image. The label is the name of the containing
        # folder and the input directory is two levels above that folder,
        # i.e. the file is expected at <InputDir>/Par label/<label>/<file>.jpg
        label = arg_path.parent.name
        INPUT_DIR = arg_path.parent.parent.parent
        COPIES_DIR = INPUT_DIR / "Copies"
        GROUPS_DIR = INPUT_DIR / "Par label"
        tasks.append((str(arg_path), label))
        results.setdefault(label, [])

    elif arg_path.is_dir():
        # Whole input directory: one task per JPEG of every label sub-folder.
        INPUT_DIR = arg_path
        COPIES_DIR = INPUT_DIR / "Copies"
        GROUPS_DIR = INPUT_DIR / "Par label"
        if not GROUPS_DIR.is_dir():
            print(f"Warning: {GROUPS_DIR} not found. No tasks collected from {path_str}.")
            continue
        for sub in GROUPS_DIR.iterdir():
            if sub.is_dir():
                label = sub.name
                results.setdefault(label, [])
                for img in sub.glob("*.jpg"):
                    tasks.append((str(img), label))

    else:
        print(f"Warning: {path_str} is neither a .jpg file nor a directory. Skipping.")

if INPUT_DIR is None:
    sys.exit("Error: none of the given paths is a usable .jpg file or directory.")
|
|
|
|
# Grading prompt template. The placeholders <<text>>, <<corr>>, <<persp>> and
# <<label>> are substituted by make_prompt(). The JSON example below must stay
# valid JSON and use the same field names as the response schema (`box_2d`).
my_prompt = """I'm giving you an image of several written answers to an exam.

Each answer is separated by a black horizontal line, and underneath,
to the left, is indicated the ID of the answer, from `01` to `50`.

I want you to score each answer, from 0 to 4, you may score half
points, such as 2.5. Even if a result is wrong, if the reasoning is
correct and could lead to a right answer, you should give at least
half the points.

You also need to give feedback to the student, in french :
- which part of his answer is wrong,
- why is it wrong
- possibly, what he should have done instead.
Your feedback may contain LaTeX fragments written like `$a^2 + b^2 = c^2$`.

If your score is not 4, you should always provide some feedback
explaining what's missing.

For each piece of feedback, if it is related to a specific part of the
answer that is wrong, you may provide a `box_2d`, to locate this
specific part of the answer. This `box_2d` should be in the form
[ymin, xmin, ymax, xmax] normalized to 0-1000. If you do not provide
one, set `box_2d` to `null`.

If the answer is correct, there is no need to provide feedback. You do
not have to give positive feedback, but if you do, do not provide a
`box_2d` for it.

For example, if the student says a function is continuous when it
isn't, provide the coordinates where the word «continuous» is. If a
calculation went wrong, gives the coordinates of the step where it
goes wrong, and as feedback, what went wrong.

Avoid giving feedback about confusing letters `n` with `m`, `x` with
`n` or `h` with `k`. If it looks wrong, assume you read it wrong,
unless the distinction is very important.

You should also give me a measure of confidence, from 0 to 1 that you
were able to correctly understand the answer. A score below 0.5 means
that you think it is likely that you couldn't understand an important
part.

In some case, you may find that either
- The student didn't answer the right question. Set the score to 0.
Since it could be a labeling error, indicate is by setting `error`
to \"wrong-label\".
- You can find an answer to another question of the exercice (taking
more than a couple of lines). Score the question you are supposed
to score, but set `error` to \"additional-answer\".
- The answer to the question is empty, or the student has only
rewritten the statement of the question. In this case, set `error`
to \"empty-answer\" and do not provide any kind of feedback.
If there's no error, set `error` to `\"\"`.

You will answer using json describing a list of dictionary with a key
\"id\", and a key \"result\" that contains the \"score\", the \"confidence\", a
list \"feedback\", and possibly an \"error\". Like this example :

[{ \"id\": \"01\",
   \"result\": {\"score\" : 2.5,
              \"confidence\" : 0.8,
              \"feedback\": [{\"text\": \"Un retour générique. Il faut apprendre le cours.\", \"box_2d\": null},
                           {\"text\": \"Non, la fonction n'est pas forcément continue\", \"box_2d\": [145, 280, 340, 500]}],
              \"error\": \"\"}
 },
 { \"id\": \"04\",
   \"result\": {\"score\" : 4.0,
              \"confidence\" : 0.9,
              \"feedback\" : [],
              \"error\": \"\" }
 }
]

Here is the text of the exercice (or the relevant part of the problem)
of the exam :

```
<<text>>
```

Here is a possible correct answer :

```
<<corr>>
```
<<persp>>

You are asked to score the question or exercice labeled `<<label>>`,
do not score or give feedback to any other question."""
|
|
|
|
def make_prompt(full_label, input_dir=None, template=None):
    """Build the grading prompt for the question labeled ``full_label``.

    For each of the sub-directories ``Text``, ``Sol`` and ``Persp`` of the
    input directory, the file whose *name* is the longest prefix of
    ``full_label`` is read (``.pdf`` and ``.tex`` files are ignored). A
    missing directory or the absence of a matching file yields an empty
    string for that part.

    Args:
        full_label: Label of the question to grade.
        input_dir: Directory holding ``Text``/``Sol``/``Persp``. Defaults to
            the module-level ``INPUT_DIR``.
        template: Prompt template containing the ``<<text>>``, ``<<corr>>``,
            ``<<persp>>`` and ``<<label>>`` placeholders. Defaults to the
            module-level ``my_prompt``.

    Returns:
        The template with every placeholder substituted.
    """
    import re  # local import: `re` is not among the file's top-level imports

    base_dir = INPUT_DIR if input_dir is None else Path(input_dir)
    prompt_template = my_prompt if template is None else template

    def read_longest_prefix_file(subdir):
        dir_path = base_dir / subdir
        # Some inputs have no such directory (e.g. no extra scoring
        # instructions); treat that exactly like "no matching file".
        if not dir_path.is_dir():
            return ""
        matches = [f for f in dir_path.iterdir()
                   if f.is_file()
                   and full_label.startswith(f.name)
                   and f.suffix not in (".pdf", ".tex")]
        if not matches:
            return ""
        return max(matches, key=lambda f: len(f.name)).read_text(encoding="utf-8", errors="replace")

    text = read_longest_prefix_file("Text")
    corr = read_longest_prefix_file("Sol")
    persp = read_longest_prefix_file("Persp")

    if persp != "":
        persp = "\n\nHere are additional scoring instructions : \n\n```\n" + persp + "\n```\n"

    substitutions = {
        "<<text>>": text,
        "<<corr>>": corr,
        "<<persp>>": persp,
        "<<label>>": full_label,
    }
    # Substitute in a single pass so that a placeholder-looking sequence inside
    # an inserted file (e.g. "<<label>>" in the exam text) is left untouched.
    placeholder_re = re.compile("|".join(re.escape(key) for key in substitutions))
    return placeholder_re.sub(lambda m: substitutions[m.group(0)], prompt_template)
|
|
|
|
from google import genai
|
|
from google.genai import types
|
|
import base64
|
|
import shlex
|
|
import json
|
|
import os
|
|
import threading
|
|
import concurrent.futures
|
|
|
|
# Number of worker threads used by the ThreadPoolExecutor that runs the tasks.
NB_THREADS = 12

# Optional HTTP(S) proxy. Set to a URL to export it through the environment.
# PROXY_URL = "http://192.168.241.1:3128"
PROXY_URL = None

if PROXY_URL:
    os.environ["http_proxy"] = PROXY_URL
    os.environ["https_proxy"] = PROXY_URL

# Model identifiers: "pro" is the default grader, "flash" is used for
# large/short groups, for auxiliary requests and as the quota fallback.
MODEL_ID_pro = "gemini-3.1-pro-preview"
MODEL_ID_flash = "gemini-3-flash-preview"
# Raises KeyError at import time when GEMINI_API_KEY is not set.
api_key = os.environ["GEMINI_API_KEY"]
|
|
|
|
import signal
|
|
import sys
|
|
|
|
# --- Thread-safe Logging ---
|
|
# Guards every access to `thread_logs`.
log_lock = threading.Lock()
# thread name -> list of messages buffered for that thread
thread_logs = {}


def tprint(*args, **kwargs):
    """Print a message and buffer it under the current thread's name.

    Positional arguments are converted with ``str`` and joined with a single
    space. The message is appended to ``thread_logs`` (under ``log_lock``) and
    also printed immediately, prefixed with ``[<thread name>]``; keyword
    arguments are forwarded to ``print``.
    """
    thread_name = threading.current_thread().name
    text = " ".join(str(part) for part in args)

    with log_lock:
        thread_logs.setdefault(thread_name, []).append(text)

    # Console output stays live; only the buffered copy is grouped per thread.
    print(f"[{thread_name}] {text}", **kwargs)
|
|
|
|
def flush_thread_log(tid=None):
    """Write one thread's buffered messages to the log file as a single group.

    Args:
        tid: Thread name whose buffer is flushed; defaults to the name of the
            calling thread.

    The messages are appended to ``INPUT_DIR / "correction_log"`` under a
    header line, then the buffer is emptied. Nothing is written when the
    buffer is missing or empty. The whole operation runs under ``log_lock``.
    """
    if not tid:
        tid = threading.current_thread().name

    with log_lock:
        pending = thread_logs.get(tid)
        if not pending:
            return
        body = "\n".join(pending)
        with open(INPUT_DIR / "correction_log", "a", encoding="utf-8") as log_file:
            log_file.write(f"--- Task Log [{tid}] ---\n")
            log_file.write(body + "\n\n")
        pending.clear()
|
|
|
|
def handle_interrupt(sig, frame):
    """Signal handler: flush every thread's buffered log, then exit with status 1.

    Args:
        sig: Signal number (unused).
        frame: Current stack frame (unused).
    """
    print("\nInterrupt received. Flushing partial logs...", file=sys.stderr)
    # Iterate over a snapshot of the keys: worker threads may still add entries.
    for pending_tid in tuple(thread_logs):
        flush_thread_log(pending_tid)
    sys.exit(1)


# Flush partial logs on Ctrl-C as well as on termination requests.
for _signum in (signal.SIGINT, signal.SIGTERM):
    signal.signal(_signum, handle_interrupt)
|
|
# ---------------------------
|
|
|
|
from pydantic import BaseModel, Field, TypeAdapter
|
|
from typing import List, Optional, Tuple
|
|
|
|
# One feedback remark on an answer. Documented with comments rather than a
# class docstring so that the generated JSON schema is left unchanged.
class FeedbackItem(BaseModel):
    # Feedback text shown to the student.
    text: str = Field(description="Feedback content")
    # Optional location of the remark; the grading prompt in this file asks for
    # [ymin, xmin, ymax, xmax] normalized to 0-1000, or null.
    box_2d: Optional[List[int]] = Field(None, description="box coordinates or null")
|
|
|
|
# Grading result for a single answer (comments only: a class docstring would
# be picked up in the generated JSON schema).
class ResultData(BaseModel):
    # The grading prompt in this file asks for a score from 0 to 4.
    score: float = Field(description="The numeric score")
    # The grading prompt asks for a confidence between 0 and 1.
    confidence: float = Field(description="Confidence level")
    feedback: List[FeedbackItem] = Field(description="List of feedback items")
    # Empty string when there is no error; otherwise one of the error tags
    # described in the grading prompt (e.g. "wrong-label").
    error: str = Field(description="Indicates if an error occurred")
|
|
|
|
# One element of the model's answer: an answer id and its grading result.
class EvaluationEntry(BaseModel):
    # Answer identifier; used by process_single_task as the copy id (`pid`).
    id: str = Field(description="Entry identifier")
    result: ResultData = Field(description="Result details")
|
|
|
|
# The nested pydantic definitions above do not work with the batch API, so the
# same structure (a list of {id, result{score, confidence, error, feedback[]}})
# is written out by hand here. It is used as "responseSchema" in the batch
# request payloads.
UNROLLED_SCHEMA = {
    "type": "ARRAY",
    "items": {
        "type": "OBJECT",
        "properties": {
            "id": {"type": "STRING", "description": "Entry identifier"},
            "result": {
                "type": "OBJECT",
                "properties": {
                    "score": {"type": "NUMBER", "description": "The numeric score"},
                    "confidence": {"type": "NUMBER", "description": "Confidence level"},
                    "error": {"type": "STRING", "description": "Indicates if an error occurred"},
                    "feedback": {
                        "type": "ARRAY",
                        "description": "List of feedback items",
                        "items": {
                            "type": "OBJECT",
                            "properties": {
                                "text": {"type": "STRING", "description": "Feedback content"},
                                # Mirrors FeedbackItem.box_2d: optional list of integers.
                                "box_2d": {
                                    "type": "ARRAY",
                                    "items": {"type": "INTEGER"},
                                    "nullable": True,
                                    "description": "box coordinates or null"
                                }
                            },
                            "required": ["text"]
                        }
                    }
                },
                "required": ["score", "confidence", "feedback", "error"]
            }
        },
        "required": ["id", "result"]
    }
}
|
|
|
|
# The root model for parsing is List[EvaluationEntry].
|
|
def generate_request(file, full_label):
    """Build the live Gemini request for one group image.

    Args:
        file: Path of the JPEG group image to grade.
        full_label: Label of the question, forwarded to ``make_prompt``.

    Returns:
        A ``(contents, config)`` tuple: a one-element list holding the user
        turn (image part followed by the prompt part) and the generation
        configuration requesting JSON shaped like ``List[EvaluationEntry]``.
    """
    grading_prompt = make_prompt(full_label)
    jpeg_bytes = Path(file).read_bytes()

    image_part = types.Part.from_bytes(data=jpeg_bytes, mime_type="image/jpeg")
    prompt_part = types.Part.from_text(text=grading_prompt)
    user_turn = types.Content(role="user", parts=[image_part, prompt_part])

    answer_schema = TypeAdapter(List[EvaluationEntry]).json_schema()
    generation_settings = types.GenerateContentConfig(
        temperature=1.0,
        top_p=0.95,
        seed=0,
        max_output_tokens=65535,
        response_mime_type="application/json",
        response_json_schema=answer_schema,
    )
    return ([user_turn], generation_settings)
|
|
|
|
client = genai.Client(api_key=api_key)

# Output files live next to the input data.
output_path = INPUT_DIR / "correction.json"
progress_path = INPUT_DIR / "correction_progress.json"

start_time = time.time()
overwrite = args.overwrite
limit = args.limit
completed_tasks = []  # (file_path, label) pairs already processed
errors_summary = []   # (message, file_path) pairs reported at the end

# --- Locks and counters shared by the worker threads ---
io_lock = threading.Lock()   # results / progress files and errors_summary
pro_lock = threading.Lock()  # pro_count / flash_count
pro_count = 0
flash_count = 0
pro_quota_exhausted = False

if overwrite:
    # Start from scratch: drop any previous output and progress file.
    for stale_path in (output_path, progress_path):
        if stale_path.exists():
            stale_path.unlink()
else:
    # Resume: reload what a previous run already produced.
    if progress_path.exists():
        with open(progress_path, "r", encoding="utf-8") as progress_file:
            completed_tasks = json.load(progress_file)
    if output_path.exists():
        with open(output_path, "r", encoding="utf-8") as output_file:
            results = json.load(output_file)

completed_set = {(str(done_file), done_label) for done_file, done_label in completed_tasks}
tasks_to_process = [
    task for task in tasks
    if (str(task[0]), task[1]) not in completed_set
]
|
|
|
|
def call_gemini_with_retries(model_id, contents, config,
                             fallback_model_id=MODEL_ID_flash):
    """Send a streaming request to Gemini, with retries and a Pro quota fallback.

    Args:
        model_id: Model to query first.
        contents: Request contents passed to ``generate_content_stream``.
        config: Generation configuration passed to ``generate_content_stream``.
        fallback_model_id: Model used instead of ``MODEL_ID_pro`` once a quota
            error was seen on it (by this call or by another thread). A falsy
            value disables the fallback.

    Returns:
        The concatenated text of all streamed chunks.

    Raises:
        Exception: The last API error, re-raised after the initial attempt and
            the two delayed retries (60 s, then 300 s) have all failed.
    """
    global pro_quota_exhausted
    delays = [60, 300]
    retries_done = 0

    while True:
        # Switch to fallback immediately if quota was exhausted by another thread
        if model_id == MODEL_ID_pro and pro_quota_exhausted and fallback_model_id:
            model_id = fallback_model_id

        try:
            pieces = []
            for chunk in client.models.generate_content_stream(
                model=model_id,
                contents=contents,
                config=config,
            ):
                if chunk.text:
                    pieces.append(chunk.text)
            return "".join(pieces)
        except Exception as e:
            error_msg = str(e).lower()
            is_quota_error = "429" in error_msg or "quota" in error_msg or "exhausted" in error_msg

            # Immediately fall back to Flash without waiting if it's a Pro quota
            # error. The switch does not consume a retry, so it can no longer
            # make the function run out of attempts and return None. The
            # `!= model_id` guard keeps the switch from looping forever when
            # the fallback is the Pro model itself.
            if (is_quota_error and model_id == MODEL_ID_pro
                    and fallback_model_id and fallback_model_id != model_id):
                tprint(f"\tGemini Pro quota hit ({e}). \n\n\tFalling back to Flash permanently...")
                model_id = fallback_model_id
                pro_quota_exhausted = True
                continue  # Retry immediately with Flash

            if retries_done < len(delays):
                delay = delays[retries_done]
                tprint(f"\tGemini API failure: {e}. Retrying in {delay} seconds...")
                time.sleep(delay)
                retries_done += 1
            else:
                tprint(f"\tGemini API failure: {e}. Maximum retries reached.")
                raise
|
|
|
|
import io
|
|
from pdf2image import convert_from_path
|
|
from PIL import Image
|
|
|
|
def get_single_image_bytes(pdf_path):
    """Render a PDF to one JPEG, stacking its pages vertically, and return the bytes.

    Args:
        pdf_path: Path of the PDF to render (at 200 dpi).

    Returns:
        The JPEG-encoded image (quality 85) as ``bytes``.

    Raises:
        ValueError: If the PDF yields no page.
    """
    pages = convert_from_path(pdf_path, dpi=200)  # Same DPI as grouping.py
    if not pages:
        raise ValueError(f"No pages in {pdf_path}")

    if len(pages) > 1:
        # Paste the pages one under the other on a white canvas as wide as
        # the widest page.
        canvas_width = max(page.width for page in pages)
        canvas_height = sum(page.height for page in pages)
        stitched = Image.new('RGB', (canvas_width, canvas_height), 'white')
        top = 0
        for page in pages:
            stitched.paste(page, (0, top))
            top += page.height
    else:
        stitched = pages[0]

    buffer = io.BytesIO()
    stitched.save(buffer, format='JPEG', quality=85)
    return buffer.getvalue()
|
|
|
|
def correct_boxes_with_gemini(pid, label, original_feedbacks,
                              yming, ymaxg, width_r, total_height):
    """Ask Gemini Flash to re-locate feedback boxes on the single copy's image.

    The copy's own PDF (``COPIES_DIR/Copie<pid>/<label>.pdf``) is rendered to
    one image and sent together with the feedback items that carry a box. The
    returned boxes, expressed relative to that single image, are mapped back
    to the coordinate system of the group image.

    Args:
        pid: Copy identifier.
        label: Question label.
        original_feedbacks: Feedback dicts (``text`` and optional ``box_2d``).
        yming: Top of this copy's band in the group image
            (presumably in pixels, like ``total_height`` - confirm in grouping).
        ymaxg: Bottom of this copy's band in the group image.
        width_r: Width ratio of this copy's image versus the group image.
        total_height: Total height of the group image.

    Returns:
        The feedback items without a box (unchanged) followed by the items
        returned by Gemini, with boxes normalized to the group image.
    """
    pdf_path = COPIES_DIR / f"Copie{pid}" / f"{label}.pdf"
    img_bytes = get_single_image_bytes(pdf_path)

    # `box_2d` is optional in the response schema, so it may be missing.
    localized_feedbacks = [f for f in original_feedbacks if f.get("box_2d")]
    global_feedbacks = [f for f in original_feedbacks if not f.get("box_2d")]

    prompt = f"""
Here is a single student's submission to a question in a written exam. The following JSON contains feedback items with bounding boxes (box_2d) that are incorrect. Each piece of feedback is supposed to be related to a piece of the answer that is wrong.

For example, if the student says a function is continuous when it
isn't, the coordinates should be where the word «continuous» is. If a
calculation went wrong, the coordinates should be where the step where
it goes wrong, and the feedback is what went wrong.

Please analyze the image and return the exact same feedback text, but with ONLY the box_2d coordinates corrected for this specific image.
Coordinates must be [ymin, xmin, ymax, xmax] scaled to 1000. If a box is invalid/not found, return null for it.
Original feedback:

{json.dumps(localized_feedbacks, indent=2)}
"""

    contents = [
        types.Content(
            role="user",
            parts=[
                types.Part.from_bytes(data=img_bytes, mime_type="image/jpeg"),
                types.Part.from_text(text=prompt),
            ],
        )
    ]

    config = types.GenerateContentConfig(
        temperature=0.0,  # Low temperature for accurate correction
        response_mime_type="application/json",
        response_json_schema=TypeAdapter(List[FeedbackItem]).json_schema()
    )

    response_text = call_gemini_with_retries(MODEL_ID_flash, contents, config)
    corrected_feedbacks = json.loads(response_text)

    # Map the coordinates back from the single image to the group canvas
    single_h = ymaxg - yming
    for f in corrected_feedbacks:
        b = f.get("box_2d")
        if not b:
            continue
        if len(b) != 4:
            # A malformed box cannot be mapped; drop it rather than fail.
            f["box_2d"] = None
            continue
        ymin_s, xmin_s, ymax_s, xmax_s = b

        # Y mapping: Add the group Y-offset (yming), then normalize to total_height
        new_ymin = int((yming + (ymin_s * single_h / 1000.0)) * 1000.0 / total_height)
        new_ymax = int((yming + (ymax_s * single_h / 1000.0)) * 1000.0 / total_height)

        # X mapping: Multiply by the width ratio of this sub-image vs the group image
        new_xmin = int(xmin_s * width_r)
        new_xmax = int(xmax_s * width_r)

        f["box_2d"] = [new_ymin, new_xmin, new_ymax, new_xmax]

    return global_feedbacks + corrected_feedbacks
|
|
|
|
import shutil
|
|
import grouping
|
|
|
|
def get_next_group_idx(label, groups_dir=None):
    """Return the highest existing ``Group_<n>.jpg`` index for ``label``.

    The label folder is created when missing. The value returned is the
    highest index already in use (0 when there is none); callers in this file
    pass it to ``grouping.create_jpg`` and then refer to the new image as
    ``Group_<idx+1>.jpg``.

    Args:
        label: Question label, i.e. the name of the sub-folder.
        groups_dir: Directory containing the label folders. Defaults to the
            module-level ``GROUPS_DIR``.

    Returns:
        The highest numeric index found, or 0. Files whose name does not carry
        an integer index (e.g. ``Group_old.jpg``) are ignored.
    """
    base_dir = GROUPS_DIR if groups_dir is None else Path(groups_dir)
    target_folder = base_dir / label
    target_folder.mkdir(parents=True, exist_ok=True)

    highest = 0
    for group_file in target_folder.glob("Group_*.jpg"):
        try:
            index = int(group_file.stem.split("_")[1])
        except ValueError:
            continue
        highest = max(highest, index)
    return highest
|
|
|
|
from utils import read_all_labels, enonce_total
|
|
|
|
def _make_single_copy_task(pid, label, pdf_path):
    """Create a one-copy group image for ``label`` and return its follow-up task."""
    idx = get_next_group_idx(label)
    height = grouping.get_pdf_height(str(pdf_path))
    grouping.create_jpg(label, idx, [(pid, str(pdf_path), height)], GROUPS_DIR)
    tprint(f"\t\tMaking {label} group {idx+1}")
    # The trailing False marks the task as not allowed to spawn further tasks.
    return (str(GROUPS_DIR / label / f"Group_{idx+1}.jpg"), label, False)


def handle_label_errors(pid, label, res, pdf_path):
    """Handle a "wrong-label" or "additional-answer" error reported by Gemini.

    Depending on ``res["error"]``:

    * ``"wrong-label"``: Gemini Flash is asked for the real label; the copy's
      PDF is moved under that label (unless a PDF already exists there) and a
      new single-copy group is created for it.
    * ``"additional-answer"``: Gemini Flash is asked for all labels answered in
      the image; the PDF is copied under each new label that has no PDF yet and
      a new single-copy group is created for each copy made.

    ``res["error"]`` is rewritten in place to record what was done (or reset to
    ``""`` when nothing had to be done). Any other error value is left alone.

    Args:
        pid: Copy identifier.
        label: Label under which the answer was graded.
        res: Result dict of this copy; its ``"error"`` entry is updated.
        pdf_path: Path of the copy's PDF for ``label``.

    Returns:
        A list of new task tuples ``(group_jpg_path, label, False)``.
    """
    error_type = res.get("error")
    if error_type not in ("wrong-label", "additional-answer"):
        return []

    new_tasks = []
    all_labels = read_all_labels(INPUT_DIR)
    labels_txt = (INPUT_DIR / "labels").read_text(encoding="utf-8", errors="replace")
    enonce = enonce_total(INPUT_DIR)

    if error_type == "wrong-label":
        tprint(f"\tHandling wrong-label for {pid} {label}")
        prompt = f"""This image is a part of the answer of a student to a written exam.

It was initially labeled '{label}' but I suspect this label is wrong. Perhaps the student himself wrote the wrong label.

You need to analyse this image, and find the label of the question it answers. Do not trust the label written by the student but instead check the content of its answer and the notation he uses to identify the correct label of the question the student answered.

Return ONLY the exact label string.

Here is the full content of the exam :

{enonce}

Here is a list of all possible labels. You need to answer with one of these :

{labels_txt}
"""

        contents = [types.Content(role="user", parts=[
            types.Part.from_bytes(data=get_single_image_bytes(pdf_path), mime_type="image/jpeg"),
            types.Part.from_text(text=prompt)])]
        config = types.GenerateContentConfig(temperature=0.0)
        new_label = call_gemini_with_retries(MODEL_ID_flash, contents, config).strip().strip('"\'')
        if new_label not in all_labels:
            tprint(f"\t\tCopie{pid} returned an incorrect label {new_label} from an initial wrong label {label}. Ignoring")
            res["error"] = "wrg-lbl:cldtfix"
            return []
        if new_label == label:
            # The second opinion confirms the original label: not an error.
            res["error"] = ""
            return []
        new_pdf_path = COPIES_DIR / f"Copie{pid}" / f"{new_label}.pdf"
        if new_pdf_path.exists():
            tprint(f"\t\tCopie{pid} tried to move wrong {label} to {new_label}, but it already exists.")
            res["error"] = f"wrg-lbl:{new_label}?exists"
        else:
            res["error"] = f"wrg-lbl-moved-to:{new_label}"
            tprint(f"\t\tCopie{pid} : moving wrong {label} to {new_label}.")
            shutil.move(str(pdf_path), str(new_pdf_path))
            # Since we moved the file, this Copie/label should not be taken
            # into account in the future, I think
            new_tasks.append(_make_single_copy_task(pid, new_label, new_pdf_path))

    else:  # "additional-answer"
        prompt = f"""This image is a part of the answer of a student to a written exam.

It was initially labeled '{label}' but I suspect this image also contains answers to another, or several other questions.

You need to analyse this image, and find the list of the labels of the questions it answers. Return ONLY the list of the exact label strings.

If the end of the image only contains the first line of an answer to another question, ignore it.

Here is the full content of the exam :

{enonce}

Here is a list of all possible labels. You need to answer with a list one of these :

{labels_txt}
"""
        tprint(f"\tHandling additional-answer for {pid} {label}")
        contents = [types.Content(role="user", parts=[
            types.Part.from_bytes(data=get_single_image_bytes(pdf_path), mime_type="image/jpeg"),
            types.Part.from_text(text=prompt)
        ])]
        config = types.GenerateContentConfig(temperature=0.0, response_mime_type="application/json")
        try:
            add_labels = json.loads(call_gemini_with_retries(MODEL_ID_flash, contents, config))
        except Exception as e:
            # Best effort: without a usable answer, no extra label is handled.
            tprint(f"\t\tCould not get the additional labels for {pid} {label}: {e}")
            add_labels = []
        if not isinstance(add_labels, list):
            tprint(f"\t\tUnexpected additional-answer response for {pid} {label}: {add_labels!r}. Ignoring")
            add_labels = []

        # One note per extra label; joined into the final error string.
        notes = []
        for add_label in add_labels:
            if add_label == label:
                continue
            if add_label not in all_labels:
                tprint(f"\t\t Inexistent label ({add_label}) from additional-answer processing {pid} {label}. Ignoring")
                notes.append(f"{add_label}??")
                continue
            new_pdf_path = COPIES_DIR / f"Copie{pid}" / f"{add_label}.pdf"
            if not new_pdf_path.exists():
                shutil.copy(str(pdf_path), str(new_pdf_path))
                tprint(f"\t\tCopying Copie{pid} : {label} -> {add_label}")
                new_tasks.append(_make_single_copy_task(pid, add_label, new_pdf_path))
                notes.append(f"(->){add_label}")
            else:
                notes.append(f"(xx){add_label}")
                tprint(f"\t\tAlready present (not copied) Copie{pid} : {label} -> {add_label}")

        # No extra label at all: the "additional-answer" flag is cleared.
        res["error"] = "al:" + "".join(notes) if notes else ""

    return new_tasks
|
|
|
|
def process_single_task(task_tuple, precomputed_response=None):
    """Grade one group image and record the result.

    The group's side-car JSON (same path as the image, ``.json`` extension) is
    read as a list of entries ``[pid, ymin, ymax, width_ratio]``; the height of
    the group image is taken from the last entry's ``ymax``
    (NOTE(review): units presumably pixels - confirm in grouping).

    Args:
        task_tuple: ``(file_path, label)`` or ``(file_path, label, can_spawn)``.
            ``can_spawn`` (default True) allows label errors to create
            follow-up tasks.
        precomputed_response: JSON text obtained from the batch API. When
            given, no live grading request is made and the Pro/Flash counters
            are left untouched.

    Returns:
        A list of follow-up task tuples created by ``handle_label_errors``
        (possibly empty).

    Side effects:
        Appends the parsed answer to ``results[label]``, rewrites the output
        and progress files, and flushes this thread's log buffer on exit.
        Errors in the request/parsing phase are logged, not raised.
    """
    global pro_count, flash_count
    try:
        file_path = task_tuple[0]
        label = task_tuple[1]
        can_spawn_tasks = task_tuple[2] if len(task_tuple) > 2 else True

        group_name = os.path.splitext(file_path)[0]
        json_path = group_name + '.json'
        new_tasks = []

        with open(json_path, 'r') as f:
            group_data = json.load(f)

        n = len(group_data)
        d_data = {entry[0]: (entry[1], entry[2], entry[3]) for entry in group_data}
        total_height = group_data[-1][2]
        # Groups with many copies or a small total height go to Flash.
        use_flash = n >= 4 or total_height <= 500
        is_live = precomputed_response is None

        # Only apply limits and counts if we are making a live call
        if is_live:
            if not use_flash:
                with pro_lock:
                    if pro_quota_exhausted:
                        use_flash = True
                    elif limit is None or pro_count < limit:
                        pro_count += 1
                    else:
                        use_flash = True

            if use_flash:
                with pro_lock:
                    flash_count += 1

        try:
            if is_live:
                # The request (image bytes + prompt) is only needed for a live call.
                contents, config = generate_request(file_path, label)
                model_to_use = MODEL_ID_flash if use_flash else MODEL_ID_pro
                tprint(f"Asking Gemini {'Flash' if use_flash else 'Pro '}: {label} {group_name}")
                full_response_text = call_gemini_with_retries(model_to_use, contents, config)
            else:
                tprint(f"Using batched response for: {label} {group_name}")
                full_response_text = precomputed_response

            json_data = json.loads(full_response_text)

            # Ensure consistency of answer placements
            for p in json_data:
                pid = p["id"]
                res = p["result"]

                # Checked before any lookup: an id that is not part of this
                # group has no band to validate against. The entry itself is
                # kept in the saved result.
                if pid not in d_data:
                    tprint("Error : Gemini answered a copie id not present",
                           pid, label, group_name)
                    continue

                yming, ymaxg, width_r = d_data[pid]

                pdf_path = COPIES_DIR / f"Copie{pid}" / f"{label}.pdf"
                if (not can_spawn_tasks) and res["error"] == "additional-answer":
                    tprint("\tSwallowing an additional-answer from a subsequent task.")
                    res["error"] = ""
                if res["error"] != "":
                    tprint("\tError :", res["error"], "for Copie", pid, group_name)

                if can_spawn_tasks and res.get("error") in ["wrong-label", "additional-answer"]:
                    new_tasks.extend(handle_label_errors(pid, label, res, pdf_path))

                # Indexes of the feedback items whose box is clearly misplaced.
                needs_correction = []
                for (i, f) in enumerate(res["feedback"]):
                    b = f.get("box_2d")
                    if not b:
                        continue
                    if len(b) != 4:
                        needs_correction.append(i)
                        continue

                    ymin, xmin, ymax, xmax = b
                    ymin = ymin * total_height // 1000
                    ymax = ymax * total_height // 1000

                    # Far outside this copy's band (or beyond its width): ask
                    # for a correction. Every such box is recorded, so that the
                    # fallback below can clear all of them.
                    if (ymin < yming - 50 or ymax > ymaxg + 50 or xmax / 1000 > width_r):
                        needs_correction.append(i)
                        continue
                    # Slightly outside: clamp to the band with a 5-unit margin.
                    if ymin < yming - 5:
                        ymin = yming - 5
                        b[0] = ymin * 1000 // total_height
                    if ymax > ymaxg + 5:
                        ymax = ymaxg + 5
                        b[2] = ymax * 1000 // total_height

                if needs_correction:
                    tprint(f"\tBox anomalies detected for Copie {pid} {group_name}. \n\tRequesting isolated correction from Gemini Flash...")
                    try:
                        res["feedback"] = correct_boxes_with_gemini(
                            pid, label, res["feedback"],
                            yming, ymaxg, width_r, total_height)
                    except Exception as e:
                        tprint(f"\tCorrection failed for Copie {pid}, {group_name} : {e}\n\tRemoving the boxes")
                        # Fallback if the second request fails entirely
                        for (i, f) in enumerate(res["feedback"]):
                            if i in needs_correction:
                                f["box_2d"] = None

            # --- Use Lock for writing shared data ---
            with io_lock:
                if label not in results:
                    results[label] = []
                results[label].append(json_data)

                with open(output_path, "w", encoding="utf-8") as f:
                    json.dump(results, f, indent=2)

                # To track progress
                completed_tasks.append((file_path, label))
                with open(progress_path, "w", encoding="utf-8") as f:
                    json.dump(completed_tasks, f, indent=2)

        except json.JSONDecodeError:
            tprint(f"Error decoding JSON for {file_path}", file=sys.stderr)
        except Exception as e:
            error_msg = f"Exception processing {file_path}: {e}"
            print(error_msg, file=sys.stderr)
            with io_lock:
                errors_summary.append((error_msg, file_path))
        return new_tasks
    finally:
        flush_thread_log()
|
|
|
|
if __name__ == "__main__":
    # ------------------------------------------------------------------
    # --refaire: remove selected corrections, back them up, and schedule a
    # fresh single-copy task for each (copy, label) pair.
    # ------------------------------------------------------------------
    if args.refaire:
        refaire_path = INPUT_DIR / "refaire.json"
        overwritten_path = INPUT_DIR / "overwritten_correction.json"

        if refaire_path.exists():
            with open(refaire_path, "r", encoding="utf-8") as f:
                refaire_list = json.load(f)

            overwritten_data = []
            if overwritten_path.exists():
                with open(overwritten_path, "r", encoding="utf-8") as f:
                    overwritten_data = json.load(f)

            dirty_results = False

            for copie_name, labels in refaire_list:
                pid = copie_name.replace("Copie", "")
                copie_dir = COPIES_DIR / copie_name

                # If list is empty, redo all labels available for this Copie
                if not labels:
                    labels = [p.stem for p in copie_dir.glob("*.pdf")]

                for label in labels:
                    # 1. Extract and backup old corrections
                    if label in results:
                        for batch in results[label]:
                            # At most one entry per batch is removed (first match).
                            to_remove = next(
                                (item for item in batch if item.get("id") == pid), None)
                            if to_remove:
                                batch.remove(to_remove)
                                overwritten_data.append({
                                    "pid": pid,
                                    "label": label,
                                    "data": to_remove,
                                    "timestamp": time.time()
                                })
                                dirty_results = True
                        # Clean up empty batches
                        results[label] = [b for b in results[label] if b]

                    # 2. Make new group and add to tasks
                    pdf_path = copie_dir / f"{label}.pdf"
                    if pdf_path.exists():
                        idx = get_next_group_idx(label)
                        height = grouping.get_pdf_height(str(pdf_path))
                        grouping.create_jpg(label, idx, [(pid, str(pdf_path), height)], GROUPS_DIR)
                        new_group_path = str(GROUPS_DIR / label / f"Group_{idx+1}.jpg")
                        tasks_to_process.append((new_group_path, label))

            if dirty_results:
                with open(output_path, "w", encoding="utf-8") as f:
                    json.dump(results, f, indent=2)
                with open(overwritten_path, "w", encoding="utf-8") as f:
                    json.dump(overwritten_data, f, indent=2)
        else:
            print(f"Warning: --refaire flag used, but {refaire_path} not found.", file=sys.stderr)

    # ------------------------------------------------------------------
    # --batch / --batch-from: write Batch API request files instead of (or in
    # addition to) doing live requests.
    # ------------------------------------------------------------------
    if args.batch or args.batch_from:
        from utils import read_all_labels
        all_labels = read_all_labels(INPUT_DIR)

        batch_tasks = []
        if args.batch_from:
            if args.batch_from not in all_labels:
                sys.exit(f"Error: Label '{args.batch_from}' not found. Available labels: {all_labels}")

            target_idx = all_labels.index(args.batch_from)
            live_tasks = []

            for task in tasks_to_process:
                lbl = task[1]
                # Any label found sequentially equal or after `args.batch_from` gets batched
                if lbl in all_labels and all_labels.index(lbl) >= target_idx:
                    batch_tasks.append(task)
                else:
                    live_tasks.append(task)

            tasks_to_process = live_tasks  # Keep live tasks to be run right after
        else:
            batch_tasks = tasks_to_process
            tasks_to_process = []  # Run nothing live if just `--batch`

        if batch_tasks:
            batch_flash_file = INPUT_DIR / "batch_requests_flash.jsonl"
            batch_pro_file = INPUT_DIR / "batch_requests_pro.jsonl"

            count_flash = 0
            count_pro = 0

            with open(batch_flash_file, "w", encoding="utf-8") as f_flash, \
                 open(batch_pro_file, "w", encoding="utf-8") as f_pro:

                for task in batch_tasks:
                    file_path, label = task[0], task[1]
                    group_name = os.path.splitext(file_path)[0]
                    json_path = group_name + '.json'

                    with open(json_path, 'r') as jf:
                        group_data = json.load(jf)
                    # Same Pro/Flash routing rule as process_single_task.
                    use_flash = len(group_data) >= 4 or group_data[-1][2] <= 500

                    image_data = Path(file_path).read_bytes()
                    b64_img = base64.b64encode(image_data).decode("utf-8")

                    # Format payload matching Gemini Batch API file requirements
                    req = {
                        "key": file_path,  # The ID returned in the output file
                        "request": {
                            "contents": [{
                                "role": "user",
                                "parts": [
                                    {"inlineData": {"mimeType": "image/jpeg", "data": b64_img}},
                                    {"text": make_prompt(label)}
                                ]
                            }],
                            "generation_config": {
                                "temperature": 1.0,
                                "topP": 0.95,
                                "maxOutputTokens": 65535,
                                "responseMimeType": "application/json",
                                "responseSchema": UNROLLED_SCHEMA
                            }
                        }
                    }

                    if use_flash:
                        f_flash.write(json.dumps(req) + "\n")
                        count_flash += 1
                    else:
                        f_pro.write(json.dumps(req) + "\n")
                        count_pro += 1

            print(f"Batch generation complete.")
            print(f" - {count_flash} requests saved to {batch_flash_file} (for {MODEL_ID_flash})")
            print(f" - {count_pro} requests saved to {batch_pro_file} (for {MODEL_ID_pro})")
            print("Upload these files via the File API and create two separate batch jobs.")

    # If there's no live tasks to do, and we aren't doing a batched ingestion, exit right away
    if not tasks_to_process and not args.deal_with_batched:
        sys.exit(0)

    # ------------------------------------------------------------------
    # --deal-with-batched: load completed batch answers, keyed by image path.
    # ------------------------------------------------------------------
    batched_responses = {}
    if args.deal_with_batched:
        batch_results_path = INPUT_DIR / "batched_correction_result.jsonl"
        if batch_results_path.exists():
            print(f"Loading batch results from {batch_results_path}...")
            with open(batch_results_path, "r", encoding="utf-8") as f:
                for line in f:
                    if not line.strip():
                        continue
                    data = json.loads(line)
                    task_id = data.get("key")  # Corresponds to the key sent in the request

                    if "response" in data:
                        try:
                            # Extract the JSON response text per standard Batch API schema
                            resp_text = data["response"]["candidates"][0]["content"]["parts"][0]["text"]
                            batched_responses[task_id] = resp_text
                        except (KeyError, IndexError) as e:
                            print(f"Warning: Could not parse response for {task_id}: {e}", file=sys.stderr)
                    elif "error" in data:
                        print(f"Batch API Error for {task_id}: {data['error']}", file=sys.stderr)
        else:
            print(f"Warning: Batch results file {batch_results_path} not found.", file=sys.stderr)

    # ------------------------------------------------------------------
    # Run the tasks. Tasks without a batched answer use the live API.
    # ------------------------------------------------------------------
    print(f"Starting processing on {len(tasks_to_process)} tasks with {NB_THREADS} threads...")
    with concurrent.futures.ThreadPoolExecutor(max_workers=NB_THREADS) as executor:
        pending = set()
        for task in tasks_to_process:
            precomp = batched_responses.get(task[0])
            pending.add(executor.submit(process_single_task, task, precomp))

        # as_completed() only watches the futures it was given when called, so
        # it would never report on tasks submitted later. Polling with wait()
        # lets follow-up tasks join the pending set and be checked as well.
        while pending:
            done, pending = concurrent.futures.wait(
                pending, return_when=concurrent.futures.FIRST_COMPLETED)
            for future in done:
                try:
                    new_generated_tasks = future.result()
                except Exception as e:
                    print(f"Exception during task execution: {e}", file=sys.stderr)
                    continue
                for new_task in new_generated_tasks or []:
                    # New tasks from wrong-label/additional-answer will fallback to live API
                    pending.add(executor.submit(process_single_task, new_task))

    end_time = time.time()
    print("Time elapsed : ", end_time - start_time)
    print("Requests to pro / flash : ", pro_count, flash_count)
    if errors_summary:
        print("\n--- Summary of Exceptions (You can use several images on one instance) ---", file=sys.stderr)
        for (err, file) in errors_summary:
            print(err, file=sys.stderr)
            escaped_path = shlex.quote(str(file))
            print(f"Run : python correction.py {escaped_path}")
|