From b13ed34acf4788d1624df406e445deb000deb912 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?S=C3=A9bastien=20Miquel?=
Date: Wed, 25 Feb 2026 20:53:09 +0100
Subject: [PATCH] Retry when a label is wrong

Build the set of valid labels from the labels file and check every
label in the model response against it. When a response contains an
unknown label, log it and resend the request once. If the second
response still contains unknown labels, the error is logged and the
result is saved anyway. Because the exception handler now sits inside
the retry loop, an exception on the first attempt is also followed by
one more attempt.

---
 gemini_for_labels.py | 42 +++++++++++++++++++++++++-----------------
 1 file changed, 25 insertions(+), 17 deletions(-)

diff --git a/gemini_for_labels.py b/gemini_for_labels.py
index 13e0618..04a03d2 100644
--- a/gemini_for_labels.py
+++ b/gemini_for_labels.py
@@ -203,6 +203,7 @@ for path_str in args.input_paths:
         print(f"Error: {input_arg} is not a valid file or directory.")
 
 labels_txt = (INPUT_DIR / "labels").read_text()
+valid_labels_set = set(line.strip() for line in labels_txt.splitlines() if line.strip())
 names_txt = (INPUT_DIR / "names").read_text()
 
 client = genai.Client(api_key=api_key)
@@ -252,27 +253,34 @@ def process_copy_group(group_key, files):
 
         print(f"[{group_key}] Processing {image_file.name} with {len(accumulated_labels)} accumulated labels...")
 
-        try:
-            contents, config = generate_request(image_file, labels_txt, names_txt, accumulated_labels)
+        for attempt in range(2):
+            try:
+                contents, config = generate_request(image_file, labels_txt, names_txt, accumulated_labels)
 
-            response = client.models.generate_content(
-                model=MODEL_ID,
-                contents=contents,
-                config=config
-            )
+                response = client.models.generate_content(
+                    model=MODEL_ID,
+                    contents=contents,
+                    config=config
+                )
 
-            annota = AnnotationData.model_validate_json(response.text)
+                annota = AnnotationData.model_validate_json(response.text)
+                unknown = [item.label for item in annota.list if item.label not in valid_labels_set]
+                if unknown:
+                    print(f"Error: {image_file.name} contained unknown labels: {unknown}")
+                    if attempt == 0:
+                        print("Retrying request...")
+                        continue  # Retry immediately
 
-            # Save result
-            with open(output_json, "w", encoding="utf-8") as f:
-                json.dump(annota.model_dump(), f, indent=2)
+                # Save result
+                with open(output_json, "w", encoding="utf-8") as f:
+                    json.dump(annota.model_dump(), f, indent=2)
 
-            # Update context for the next part in this group
-            for box in annota.list:
-                accumulated_labels.append(box.label)
-
-        except Exception as e:
-            print(f"Error processing {image_file.name}: {e}")
+                # Update context for the next part in this group
+                for box in annota.list:
+                    accumulated_labels.append(box.label)
+                break  # exit retry loop
+            except Exception as e:
+                print(f"Error processing {image_file.name}: {e}")
 
         # Rate Limiting
         elapsed = time.time() - start_time