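"""Annotate exam page images with Gemini.

For each exam copy (``Copie*.pdf``), the page images stored in the
``Cutleft`` directory next to the PDF are sent to Gemini, which returns the
student name and the bounding boxes of the exercise/question labels. The
result for each image is written as a JSON file next to the PDF.
"""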
import argparse
import base64
import json
import mimetypes
import os
import re
import sys
import time
import traceback
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import List, Dict

from google import genai
from google.genai import types
from pydantic import BaseModel, Field

# Gemini model used for every request.
MODEL_ID = "gemini-3-flash-preview"

# Fail fast with a readable message instead of a bare KeyError traceback
# (an empty value is rejected too, since no request could succeed with it).
api_key = os.environ.get("GEMINI_API_KEY")
if not api_key:
    sys.exit("Error: the GEMINI_API_KEY environment variable is not set.")

# Prompt for the first page of a copy. ##labels## is replaced by the
# newline-separated list of valid labels and ##names## by the list of
# student names.
my_prompt = """I'm giving you an image of the left columns of a written exam.
Students answer several exercises, which can have several questions.

The image consists of several columns, separated by vertical black
lines. The image should be read top to bottom and then left to right,
meaning first column, then second column, etc.

In their sheet, students delimit exercises and questions using
delimiters such as `Ex 1`, or `Exercice 1`, and `1)` or `a)`. You need
to give me the bounding boxes of each delimiter.

When giving the bounding box of the first question of an exercise, the
box should be large enough to contain both the exercise label
(`Exercice i`) and the question label (`1)`) parts. If they are
horizontally far apart (example : if the `1)` is to the left and the
`Exercice i` is either to the right, or in the middle) then give only
the bounding box of the question label `1)` part. You should still
label it as `Exercice i : 1)` though.

You also need to give me the student name. It should appear on the top
left of the image. Disregard any mention of `MPSI 3`, it is their
class. A list of possible student names will be given below.

You will answer with a JSON object, containing a `name` field with the
name, and a `list` field, with the list of the bounding boxes and
their labels. The box_2d should be [ymin, xmin, ymax, xmax] normalized
to 0-1000.

Here is an example :
{"name" : "John Doe", "list" : [{"box_2d": [10, 20, 30, 40], "label" : "Ex 1 : 1)"}]}

Do not provide a box_2d for the name. Only for the labels. Order the
box_2d by their position in the page, column by column : first column
(top to bottom), then second column, etc.

You may find the same label present several times, as a student either
recalls the current label on a new page, or adds content to their answer
later on. Give the position of each instance of each label.

For this exam you should look for the labels given below, separated by
newlines. A student need not have answered every question, so some may
be missing.

##labels##

Here's a list of the names of the students, pick the one that matches
the best or `"Unknown"` if you cannot read the name.

##names##"""

# Prompt for the second and later pages of a copy. ##prev_context## is
# replaced by the labels found on the previous pages, ##labels## by the
# newline-separated list of valid labels; the name is fixed to "Continued".
my_prompt2 = """I'm giving you an image of the left columns of a written exam.
Students answer several exercises, which can have several questions.

The image consists of several columns, separated by vertical black
lines. The image should be read top to bottom and then left to right,
meaning first column, then second column, etc.

In their sheet, students delimit exercises and questions using
delimiters such as `Ex 1`, or `Exercice 1`, and `1)` or `a)`. You need
to give me the bounding boxes of each delimiter.

When giving the bounding box of the first question of an exercise, the
box should be large enough to contain both the exercise label
(`Exercice i`) and the question label (`1)`) parts.

You also need to give me a `name` field. This image is not the first
part of the copy, so do not look for the student name: see the last
instruction below for the value to use.

You will answer with a JSON object, containing a `name` field with the
name, and a `list` field, with the list of the bounding boxes and
their labels. The box_2d should be [ymin, xmin, ymax, xmax] normalized
to 0-1000.

Here is an example :
{"name" : "Continued", "list" : [{"box_2d": [10, 20, 30, 40], "label" : "Ex 1 : 1)"}]}

Do not provide a box_2d for the name. Only for the labels. Order the
box_2d by their position in the page, column by column : first column
(top to bottom), then second column, etc.

You may find the same label present several times, as a student either
recalls the current label on a new page, or adds content to their answer
later on. Give the position of each instance of each label.

This image is one part of a sequence (e.g., part 2 of 3) for a single
student. Here is the list of labels found in the *previous* parts of
this copy:

[
##prev_context##
]

If the first column starts with a number like `3)` or `c)`, look at
the labels in the list above. If the last relevant label was
`Ex 4 : 2)`, you should label the new box `Ex 4 : 3)`.

For this exam you should look for the labels given below, separated by
newlines. A student need not have answered every question, so some may
be missing.

##labels##

Since this copy isn't the first part of a sequence, simply set the
name to `"Continued"`."""

class BoxItem(BaseModel):
    # One detected delimiter: its bounding box and the exam label it was
    # matched to. No class docstring on purpose: pydantic copies it into the
    # JSON schema that is sent to Gemini as response_json_schema.

    # Exactly four ints, in the order the prompts request. The length bound
    # makes a truncated/oversized box fail validation (so the request is
    # retried) instead of being saved.
    box_2d: List[int] = Field(
        description="Bounding box coordinates as [ymin, xmin, ymax, xmax], normalized to 0-1000",
        min_length=4,
        max_length=4,
    )
    # Expected to be one of the exam's labels; checked after validation.
    label: str = Field(description="The label associated with the specific box")

class AnnotationData(BaseModel):
    # Top-level shape of the model's JSON answer; its JSON schema is also
    # used as the response schema of each request. Documented with comments
    # rather than a docstring, which pydantic would put into that schema.

    # Student name returned by the model (validated against the names list
    # by the processing loop).
    name: str = Field(..., description="The name identifier")
    # Every detected delimiter, in the order returned by the model.
    list: List[BoxItem] = Field(..., description="List of bounding box items")

def generate_request(file, labels, names, context_labels):
    """Build the Gemini request for one page image.

    The first page of a copy (no context labels) uses ``my_prompt``, which
    asks for the student name among ``names``. Later pages use
    ``my_prompt2``, which receives the labels already found on the previous
    pages of the same copy.

    Args:
        file: path (str or Path) of the page image to annotate.
        labels: newline-separated list of the exam's valid labels.
        names: newline-separated list of student names (first page only).
        context_labels: labels found on the previous pages of the copy;
            empty or None for the first page.

    Returns:
        A ``(contents, config)`` tuple for ``client.models.generate_content``.
    """
    image_path = Path(file)

    # `not` treats None like an empty list: both mean "first page". The old
    # `== []` test sent None to the continuation prompt.
    if not context_labels:
        text = my_prompt.replace("##labels##", labels).replace("##names##", names)
    else:
        # Previous labels are inlined as a comma-separated list of quoted strings.
        context_str = ", ".join(f'"{label}"' for label in context_labels)
        text = (
            my_prompt2.replace("##labels##", labels)
            .replace("##prev_context##", context_str)
        )

    # Derive the MIME type from the extension (.jpg -> image/jpeg, as before);
    # fall back to JPEG for anything that is not recognised as an image.
    guessed_type = mimetypes.guess_type(image_path.name)[0]
    mime_type = guessed_type if guessed_type and guessed_type.startswith("image/") else "image/jpeg"

    contents = [
        types.Content(
            role="user",
            parts=[
                types.Part.from_bytes(
                    data=image_path.read_bytes(),
                    mime_type=mime_type,
                ),
                types.Part.from_text(text=text),
            ],
        )
    ]

    generate_content_config = types.GenerateContentConfig(
        temperature=1.0,
        top_p=0.95,
        seed=0,
        max_output_tokens=65535,
        # Constrain the answer to the AnnotationData shape.
        response_mime_type="application/json",
        response_json_schema=AnnotationData.model_json_schema(),
    )
    return (contents, generate_content_config)

# Command-line interface: one or more exam PDFs or directories holding
# Copie*.pdf files, plus an opt-in flag to redo already-annotated pages.
parser = argparse.ArgumentParser(
    description="Process a directory or specific files using Gemini.",
)
parser.add_argument(
    "input_paths",
    nargs="+",
    help="The input directory or specific files",
)
parser.add_argument(
    "--overwrite",
    action="store_true",
    help="Regenerate output even if it exists",
)
args = parser.parse_args()

from utils import natural_key

# Cutleft page images to annotate, accumulated across every input path.
image_files = []

for path_str in args.input_paths:
    input_arg = Path(path_str)

    # 1. Determine which exam PDFs this argument designates.
    if input_arg.is_file():
        target_files = [input_arg]
    elif input_arg.is_dir():
        # Sorted so the processing order does not depend on the filesystem.
        target_files = sorted(input_arg.glob("Copie*.pdf"), key=natural_key)
        if not target_files:
            print(f"Warning: No Copie*.pdf files found in {input_arg}")
    else:
        print(f"Error: {input_arg} is not a valid file or directory.")
        continue

    # 2. Collect each PDF's page images and load the exam's label/name lists.
    for target_file in target_files:
        INPUT_DIR = target_file.parent
        CUTLEFT_DIR = INPUT_DIR / "Cutleft"

        # Matches stem_01.jpg, stem_02.jpg, etc.
        found_files = sorted(
            CUTLEFT_DIR.glob(f"{target_file.stem}_*.jpg"),
            key=natural_key,
        )
        if found_files:
            image_files.extend(found_files)
        else:
            print(f"Warning: No variants found for {target_file.stem} in {CUTLEFT_DIR}")

        # NOTE(review): these module-level values are overwritten for every
        # PDF, so with several input directories the labels/names of the last
        # one are used for all copies.
        # UTF-8 explicitly, matching the encoding of the JSON output files.
        labels_txt = (INPUT_DIR / "labels").read_text(encoding="utf-8")
        valid_labels_set = {line.strip() for line in labels_txt.splitlines() if line.strip()}

        # Fall back to a "names" file in the current working directory.
        names_path = INPUT_DIR / "names"
        if not names_path.exists():
            names_path = Path("names")
        names_txt = names_path.read_text(encoding="utf-8")

        valid_names_set = {line.strip() for line in names_txt.splitlines() if line.strip()}
        # Placeholder names the prompts allow: unreadable name / later page.
        valid_names_set.add("Unknown")
        valid_names_set.add("Continued")

client = genai.Client(api_key=api_key)

# Page images are named <copy>_<page>.jpg (e.g. Copie01_01.jpg -> copy
# Copie01, page 1). Stems without a trailing _<digits> form their own group.
_PAGE_SUFFIX_RE = re.compile(r"(.+)_(\d+)$")


def _page_number(path):
    """Return the numeric page suffix of *path*'s stem, or -1 if it has none."""
    match = _PAGE_SUFFIX_RE.match(path.stem)
    return int(match.group(2)) if match else -1


file_groups = defaultdict(list)
for img in image_files:
    match = _PAGE_SUFFIX_RE.match(img.stem)
    group_key = match.group(1) if match else img.stem
    file_groups[group_key].append(img)

# Pages must be in numeric order because each page's context is built from
# the previous ones; a plain name sort would put Copie01_10 before Copie01_2.
for pages in file_groups.values():
    pages.sort(key=lambda p: (_page_number(p), p.name))

# Minimum duration, in seconds, of each request cycle in a worker thread
# (used for rate limiting in process_copy_group).
TARGET_INTERVAL = 3.5

def _load_existing_labels(output_json):
    """Return the labels stored in an annotation file written by a previous run.

    Prints a warning and returns an empty list when the file cannot be read
    or does not have the expected ``{"list": [{"label": ...}, ...]}`` shape.
    """
    try:
        with open(output_json, "r", encoding="utf-8") as f:
            data = json.load(f)
        return [item["label"] for item in data.get("list", [])]
    except (OSError, ValueError, KeyError, TypeError, AttributeError) as e:
        print(f"Warning: could not reload labels from {output_json}: {e}")
        return []


def _annotate_image(image_file, context_labels, max_attempts):
    """Request and validate the annotation of one image.

    A request is retried, after sleeping 10 s times the number of previous
    attempts, when it fails or when the answer contains labels outside the
    exam's label list. An unrecognised student name is retried once, then
    replaced by "Unknown". Returns the validated AnnotationData, or None once
    ``max_attempts`` requests have been made without an acceptable answer.
    """
    for attempt in range(max_attempts):
        if attempt > 0:
            time.sleep(10 * attempt)
        try:
            contents, config = generate_request(image_file, labels_txt, names_txt, context_labels)
            response = client.models.generate_content(
                model=MODEL_ID,
                contents=contents,
                config=config,
            )
            annota = AnnotationData.model_validate_json(response.text)
        except Exception as e:
            # API/network errors and malformed answers are all retried.
            print(f"Error processing {image_file.name}: {e}")
            continue

        unknown = [item.label for item in annota.list if item.label not in valid_labels_set]
        if unknown:
            print(f"Error: {image_file.name} contained unknown labels: {unknown}")
            print("Retrying request...")
            continue

        if annota.name not in valid_names_set:
            print(f"Error: {image_file.name} returned unknown name : {annota.name}")
            if attempt == 0 and attempt + 1 < max_attempts:
                print("Retrying request...")
                continue
            # Bad name again: keep the boxes, but save the name as unknown.
            annota.name = "Unknown"

        return annota

    return None


def process_copy_group(group_key, files, max_attempts=5):
    """Annotate the page images of one exam copy, in order.

    Pages are processed sequentially so the labels found on earlier pages
    can be passed as context for later ones. Each image's annotation is
    written as ``<stem>.json`` in the directory that contains its
    ``Cutleft`` folder, unless that file exists and --overwrite is not set.

    Args:
        group_key: identifier of the copy, used in log messages.
        files: the copy's page images, in page order.
        max_attempts: requests tried per image before giving up on it.
    """
    # Labels found so far in this copy, passed as context to later pages.
    accumulated_labels = []

    for image_file in files:
        start_time = time.time()
        # Images live in <input dir>/Cutleft/; the JSON goes in <input dir>/.
        output_json = image_file.parent.parent / f"{image_file.stem}.json"

        if output_json.exists() and not args.overwrite:
            print(f"[{group_key}] Skipping {image_file.name}, output exists.")
            # Reload its labels so the next pages keep their context.
            accumulated_labels.extend(_load_existing_labels(output_json))
            continue

        print(f"[{group_key}] Processing {image_file.name} with {len(accumulated_labels)} accumulated labels...")

        annota = _annotate_image(image_file, accumulated_labels, max_attempts)
        if annota is None:
            print(f"[{group_key}] Giving up on {image_file.name} after {max_attempts} attempts.")
        else:
            with open(output_json, "w", encoding="utf-8") as f:
                json.dump(annota.model_dump(), f, indent=2)
            # Update context for the next part in this group.
            accumulated_labels.extend(box.label for box in annota.list)

        # Rate limiting: each image takes at least TARGET_INTERVAL seconds.
        elapsed = time.time() - start_time
        time.sleep(max(0, TARGET_INTERVAL - elapsed))

# One task per copy (not per file): each worker processes all pages of one
# student's copy sequentially, so that context carries from page to page.
failed_groups = []
with ThreadPoolExecutor(max_workers=12) as executor:
    futures = {
        executor.submit(process_copy_group, key, files): key
        for key, files in file_groups.items()
    }
    # Collect every outcome so that one crashing copy does not hide others.
    for future, key in futures.items():
        try:
            future.result()
        except Exception:
            print(f"Error: processing of copy {key} failed:")
            traceback.print_exc()
            failed_groups.append(str(key))

if failed_groups:
    sys.exit(f"{len(failed_groups)} copies failed: {', '.join(failed_groups)}")