From d725b9edbc73cbcceba2d65b70a9d218dc20ce25 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Miquel?= Date: Tue, 10 Feb 2026 22:06:50 +0100 Subject: [PATCH] Limit requests to Gemini pro --- correction.py | 34 ++++++++++++++++++++++++---------- grouping.py | 2 +- 2 files changed, 25 insertions(+), 11 deletions(-) diff --git a/correction.py b/correction.py index 9f2cafb..7551e83 100644 --- a/correction.py +++ b/correction.py @@ -13,13 +13,13 @@ results = {} # Parse Arguments parser = argparse.ArgumentParser() -parser.add_argument("--overwrite", action="store_true", help="Force redo requests even if output exists") -# parse_known_args is used to avoid conflicts if run inside an environment passing other flags +parser.add_argument("--overwrite", action="store_true", + help="Force redo requests even if output exists") +parser.add_argument("--limit", type=int, help="limit calls to gemini pro integer") args, _ = parser.parse_known_args() if arg_path.suffix == ".jpg": - # Preserve original behaviour INPUT_DIR = str(arg_path.parents[1]) FULL_LABEL = arg_path.parent.name tasks.append((str(arg_path), FULL_LABEL)) @@ -151,8 +151,8 @@ if PROXY_URL: os.environ["http_proxy"] = PROXY_URL os.environ["https_proxy"] = PROXY_URL -MODEL_ID = "gemini-3-pro-preview" -MODEL_ID_BIS = "gemini-3-flash-preview" +MODEL_ID_pro = "gemini-3-pro-preview" +MODEL_ID_flash = "gemini-3-flash-preview" api_key="REMOVED_API_KEY" from pydantic import BaseModel, Field, TypeAdapter @@ -202,17 +202,19 @@ def generate_request(file, full_label): ) return (contents, generate_content_config) - - client = genai.Client(api_key=api_key) output_path = Path(INPUT_DIR) / "correction.json" progress_path = Path(INPUT_DIR) / "correction_progress.json" start_time = time.time() overwrite = args.overwrite +limit = args.limit completed_tasks = [] # --- Lock for thread-safe file writing --- io_lock = threading.Lock() +pro_lock = threading.Lock() # New lock for counter +pro_count = 0 # New counter +flash_count = 0 
if overwrite: if output_path.exists(): @@ -244,7 +246,20 @@ def process_single_task(task_tuple): n = len(group_data) d_data = {l[0]: (l[1], l[2]) for l in group_data} total_height = group_data[-1][2] - use_flash = n >= 5 + # Counters are module-level; without this declaration the += below raises UnboundLocalError + global pro_count, flash_count + use_flash = n >= 4 or total_height <= 500 + if not use_flash: + with pro_lock: + if limit is None or pro_count < limit: + pro_count += 1 + else: + # Limit reached, force switch to Flash + use_flash = True + if use_flash: + with pro_lock: + flash_count += 1 + try: contents, config = generate_request(file_path, label) if use_flash: @@ -256,7 +271,7 @@ def process_single_task(task_tuple): # Assuming client is thread-safe (usually is). # If not, create a new client instance inside this function. for chunk in client.models.generate_content_stream( - model=MODEL_ID_BIS if use_flash else MODEL_ID, + model=MODEL_ID_flash if use_flash else MODEL_ID_pro, contents=contents, config=config, ): @@ -322,4 +337,5 @@ with concurrent.futures.ThreadPoolExecutor(max_workers=NB_THREADS) as executor: executor.map(process_single_task, tasks_to_process) end_time = time.time() -print("Time elapsed : ", end_time - start_time,"\n\n\n\n\n") +print("Time elapsed : ", end_time - start_time) +print("Requests to pro / flash : ", pro_count, flash_count) diff --git a/grouping.py b/grouping.py index 231b141..18300c6 100644 --- a/grouping.py +++ b/grouping.py @@ -11,7 +11,7 @@ from pdf2image import convert_from_path, pdfinfo_from_path DPI = 200 # Good balance for readability and size A4_HEIGHT_INCHES = 11.69 FULL_PAGE_PX = int(A4_HEIGHT_INCHES * DPI) -MAX_GROUP_HEIGHT = 1.75 * FULL_PAGE_PX +MAX_GROUP_HEIGHT = 1.5 * FULL_PAGE_PX MAX_GROUP_COUNT = 8 SEPARATOR_HEIGHT = 20 LABEL_HEIGHT = 50