# Copies/gemini_for_labels.py
#
# (file-viewer header residue: 301 lines, 11 KiB, Python)

from google import genai
from google.genai import types
import base64
from pathlib import Path
from pydantic import BaseModel, Field
from typing import List, Dict
import sys
import os
import time
import json
import argparse
import re
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
# Gemini model used for every annotation request.
MODEL_ID = "gemini-3-flash-preview"
# Gemini API key. Read from the GEMINI_API_KEY environment variable so the
# secret does not have to be committed with the script; the literal is only a
# fallback placeholder kept for backward compatibility.
api_key = os.environ.get("GEMINI_API_KEY", "REMOVED_API_KEY")
# Prompt for the FIRST page of a copy (used by generate_request when no labels
# from earlier pages are available). Placeholders replaced before sending:
#   ##labels##  -> contents of the `labels` file (one valid label per line)
#   ##names##   -> contents of the `names` file (candidate student names)
# The text is sent verbatim to the model, so any wording change alters the
# request.
# NOTE(review): the JSON example writes box_2d as a parenthesised tuple,
# which is not valid JSON; the response schema (AnnotationData) still
# constrains the actual output to a list.
my_prompt = """I'm giving you an image of the left columns of a written exam.
Students answer several exercises, which can have several questions.
The image consists of several columns, separated by vertical black
lines. The image should be read top to bottom and then left to right,
meaning first column, then second column, etc.
In their sheet, students delimit exercises and questions using
delimiters such as `Ex 1`, or `Exercice 1`, and `1)` or `a)`. You need
to give me the bounding boxes of each delimiter.
When giving the bounding box of the first question of an exercise, the
box should be large enough to contain both the exercice label
(`Exercice i`) and the question label (`1)`) parts. If they are
horizontally far apart (example : if the `1)` is to the left and the
`Exercice i` is either to the right, or in the middle) then give only
the bounding box of the question label `1)` part. You should still
label it as `Exercice i : 1)` though.
You also need to give me the student name. It should appear on the top
left of the image. Disregard any mention of `MPSI 3`, it is their
class. A list of possible student names will be given below.
You will answer with a JSON object, containing a `name` field with the
name, and a `list` field, with the list of the bounding boxes and
their labels. The box_2d should be [ymin, xmin, ymax, xmax] normalized
to 0-1000.
Here is an example :
{\"name\" : \"John Doe\", \"list\" : [{\"box_2d\": (10, 20, 30, 40), \"label\" : \"Ex 1 : 1)\"}]}
Do not provide a box_2d for the name. Only for the labels. Order the
box_2d by their position in the page, column by column : first column
(top to bottom), then second column, etc.
You may find the same label present several times, as a student either
recall the current label on a new page, or adds content to its answer
later on. Give the position of each instance of each label.
For this exam you should look for the labels given below, separated by
newlines. A student need not have answered every question, so some may
be missing.
##labels##
Here's a list of the names of the students, pick the one that matches
the best or `\"Unknown\"` if you cannot read the name
##names##"""
# Prompt for the FOLLOWING pages of a copy (used by generate_request when
# labels from earlier pages are available). Placeholders replaced before
# sending:
#   ##prev_context## -> comma-separated, double-quoted labels already found
#                       on the previous pages of this copy
#   ##labels##       -> contents of the `labels` file (one valid label per line)
# There is no ##names## placeholder: the model is told to answer "Continued".
# NOTE(review): this text marks inline examples with =...= (e.g. =3)=) while
# my_prompt uses backticks; harmonise only if the model output is re-checked,
# since the text is sent verbatim.
my_prompt2 = """I'm giving you an image of the left columns of a written exam.
Students answer several exercises, which can have several questions.
The image consists of several columns, separated by vertical black
lines. The image should be read top to bottom and then left to right,
meaning first column, then second column, etc.
In their sheet, students delimit exercises and questions using
delimiters such as `Ex 1`, or `Exercice 1`, and `1)` or `a)`. You need
to give me the bounding boxes of each delimiter.
When giving the bounding box of the first question of an exercise, the
box should be large enough to contain both the exercice label
(`Exercice i`) and the question label (`1)`) parts.
You also need to give me the student name. It should appear on the top
left of the image. Disregard any mention of `MPSI 3`, it is their
class. A list of possible student names will be given below.
You will answer with a JSON object, containing a `name` field with the
name, and a `list` field, with the list of the bounding boxes and
their labels. The box_2d should be [ymin, xmin, ymax, xmax] normalized
to 0-1000.
Here is an example :
{\"name\" : \"John Doe\", \"list\" : [{\"box_2d\": (10, 20, 30, 40), \"label\" : \"Ex 1 : 1)\"}]}
Do not provide a box_2d for the name. Only for the labels.
You may find the same label present several times, as a student either
recall the current label on a new page, or adds content to its answer
later on. Give the position of each instance of each label.
This image is one part of a sequence (e.g., part 2 of 3) for a single
student. Here is the list of labels found in the *previous* parts of
this copy:
[
##prev_context##
]
If the first column starts with a number like =3)= or =c)=, look at
the labels in the list above. If the last relevant label was =Ex 4 :
2)=, you should label the new box =Ex 4 : 3)=.
For this exam you should look for the labels given below, separated by
newlines. A student need not have answered every question, so some may
be missing.
##labels##
Since this copy isn't the first part of a sequence, simply set the
name to `\"Continued\"`."""
# One detected delimiter box. Nested inside AnnotationData, whose
# model_json_schema() is sent to Gemini as the response schema, so the field
# names and descriptions below are part of every request.
# Documented with comments rather than a docstring on purpose: pydantic puts
# a model's docstring into its generated JSON schema.
class BoxItem(BaseModel):
    # Box as [ymin, xmin, ymax, xmax] normalised to 0-1000 (as the prompts
    # request); the schema itself does not enforce four elements.
    box_2d: List[int] = Field(description="Bounding box coordinates (e.g., [ymin, xmin, ymax, xmax])")
    # Delimiter label; the caller rejects labels that are not lines of the
    # `labels` file.
    label: str = Field(description="The label associated with the specific box")
# Structured response expected from Gemini for one page image. Its JSON schema
# is passed as response_json_schema, and replies are parsed back with
# model_validate_json. Documented with comments rather than a docstring
# because pydantic copies a model's docstring into the generated schema.
class AnnotationData(BaseModel):
    # Student name picked from the `names` list, "Unknown", or "Continued"
    # for pages after the first one (as instructed by the prompts).
    name: str = Field(description="The name identifier")
    # Every delimiter box found on the page; the first-page prompt asks for
    # them column by column, top to bottom.
    # NOTE: the field is called `list` to match the JSON key used in the
    # prompts; it shadows the builtin name only inside this class namespace.
    list: List[BoxItem] = Field(description="List of bounding box items")
def generate_request(file, labels, names, context_labels):
    """Build the Gemini request (contents and config) for one page image.

    The first page of a copy (no context labels) uses ``my_prompt``, which
    also asks the model to identify the student from ``names``. Later pages
    use ``my_prompt2``, which lists the labels already found on previous
    pages so that a page starting mid-exercise can be labelled correctly.

    Args:
        file: Path (``str`` or ``Path``) of the JPEG page image.
        labels: Text of the ``labels`` file, one valid label per line.
        names: Text of the ``names`` file with the candidate student names.
        context_labels: Labels found on the previous pages of the same copy.
            An empty sequence (or ``None``) means this is the first page.

    Returns:
        A ``(contents, generate_content_config)`` tuple to pass to
        ``client.models.generate_content``.
    """
    image_path = Path(file)

    if not context_labels:
        # First page: ask for the student name as well.
        text = my_prompt.replace("##labels##", labels).replace("##names##", names)
    else:
        # Continuation page: show the earlier labels, each double-quoted,
        # e.g. "Ex 1 : 1)", "Ex 1 : 2)".
        context_str = ", ".join(f'"{label}"' for label in context_labels)
        text = my_prompt2.replace("##labels##", labels).replace("##prev_context##", context_str)

    contents = [
        types.Content(
            role="user",
            parts=[
                types.Part.from_bytes(
                    data=image_path.read_bytes(),
                    mime_type="image/jpeg",
                ),
                types.Part.from_text(text=text),
            ],
        )
    ]
    generate_content_config = types.GenerateContentConfig(
        temperature=1.0,
        top_p=0.95,
        seed=0,
        max_output_tokens=65535,
        # Force a JSON reply shaped like AnnotationData.
        response_mime_type="application/json",
        response_json_schema=AnnotationData.model_json_schema(),
    )
    return (contents, generate_content_config)
# --- Command-line interface --------------------------------------------------
# Positional arguments are exam directories and/or individual scan files;
# each one is expanded into Cutleft/ page images by the code that follows.
parser = argparse.ArgumentParser(
    description="Process a directory or specific files using Gemini.",
)
parser.add_argument(
    "input_paths",
    nargs="+",
    help="The input directory or specific files",
)
parser.add_argument(
    "--overwrite",
    action="store_true",
    help="Regenerate output even if it exists",
)
args = parser.parse_args()

# Page images to annotate, filled in from args.input_paths.
image_files = []
def natural_key(text):
    """Return a sort key that compares embedded integers numerically.

    ``text`` is converted with ``str()`` (so ``Path`` objects work) and split
    into alternating non-digit / digit runs. Digit runs become ``int`` and the
    rest is lower-cased, so ``"x2"`` sorts before ``"X10"``.
    """
    key = []
    for chunk in re.split(r'(\d+)', str(text)):
        if chunk.isdigit():
            key.append(int(chunk))
        else:
            key.append(chunk.lower())
    return key
# Expand each command-line argument into page images under <dir>/Cutleft/:
#   * a file <dir>/<stem>.<ext> selects <dir>/Cutleft/<stem>_*.jpg;
#   * a directory <dir> selects every <dir>/Cutleft/*.jpg.
# INPUT_DIR is the directory that holds the `labels` and `names` files and
# receives the JSON outputs (process_copy_group writes there).
input_dirs = []
for path_str in args.input_paths:
    input_arg = Path(path_str)
    if input_arg.is_file():
        INPUT_DIR = input_arg.parent
        CUTLEFT_DIR = INPUT_DIR / 'Cutleft'
        # Matches stem_01.jpg, stem_02.jpg, etc.
        found_files = sorted(CUTLEFT_DIR.glob(f"{input_arg.stem}_*.jpg"), key=natural_key)
        if found_files:
            image_files.extend(found_files)
        else:
            print(f"Warning: No variants found for {input_arg.stem} in {CUTLEFT_DIR}")
    elif input_arg.is_dir():
        INPUT_DIR = input_arg
        CUTLEFT_DIR = INPUT_DIR / 'Cutleft'
        image_files.extend(sorted(CUTLEFT_DIR.glob("*.jpg"), key=natural_key))
    else:
        print(f"Error: {input_arg} is not a valid file or directory.")
        continue
    input_dirs.append(INPUT_DIR.resolve())

# labels, names and the output location all come from one INPUT_DIR. Without
# a valid input it would be undefined. Mixing directories would apply the last
# directory's labels to every copy and write every JSON there.
if not input_dirs:
    sys.exit("Error: no valid input file or directory was given.")
distinct_dirs = set(input_dirs)
if len(distinct_dirs) > 1:
    sys.exit("Error: all inputs must be in the same directory; got: "
             + ", ".join(sorted(str(d) for d in distinct_dirs)))

# Drop duplicate images (e.g. a file given twice, or a file plus its
# directory) while keeping first-seen order.
image_files[:] = list(dict.fromkeys(image_files))

labels_txt = (INPUT_DIR / "labels").read_text(encoding="utf-8")
# One valid label per non-blank line; used to reject labels the model invents.
valid_labels_set = {line.strip() for line in labels_txt.splitlines() if line.strip()}
names_txt = (INPUT_DIR / "names").read_text(encoding="utf-8")
client = genai.Client(api_key=api_key)

# Group page images by copy: Copie01_01.jpg, Copie01_02.jpg -> "Copie01".
# Stems without a trailing "_<digits>" form a group on their own.
file_groups = defaultdict(list)
for img in image_files:
    match = re.match(r"(.+)_(\d+)$", img.stem)
    group_key = match.group(1) if match else img.stem
    file_groups[group_key].append(img)

# Put each copy's pages in reading order. natural_key compares page numbers
# numerically, so "_10" comes after "_2" (a plain name sort would not).
for pages in file_groups.values():
    pages.sort(key=natural_key)

# Minimum wall-clock seconds spent per annotated page, per worker thread.
TARGET_INTERVAL = 3.5
def process_copy_group(group_key, files):
    """Annotate the pages of one exam copy sequentially, carrying context.

    Pages are processed one after another (not in parallel) so that the
    labels found on earlier pages can be passed to ``generate_request`` as
    context for the following pages of the same copy.

    Args:
        group_key: Copy identifier, used only as a log prefix
            (e.g. ``"Copie01"``).
        files: The copy's page images as ``Path`` objects, in reading order.

    Side effects:
        Writes ``INPUT_DIR / "<page stem>.json"`` for each page it
        annotates. A page whose JSON already exists is skipped unless
        ``--overwrite`` was given, and its saved labels are reloaded as
        context. Each annotated page takes at least ``TARGET_INTERVAL``
        seconds (the thread sleeps for the remainder).

    NOTE(review): generate_request chooses the first-page prompt whenever
    the accumulated labels are empty. A later page therefore still gets that
    prompt if no labels were found (or reloaded) for the earlier pages.
    """
    max_attempts = 2
    # Labels found so far in this copy; passed as context to later pages.
    accumulated_labels = []

    for image_file in files:
        start_time = time.monotonic()
        output_json = INPUT_DIR / f"{image_file.stem}.json"

        if output_json.exists() and not args.overwrite:
            print(f"[{group_key}] Skipping {image_file.name}, output exists.")
            # Reload the saved labels so later pages keep their context.
            try:
                with open(output_json, "r", encoding="utf-8") as f:
                    data = json.load(f)
                saved_labels = [item["label"] for item in data.get("list", [])]
            except (OSError, ValueError, KeyError, TypeError, AttributeError) as e:
                # Best effort: an unreadable file only costs this page's context.
                print(f"[{group_key}] Warning: could not reload labels from "
                      f"{output_json.name}: {e}")
            else:
                accumulated_labels.extend(saved_labels)
            continue

        print(f"[{group_key}] Processing {image_file.name} with "
              f"{len(accumulated_labels)} accumulated labels...")
        for attempt in range(max_attempts):
            is_last_attempt = attempt == max_attempts - 1
            try:
                contents, config = generate_request(
                    image_file, labels_txt, names_txt, accumulated_labels)
                response = client.models.generate_content(
                    model=MODEL_ID,
                    contents=contents,
                    config=config,
                )
                annota = AnnotationData.model_validate_json(response.text)

                unknown = [item.label for item in annota.list
                           if item.label not in valid_labels_set]
                if unknown:
                    print(f"Error: {image_file.name} contained unknown labels: {unknown}")
                    if not is_last_attempt:
                        print("Retrying request...")
                        continue  # Retry immediately
                    # Last attempt: keep the result despite the unknown labels.

                with open(output_json, "w", encoding="utf-8") as f:
                    json.dump(annota.model_dump(), f, indent=2)
                # Update context for the next page of this copy.
                accumulated_labels.extend(box.label for box in annota.list)
                break
            except Exception as e:
                # API, validation or I/O failure: log it and use the next attempt.
                print(f"Error processing {image_file.name}: {e}")
        else:
            print(f"[{group_key}] Giving up on {image_file.name} "
                  f"after {max_attempts} attempts.")

        # Rate limiting: pad this page to at least TARGET_INTERVAL seconds.
        elapsed = time.monotonic() - start_time
        time.sleep(max(0, TARGET_INTERVAL - elapsed))
# Parallelise across copies, never within one. Each task runs
# process_copy_group for a whole copy, so its pages are annotated in order
# and share label context.
with ThreadPoolExecutor(max_workers=6) as executor:
    futures = [
        executor.submit(process_copy_group, copy_id, pages)
        for copy_id, pages in file_groups.items()
    ]
    # Wait for every copy. result() re-raises a worker's exception; the
    # first failing task in submission order propagates.
    for future in futures:
        future.result()