"""Locate exercise/question delimiters on scanned exam copies using Gemini.

For every ``Copie*.pdf`` given on the command line (or found in a given
directory), the matching ``Cutleft/<stem>_*.jpg`` images are sent to the
model one after the other. The model answers with the student name and the
bounding boxes of the exercise/question labels; each answer is validated and
saved as ``<image stem>.json`` next to the PDF.
"""
import argparse
import base64
import json
import os
import re
import sys
import time
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Dict, List, NamedTuple

from google import genai
from google.genai import types
from pydantic import BaseModel, Field

from utils import natural_key

MODEL_ID = "gemini-3-flash-preview"
api_key = os.environ["GEMINI_API_KEY"]

# Minimum wall-clock time (seconds) each worker spends per processed image.
TARGET_INTERVAL = 3.5
# Upper bound on requests sent for a single image before giving up on it.
MAX_ATTEMPTS = 5
# Number of copies processed concurrently (one worker thread per copy).
MAX_WORKERS = 12

# Prompt for the first image of a copy: asks for the student name and boxes.
my_prompt = """I'm giving you an image of the left columns of a written exam. Students answer several exercises, which can have several questions. The image consists of several columns, separated by vertical black lines. The image should be read top to bottom and then left to right, meaning first column, then second column, etc. In their sheet, students delimit exercises and questions using delimiters such as `Ex 1`, or `Exercice 1`, and `1)` or `a)`. You need to give me the bounding boxes of each delimiter. When giving the bounding box of the first question of an exercise, the box should be large enough to contain both the exercice label (`Exercice i`) and the question label (`1)`) parts. If they are horizontally far apart (example : if the `1)` is to the left and the `Exercice i` is either to the right, or in the middle) then give only the bounding box of the question label `1)` part. You should still label it as `Exercice i : 1)` though. You also need to give me the student name. It should appear on the top left of the image. Disregard any mention of `MPSI 3`, it is their class. A list of possible student names will be given below. You will answer with a JSON object, containing a `name` field with the name, and a `list` field, with the list of the bounding boxes and their labels. The box_2d should be [ymin, xmin, ymax, xmax] normalized to 0-1000. Here is an example : {\"name\" : \"John Doe\", \"list\" : [{\"box_2d\": (10, 20, 30, 40), \"label\" : \"Ex 1 : 1)\"}]} Do not provide a box_2d for the name. Only for the labels. 
Order the box_2d by their position in the page, column by column : first column (top to bottom), then second column, etc. You may find the same label present several times, as a student either recall the current label on a new page, or adds content to its answer later on. Give the position of each instance of each label. For this exam you should look for the labels given below, separated by newlines. A student need not have answered every question, so some may be missing. ##labels## Here's a list of the names of the students, pick the one that matches the best or `\"Unknown\"` if you cannot read the name ##names##"""

# Prompt for the following images of a copy: carries the labels already found.
my_prompt2 = """I'm giving you an image of the left columns of a written exam. Students answer several exercises, which can have several questions. The image consists of several columns, separated by vertical black lines. The image should be read top to bottom and then left to right, meaning first column, then second column, etc. In their sheet, students delimit exercises and questions using delimiters such as `Ex 1`, or `Exercice 1`, and `1)` or `a)`. You need to give me the bounding boxes of each delimiter. When giving the bounding box of the first question of an exercise, the box should be large enough to contain both the exercice label (`Exercice i`) and the question label (`1)`) parts. You also need to give me the student name. It should appear on the top left of the image. Disregard any mention of `MPSI 3`, it is their class. A list of possible student names will be given below. You will answer with a JSON object, containing a `name` field with the name, and a `list` field, with the list of the bounding boxes and their labels. The box_2d should be [ymin, xmin, ymax, xmax] normalized to 0-1000. Here is an example : {\"name\" : \"John Doe\", \"list\" : [{\"box_2d\": (10, 20, 30, 40), \"label\" : \"Ex 1 : 1)\"}]} Do not provide a box_2d for the name. Only for the labels. 
You may find the same label present several times, as a student either recall the current label on a new page, or adds content to its answer later on. Give the position of each instance of each label. This image is one part of a sequence (e.g., part 2 of 3) for a single student. Here is the list of labels found in the *previous* parts of this copy: [ ##prev_context## ] If the first column starts with a number like =3)= or =c)=, look at the labels in the list above. If the last relevant label was =Ex 4 : 2)=, you should label the new box =Ex 4 : 3)=. For this exam you should look for the labels given below, separated by newlines. A student need not have answered every question, so some may be missing. ##labels## Since this copy isn't the first part of a sequence, simply set the name to `\"Continued\"`."""


class BoxItem(BaseModel):
    """One labelled bounding box returned by the model."""

    box_2d: List[int] = Field(description="Bounding box coordinates (e.g., [ymin, xmin, ymax, xmax])")
    label: str = Field(description="The label associated with the specific box")


class AnnotationData(BaseModel):
    """Full model answer for one image: the student name and the boxes found."""

    name: str = Field(description="The name identifier")
    # The field name is part of the JSON schema sent to the model and of the
    # saved files, so it stays `list` even though it mirrors a builtin name.
    list: List[BoxItem] = Field(description="List of bounding box items")


class ExamMetadata(NamedTuple):
    """Labels and student names attached to one exam directory."""

    labels_txt: str          # raw content of the `labels` file (sent to the model)
    valid_labels: frozenset  # stripped, non-empty lines of `labels`
    names_txt: str           # raw content of the `names` file (sent to the model)
    valid_names: frozenset   # stripped, non-empty lines of `names` + sentinels


def generate_request(file, labels, names, context_labels, seed=0):
    """Build the Gemini request for one image.

    Args:
        file: Path of the JPEG image to send.
        labels: Text substituted for ``##labels##`` in the prompt.
        names: Text substituted for ``##names##`` (first-part prompt only).
        context_labels: Labels found in the previous parts of the same copy.
            When empty, the first-part prompt (which asks for the student
            name) is used; otherwise the continuation prompt is used.
        seed: Sampling seed placed in the generation config. Defaults to 0.

    Returns:
        A ``(contents, generate_content_config)`` tuple ready to be passed to
        ``client.models.generate_content``.
    """
    image_path = Path(file)

    if not context_labels:
        text = my_prompt.replace("##labels##", labels).replace("##names##", names)
    else:
        context_str = ", ".join(f'"{label}"' for label in context_labels)
        text = my_prompt2.replace("##labels##", labels).replace("##prev_context##", context_str)

    contents = [
        types.Content(
            role="user",
            parts=[
                types.Part.from_bytes(
                    data=image_path.read_bytes(),
                    mime_type="image/jpeg",
                ),
                types.Part.from_text(text=text),
            ],
        )
    ]
    generate_content_config = types.GenerateContentConfig(
        temperature=1.0,
        top_p=0.95,
        seed=seed,
        max_output_tokens=65535,
        response_mime_type="application/json",
        response_json_schema=AnnotationData.model_json_schema(),
    )
    return (contents, generate_content_config)


def _non_empty_lines(text):
    """Return the set of stripped, non-empty lines of *text*."""
    return {line.strip() for line in text.splitlines() if line.strip()}


def load_exam_metadata(input_dir):
    """Read the `labels` and `names` files that apply to *input_dir*.

    `names` is looked up in *input_dir* first, then in the current working
    directory. "Unknown" and "Continued" are always accepted as names because
    the prompts ask the model to answer with them in some situations.
    """
    labels_txt = (input_dir / "labels").read_text(encoding="utf-8")

    names_path = input_dir / "names"
    if not names_path.exists():
        names_path = Path("names")
    names_txt = names_path.read_text(encoding="utf-8")

    return ExamMetadata(
        labels_txt=labels_txt,
        valid_labels=frozenset(_non_empty_lines(labels_txt)),
        names_txt=names_txt,
        valid_names=frozenset(_non_empty_lines(names_txt) | {"Unknown", "Continued"}),
    )


def _load_saved_labels(output_json):
    """Return the labels stored in an existing result file, or [] on failure."""
    try:
        with open(output_json, "r", encoding="utf-8") as f:
            data = json.load(f)
        return [item["label"] for item in data.get("list", [])]
    except (OSError, ValueError, KeyError, TypeError, AttributeError) as e:
        # Best effort: the next parts simply get less context.
        print(f"Warning: could not reuse labels from {output_json.name}: {e}")
        return []


def _annotate_image(image_file, metadata, context_labels):
    """Query the model for one image until the answer validates.

    At most MAX_ATTEMPTS requests are sent, with a growing pause between
    them. Returns the validated AnnotationData, or None when every attempt
    failed.
    """
    for attempt in range(MAX_ATTEMPTS):
        if attempt > 0:
            time.sleep(10 * attempt)  # linear back-off before each retry

        try:
            # The first attempt uses seed 0; retries change the seed so that
            # they are not asking for a repeat of the answer just rejected.
            contents, config = generate_request(
                image_file,
                metadata.labels_txt,
                metadata.names_txt,
                context_labels,
                seed=attempt,
            )
            response = client.models.generate_content(
                model=MODEL_ID,
                contents=contents,
                config=config,
            )
            annota = AnnotationData.model_validate_json(response.text)
        except Exception as e:
            print(f"Error processing {image_file.name}: {e}")
            continue

        unknown = [item.label for item in annota.list if item.label not in metadata.valid_labels]
        if unknown:
            print(f"Error: {image_file.name} contained unknown labels: {unknown}")
            print("Retrying request...")
            continue

        if annota.name not in metadata.valid_names:
            print(f"Error: {image_file.name} returned unknown name : {annota.name}")
            if attempt == 0:
                print("Retrying request...")
                continue
            # A second bad name is not worth more requests: keep the boxes.
            annota.name = "Unknown"

        return annota

    return None


def process_copy_group(group_key, files):
    """Process the images of one copy sequentially, carrying label context.

    Args:
        group_key: Name of the copy, used as a prefix in log messages.
        files: Image paths of that copy, in reading order. Each one is
            expected to sit in ``<exam dir>/Cutleft/``.

    Results are written to ``<exam dir>/<image stem>.json``. Images whose
    output already exists are skipped unless ``--overwrite`` was given; their
    saved labels are still loaded so later parts keep their context.
    """
    accumulated_labels = []

    for image_file in files:
        start_time = time.time()
        # Images live in <exam dir>/Cutleft/, so the exam directory is two
        # levels up; results and metadata belong to that directory.
        input_dir = image_file.parent.parent
        metadata = metadata_by_dir[input_dir]
        output_json = input_dir / f"{image_file.stem}.json"

        if output_json.exists() and not args.overwrite:
            print(f"[{group_key}] Skipping {image_file.name}, output exists.")
            accumulated_labels.extend(_load_saved_labels(output_json))
            continue

        print(f"[{group_key}] Processing {image_file.name} with {len(accumulated_labels)} accumulated labels...")

        annota = _annotate_image(image_file, metadata, accumulated_labels)
        if annota is None:
            print(f"[{group_key}] Giving up on {image_file.name} after {MAX_ATTEMPTS} attempts.")
        else:
            try:
                with open(output_json, "w", encoding="utf-8") as f:
                    json.dump(annota.model_dump(), f, indent=2)
            except OSError as e:
                # Re-querying the model cannot fix a write failure.
                print(f"Error writing {output_json}: {e}")
            # Update context for the next part in this group.
            accumulated_labels.extend(box.label for box in annota.list)

        # Rate limiting
        elapsed = time.time() - start_time
        time.sleep(max(0, TARGET_INTERVAL - elapsed))


# Argument Parsing
parser = argparse.ArgumentParser(description="Process a directory or specific files using Gemini.")
parser.add_argument("input_paths", nargs='+', help="The input directory or specific files")
parser.add_argument("--overwrite", action="store_true", help="Regenerate output even if it exists")
args = parser.parse_args()

image_files = []
for path_str in args.input_paths:
    input_arg = Path(path_str)

    # 1. Determine which files to process
    if input_arg.is_file():
        target_files = [input_arg]
    elif input_arg.is_dir():
        target_files = list(input_arg.glob("Copie*.pdf"))
        if not target_files:
            print(f"Warning: No Copie*.pdf files found in {input_arg}")
    else:
        print(f"Error: {input_arg} is not a valid file or directory.")
        continue

    # 2. Collect the image variants of every target file
    for target_file in target_files:
        cutleft_dir = target_file.parent / 'Cutleft'
        # Matches stem_01.jpg, stem_02.jpg, etc.
        found_files = sorted(
            cutleft_dir.glob(f"{target_file.stem}_*.jpg"),
            key=natural_key,
        )
        if found_files:
            image_files.extend(found_files)
        else:
            print(f"Warning: No variants found for {target_file.stem} in {cutleft_dir}")

# The same image may be reachable through several arguments; process it once.
image_files = list(dict.fromkeys(image_files))
if not image_files:
    sys.exit("Error: no image to process.")

# Group files by exam directory and Copy ID (e.g. Copie01_01.jpg -> Copie01).
# The directory is part of the key so that copies with the same name coming
# from different exam directories are never merged.
file_groups = defaultdict(list)
for img in image_files:
    # match CopieXX_YY -> Group CopieXX
    match = re.match(r"(.+)_(\d+)$", img.stem)
    # Fallback for files without underscore numbering: the stem is the group.
    copy_name = match.group(1) if match else img.stem
    file_groups[(img.parent.parent, copy_name)].append(img)

# Keep each group in the same natural order used when collecting the files,
# so that part 10 does not end up before part 2.
for group_files in file_groups.values():
    group_files.sort(key=natural_key)

# Labels and names are loaded once per exam directory, before any worker starts.
metadata_by_dir = {
    input_dir: load_exam_metadata(input_dir)
    for input_dir in {input_dir for input_dir, _ in file_groups}
}

client = genai.Client(api_key=api_key)

# Run ThreadPool on GROUPS (Copies), not individual files:
# each thread handles one student's full exam copy sequentially.
with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
    futures = [
        executor.submit(process_copy_group, copy_name, files)
        for (_, copy_name), files in file_groups.items()
    ]
    # Wait for all to complete; result() re-raises any worker exception.
    for future in futures:
        future.result()