# Copies/gemini_dir_batching.py
# (file-viewer metadata: 171 lines, 5.7 KiB, Python)

import argparse
import base64
import json
import os
import sys
import threading
import time
from pathlib import Path
from typing import List

from google import genai
from google.genai import types
from pydantic import BaseModel, Field
# Gemini model used for every request.
MODEL_ID = "gemini-3-flash-preview"
# API key: taken from the environment when available so the secret does not
# have to live in the source file. GOOGLE_API_KEY is checked before
# GEMINI_API_KEY; the literal is the original value, kept as a fallback so
# existing setups behave as before.
api_key = (
    os.environ.get("GOOGLE_API_KEY")
    or os.environ.get("GEMINI_API_KEY")
    or "REMOVED_API_KEY"
)
# Prompt template sent with every image. The two placeholders `##labels##`
# and `##names##` are substituted per exam by generate_request().
# The inline example is written as valid JSON (box_2d as a list) so that it
# agrees with the `[ymin, xmin, ymax, xmax]` format described just above it.
my_prompt = """I'm giving you an image of the left columns of a written exam.
Students answer several exercises, which can have several questions.
The image consists of several columns, separated by vertical black
lines. The image should be read top to bottom and then left to right,
meaning first column, then second column, etc.
In their sheet, students delimit exercises and questions using
delimiters such as `Ex 1`, or `Exercice 1`, and `1)` or `a)`. You need
to give me the bounding boxes of each delimiter.
When giving the bounding box of the first question of an exercise, the
box should be large enough to contain both the exercice label
(`Exercice i`) and the question label (`1)`) parts.
You also need to give me the student name. It should appear on the top
left of the image. Disregard any mention of `MPSI 3`, it is their
class. A list of possible student names will be given below.
You will answer with a JSON object, containing a `name` field with the
name, and a `list` field, with the list of the bounding boxes and
their labels. The box_2d should be [ymin, xmin, ymax, xmax] normalized
to 0-1000.
Here is an example :
{"name" : "John Doe", "list" : [{"box_2d": [10, 20, 30, 40], "label" : "Ex 1 : 1)"}]}
Do not provide a box_2d for the name. Only for the labels.
You may find the same label present several times, as a student either
recall the current label on a new page, or adds content to its answer
later on. Give the position of each instance of each label.
For this exam you should look for the labels given below, separated by
newlines. A student need not have answered every question, so some may
be missing.
##labels##
Here's a list of the names of the students, pick the one that matches
the best or `"Unknown"` if you cannot read the name
##names##"""
class BoxItem(BaseModel):
    # One labelled bounding box in the model's answer.
    # NOTE: documented with `#` comments on purpose - pydantic copies a class
    # docstring into the generated JSON schema, which would alter the schema
    # sent to the API.
    box_2d: List[int] = Field(
        description="Bounding box coordinates (e.g., [ymin, xmin, ymax, xmax])",
    )
    label: str = Field(
        description="The label associated with the specific box",
    )
class AnnotationData(BaseModel):
    # Top-level shape of the JSON answer requested by the prompt: the student
    # name plus every labelled box found on the page.
    # NOTE: documented with `#` comments on purpose - pydantic copies a class
    # docstring into the generated JSON schema, which would alter the schema
    # sent to the API.
    name: str = Field(
        description="The name identifier",
    )
    # The field is called `list` because that is the JSON key the prompt asks
    # the model to produce; renaming it would change the output format.
    list: List[BoxItem] = Field(
        description="List of bounding box items",
    )
def generate_request(file, labels, names):
    """Build the Gemini request payload for one exam image.

    Args:
        file: Path (or path string) of the image; its bytes are sent with
            the MIME type ``image/jpeg``.
        labels: Text substituted for the ``##labels##`` placeholder of the
            prompt template.
        names: Text substituted for the ``##names##`` placeholder of the
            prompt template.

    Returns:
        A ``(contents, config)`` tuple: ``contents`` is a one-element list
        holding the user message (image part followed by the filled-in
        prompt), and ``config`` is the generation config requesting JSON
        output described by the ``AnnotationData`` schema.
    """
    # Fill in the prompt first, then load the image, as two separate parts.
    prompt_text = my_prompt.replace("##labels##", labels).replace("##names##", names)
    image_part = types.Part.from_bytes(
        data=Path(file).read_bytes(),
        mime_type="image/jpeg",
    )
    prompt_part = types.Part.from_text(text=prompt_text)

    user_message = types.Content(role="user", parts=[image_part, prompt_part])

    config = types.GenerateContentConfig(
        temperature=1.0,
        top_p=0.95,
        seed=0,
        max_output_tokens=65535,
        response_mime_type="application/json",
        response_json_schema=AnnotationData.model_json_schema(),
    )
    return ([user_message], config)
# Argument parsing
parser = argparse.ArgumentParser(description="Process a directory or specific file using Gemini.")
parser.add_argument("input_path", help="The input directory or specific file (e.g., Dir/File.pdf)")
parser.add_argument("--overwrite", action="store_true", help="Regenerate output even if it exists")
args = parser.parse_args()

input_arg = Path(args.input_path)
image_files = []

# Resolve the working directory and the list of images to process.
if input_arg.is_file():
    # Argument is a single file such as Dir/Copiedd.pdf: process only the
    # image with the same stem in Dir/Cutleft (e.g. Dir/Cutleft/Copiedd.jpg).
    INPUT_DIR = input_arg.parent
    CUTLEFT_DIR = INPUT_DIR / 'Cutleft'
    target_image = CUTLEFT_DIR / f"{input_arg.stem}.jpg"
    if not target_image.exists():
        print(f"Error: Corresponding image {target_image} not found.")
        sys.exit(1)
    image_files = [target_image]
elif input_arg.is_dir():
    # Argument is a directory: process every .jpg in its Cutleft subfolder,
    # in sorted order.
    INPUT_DIR = input_arg
    CUTLEFT_DIR = INPUT_DIR / 'Cutleft'
    image_files = sorted(CUTLEFT_DIR.glob("*.jpg"))
else:
    # Neither a file nor a directory (typically a typo in the path): fail
    # here with a clear message instead of crashing later on a missing file.
    print(f"Error: Input path {input_arg} does not exist.")
    sys.exit(1)

# Per-exam prompt inputs stored next to the scans. Read explicitly as UTF-8
# (the same encoding the JSON results are written in) so accented names do
# not depend on the platform's default encoding.
try:
    labels = (INPUT_DIR / "labels").read_text(encoding="utf-8")
    names = (INPUT_DIR / "names").read_text(encoding="utf-8")
except (OSError, UnicodeDecodeError) as e:
    print(f"Error: Could not read labels/names from {INPUT_DIR}: {e}")
    sys.exit(1)

client = genai.Client(api_key=api_key)

# Pacing interval in seconds used by process_image: > 3.0 s per request
# corresponds to fewer than 20 requests per minute.
TARGET_INTERVAL = 3.5
from concurrent.futures import ThreadPoolExecutor
# State shared by all worker threads for request pacing: the earliest
# monotonic time at which the next request may start, and the lock guarding it.
_rate_lock = threading.Lock()
_next_request_at = 0.0


def _wait_for_request_slot():
    """Reserve the next request slot and sleep until it arrives.

    Slots are handed out TARGET_INTERVAL seconds apart across all threads,
    so the combined request rate stays at or below one request per
    TARGET_INTERVAL regardless of how many workers are running.
    """
    global _next_request_at
    with _rate_lock:
        now = time.monotonic()
        slot = max(now, _next_request_at)
        _next_request_at = slot + TARGET_INTERVAL
    # Sleep outside the lock so other threads can book their own slots.
    if slot > now:
        time.sleep(slot - now)


def process_image(image_file):
    """Annotate one image with Gemini and save the validated result as JSON.

    The result is written to ``INPUT_DIR/<image stem>.json``. If that file
    already exists and ``--overwrite`` was not given, the image is skipped
    without sending a request.

    Args:
        image_file: ``Path`` of the image to process.

    Returns:
        None. Any error raised while requesting, validating or saving is
        printed and swallowed so that one failing image does not stop the
        rest of the batch.
    """
    output_json = Path(INPUT_DIR) / f"{image_file.stem}.json"

    # Skip if already processed unless overwrite is enabled
    if output_json.exists() and not args.overwrite:
        print(f"Skipping {image_file.name}, output exists.")
        return

    # Global rate limiting: request starts are spaced across ALL threads, not
    # per thread, so running several workers cannot exceed the target rate.
    _wait_for_request_slot()

    print(f"Processing {image_file.name}...")
    try:
        # Prepare and execute request
        contents, config = generate_request(image_file, labels, names)
        response = client.models.generate_content(
            model=MODEL_ID,
            contents=contents,
            config=config,
        )
        # Validate before opening the output file so an invalid answer never
        # leaves a file behind that a later run would treat as "done".
        annota = AnnotationData.model_validate_json(response.text)

        # Save result
        with open(output_json, "w", encoding="utf-8") as f:
            json.dump(annota.model_dump(), f, indent=2)
    except Exception as e:
        print(f"Error processing {image_file.name}: {e}")
# Fan the images out over a pool of 6 worker threads; leaving the `with`
# block waits for every submitted task to finish.
with ThreadPoolExecutor(max_workers=6) as executor:
    for image_file in image_files:
        executor.submit(process_image, image_file)