Limit requests to Gemini pro

master
Sébastien Miquel 2026-02-10 22:06:50 +01:00
parent bd1362dff8
commit d725b9edbc
2 changed files with 25 additions and 11 deletions

View File

@ -13,13 +13,13 @@ results = {}
# Parse Arguments
parser = argparse.ArgumentParser()
parser.add_argument("--overwrite", action="store_true", help="Force redo requests even if output exists")
# parse_known_args is used to avoid conflicts if run inside an environment passing other flags
# Command-line flags; both are optional.
parser.add_argument(
    "--overwrite",
    action="store_true",
    help="Force redo requests even if output exists",
)
# --limit caps how many requests may be sent to the Pro model. When the flag
# is omitted the value is None and no cap is applied.
parser.add_argument(
    "--limit",
    type=int,
    help="Maximum number of requests sent to Gemini Pro (integer)",
)
args, _ = parser.parse_known_args()
if arg_path.suffix == ".jpg":
# Preserve original behaviour
INPUT_DIR = str(arg_path.parents[1])
FULL_LABEL = arg_path.parent.name
tasks.append((str(arg_path), FULL_LABEL))
@ -151,8 +151,8 @@ if PROXY_URL:
os.environ["http_proxy"] = PROXY_URL
os.environ["https_proxy"] = PROXY_URL
MODEL_ID = "gemini-3-pro-preview"
MODEL_ID_BIS = "gemini-3-flash-preview"
MODEL_ID_pro = "gemini-3-pro-preview"
MODEL_ID_flash = "gemini-3-flash-preview"
api_key="REMOVED_API_KEY"
from pydantic import BaseModel, Field, TypeAdapter
@ -202,17 +202,19 @@ def generate_request(file, full_label):
)
return (contents, generate_content_config)
# API client built once from the module-level api_key; the worker function
# below issues its streaming requests through this single instance.
client = genai.Client(api_key=api_key)
# Both JSON files are placed directly under INPUT_DIR.
output_path = Path(INPUT_DIR) / "correction.json"
progress_path = Path(INPUT_DIR) / "correction_progress.json"
# Wall-clock start; the elapsed time is printed from this value at the end of the run.
start_time = time.time()
# Command-line options copied into module-level names.
overwrite = args.overwrite
limit = args.limit  # None when --limit was not passed on the command line
completed_tasks = []  # NOTE(review): where this list gets filled is not visible here — confirm
# --- Synchronisation primitives and per-model request counters ---
# io_lock: taken around file writes so they are thread-safe.
io_lock = threading.Lock()
# pro_lock: taken around every update of the two counters below.
pro_lock = threading.Lock()
# Requests routed to the Pro and Flash models so far.
pro_count, flash_count = 0, 0
if overwrite:
if output_path.exists():
@ -244,7 +246,18 @@ def process_single_task(task_tuple):
n = len(group_data)
d_data = {l[0]: (l[1], l[2]) for l in group_data}
total_height = group_data[-1][2]
use_flash = n >= 5
use_flash = n >= 4 or total_height <= 500
if not use_flash and limit is not None:
with pro_lock:
if pro_count < limit:
pro_count += 1
else:
# Limit reached, force switch to Flash
use_flash = True
if use_flash:
with pro_lock:
flash_count += 1
try:
contents, config = generate_request(file_path, label)
if use_flash:
@ -256,7 +269,7 @@ def process_single_task(task_tuple):
# Assuming client is thread-safe (usually is).
# If not, create a new client instance inside this function.
for chunk in client.models.generate_content_stream(
model=MODEL_ID_BIS if use_flash else MODEL_ID,
model=MODEL_ID_flash if use_flash else MODEL_ID_pro,
contents=contents,
config=config,
):
@ -322,4 +335,5 @@ with concurrent.futures.ThreadPoolExecutor(max_workers=NB_THREADS) as executor:
executor.map(process_single_task, tasks_to_process)
end_time = time.time()
print("Time elapsed : ", end_time - start_time,"\n\n\n\n\n")
print("Time elapsed : ", end_time - start_time)
print("Requests to pro / flash : ", pro_count, flash_count)

View File

@ -11,7 +11,7 @@ from pdf2image import convert_from_path, pdfinfo_from_path
DPI = 200 # Good balance for readability and size
A4_HEIGHT_INCHES = 11.69
FULL_PAGE_PX = int(A4_HEIGHT_INCHES * DPI)
MAX_GROUP_HEIGHT = 1.75 * FULL_PAGE_PX
MAX_GROUP_HEIGHT = 1.5 * FULL_PAGE_PX
MAX_GROUP_COUNT = 8
SEPARATOR_HEIGHT = 20
LABEL_HEIGHT = 50