# Copies/fetch_batched_results.py
#
# 64 lines
# 2.3 KiB
# Python

import os
import sys
import argparse
from pathlib import Path
from google import genai
_OUTPUT_FILENAME = "batched_correction_result.jsonl"
_SUCCEEDED_STATE = "JOB_STATE_SUCCEEDED"


def _state_name(job):
    """Return the job's state as its enum name, or the raw value if it has none."""
    return getattr(job.state, "name", job.state)


def _output_file_name(job):
    """Return the job's destination file name, or None when it has no output file."""
    dest = getattr(job, "dest", None)
    if not dest:
        return None
    return getattr(dest, "file_name", None) or None


def main():
    """Download the outputs of all batch jobs for a directory into one JSONL file.

    Command line: one positional argument, ``root_dir``. Batch jobs are
    selected when the directory's own name occurs in their display name.
    Every selected job must be in state ``JOB_STATE_SUCCEEDED``; their output
    files are then downloaded, concatenated (each one terminated by a
    newline) and written to ``<root_dir>/batched_correction_result.jsonl``.

    Raises:
        SystemExit: if ``root_dir`` is not an existing directory, its name
            cannot be determined, ``GEMINI_API_KEY`` is unset, no job
            matches, a matching job has not succeeded, or no matching job
            produced any output data.
    """
    parser = argparse.ArgumentParser(
        description="Download and combine completed batch jobs for a directory."
    )
    parser.add_argument(
        "root_dir", type=str, help="Directory containing the original batches"
    )
    args = parser.parse_args()

    target_dir = Path(args.root_dir)
    # Fail before any network traffic: the result file is written into this
    # directory, so a wrong path would otherwise only surface after all
    # downloads have finished.
    if not target_dir.is_dir():
        sys.exit(f"Error: '{target_dir}' is not an existing directory.")

    # Path(".").name is "" and Path("..").name is ".."; an empty name is a
    # substring of every display name and would select every job in the
    # account. abspath normalises both without following symlinks.
    dir_name = Path(os.path.abspath(target_dir)).name
    if not dir_name:
        sys.exit(f"Error: cannot derive a directory name from '{target_dir}'.")
    output_path = target_dir / _OUTPUT_FILENAME

    if "GEMINI_API_KEY" not in os.environ:
        sys.exit("Error: GEMINI_API_KEY environment variable not set.")

    client = genai.Client()
    print(f"Fetching jobs matching '{dir_name}'...")

    # 1. Find jobs associated with this directory
    matching_jobs = [
        job
        for job in client.batches.list()
        if dir_name in (getattr(job, "display_name", None) or "")
    ]
    if not matching_jobs:
        sys.exit(f"No batch jobs found containing '{dir_name}' in their display name.")

    # 2. Check that all matching jobs are complete
    for job in matching_jobs:
        state = _state_name(job)
        print(f"Found Job: {job.display_name} | State: {state}")
        if state != _SUCCEEDED_STATE:
            sys.exit(
                f"Error: Job '{job.display_name}' has not succeeded yet. Try again later."
            )

    # 3. Download and concatenate
    print("\nAll jobs succeeded. Downloading results...")
    chunks = []
    for job in matching_jobs:
        file_name = _output_file_name(job)
        if not file_name:
            print(f"Warning: Job {job.display_name} succeeded but has no output file.")
            continue

        print(f"Downloading output for {job.display_name}...")
        data = client.files.download(file=file_name)
        if not data:
            continue
        chunks.append(data)
        # Ensure proper line separation between files in JSONL
        if not data.endswith(b"\n"):
            chunks.append(b"\n")

    # Do not write an empty file and report success when nothing was fetched.
    if not chunks:
        sys.exit("Error: none of the matching jobs produced any output data.")

    # 4. Save to destination
    output_path.write_bytes(b"".join(chunks))
    print(f"\nSuccess! All results concatenated and saved to:\n{output_path}")
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()