Fix batching, hopefully
parent
b6a0f5d83f
commit
bc47f81556
|
|
@ -741,9 +741,17 @@ if __name__ == "__main__":
|
|||
else:
|
||||
print(f"Warning: --refaire flag used, but {refaire_path} not found.", file=sys.stderr)
|
||||
|
||||
|
||||
if args.batch:
|
||||
batch_file = Path(INPUT_DIR) / "batch_requests.jsonl"
|
||||
with open(batch_file, "w", encoding="utf-8") as f:
|
||||
batch_flash_file = Path(INPUT_DIR) / "batch_requests_flash.jsonl"
|
||||
batch_pro_file = Path(INPUT_DIR) / "batch_requests_pro.jsonl"
|
||||
|
||||
count_flash = 0
|
||||
count_pro = 0
|
||||
|
||||
with open(batch_flash_file, "w", encoding="utf-8") as f_flash, \
|
||||
open(batch_pro_file, "w", encoding="utf-8") as f_pro:
|
||||
|
||||
for task in tasks_to_process:
|
||||
file_path, label = task[0], task[1]
|
||||
group_name = os.path.splitext(file_path)[0]
|
||||
|
|
@ -752,18 +760,14 @@ if __name__ == "__main__":
|
|||
with open(json_path, 'r') as jf:
|
||||
group_data = json.load(jf)
|
||||
use_flash = len(group_data) >= 4 or group_data[-1][2] <= 500
|
||||
model_to_use = MODEL_ID_flash if use_flash else MODEL_ID_pro
|
||||
|
||||
image_data = Path(file_path).read_bytes()
|
||||
b64_img = base64.b64encode(image_data).decode("utf-8")
|
||||
|
||||
# Format payload. NOTE: adapt the JSON format if your specific Gemini
|
||||
# Batch API endpoint expects a slightly different schema.
|
||||
# Format payload matching Gemini Batch API file requirements
|
||||
req = {
|
||||
"custom_id": file_path, # Mapping ID
|
||||
"method": "POST",
|
||||
"url": f"/v1beta/models/{model_to_use}:generateContent",
|
||||
"body": {
|
||||
"key": file_path, # The ID returned in the output file
|
||||
"request": {
|
||||
"contents": [{
|
||||
"role": "user",
|
||||
"parts": [
|
||||
|
|
@ -771,7 +775,7 @@ if __name__ == "__main__":
|
|||
{"text": make_prompt(label)}
|
||||
]
|
||||
}],
|
||||
"generationConfig": {
|
||||
"generation_config": {
|
||||
"temperature": 1.0,
|
||||
"topP": 0.95,
|
||||
"maxOutputTokens": 65535,
|
||||
|
|
@ -780,9 +784,18 @@ if __name__ == "__main__":
|
|||
}
|
||||
}
|
||||
}
|
||||
f.write(json.dumps(req) + "\n")
|
||||
|
||||
print(f"Batch generation complete. {len(tasks_to_process)} requests saved to {batch_file}")
|
||||
if use_flash:
|
||||
f_flash.write(json.dumps(req) + "\n")
|
||||
count_flash += 1
|
||||
else:
|
||||
f_pro.write(json.dumps(req) + "\n")
|
||||
count_pro += 1
|
||||
|
||||
print(f"Batch generation complete.")
|
||||
print(f" - {count_flash} requests saved to {batch_flash_file} (for {MODEL_ID_flash})")
|
||||
print(f" - {count_pro} requests saved to {batch_pro_file} (for {MODEL_ID_pro})")
|
||||
print("Upload these files via the File API and create two separate batch jobs.")
|
||||
sys.exit(0)
|
||||
|
||||
batched_responses = {}
|
||||
|
|
@ -794,19 +807,21 @@ if __name__ == "__main__":
|
|||
for line in f:
|
||||
if not line.strip(): continue
|
||||
data = json.loads(line)
|
||||
task_id = data.get("custom_id")
|
||||
task_id = data.get("key") # Corresponds to the key sent in the request
|
||||
|
||||
# Extract the JSON response text. Adapt this path to match your API output schema!
|
||||
try:
|
||||
resp_text = data["response"]["body"]["candidates"][0]["content"]["parts"][0]["text"]
|
||||
batched_responses[task_id] = resp_text
|
||||
except (KeyError, IndexError):
|
||||
batched_responses[task_id] = data.get("response_text", "")
|
||||
if "response" in data:
|
||||
try:
|
||||
# Extract the JSON response text per standard Batch API schema
|
||||
resp_text = data["response"]["candidates"][0]["content"]["parts"][0]["text"]
|
||||
batched_responses[task_id] = resp_text
|
||||
except (KeyError, IndexError) as e:
|
||||
print(f"Warning: Could not parse response for {task_id}: {e}", file=sys.stderr)
|
||||
elif "error" in data:
|
||||
print(f"Batch API Error for {task_id}: {data['error']}", file=sys.stderr)
|
||||
else:
|
||||
print(f"Warning: Batch results file {batch_results_path} not found.", file=sys.stderr)
|
||||
|
||||
print(f"Starting processing on {len(tasks_to_process)} tasks with {NB_THREADS} threads...")
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=NB_THREADS) as executor:
|
||||
futures = {}
|
||||
for task in tasks_to_process:
|
||||
|
|
|
|||
Loading…
Reference in New Issue