Limit requests to Gemini pro
parent
bd1362dff8
commit
d725b9edbc
|
|
@ -13,13 +13,13 @@ results = {}
|
||||||
|
|
||||||
# Parse Arguments
|
# Parse Arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--overwrite", action="store_true", help="Force redo requests even if output exists")
|
parser.add_argument("--overwrite", action="store_true",
|
||||||
# parse_known_args is used to avoid conflicts if run inside an environment passing other flags
|
help="Force redo requests even if output exists")
|
||||||
|
parser.add_argument("--limit", type=int, help="limit calls to gemini rpo integer")
|
||||||
args, _ = parser.parse_known_args()
|
args, _ = parser.parse_known_args()
|
||||||
|
|
||||||
|
|
||||||
if arg_path.suffix == ".jpg":
|
if arg_path.suffix == ".jpg":
|
||||||
# Preserve original behaviour
|
|
||||||
INPUT_DIR = str(arg_path.parents[1])
|
INPUT_DIR = str(arg_path.parents[1])
|
||||||
FULL_LABEL = arg_path.parent.name
|
FULL_LABEL = arg_path.parent.name
|
||||||
tasks.append((str(arg_path), FULL_LABEL))
|
tasks.append((str(arg_path), FULL_LABEL))
|
||||||
|
|
@ -151,8 +151,8 @@ if PROXY_URL:
|
||||||
os.environ["http_proxy"] = PROXY_URL
|
os.environ["http_proxy"] = PROXY_URL
|
||||||
os.environ["https_proxy"] = PROXY_URL
|
os.environ["https_proxy"] = PROXY_URL
|
||||||
|
|
||||||
MODEL_ID = "gemini-3-pro-preview"
|
MODEL_ID_pro = "gemini-3-pro-preview"
|
||||||
MODEL_ID_BIS = "gemini-3-flash-preview"
|
MODEL_ID_flash = "gemini-3-flash-preview"
|
||||||
api_key="REMOVED_API_KEY"
|
api_key="REMOVED_API_KEY"
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, TypeAdapter
|
from pydantic import BaseModel, Field, TypeAdapter
|
||||||
|
|
@ -202,17 +202,19 @@ def generate_request(file, full_label):
|
||||||
)
|
)
|
||||||
return (contents, generate_content_config)
|
return (contents, generate_content_config)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
client = genai.Client(api_key=api_key)
|
client = genai.Client(api_key=api_key)
|
||||||
output_path = Path(INPUT_DIR) / "correction.json"
|
output_path = Path(INPUT_DIR) / "correction.json"
|
||||||
progress_path = Path(INPUT_DIR) / "correction_progress.json"
|
progress_path = Path(INPUT_DIR) / "correction_progress.json"
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
overwrite = args.overwrite
|
overwrite = args.overwrite
|
||||||
|
limit = args.limit
|
||||||
completed_tasks = []
|
completed_tasks = []
|
||||||
|
|
||||||
# --- Lock for thread-safe file writing ---
|
# --- Lock for thread-safe file writing ---
|
||||||
io_lock = threading.Lock()
|
io_lock = threading.Lock()
|
||||||
|
pro_lock = threading.Lock() # New lock for counter
|
||||||
|
pro_count = 0 # New counter
|
||||||
|
flash_count = 0
|
||||||
|
|
||||||
if overwrite:
|
if overwrite:
|
||||||
if output_path.exists():
|
if output_path.exists():
|
||||||
|
|
@ -244,7 +246,18 @@ def process_single_task(task_tuple):
|
||||||
n = len(group_data)
|
n = len(group_data)
|
||||||
d_data = {l[0]: (l[1], l[2]) for l in group_data}
|
d_data = {l[0]: (l[1], l[2]) for l in group_data}
|
||||||
total_height = group_data[-1][2]
|
total_height = group_data[-1][2]
|
||||||
use_flash = n >= 5
|
use_flash = n >= 4 or total_height <= 500
|
||||||
|
if not use_flash and limit is not None:
|
||||||
|
with pro_lock:
|
||||||
|
if pro_count < limit:
|
||||||
|
pro_count += 1
|
||||||
|
else:
|
||||||
|
# Limit reached, force switch to Flash
|
||||||
|
use_flash = True
|
||||||
|
if use_flash:
|
||||||
|
with pro_lock:
|
||||||
|
flash_count += 1
|
||||||
|
|
||||||
try:
|
try:
|
||||||
contents, config = generate_request(file_path, label)
|
contents, config = generate_request(file_path, label)
|
||||||
if use_flash:
|
if use_flash:
|
||||||
|
|
@ -256,7 +269,7 @@ def process_single_task(task_tuple):
|
||||||
# Assuming client is thread-safe (usually is).
|
# Assuming client is thread-safe (usually is).
|
||||||
# If not, create a new client instance inside this function.
|
# If not, create a new client instance inside this function.
|
||||||
for chunk in client.models.generate_content_stream(
|
for chunk in client.models.generate_content_stream(
|
||||||
model=MODEL_ID_BIS if use_flash else MODEL_ID,
|
model=MODEL_ID_flash if use_flash else MODEL_ID_pro,
|
||||||
contents=contents,
|
contents=contents,
|
||||||
config=config,
|
config=config,
|
||||||
):
|
):
|
||||||
|
|
@ -322,4 +335,5 @@ with concurrent.futures.ThreadPoolExecutor(max_workers=NB_THREADS) as executor:
|
||||||
executor.map(process_single_task, tasks_to_process)
|
executor.map(process_single_task, tasks_to_process)
|
||||||
|
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
print("Time elapsed : ", end_time - start_time,"\n\n\n\n\n")
|
print("Time elapsed : ", end_time - start_time)
|
||||||
|
print("Requests to pro / flash : ", pro_count, flash_count)
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ from pdf2image import convert_from_path, pdfinfo_from_path
|
||||||
DPI = 200 # Good balance for readability and size
|
DPI = 200 # Good balance for readability and size
|
||||||
A4_HEIGHT_INCHES = 11.69
|
A4_HEIGHT_INCHES = 11.69
|
||||||
FULL_PAGE_PX = int(A4_HEIGHT_INCHES * DPI)
|
FULL_PAGE_PX = int(A4_HEIGHT_INCHES * DPI)
|
||||||
MAX_GROUP_HEIGHT = 1.75 * FULL_PAGE_PX
|
MAX_GROUP_HEIGHT = 1.5 * FULL_PAGE_PX
|
||||||
MAX_GROUP_COUNT = 8
|
MAX_GROUP_COUNT = 8
|
||||||
SEPARATOR_HEIGHT = 20
|
SEPARATOR_HEIGHT = 20
|
||||||
LABEL_HEIGHT = 50
|
LABEL_HEIGHT = 50
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue