326 lines
11 KiB
Python
326 lines
11 KiB
Python
import sys
|
|
import os
|
|
import time
|
|
from pathlib import Path
|
|
import argparse
|
|
|
|
# ---------------------------------------------------------------------------
# Command line handling and task discovery.
#
# Two invocation modes are supported:
#   * a single image  : <InputDir>/<Label>/<Group>.jpg
#   * a whole directory: <InputDir>, whose sub-directories named "Ex..." each
#     contain the .jpg images to score.
#
# Produces the module-level names used by the rest of the script:
#   args      - parsed options (only `overwrite` is used later on)
#   INPUT_DIR - root directory (str) holding Text/, Sol/, Persp/ and outputs
#   tasks     - list of (filepath_str, label_str) tuples
#   results   - dict mapping each label to the list of collected answers
# ---------------------------------------------------------------------------

# Parse Arguments
parser = argparse.ArgumentParser()
# The input path is declared to argparse (instead of being read from
# sys.argv[1]) so that `script.py --overwrite <path>` works as well as
# `script.py <path> --overwrite`.
parser.add_argument("input_path", nargs="?", help="A single .jpg image or an input directory")
parser.add_argument("--overwrite", action="store_true", help="Force redo requests even if output exists")
# parse_known_args is used to avoid conflicts if run inside an environment passing other flags
args, _ = parser.parse_known_args()

if args.input_path is None:
    sys.exit("Usage: python script.py InterroTest/Ex 2/Group_1.jpg OR <InputDir>")

arg_path = Path(args.input_path)
tasks = []  # List of tuples: (filepath_str, label_str)
results = {}

if arg_path.suffix == ".jpg":
    # Preserve original behaviour: the label is the image's parent directory
    # name and the input root is the directory above it.
    INPUT_DIR = str(arg_path.parents[1])
    FULL_LABEL = arg_path.parent.name
    tasks.append((str(arg_path), FULL_LABEL))
    results[FULL_LABEL] = []
else:
    # Directory behaviour
    INPUT_DIR = str(arg_path)
    # is_dir() rather than exists(): a regular file that is not a .jpg would
    # otherwise pass the check and crash in iterdir() below.
    if not arg_path.is_dir():
        sys.exit(f"Directory {INPUT_DIR} not found.")

    for sub in arg_path.iterdir():
        if sub.is_dir() and sub.name.startswith("Ex"):
            label = sub.name
            results[label] = []
            for img in sub.glob("*.jpg"):
                tasks.append((str(img), label))
|
|
|
|
# Prompt template sent to the model together with the image.
# The placeholders <<text>>, <<corr>>, <<persp>> and <<label>> are substituted
# by make_prompt(). The embedded example must stay valid JSON and must use the
# same field names as the instructions (`box_2d`), otherwise it contradicts
# the response schema.
my_prompt = """I'm giving you an image of several written answers to an exam.

Each answer is separated by a black horizontal line, and underneath,
to the left, is indicated the ID of the answer, from `01` to `50`.

I want you to score each answer, from 0 to 4, you may score half
points, such as 2.5. Even if a result is wrong, if the reasoning is
correct and could lead to a right answer, you should give at least
half the points.

You also need to give feedback to the student, in french :
- which part of his answer is wrong,
- why is it wrong
- possibly, what he should have done instead.
Your feedback may contain LaTeX fragments written like `$a^2 + b^2 = c^2$`.

If your score is not 4, you should always provide some feedback
explaining what's missing.

For each piece of feedback, if it is related to a specific part of the
answer that is wrong, you may provide a `box_2d`, to locate this
specific part of the answer. This `box_2d` should be in the form
[ymin, xmin, ymax, xmax] normalized to 0-1000. If you do not provide
one, set `box_2d` to `null`.

If the answer is correct, there is no need to provide feedback.

For example, if the student says a function is continuous when it
isn't, provide the coordinates where the word «continuous» is. If a
calculation went wrong, gives the coordinates of the step where it
goes wrong, and as feedback, what went wrong.

You should also give me a measure of confidence, from 0 to 1 that you
were able to correctly understand the answer. A score below 0.5 means
that you think it is likely that you couldn't understand an important
part.

In some case, you may find that either
- The student didn't answer the right question. Set the score to 0.
  Since it could be a labeling error, indicate it by setting `error`
  to "wrong-label".
- You can find an answer to another question of the exercice (taking
  more than a couple of lines). Score the question you are supposed
  to score, but set `error` to "additional-answer".
If there's no error, set `error` to `""`.

You will answer using json describing a list of dictionary with a key
"id", and a key "result" that contains the "score", the "confidence", a
list "feedback", and possibly an "error". Like this example :

[{ "id": "01",
   "result": {"score" : 2.5,
              "confidence" : 0.8,
              "feedback": [{"text": "Un retour générique. Il faut apprendre le cours.", "box_2d": null},
                           {"text": "Non, la fonction n'est pas forcément continue", "box_2d": [145, 280, 340, 500]}],
              "error": ""}
 },
 { "id": "04",
   "result": {"score" : 4.0,
              "confidence" : 0.9,
              "feedback" : [],
              "error": "" }
 }
]

Here is the text of the exercice of the exam :

```
<<text>>
```

Here is a possible correct answer :

```
<<corr>>
```

Here is some additional scoring instructions :

```
<<persp>>
```

You are asked to score the question or exercice labeled `<<label>>`,
do not score or give feedback to any other question."""
|
|
|
|
def make_prompt(full_label):
    """Return the scoring prompt for the question labeled *full_label*.

    The exercise label is made of the first two space-separated words of
    *full_label* (e.g. "Ex 2" for "Ex 2 Q1"). The exercise statement, a
    reference solution and extra scoring instructions are read from the files
    ``Text/<ex_label>``, ``Sol/<ex_label>`` and ``Persp/<ex_label>`` under
    ``INPUT_DIR`` and substituted into the ``my_prompt`` template.

    Raises:
        IndexError: if *full_label* contains fewer than two words.
        OSError: if one of the three files cannot be read.
    """
    words = full_label.split(" ")
    ex_label = words[0] + " " + words[1]
    root = Path(INPUT_DIR)
    text = (root / "Text" / ex_label).read_text()
    corr = (root / "Sol" / ex_label).read_text()
    persp = (root / "Persp" / ex_label).read_text()
    # An empty (or whitespace-only) instruction file gets an explicit
    # sentence, so the model is not shown an empty fenced block.
    if not persp.strip():
        persp = "There is no additional scoring instructions."
    return (
        my_prompt
        .replace("<<text>>", text)
        .replace("<<corr>>", corr)
        .replace("<<persp>>", persp)
        .replace("<<label>>", full_label)
    )
|
|
|
|
from google import genai
|
|
from google.genai import types
|
|
import base64
|
|
import json
|
|
from pathlib import Path
|
|
import os
|
|
import threading
|
|
import concurrent.futures
|
|
|
|
# Number of worker threads sending requests concurrently.
NB_THREADS = 8

# Optional HTTP(S) proxy. Set to a URL such as "http://192.168.241.1:3128"
# to route requests through it; None leaves the environment untouched.
PROXY_URL = None

if PROXY_URL:
    os.environ["http_proxy"] = PROXY_URL
    os.environ["https_proxy"] = PROXY_URL

# Default model, and the alternate model used for images containing many answers.
MODEL_ID = "gemini-3-pro-preview"
MODEL_ID_BIS = "gemini-3-flash-preview"

# Credentials should not live in the source file: the key is taken from the
# GEMINI_API_KEY environment variable, with the previous literal kept as a
# fallback so existing setups behave as before.
api_key = os.environ.get("GEMINI_API_KEY", "REMOVED_API_KEY")
|
|
|
|
from pydantic import BaseModel, Field, TypeAdapter
|
|
from typing import List, Optional, Tuple
|
|
|
|
# One feedback remark attached to a scored answer.
# NOTE(review): documented with `#` comments rather than a docstring on
# purpose - pydantic is expected to copy a model docstring into the generated
# JSON schema, which would change what is sent to the API; confirm before
# converting these comments to docstrings.
class FeedbackItem(BaseModel):
    # Text of the remark (the prompt in this file asks for French feedback).
    text: str = Field(description="Feedback content")

    # Optional location of the remark; defaults to None. The prompt in this
    # file asks for [ymin, xmin, ymax, xmax] normalized to 0-1000.
    box_2d: Optional[List[int]] = Field(None, description="box coordinates or null")
|
|
|
|
# Scoring outcome for a single answer.
# NOTE(review): `#` comments are used instead of a docstring because a model
# docstring is expected to end up in the generated JSON schema - confirm
# before changing.
class ResultData(BaseModel):
    # Score given to the answer (the prompt in this file asks for 0 to 4,
    # half points allowed).
    score: float = Field(description="The numeric score")

    # Self-reported confidence that the answer was understood (the prompt
    # asks for a value from 0 to 1).
    confidence: float = Field(description="Confidence level")

    # Feedback remarks; may be empty.
    feedback: List[FeedbackItem] = Field(description="List of feedback items")

    # Error tag; the prompt defines "wrong-label", "additional-answer" and ""
    # (no error).
    error: str = Field(description="Indicates if an error occurred")
|
|
|
|
# One element of the model's answer: an answer ID paired with its result.
# NOTE(review): `#` comments are used instead of a docstring because a model
# docstring is expected to end up in the generated JSON schema - confirm
# before changing.
class EvaluationEntry(BaseModel):
    # Identifier of the answer as printed on the image (the prompt mentions
    # IDs from `01` to `50`).
    id: str = Field(description="Entry identifier")

    # Score, confidence, feedback and error tag for that answer.
    result: ResultData = Field(description="Result details")
|
|
|
|
# The root model used for parsing is List[EvaluationEntry].
|
|
|
|
def generate_request(file, full_label):
    """Build the Gemini request for one image.

    Args:
        file: path of the JPEG image to send.
        full_label: label of the question to score; forwarded to make_prompt().

    Returns:
        A ``(contents, config)`` tuple: the single-turn user content (image
        followed by the prompt text) and the generation configuration, which
        requests JSON output matching ``List[EvaluationEntry]``.
    """
    prompt_text = make_prompt(full_label)
    image_bytes = Path(file).read_bytes()

    # Single user turn: the image first, then the instructions.
    image_part = types.Part.from_bytes(data=image_bytes, mime_type="image/jpeg")
    text_part = types.Part.from_text(text=prompt_text)
    user_turn = types.Content(role="user", parts=[image_part, text_part])

    # JSON schema describing the expected answer.
    answer_schema = TypeAdapter(List[EvaluationEntry]).json_schema()
    request_config = types.GenerateContentConfig(
        temperature=1.0,
        top_p=0.95,
        seed=0,
        max_output_tokens=65535,
        response_mime_type="application/json",
        response_json_schema=answer_schema,
    )
    return ([user_turn], request_config)
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Shared run state: API client, output locations and resume bookkeeping.
# ---------------------------------------------------------------------------
client = genai.Client(api_key=api_key)
output_path = Path(INPUT_DIR) / "correction.json"
progress_path = Path(INPUT_DIR) / "correction_progress.json"
start_time = time.time()
overwrite = args.overwrite
completed_tasks = []

# --- Lock for thread-safe file writing ---
io_lock = threading.Lock()

if not overwrite:
    # Resume: pick up the list of finished tasks, if any.
    if progress_path.exists():
        completed_tasks = json.loads(progress_path.read_text(encoding="utf-8"))
    # Reload existing results to avoid overwriting them with partial data
    if output_path.exists():
        results = json.loads(output_path.read_text(encoding="utf-8"))
else:
    # Forced redo: drop whatever a previous run left behind.
    for stale_file in (output_path, progress_path):
        if stale_file.exists():
            stale_file.unlink()

# Create a set for O(1) lookup. Normalize paths to strings.
completed_set = {(str(done_file), done_label) for done_file, done_label in completed_tasks}

# Filter tasks first to avoid overhead in threads
tasks_to_process = [
    task for task in tasks
    if (str(task[0]), task[1]) not in completed_set
]
|
|
|
|
def _load_group_layout(json_path):
    """Read the sidecar JSON describing where each answer sits in the image.

    The file holds a list of (groupid, start, end) entries, in pixels.

    Returns:
        ``(bounds_by_id, total_height, count)`` where ``bounds_by_id`` maps a
        group id to its ``(start, end)`` pair, ``total_height`` is the ``end``
        value of the last entry and ``count`` is the number of entries.
    """
    with open(json_path, 'r') as f:
        # List of (groupid, start, end), in pixels
        group_data = json.load(f)
    bounds_by_id = {entry[0]: (entry[1], entry[2]) for entry in group_data}
    total_height = group_data[-1][2]
    return bounds_by_id, total_height, len(group_data)


def _check_feedback_boxes(json_data, d_data, total_height, label, group_name):
    """Make every feedback box consistent with the answer it belongs to (in place).

    Box ordinates are given on a 0-1000 scale; they are converted to pixels
    using ``total_height``, compared with the (start, end) range recorded for
    the answer id, then clamped to that range and converted back.
    A box that does not overlap its answer at all is removed.
    """
    tolerance = 50  # pixels a box may exceed its answer before being reported

    for entry in json_data:
        pid = entry["id"]
        res = entry["result"]
        if res.get("error"):
            print("Error :", res["error"], "for Copie", pid, label, group_name)

        bounds = d_data.get(pid)
        if bounds is None:
            # Reported once per answer; its boxes cannot be checked and are
            # left untouched.
            print("Error : Gemini answered a copie id not present", pid, label, group_name)
            continue
        yming, ymaxg = bounds

        for item in res["feedback"]:
            box = item.get("box_2d")
            if not box:
                continue
            if len(box) != 4:
                # A malformed box must not discard the whole response.
                print("Error : Gemini answered a malformed box2d", pid, label, group_name)
                print("Removing the box.")
                item["box_2d"] = None
                continue

            ymin, xmin, ymax, xmax = box
            ymin = ymin * total_height // 1000
            ymax = ymax * total_height // 1000

            if ymax < yming or ymin > ymaxg:
                # No overlap with the answer: clamping would yield an
                # inverted box, so drop it.
                print("Error : Gemini answered box2d not at the right position", pid, label, group_name)
                print("Removing the box.")
                item["box_2d"] = None
                continue
            if ymin < yming - tolerance or ymax > ymaxg + tolerance:
                print("Error : Gemini answered box2d not at the right position", pid, label, group_name)

            nymin = max(ymin, yming) * 1000 // total_height
            nymax = min(ymax, ymaxg) * 1000 // total_height
            item["box_2d"] = [nymin, xmin, nymax, xmax]


def process_single_task(task_tuple):
    """Score one image with Gemini and persist the outcome.

    Args:
        task_tuple: ``(file_path, label)``; a JSON file with the same base name
            as the image must sit next to it and describe the answer layout.

    The flash model is used when the image holds five answers or more. On
    success the parsed answer is appended to ``results[label]``, the results
    file is rewritten, and the task is recorded in the progress file so a
    later run without ``--overwrite`` skips it. All file writes happen under
    ``io_lock``.

    Returns:
        None. Failures are reported on stderr and never raised, because the
        caller does not inspect worker results.
    """
    file_path, label = task_tuple
    group_name = os.path.splitext(file_path)[0]
    json_path = group_name + '.json'

    # Loading the layout is guarded too: an exception escaping a worker would
    # otherwise go unnoticed.
    try:
        d_data, total_height, n = _load_group_layout(json_path)
    except (OSError, ValueError, LookupError, TypeError) as e:
        print(f"Cannot read group data {json_path}: {e}", file=sys.stderr)
        return

    use_flash = n >= 5
    try:
        contents, config = generate_request(file_path, label)
        if use_flash:
            print(f"Asking Flash Gemini: {label} {file_path}")
        else:
            print(f"Asking Gemini: {label} {file_path}")

        # Assuming client is thread-safe (usually is).
        # If not, create a new client instance inside this function.
        stream = client.models.generate_content_stream(
            model=MODEL_ID_BIS if use_flash else MODEL_ID,
            contents=contents,
            config=config,
        )
        full_response_text = "".join(chunk.text for chunk in stream if chunk.text)

        # Parse JSON
        json_data = json.loads(full_response_text)

        if use_flash:
            print(f"Gemini Flash answered for {file_path}")
        else:
            print(f"Gemini answered for {file_path}")

        # Ensure consistency of answer placements
        _check_feedback_boxes(json_data, d_data, total_height, label, group_name)

        # --- CRITICAL: Use Lock for writing shared data ---
        with io_lock:
            results.setdefault(label, []).append(json_data)

            # Save Results
            with open(output_path, "w", encoding="utf-8") as f:
                json.dump(results, f, indent=2)

            # Record the finished task in the list that is read back at
            # start-up to skip completed work.
            completed_tasks.append([str(file_path), label])
            with open(progress_path, "w", encoding="utf-8") as f:
                json.dump(completed_tasks, f, indent=2)

    except json.JSONDecodeError:
        print(f"Error decoding JSON for {file_path}", file=sys.stderr)
    except Exception as e:
        print(f"Exception processing {file_path}: {e}", file=sys.stderr)
|
|
|
|
# ---------------------------------------------------------------------------
# Run all pending tasks on a thread pool and report the elapsed time.
# ---------------------------------------------------------------------------
print(f"Starting processing on {len(tasks_to_process)} tasks with {NB_THREADS} threads...")

with concurrent.futures.ThreadPoolExecutor(max_workers=NB_THREADS) as executor:
    # submit() + as_completed() instead of an unconsumed executor.map():
    # map() only re-raises a worker's exception when its result is read, so
    # failures were silently dropped. Each failure is reported here without
    # stopping the remaining tasks.
    pending = {
        executor.submit(process_single_task, task): task
        for task in tasks_to_process
    }
    for future in concurrent.futures.as_completed(pending):
        failure = future.exception()
        if failure is not None:
            failed_file, failed_label = pending[future]
            print(
                f"Unhandled exception for {failed_file} ({failed_label}): {failure}",
                file=sys.stderr,
            )

end_time = time.time()
print("Time elapsed : ", end_time - start_time, "\n\n\n\n\n")
|