Limit requests to Gemini pro
parent
bd1362dff8
commit
d725b9edbc
|
|
@ -13,13 +13,13 @@ results = {}
|
|||
|
||||
# Parse Arguments
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--overwrite", action="store_true", help="Force redo requests even if output exists")
|
||||
# parse_known_args is used to avoid conflicts if run inside an environment passing other flags
|
||||
parser.add_argument("--overwrite", action="store_true",
|
||||
help="Force redo requests even if output exists")
|
||||
# Optional cap on Gemini Pro requests; when omitted, args.limit is None and no cap is applied.
parser.add_argument("--limit", type=int, help="Maximum number of calls to Gemini Pro (integer)")
|
||||
args, _ = parser.parse_known_args()
|
||||
|
||||
|
||||
if arg_path.suffix == ".jpg":
|
||||
# Preserve original behaviour
|
||||
INPUT_DIR = str(arg_path.parents[1])
|
||||
FULL_LABEL = arg_path.parent.name
|
||||
tasks.append((str(arg_path), FULL_LABEL))
|
||||
|
|
@ -151,8 +151,8 @@ if PROXY_URL:
|
|||
os.environ["http_proxy"] = PROXY_URL
|
||||
os.environ["https_proxy"] = PROXY_URL
|
||||
|
||||
MODEL_ID = "gemini-3-pro-preview"
|
||||
MODEL_ID_BIS = "gemini-3-flash-preview"
|
||||
MODEL_ID_pro = "gemini-3-pro-preview"
|
||||
MODEL_ID_flash = "gemini-3-flash-preview"
|
||||
api_key="REMOVED_API_KEY"
|
||||
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
|
|
@ -202,17 +202,19 @@ def generate_request(file, full_label):
|
|||
)
|
||||
return (contents, generate_content_config)
|
||||
|
||||
|
||||
|
||||
client = genai.Client(api_key=api_key)
|
||||
output_path = Path(INPUT_DIR) / "correction.json"
|
||||
progress_path = Path(INPUT_DIR) / "correction_progress.json"
|
||||
start_time = time.time()
|
||||
overwrite = args.overwrite
|
||||
limit = args.limit
|
||||
completed_tasks = []
|
||||
|
||||
# --- Lock for thread-safe file writing ---
|
||||
io_lock = threading.Lock()
|
||||
pro_lock = threading.Lock() # Guards updates to the pro_count and flash_count counters from worker threads
|
||||
pro_count = 0 # New counter
|
||||
flash_count = 0
|
||||
|
||||
if overwrite:
|
||||
if output_path.exists():
|
||||
|
|
@ -244,7 +246,18 @@ def process_single_task(task_tuple):
|
|||
n = len(group_data)
|
||||
d_data = {l[0]: (l[1], l[2]) for l in group_data}
|
||||
total_height = group_data[-1][2]
|
||||
use_flash = n >= 5
|
||||
use_flash = n >= 4 or total_height <= 500
|
||||
if not use_flash and limit is not None:
|
||||
with pro_lock:
|
||||
if pro_count < limit:
|
||||
pro_count += 1
|
||||
else:
|
||||
# Limit reached, force switch to Flash
|
||||
use_flash = True
|
||||
if use_flash:
|
||||
with pro_lock:
|
||||
flash_count += 1
|
||||
|
||||
try:
|
||||
contents, config = generate_request(file_path, label)
|
||||
if use_flash:
|
||||
|
|
@ -256,7 +269,7 @@ def process_single_task(task_tuple):
|
|||
# Assuming client is thread-safe (usually is).
|
||||
# If not, create a new client instance inside this function.
|
||||
for chunk in client.models.generate_content_stream(
|
||||
model=MODEL_ID_BIS if use_flash else MODEL_ID,
|
||||
model=MODEL_ID_flash if use_flash else MODEL_ID_pro,
|
||||
contents=contents,
|
||||
config=config,
|
||||
):
|
||||
|
|
@ -322,4 +335,5 @@ with concurrent.futures.ThreadPoolExecutor(max_workers=NB_THREADS) as executor:
|
|||
executor.map(process_single_task, tasks_to_process)
|
||||
|
||||
end_time = time.time()
|
||||
print("Time elapsed : ", end_time - start_time,"\n\n\n\n\n")
|
||||
print("Time elapsed : ", end_time - start_time)
|
||||
print("Requests to pro / flash : ", pro_count, flash_count)
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from pdf2image import convert_from_path, pdfinfo_from_path
|
|||
DPI = 200 # Good balance for readability and size
|
||||
A4_HEIGHT_INCHES = 11.69
|
||||
FULL_PAGE_PX = int(A4_HEIGHT_INCHES * DPI)
|
||||
MAX_GROUP_HEIGHT = 1.75 * FULL_PAGE_PX
|
||||
MAX_GROUP_HEIGHT = 1.5 * FULL_PAGE_PX
|
||||
MAX_GROUP_COUNT = 8
|
||||
SEPARATOR_HEIGHT = 20
|
||||
LABEL_HEIGHT = 50
|
||||
|
|
|
|||
Loading…
Reference in New Issue