# Copies/correction.py
#
# 356 lines
# 12 KiB
# Python

import sys
import os
import time
from pathlib import Path
import argparse
# Parse the command line. The input is either one group image
# ("<InputDir>/<Label>/<Group>.jpg") or an input directory whose "Ex*"
# sub-directories contain the group images to grade.
parser = argparse.ArgumentParser()
# The path is an optional positional so that options may come before or
# after it (reading sys.argv[1] directly broke "script.py --overwrite <path>"),
# while a missing path still yields the usage message below.
parser.add_argument("path", nargs="?",
                    help="A group image (.../<Label>/<Group>.jpg) or an input directory")
parser.add_argument("--overwrite", action="store_true",
                    help="Force redo requests even if output exists")
parser.add_argument("--limit", type=int,
                    help="limit calls to gemini pro (integer); later tasks fall back to flash")
args, _ = parser.parse_known_args()
if args.path is None:
    sys.exit("Usage: python script.py InterroTest/Ex 2/Group_1.jpg OR <InputDir>")
arg_path = Path(args.path)

tasks = []  # List of tuples: (filepath_str, label_str)
results = {}  # label -> list of per-group Gemini answers

if arg_path.suffix.lower() in (".jpg", ".jpeg"):
    # Single-image mode: the label is the image's directory name, and
    # INPUT_DIR is the directory above it (where Text/, Sol/, Persp/ live).
    if not arg_path.is_file():
        sys.exit(f"File {arg_path} not found.")
    INPUT_DIR = str(arg_path.parents[1])
    FULL_LABEL = arg_path.parent.name
    tasks.append((str(arg_path), FULL_LABEL))
    results[FULL_LABEL] = []
else:
    # Directory mode: every "Ex*" sub-directory is a label and each *.jpg
    # inside it is one group to grade. Sorted for a reproducible task order.
    INPUT_DIR = str(arg_path)
    if not arg_path.is_dir():
        sys.exit(f"Directory {INPUT_DIR} not found.")
    for sub in sorted(arg_path.iterdir()):
        if sub.is_dir() and sub.name.startswith("Ex"):
            label = sub.name
            results[label] = []
            for img in sorted(sub.glob("*.jpg")):
                tasks.append((str(img), label))
# Prompt template sent to Gemini together with one group image. The
# <<text>>, <<corr>>, <<persp>> and <<label>> placeholders are filled in by
# make_prompt(). The JSON example must stay valid JSON and use the same keys
# as the response schema ("id", "result", "score", "confidence", "feedback",
# "text", "box_2d", "error").
my_prompt = """I'm giving you an image of several written answers to an exam.
Each answer is separated by a black horizontal line, and underneath,
to the left, is indicated the ID of the answer, from `01` to `50`.
I want you to score each answer, from 0 to 4, you may score half
points, such as 2.5. Even if a result is wrong, if the reasoning is
correct and could lead to a right answer, you should give at least
half the points.
You also need to give feedback to the student, in french :
- which part of his answer is wrong,
- why is it wrong
- possibly, what he should have done instead.
Your feedback may contain LaTeX fragments written like `$a^2 + b^2 = c^2$`.
If your score is not 4, you should always provide some feedback
explaining what's missing.
For each piece of feedback, if it is related to a specific part of the
answer that is wrong, you may provide a `box_2d`, to locate this
specific part of the answer. This `box_2d` should be in the form
[ymin, xmin, ymax, xmax] normalized to 0-1000. If you do not provide
one, set `box_2d` to `null`.
If the answer is correct, there is no need to provide feedback. You do
not have to give positive feedback, but if you do, do not provide a
`box_2d` for it.
For example, if the student says a function is continuous when it
isn't, provide the coordinates where the word «continuous» is. If a
calculation went wrong, give the coordinates of the step where it
goes wrong, and as feedback, what went wrong.
You should also give me a measure of confidence, from 0 to 1 that you
were able to correctly understand the answer. A score below 0.5 means
that you think it is likely that you couldn't understand an important
part.
In some cases, you may find that either
- The student didn't answer the right question. Set the score to 0.
Since it could be a labeling error, indicate it by setting `error`
to "wrong-label".
- You can find an answer to another question of the exercice (taking
more than a couple of lines). Score the question you are supposed
to score, but set `error` to "additional-answer".
- The answer to the question is empty, or the student has only
rewritten the statement of the question. In this case, set `error`
to "empty-answer" and do not provide any kind of feedback.
If there's no error, set `error` to `""`.
You will answer using json describing a list of dictionaries with a key
"id", and a key "result" that contains the "score", the "confidence", a
list "feedback", and an "error". Like this example :
[{"id": "01",
  "result": {"score": 2.5,
             "confidence": 0.8,
             "feedback": [{"text": "Un retour générique. Il faut apprendre le cours.", "box_2d": null},
                          {"text": "Non, la fonction n'est pas forcément continue", "box_2d": [145, 280, 340, 500]}],
             "error": ""}
 },
 {"id": "04",
  "result": {"score": 4.0,
             "confidence": 0.9,
             "feedback": [],
             "error": ""}
 }
]
Here is the text of the exercice of the exam :
```
<<text>>
```
Here is a possible correct answer :
```
<<corr>>
```
Here is some additional scoring instructions :
```
<<persp>>
```
You are asked to score the question or exercice labeled `<<label>>`,
do not score or give feedback to any other question."""


def make_prompt(full_label, input_dir=None):
    """Build the grading prompt for one label.

    Args:
        full_label: label of the question/exercise to grade, e.g. "Ex 2" or
            "Ex 2 a". Its first two space-separated words name the exercise
            whose files are read (the whole label if it has fewer words).
        input_dir: directory holding the "Text", "Sol" and "Persp"
            sub-directories. Defaults to the module-level INPUT_DIR.

    Returns:
        ``my_prompt`` with the exercise statement, the reference solution,
        the extra scoring instructions and ``full_label`` substituted in.

    Raises:
        OSError: if one of the three exercise files cannot be read.
        UnicodeDecodeError: if one of them is not valid UTF-8.
    """
    base = Path(INPUT_DIR if input_dir is None else input_dir)
    ex_label = " ".join(full_label.split(" ")[:2])
    # Explicit UTF-8: the files hold French text and the platform default
    # encoding (e.g. cp1252 on Windows) would garble accented characters.
    text = (base / "Text" / ex_label).read_text(encoding="utf-8")
    corr = (base / "Sol" / ex_label).read_text(encoding="utf-8")
    persp = (base / "Persp" / ex_label).read_text(encoding="utf-8")
    if not persp.strip():
        # Empty (or whitespace-only) instructions: say so explicitly rather
        # than leaving an empty code fence in the prompt.
        persp = "There are no additional scoring instructions."
    return (my_prompt
            .replace("<<text>>", text)
            .replace("<<corr>>", corr)
            .replace("<<persp>>", persp)
            .replace("<<label>>", full_label))
from google import genai
from google.genai import types
import base64
import json
from pathlib import Path
import os
import threading
import concurrent.futures
# Number of groups graded concurrently (each worker makes one Gemini call).
NB_THREADS = 8

# Optional HTTP(S) proxy for the Gemini client, e.g. "http://192.168.241.1:3128".
PROXY_URL = None
if PROXY_URL:
    os.environ["http_proxy"] = PROXY_URL
    os.environ["https_proxy"] = PROXY_URL

MODEL_ID_pro = "gemini-3.1-pro-preview"
MODEL_ID_flash = "gemini-3-flash-preview"

# Fail with a readable message instead of a bare KeyError traceback.
api_key = os.environ.get("GEMINI_API_KEY")
if not api_key:
    sys.exit("Environment variable GEMINI_API_KEY is not set.")
from pydantic import BaseModel, Field, TypeAdapter
from typing import List, Optional, Tuple
# One feedback remark on a graded answer (part of the response schema).
# Documented with '#' comments only: pydantic uses a class docstring as the
# JSON-schema "description", which would change the schema sent to Gemini.
class FeedbackItem(BaseModel):
    text: str = Field(description="Feedback content")
    # Per the prompt: [ymin, xmin, ymax, xmax] normalized to 0-1000 over the
    # whole group image, or null. The schema itself does not enforce the
    # length or the range.
    box_2d: Optional[List[int]] = Field(None, description="box coordinates or null")
# Grade and feedback for one answer (part of the response schema).
# No class docstring on purpose: pydantic would emit it as the schema
# "description" sent to Gemini.
class ResultData(BaseModel):
    # The prompt asks for 0 to 4 in half-point steps (not enforced here).
    score: float = Field(description="The numeric score")
    # The prompt asks for 0 to 1; below 0.5 means the model likely misread
    # an important part of the answer.
    confidence: float = Field(description="Confidence level")
    feedback: List[FeedbackItem] = Field(description="List of feedback items")
    # "" when there is no error, otherwise one of the codes named in the
    # prompt: "wrong-label", "additional-answer", "empty-answer".
    error: str = Field(description="Indicates if an error occurred")
# Result for one answer of a group image, keyed by the answer's ID.
# No class docstring on purpose: pydantic would emit it as the schema
# "description" sent to Gemini.
class EvaluationEntry(BaseModel):
    # Answer ID written under the answer in the image ("01" to "50" per the prompt).
    id: str = Field(description="Entry identifier")
    result: ResultData = Field(description="Result details")
# The root model for parsing is: List[EvaluationEntry]
def generate_request(file, full_label):
    """Assemble the Gemini request for one group image.

    Args:
        file: path (str or Path) of the JPEG group image.
        full_label: label of the question/exercise to grade; forwarded to
            make_prompt().

    Returns:
        A ``(contents, config)`` tuple: a one-message user conversation made
        of the image followed by the grading prompt, and a generation config
        requesting JSON that matches ``List[EvaluationEntry]``.
    """
    prompt_text = make_prompt(full_label)
    image_part = types.Part.from_bytes(
        data=Path(file).read_bytes(),
        mime_type="image/jpeg",
    )
    prompt_part = types.Part.from_text(text=prompt_text)
    user_message = types.Content(role="user", parts=[image_part, prompt_part])

    # Constrain the answer to the pydantic response schema.
    answer_schema = TypeAdapter(List[EvaluationEntry]).json_schema()
    config = types.GenerateContentConfig(
        temperature=1.0,
        top_p=0.95,
        seed=0,
        max_output_tokens=65535,
        response_mime_type="application/json",
        response_json_schema=answer_schema,
    )
    return [user_message], config
# --- Run-wide state shared with the worker threads ---
client = genai.Client(api_key=api_key)
output_path = Path(INPUT_DIR) / "correction.json"
progress_path = Path(INPUT_DIR) / "correction_progress.json"
start_time = time.time()
overwrite = args.overwrite
limit = args.limit
completed_tasks = []

# io_lock serialises updates of `results` and the output file;
# pro_lock guards the pro/flash request counters.
io_lock = threading.Lock()
pro_lock = threading.Lock()
pro_count = 0
flash_count = 0

if overwrite:
    # Fresh run: discard previous output, then previous progress.
    for stale in (output_path, progress_path):
        if stale.exists():
            stale.unlink()
else:
    # Resume: reload the (image, label) pairs already done...
    if progress_path.exists():
        completed_tasks = json.loads(progress_path.read_text(encoding="utf-8"))
    # ...and the results gathered so far, so they are extended, not overwritten.
    if output_path.exists():
        results = json.loads(output_path.read_text(encoding="utf-8"))

# JSON round-trips pairs as lists: normalise to tuples for O(1) set lookup,
# then drop finished tasks before handing work to the threads.
completed_set = {(str(done_path), done_label) for done_path, done_label in completed_tasks}
tasks_to_process = [task for task in tasks
                    if (str(task[0]), task[1]) not in completed_set]
def _dump_json_atomic(path, data):
    """Write ``data`` as JSON to ``path`` through a temp file + os.replace.

    A crash mid-write leaves the previous file intact rather than a
    truncated one that a resuming run would then fail to json.load.
    """
    path = Path(path)
    tmp_path = path.with_name(path.name + ".tmp")
    with open(tmp_path, "w", encoding="utf-8") as fh:
        json.dump(data, fh, indent=2, ensure_ascii=False)
    os.replace(tmp_path, path)


def _clamp_feedback_boxes(json_data, d_data, total_height, label, group_name):
    """Check Gemini's feedback boxes against the group layout, fixing in place.

    Boxes are [ymin, xmin, ymax, xmax] normalized to 0-1000 over the whole
    group image. y values are converted to pixels with ``total_height`` and
    compared with the answer's band ``d_data[id] = (y_start, y_end, width_ratio)``:
      * more than 50 px outside the band: the box is clamped to the band, or
        removed if it does not overlap the band at all;
      * xmax beyond the answer's width ratio: xmax is clamped, or the box is
        removed if xmin is beyond it too;
      * not exactly 4 coordinates: the box is removed.
    Problems are reported on stdout; nothing is raised for bad boxes.
    """
    for entry in json_data:
        pid = entry["id"]
        res = entry["result"]
        if res["error"] != "":
            print("Error :", res["error"], "for Copie", pid, label, group_name)
        for fb in res["feedback"]:
            box = fb["box_2d"]
            if not box:
                continue
            if len(box) != 4:
                # Unpacking would raise and lose the whole group's results.
                print("Error : Gemini answered a malformed box2d", pid, label, group_name)
                print("Removing the box.")
                fb["box_2d"] = None
                continue
            ymin, xmin, ymax, xmax = box
            ymin = ymin * total_height // 1000
            ymax = ymax * total_height // 1000
            if pid not in d_data:
                print("Error : Gemini answered a copie id not present", pid, label, group_name)
                continue
            yming, ymaxg, width_r = d_data[pid]
            if ymin < yming - 50 or ymax > ymaxg + 50:
                print("Error : Gemini answered box2d too low/up",
                      pid, label, group_name)
                if ymax < yming or ymin > ymaxg:
                    print("Removing the box.")
                    fb["box_2d"] = None
                    continue
                nymin = max(ymin, yming) * 1000 // total_height
                nymax = min(ymax, ymaxg) * 1000 // total_height
                fb["box_2d"] = [nymin, xmin, nymax, xmax]
            if xmax / 1000 > width_r:
                print("Error : Gemini answered box2d too right",
                      pid, label, group_name)
                if xmin / 1000 > width_r:
                    print("Removing the box.")
                    fb["box_2d"] = None
                    continue
                fb["box_2d"][3] = int(width_r * 1000)


def process_single_task(task_tuple):
    """Grade one group image with Gemini and record the answer.

    Args:
        task_tuple: ``(image_path, label)``. Next to the image, a sidecar
            "<image>.json" lists the group's answers as
            ``[answer_id, y_start, y_end, width_ratio]`` entries (y in pixels,
            width_ratio = answer width / group width).

    Side effects:
        Picks Flash for groups with >= 4 answers or height <= 500 px, and
        Pro otherwise until ``limit`` Pro calls are used. Updates
        ``pro_count`` / ``flash_count``, appends the parsed answer to
        ``results[label]`` and rewrites ``output_path``. It then appends the
        task to ``completed_tasks`` and rewrites ``progress_path`` so a
        resumed run skips it.

    Errors are reported on stderr and never propagate out of the worker.
    """
    global pro_count, flash_count
    file_path, label = task_tuple
    group_name = os.path.splitext(file_path)[0]
    try:
        # Read inside the try so a missing or corrupt sidecar is reported
        # like any other per-task failure.
        with open(group_name + ".json", "r", encoding="utf-8") as fh:
            group_data = json.load(fh)
        d_data = {row[0]: (row[1], row[2], row[3]) for row in group_data}
        total_height = group_data[-1][2]

        use_flash = len(group_data) >= 4 or total_height <= 500
        with pro_lock:
            if not use_flash and limit is not None and pro_count >= limit:
                use_flash = True  # Pro budget exhausted: fall back to Flash.
            # Count every request, not only those made while a limit is set.
            if use_flash:
                flash_count += 1
            else:
                pro_count += 1

        contents, config = generate_request(file_path, label)
        if use_flash:
            print(f"Asking Flash Gemini: {label} {file_path}")
        else:
            print(f"Asking Gemini: {label} {file_path}")
        # The shared client is assumed thread-safe; otherwise create one here.
        chunks = []
        for chunk in client.models.generate_content_stream(
            model=MODEL_ID_flash if use_flash else MODEL_ID_pro,
            contents=contents,
            config=config,
        ):
            if chunk.text:
                chunks.append(chunk.text)
        json_data = json.loads("".join(chunks))
        if use_flash:
            print(f"Gemini Flash answered for {file_path}")
        else:
            print(f"Gemini answered for {file_path}")

        _clamp_feedback_boxes(json_data, d_data, total_height, label, group_name)

        with io_lock:
            results.setdefault(label, []).append(json_data)
            # Results first, then progress: a crash in between re-runs the
            # task on resume, which is preferable to losing its result.
            _dump_json_atomic(output_path, results)
            completed_tasks.append([file_path, label])
            _dump_json_atomic(progress_path, completed_tasks)
    except json.JSONDecodeError:
        print(f"Error decoding JSON for {file_path}", file=sys.stderr)
    except Exception as e:
        print(f"Exception processing {file_path}: {e}", file=sys.stderr)
print(f"Starting processing on {len(tasks_to_process)} tasks with {NB_THREADS} threads...")
with concurrent.futures.ThreadPoolExecutor(max_workers=NB_THREADS) as executor:
    # Wait on every future explicitly. A bare executor.map() whose result
    # iterator is never consumed silently discards exceptions from workers.
    futures = {executor.submit(process_single_task, task): task
               for task in tasks_to_process}
    for future in concurrent.futures.as_completed(futures):
        exc = future.exception()
        if exc is not None:
            print(f"Unhandled exception for {futures[future][0]}: {exc}", file=sys.stderr)
end_time = time.time()
print("Time elapsed : ", end_time - start_time)
print("Requests to pro / flash : ", pro_count, flash_count)