Initial untested support for batching requests

master
Sébastien Miquel 2026-04-18 11:17:18 +02:00
parent 882c9b64ba
commit 3673bd6fe1
1 changed files with 95 additions and 16 deletions

View File

@ -18,6 +18,10 @@ parser.add_argument("--overwrite", action="store_true",
parser.add_argument("--limit", type=int, help="limit calls to gemini rpo integer") parser.add_argument("--limit", type=int, help="limit calls to gemini rpo integer")
parser.add_argument("--refaire", action="store_true", parser.add_argument("--refaire", action="store_true",
help="Redo specific copies/labels defined in refaire.json") help="Redo specific copies/labels defined in refaire.json")
parser.add_argument("--batch", action="store_true",
help="Generate a JSONL file of requests to send to the Gemini Batch API")
parser.add_argument("--deal-with-batched", type=str, metavar="FILE",
help="Process a JSONL file containing completed batch results")
args, _ = parser.parse_known_args() args, _ = parser.parse_known_args()
@ -548,7 +552,7 @@ Here is a list of all possible labels. You need to answer with a list one of the
return new_tasks return new_tasks
def process_single_task(task_tuple): def process_single_task(task_tuple, precomputed_response=None):
try: try:
global pro_count, flash_count, pro_quota_exhausted global pro_count, flash_count, pro_quota_exhausted
file_path = task_tuple[0] file_path = task_tuple[0]
@ -567,6 +571,8 @@ def process_single_task(task_tuple):
total_height = group_data[-1][2] total_height = group_data[-1][2]
use_flash = n >= 4 or total_height <= 500 use_flash = n >= 4 or total_height <= 500
# Only apply limits and counts if we are making a live call
if precomputed_response is None:
if not use_flash: if not use_flash:
with pro_lock: with pro_lock:
if pro_quota_exhausted: if pro_quota_exhausted:
@ -583,9 +589,14 @@ def process_single_task(task_tuple):
try: try:
contents, config = generate_request(file_path, label) contents, config = generate_request(file_path, label)
model_to_use = MODEL_ID_flash if use_flash else MODEL_ID_pro model_to_use = MODEL_ID_flash if use_flash else MODEL_ID_pro
tprint(f"Asking Gemini {'Flash' if use_flash else 'Pro '}: {label} {group_name}")
if precomputed_response:
tprint(f"Using batched response for: {label} {group_name}")
full_response_text = precomputed_response
else:
tprint(f"Asking Gemini {'Flash' if use_flash else 'Pro '}: {label} {group_name}")
full_response_text = call_gemini_with_retries(model_to_use, contents, config) full_response_text = call_gemini_with_retries(model_to_use, contents, config)
json_data = json.loads(full_response_text) json_data = json.loads(full_response_text)
# Ensure consistency of answer placements # Ensure consistency of answer placements
@ -723,10 +734,78 @@ if __name__ == "__main__":
else: else:
print(f"Warning: --refaire flag used, but {refaire_path} not found.", file=sys.stderr) print(f"Warning: --refaire flag used, but {refaire_path} not found.", file=sys.stderr)
if args.batch:
batch_file = Path(INPUT_DIR) / "batch_requests.jsonl"
with open(batch_file, "w", encoding="utf-8") as f:
for task in tasks_to_process:
file_path, label = task[0], task[1]
group_name = os.path.splitext(file_path)[0]
json_path = group_name + '.json'
with open(json_path, 'r') as jf:
group_data = json.load(jf)
use_flash = len(group_data) >= 4 or group_data[-1][2] <= 500
model_to_use = MODEL_ID_flash if use_flash else MODEL_ID_pro
image_data = Path(file_path).read_bytes()
b64_img = base64.b64encode(image_data).decode("utf-8")
# Format payload. NOTE: adapt the JSON format if your specific Gemini
# Batch API endpoint expects a slightly different schema.
req = {
"custom_id": file_path, # Mapping ID
"method": "POST",
"url": f"/v1beta/models/{model_to_use}:generateContent",
"body": {
"contents": [{
"role": "user",
"parts": [
{"inlineData": {"mimeType": "image/jpeg", "data": b64_img}},
{"text": make_prompt(label)}
]
}],
"generationConfig": {
"temperature": 1.0,
"topP": 0.95,
"maxOutputTokens": 65535,
"responseMimeType": "application/json",
"responseSchema": TypeAdapter(List[EvaluationEntry]).json_schema()
}
}
}
f.write(json.dumps(req) + "\n")
print(f"Batch generation complete. {len(tasks_to_process)} requests saved to {batch_file}")
sys.exit(0)
batched_responses = {}
if args.deal_with_batched:
batch_results_path = Path(args.deal_with_batched)
if batch_results_path.exists():
print(f"Loading batch results from {batch_results_path}...")
with open(batch_results_path, "r", encoding="utf-8") as f:
for line in f:
if not line.strip(): continue
data = json.loads(line)
task_id = data.get("custom_id")
# Extract the JSON response text. Adapt this path to match your API output schema!
try:
resp_text = data["response"]["body"]["candidates"][0]["content"]["parts"][0]["text"]
batched_responses[task_id] = resp_text
except (KeyError, IndexError):
batched_responses[task_id] = data.get("response_text", "")
else:
print(f"Warning: Batch results file {batch_results_path} not found.", file=sys.stderr)
print(f"Starting processing on {len(tasks_to_process)} tasks with {NB_THREADS} threads...") print(f"Starting processing on {len(tasks_to_process)} tasks with {NB_THREADS} threads...")
with concurrent.futures.ThreadPoolExecutor(max_workers=NB_THREADS) as executor: with concurrent.futures.ThreadPoolExecutor(max_workers=NB_THREADS) as executor:
futures = {executor.submit(process_single_task, task): task for task in tasks_to_process} futures = {}
for task in tasks_to_process:
file_path = task[0]
precomp = batched_responses.get(file_path)
futures[executor.submit(process_single_task, task, precomp)] = task
# Process tasks as they complete, allowing dynamic task addition # Process tasks as they complete, allowing dynamic task addition
for future in concurrent.futures.as_completed(futures): for future in concurrent.futures.as_completed(futures):
@ -734,11 +813,11 @@ if __name__ == "__main__":
new_generated_tasks = future.result() new_generated_tasks = future.result()
if new_generated_tasks: if new_generated_tasks:
for new_task in new_generated_tasks: for new_task in new_generated_tasks:
# New tasks from wrong-label/additional-answer will fallback to live API
futures[executor.submit(process_single_task, new_task)] = new_task futures[executor.submit(process_single_task, new_task)] = new_task
except Exception as e: except Exception as e:
print(f"Exception during task execution: {e}", file=sys.stderr) print(f"Exception during task execution: {e}", file=sys.stderr)
end_time = time.time() end_time = time.time()
print("Time elapsed : ", end_time - start_time) print("Time elapsed : ", end_time - start_time)
print("Requests to pro / flash : ", pro_count, flash_count) print("Requests to pro / flash : ", pro_count, flash_count)