78 lines
2.4 KiB
Python
78 lines
2.4 KiB
Python
import os
|
|
import sys
|
|
import argparse
|
|
from pathlib import Path
|
|
from google import genai
|
|
from google.genai import types
|
|
|
|
def main() -> None:
    """Upload batch request files and start one Gemini batch job per file.

    Reads a single positional command-line argument, ``root_dir``, and looks
    inside it for ``batch_requests_flash.jsonl`` and
    ``batch_requests_pro.jsonl``. Each file that exists and is non-empty is
    uploaded through the File API and then used as the source of a batch job
    for its associated model. Missing or empty files are skipped with a
    message. Progress is reported on stdout.

    Raises:
        SystemExit: If ``root_dir`` is not an existing directory (argparse
            usage error, exit status 2) or if ``GEMINI_API_KEY`` is unset or
            empty.
    """
    parser = argparse.ArgumentParser(description="Upload JSONL files and create Gemini Batch jobs.")
    parser.add_argument("root_dir", type=str, help="Root directory containing the batch JSONL files")
    args = parser.parse_args()

    root_dir = Path(args.root_dir)
    # A mistyped path would otherwise just look like "both files missing" and
    # the script would finish as if it had succeeded.
    if not root_dir.is_dir():
        parser.error(f"root_dir is not a directory: {root_dir}")

    # An empty value is as unusable as a missing one, so reject both up front
    # instead of failing later inside the client.
    if not os.environ.get("GEMINI_API_KEY"):
        sys.exit("Error: GEMINI_API_KEY environment variable not set.")

    client = genai.Client()

    # Define the batch files and their corresponding models
    batches_to_create = [
        {
            "file_path": root_dir / "batch_requests_flash.jsonl",
            "model_id": "gemini-3-flash-preview",
            "display_name": f"flash-correction-{root_dir.name}",
        },
        {
            "file_path": root_dir / "batch_requests_pro.jsonl",
            "model_id": "gemini-3.1-pro-preview",
            "display_name": f"pro-correction-{root_dir.name}",
        },
    ]

    created_jobs = 0

    for batch in batches_to_create:
        file_path = batch["file_path"]
        model_id = batch["model_id"]
        display_name = batch["display_name"]

        # A single stat() answers both "does it exist?" and "is it empty?"
        # without a window between two separate filesystem checks.
        try:
            file_size = file_path.stat().st_size
        except FileNotFoundError:
            print(f"Skipping {model_id}: {file_path.name} does not exist.")
            continue

        # An empty file is expected when every task was routed to the other model.
        if file_size == 0:
            print(f"Skipping {model_id}: {file_path.name} is empty.")
            continue

        print(f"Processing {file_path.name} for model {model_id}...")

        # 1. Upload the file to the File API
        print(" Uploading file...")
        uploaded_file = client.files.upload(
            file=str(file_path),
            config=types.UploadFileConfig(
                display_name=f"{display_name}-input",
                mime_type='jsonl',
            ),
        )
        print(f" Uploaded successfully! File ID: {uploaded_file.name}")

        # 2. Create the batch job, using the uploaded file as its source
        print(" Starting batch job...")
        batch_job = client.batches.create(
            model=model_id,
            src=uploaded_file.name,
            config={
                'display_name': display_name,
            },
        )
        print(f" Success! Batch Job Name: {batch_job.name}\n")
        created_jobs += 1

    print("-" * 50)
    # Only claim success when at least one job was actually started.
    if created_jobs:
        print("All batch jobs have been initiated.")
    else:
        print("No batch jobs were created: every input file was missing or empty.")
|
|
|
|
# Script entry point: run the uploader only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|