Fix batching, hopefully
parent
b6a0f5d83f
commit
bc47f81556
|
|
@ -741,9 +741,17 @@ if __name__ == "__main__":
|
||||||
else:
|
else:
|
||||||
print(f"Warning: --refaire flag used, but {refaire_path} not found.", file=sys.stderr)
|
print(f"Warning: --refaire flag used, but {refaire_path} not found.", file=sys.stderr)
|
||||||
|
|
||||||
|
|
||||||
if args.batch:
|
if args.batch:
|
||||||
batch_file = Path(INPUT_DIR) / "batch_requests.jsonl"
|
batch_flash_file = Path(INPUT_DIR) / "batch_requests_flash.jsonl"
|
||||||
with open(batch_file, "w", encoding="utf-8") as f:
|
batch_pro_file = Path(INPUT_DIR) / "batch_requests_pro.jsonl"
|
||||||
|
|
||||||
|
count_flash = 0
|
||||||
|
count_pro = 0
|
||||||
|
|
||||||
|
with open(batch_flash_file, "w", encoding="utf-8") as f_flash, \
|
||||||
|
open(batch_pro_file, "w", encoding="utf-8") as f_pro:
|
||||||
|
|
||||||
for task in tasks_to_process:
|
for task in tasks_to_process:
|
||||||
file_path, label = task[0], task[1]
|
file_path, label = task[0], task[1]
|
||||||
group_name = os.path.splitext(file_path)[0]
|
group_name = os.path.splitext(file_path)[0]
|
||||||
|
|
@ -752,18 +760,14 @@ if __name__ == "__main__":
|
||||||
with open(json_path, 'r') as jf:
|
with open(json_path, 'r') as jf:
|
||||||
group_data = json.load(jf)
|
group_data = json.load(jf)
|
||||||
use_flash = len(group_data) >= 4 or group_data[-1][2] <= 500
|
use_flash = len(group_data) >= 4 or group_data[-1][2] <= 500
|
||||||
model_to_use = MODEL_ID_flash if use_flash else MODEL_ID_pro
|
|
||||||
|
|
||||||
image_data = Path(file_path).read_bytes()
|
image_data = Path(file_path).read_bytes()
|
||||||
b64_img = base64.b64encode(image_data).decode("utf-8")
|
b64_img = base64.b64encode(image_data).decode("utf-8")
|
||||||
|
|
||||||
# Format payload. NOTE: adapt the JSON format if your specific Gemini
|
# Format payload matching Gemini Batch API file requirements
|
||||||
# Batch API endpoint expects a slightly different schema.
|
|
||||||
req = {
|
req = {
|
||||||
"custom_id": file_path, # Mapping ID
|
"key": file_path, # The ID returned in the output file
|
||||||
"method": "POST",
|
"request": {
|
||||||
"url": f"/v1beta/models/{model_to_use}:generateContent",
|
|
||||||
"body": {
|
|
||||||
"contents": [{
|
"contents": [{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"parts": [
|
"parts": [
|
||||||
|
|
@ -771,7 +775,7 @@ if __name__ == "__main__":
|
||||||
{"text": make_prompt(label)}
|
{"text": make_prompt(label)}
|
||||||
]
|
]
|
||||||
}],
|
}],
|
||||||
"generationConfig": {
|
"generation_config": {
|
||||||
"temperature": 1.0,
|
"temperature": 1.0,
|
||||||
"topP": 0.95,
|
"topP": 0.95,
|
||||||
"maxOutputTokens": 65535,
|
"maxOutputTokens": 65535,
|
||||||
|
|
@ -780,9 +784,18 @@ if __name__ == "__main__":
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
f.write(json.dumps(req) + "\n")
|
|
||||||
|
|
||||||
print(f"Batch generation complete. {len(tasks_to_process)} requests saved to {batch_file}")
|
if use_flash:
|
||||||
|
f_flash.write(json.dumps(req) + "\n")
|
||||||
|
count_flash += 1
|
||||||
|
else:
|
||||||
|
f_pro.write(json.dumps(req) + "\n")
|
||||||
|
count_pro += 1
|
||||||
|
|
||||||
|
print(f"Batch generation complete.")
|
||||||
|
print(f" - {count_flash} requests saved to {batch_flash_file} (for {MODEL_ID_flash})")
|
||||||
|
print(f" - {count_pro} requests saved to {batch_pro_file} (for {MODEL_ID_pro})")
|
||||||
|
print("Upload these files via the File API and create two separate batch jobs.")
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
||||||
batched_responses = {}
|
batched_responses = {}
|
||||||
|
|
@ -794,19 +807,21 @@ if __name__ == "__main__":
|
||||||
for line in f:
|
for line in f:
|
||||||
if not line.strip(): continue
|
if not line.strip(): continue
|
||||||
data = json.loads(line)
|
data = json.loads(line)
|
||||||
task_id = data.get("custom_id")
|
task_id = data.get("key") # Corresponds to the key sent in the request
|
||||||
|
|
||||||
# Extract the JSON response text. Adapt this path to match your API output schema!
|
if "response" in data:
|
||||||
try:
|
try:
|
||||||
resp_text = data["response"]["body"]["candidates"][0]["content"]["parts"][0]["text"]
|
# Extract the JSON response text per standard Batch API schema
|
||||||
|
resp_text = data["response"]["candidates"][0]["content"]["parts"][0]["text"]
|
||||||
batched_responses[task_id] = resp_text
|
batched_responses[task_id] = resp_text
|
||||||
except (KeyError, IndexError):
|
except (KeyError, IndexError) as e:
|
||||||
batched_responses[task_id] = data.get("response_text", "")
|
print(f"Warning: Could not parse response for {task_id}: {e}", file=sys.stderr)
|
||||||
|
elif "error" in data:
|
||||||
|
print(f"Batch API Error for {task_id}: {data['error']}", file=sys.stderr)
|
||||||
else:
|
else:
|
||||||
print(f"Warning: Batch results file {batch_results_path} not found.", file=sys.stderr)
|
print(f"Warning: Batch results file {batch_results_path} not found.", file=sys.stderr)
|
||||||
|
|
||||||
print(f"Starting processing on {len(tasks_to_process)} tasks with {NB_THREADS} threads...")
|
print(f"Starting processing on {len(tasks_to_process)} tasks with {NB_THREADS} threads...")
|
||||||
|
|
||||||
with concurrent.futures.ThreadPoolExecutor(max_workers=NB_THREADS) as executor:
|
with concurrent.futures.ThreadPoolExecutor(max_workers=NB_THREADS) as executor:
|
||||||
futures = {}
|
futures = {}
|
||||||
for task in tasks_to_process:
|
for task in tasks_to_process:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue