Fix batching, hopefully

master
Sébastien Miquel 2026-04-21 14:32:20 +02:00
parent b6a0f5d83f
commit bc47f81556
1 changed files with 35 additions and 20 deletions

View File

@ -741,9 +741,17 @@ if __name__ == "__main__":
else: else:
print(f"Warning: --refaire flag used, but {refaire_path} not found.", file=sys.stderr) print(f"Warning: --refaire flag used, but {refaire_path} not found.", file=sys.stderr)
if args.batch: if args.batch:
batch_file = Path(INPUT_DIR) / "batch_requests.jsonl" batch_flash_file = Path(INPUT_DIR) / "batch_requests_flash.jsonl"
with open(batch_file, "w", encoding="utf-8") as f: batch_pro_file = Path(INPUT_DIR) / "batch_requests_pro.jsonl"
count_flash = 0
count_pro = 0
with open(batch_flash_file, "w", encoding="utf-8") as f_flash, \
open(batch_pro_file, "w", encoding="utf-8") as f_pro:
for task in tasks_to_process: for task in tasks_to_process:
file_path, label = task[0], task[1] file_path, label = task[0], task[1]
group_name = os.path.splitext(file_path)[0] group_name = os.path.splitext(file_path)[0]
@ -752,18 +760,14 @@ if __name__ == "__main__":
with open(json_path, 'r') as jf: with open(json_path, 'r') as jf:
group_data = json.load(jf) group_data = json.load(jf)
use_flash = len(group_data) >= 4 or group_data[-1][2] <= 500 use_flash = len(group_data) >= 4 or group_data[-1][2] <= 500
model_to_use = MODEL_ID_flash if use_flash else MODEL_ID_pro
image_data = Path(file_path).read_bytes() image_data = Path(file_path).read_bytes()
b64_img = base64.b64encode(image_data).decode("utf-8") b64_img = base64.b64encode(image_data).decode("utf-8")
# Format payload. NOTE: adapt the JSON format if your specific Gemini # Format payload matching Gemini Batch API file requirements
# Batch API endpoint expects a slightly different schema.
req = { req = {
"custom_id": file_path, # Mapping ID "key": file_path, # The ID returned in the output file
"method": "POST", "request": {
"url": f"/v1beta/models/{model_to_use}:generateContent",
"body": {
"contents": [{ "contents": [{
"role": "user", "role": "user",
"parts": [ "parts": [
@ -771,7 +775,7 @@ if __name__ == "__main__":
{"text": make_prompt(label)} {"text": make_prompt(label)}
] ]
}], }],
"generationConfig": { "generation_config": {
"temperature": 1.0, "temperature": 1.0,
"topP": 0.95, "topP": 0.95,
"maxOutputTokens": 65535, "maxOutputTokens": 65535,
@ -780,9 +784,18 @@ if __name__ == "__main__":
} }
} }
} }
f.write(json.dumps(req) + "\n")
print(f"Batch generation complete. {len(tasks_to_process)} requests saved to {batch_file}") if use_flash:
f_flash.write(json.dumps(req) + "\n")
count_flash += 1
else:
f_pro.write(json.dumps(req) + "\n")
count_pro += 1
print(f"Batch generation complete.")
print(f" - {count_flash} requests saved to {batch_flash_file} (for {MODEL_ID_flash})")
print(f" - {count_pro} requests saved to {batch_pro_file} (for {MODEL_ID_pro})")
print("Upload these files via the File API and create two separate batch jobs.")
sys.exit(0) sys.exit(0)
batched_responses = {} batched_responses = {}
@ -794,19 +807,21 @@ if __name__ == "__main__":
for line in f: for line in f:
if not line.strip(): continue if not line.strip(): continue
data = json.loads(line) data = json.loads(line)
task_id = data.get("custom_id") task_id = data.get("key") # Corresponds to the key sent in the request
# Extract the JSON response text. Adapt this path to match your API output schema! if "response" in data:
try: try:
resp_text = data["response"]["body"]["candidates"][0]["content"]["parts"][0]["text"] # Extract the JSON response text per standard Batch API schema
resp_text = data["response"]["candidates"][0]["content"]["parts"][0]["text"]
batched_responses[task_id] = resp_text batched_responses[task_id] = resp_text
except (KeyError, IndexError): except (KeyError, IndexError) as e:
batched_responses[task_id] = data.get("response_text", "") print(f"Warning: Could not parse response for {task_id}: {e}", file=sys.stderr)
elif "error" in data:
print(f"Batch API Error for {task_id}: {data['error']}", file=sys.stderr)
else: else:
print(f"Warning: Batch results file {batch_results_path} not found.", file=sys.stderr) print(f"Warning: Batch results file {batch_results_path} not found.", file=sys.stderr)
print(f"Starting processing on {len(tasks_to_process)} tasks with {NB_THREADS} threads...") print(f"Starting processing on {len(tasks_to_process)} tasks with {NB_THREADS} threads...")
with concurrent.futures.ThreadPoolExecutor(max_workers=NB_THREADS) as executor: with concurrent.futures.ThreadPoolExecutor(max_workers=NB_THREADS) as executor:
futures = {} futures = {}
for task in tasks_to_process: for task in tasks_to_process: