Retry when a label is wrong
parent
ba95a27039
commit
b13ed34acf
|
|
@ -203,6 +203,7 @@ for path_str in args.input_paths:
|
||||||
print(f"Error: {input_arg} is not a valid file or directory.")
|
print(f"Error: {input_arg} is not a valid file or directory.")
|
||||||
|
|
||||||
labels_txt = (INPUT_DIR / "labels").read_text()
|
labels_txt = (INPUT_DIR / "labels").read_text()
|
||||||
|
valid_labels_set = set(line.strip() for line in labels_txt.splitlines() if line.strip())
|
||||||
names_txt = (INPUT_DIR / "names").read_text()
|
names_txt = (INPUT_DIR / "names").read_text()
|
||||||
client = genai.Client(api_key=api_key)
|
client = genai.Client(api_key=api_key)
|
||||||
|
|
||||||
|
|
@ -252,27 +253,34 @@ def process_copy_group(group_key, files):
|
||||||
|
|
||||||
print(f"[{group_key}] Processing {image_file.name} with {len(accumulated_labels)} accumulated labels...")
|
print(f"[{group_key}] Processing {image_file.name} with {len(accumulated_labels)} accumulated labels...")
|
||||||
|
|
||||||
try:
|
for attempt in range(2)
|
||||||
contents, config = generate_request(image_file, labels_txt, names_txt, accumulated_labels)
|
try:
|
||||||
|
contents, config = generate_request(image_file, labels_txt, names_txt, accumulated_labels)
|
||||||
|
|
||||||
response = client.models.generate_content(
|
response = client.models.generate_content(
|
||||||
model=MODEL_ID,
|
model=MODEL_ID,
|
||||||
contents=contents,
|
contents=contents,
|
||||||
config=config
|
config=config
|
||||||
)
|
)
|
||||||
|
|
||||||
annota = AnnotationData.model_validate_json(response.text)
|
annota = AnnotationData.model_validate_json(response.text)
|
||||||
|
unknown = [item.label for item in annota.list if item.label not in valid_labels_set]
|
||||||
|
if unknown:
|
||||||
|
print(f"Error: {image_file.name} contained unknown labels: {unknown}")
|
||||||
|
if attempt == 0:
|
||||||
|
print("Retrying request...")
|
||||||
|
continue # Retry immediately
|
||||||
|
|
||||||
# Save result
|
# Save result
|
||||||
with open(output_json, "w", encoding="utf-8") as f:
|
with open(output_json, "w", encoding="utf-8") as f:
|
||||||
json.dump(annota.model_dump(), f, indent=2)
|
json.dump(annota.model_dump(), f, indent=2)
|
||||||
|
|
||||||
# Update context for the next part in this group
|
# Update context for the next part in this group
|
||||||
for box in annota.list:
|
for box in annota.list:
|
||||||
accumulated_labels.append(box.label)
|
accumulated_labels.append(box.label)
|
||||||
|
break # exit retry loop
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error processing {image_file.name}: {e}")
|
print(f"Error processing {image_file.name}: {e}")
|
||||||
|
|
||||||
# Rate Limiting
|
# Rate Limiting
|
||||||
elapsed = time.time() - start_time
|
elapsed = time.time() - start_time
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue