89 lines
3.0 KiB
Python
89 lines
3.0 KiB
Python
import os
|
|
import sys
|
|
import argparse
|
|
from google import genai
|
|
|
|
# Abort early with a readable message when the API key variable is absent,
# before any client object is constructed.
if os.environ.get("GEMINI_API_KEY") is None:
    sys.exit("Error: GEMINI_API_KEY environment variable not set.")

# Module-level client used by list_jobs() and download_job(). It is built with
# no explicit arguments; presumably the SDK reads the key from the environment
# (NOTE(review): confirm against the google-genai documentation).
client = genai.Client()
|
|
|
|
def list_jobs():
    """Print a summary of every batch job returned by ``client.batches.list()``.

    For each job the name, the display name (when set) and the state are
    printed. Jobs in state ``JOB_STATE_FAILED`` also show their error when one
    is recorded, and jobs in state ``JOB_STATE_SUCCEEDED`` show the output file
    name when one is present. After the listing, a hint on how to download a
    completed job is printed; if no jobs were returned, a short notice is
    printed instead.

    Raises:
        SystemExit: If any exception is raised while fetching or printing the
            jobs. The exit message contains the original error text.
    """
    print("Fetching recent batch jobs...\n")
    separator = "-" * 60

    try:
        jobs_found = False

        for job in client.batches.list():
            jobs_found = True
            # The state may be an enum-like object exposing ``.name`` or a
            # plain value; normalise to whatever can be compared to a string.
            state = getattr(job.state, "name", job.state)

            print(separator)
            print(f"Job Name: {job.name}")

            display_name = getattr(job, "display_name", None)
            if display_name:
                print(f"Display Name: {display_name}")

            print(f"State: {state}")

            if state == "JOB_STATE_FAILED":
                # Only report an error that is actually populated, so a failed
                # job without error details does not print "Error: None".
                error = getattr(job, "error", None)
                if error:
                    print(f"Error: {error}")

            if state == "JOB_STATE_SUCCEEDED":
                dest = getattr(job, "dest", None)
                output_file = getattr(dest, "file_name", None) if dest else None
                if output_file:
                    print(f"Output File: {output_file}")

        if not jobs_found:
            print("No batch jobs found.")
        else:
            print(separator)
            print("\nTo download a completed job, run:")
            print("python batch_status.py --download batches/<YOUR_BATCH_ID>")

    except Exception as e:
        # Top-level CLI boundary: turn any failure into a non-zero exit with
        # a readable message instead of a traceback.
        sys.exit(f"An error occurred while listing jobs: {e}")
|
|
|
|
|
|
def download_job(job_name):
    """Fetch a batch job and, if it succeeded, save its result file locally.

    The job is looked up with ``client.batches.get(name=job_name)`` and its
    state is printed. When the state is not ``JOB_STATE_SUCCEEDED`` the function
    reports that and returns without downloading (for ``JOB_STATE_FAILED`` the
    recorded error, if any, is printed too). Otherwise the file named by
    ``job.dest.file_name`` is downloaded through ``client.files.download`` and
    written in binary mode to ``results_<job_name>.jsonl``, where every ``/`` in
    ``job_name`` is replaced by ``_``. The path is relative, so the file lands
    in the current working directory.

    Args:
        job_name: Name of the batch job, passed unchanged to
            ``client.batches.get``.

    Returns:
        None.

    Raises:
        SystemExit: If any exception is raised while fetching the job,
            downloading the file or writing it to disk. The exit message
            contains the original error text.
    """
    print(f"Checking status for {job_name}...\n")

    try:
        job = client.batches.get(name=job_name)
        # The state may be an enum-like object exposing ``.name`` or a plain
        # value; normalise to whatever can be compared to a string.
        state = getattr(job.state, "name", job.state)

        print(f"State: {state}")

        if state != "JOB_STATE_SUCCEEDED":
            print("Job is not ready yet or has failed.")
            if state == "JOB_STATE_FAILED":
                # Only report an error that is actually populated, so a failed
                # job without error details does not print "Error: None".
                error = getattr(job, "error", None)
                if error:
                    print(f"Error: {error}")
            return

        dest = getattr(job, "dest", None)
        result_file_name = getattr(dest, "file_name", None) if dest else None
        if not result_file_name:
            print("Job succeeded but no output file was found.")
            return

        print(f"Downloading results from {result_file_name}...")

        file_content_bytes = client.files.download(file=result_file_name)
        # Slashes in the job name would otherwise be read as directories.
        output_path = f"results_{job_name.replace('/', '_')}.jsonl"

        with open(output_path, "wb") as f:
            f.write(file_content_bytes)

        print(f"Success! Saved to {output_path}")
        print(f"You can now feed this to your correction script using: --deal-with-batched {output_path}")

    except Exception as e:
        # Top-level CLI boundary: turn any failure into a non-zero exit with
        # a readable message instead of a traceback.
        sys.exit(f"An error occurred while fetching the job: {e}")
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: with --download JOB_NAME the results of that
    # job are downloaded; without it (or with an empty value) jobs are listed.
    cli = argparse.ArgumentParser(description="Manage Gemini Batch Jobs")
    cli.add_argument(
        "--download",
        type=str,
        metavar="JOB_NAME",
        help="Download the results for a specific batch job (e.g. batches/123456)",
    )
    requested_job = cli.parse_args().download

    if not requested_job:
        list_jobs()
    else:
        download_job(requested_job)
|