Retry when a label is wrong

master
Sébastien Miquel 2026-02-25 20:53:09 +01:00
parent ba95a27039
commit b13ed34acf
1 changed files with 25 additions and 17 deletions

View File

@ -203,6 +203,7 @@ for path_str in args.input_paths:
print(f"Error: {input_arg} is not a valid file or directory.")
labels_txt = (INPUT_DIR / "labels").read_text()
valid_labels_set = set(line.strip() for line in labels_txt.splitlines() if line.strip())
names_txt = (INPUT_DIR / "names").read_text()
client = genai.Client(api_key=api_key)
@ -252,27 +253,34 @@ def process_copy_group(group_key, files):
print(f"[{group_key}] Processing {image_file.name} with {len(accumulated_labels)} accumulated labels...")
try:
contents, config = generate_request(image_file, labels_txt, names_txt, accumulated_labels)
for attempt in range(2)
try:
contents, config = generate_request(image_file, labels_txt, names_txt, accumulated_labels)
response = client.models.generate_content(
model=MODEL_ID,
contents=contents,
config=config
)
response = client.models.generate_content(
model=MODEL_ID,
contents=contents,
config=config
)
annota = AnnotationData.model_validate_json(response.text)
annota = AnnotationData.model_validate_json(response.text)
unknown = [item.label for item in annota.list if item.label not in valid_labels_set]
if unknown:
print(f"Error: {image_file.name} contained unknown labels: {unknown}")
if attempt == 0:
print("Retrying request...")
continue # Retry immediately
# Save result
with open(output_json, "w", encoding="utf-8") as f:
json.dump(annota.model_dump(), f, indent=2)
# Save result
with open(output_json, "w", encoding="utf-8") as f:
json.dump(annota.model_dump(), f, indent=2)
# Update context for the next part in this group
for box in annota.list:
accumulated_labels.append(box.label)
except Exception as e:
print(f"Error processing {image_file.name}: {e}")
# Update context for the next part in this group
for box in annota.list:
accumulated_labels.append(box.label)
break # exit retry loop
except Exception as e:
print(f"Error processing {image_file.name}: {e}")
# Rate Limiting
elapsed = time.time() - start_time