diff --git a/scripts/api.py b/scripts/api.py index 0d602db..f02d633 100644 --- a/scripts/api.py +++ b/scripts/api.py @@ -63,7 +63,7 @@ class SamPredictRequest(BaseModel): async def api_sam_predict(payload: SamPredictRequest = Body(...)) -> Any: print(f"SAM API /sam/sam-predict received request") payload.input_image = decode_to_pil(payload.input_image).convert('RGBA') - sam_output_mask_gallery, sam_message = sam_predict( + sam_output_mask_gallery, sam_message, boxes_filt = sam_predict( payload.sam_model_name, payload.input_image, payload.sam_positive_points, @@ -82,6 +82,7 @@ async def api_sam_predict(payload: SamPredictRequest = Body(...)) -> Any: result["blended_images"] = list(map(encode_to_base64, sam_output_mask_gallery[:3])) result["masks"] = list(map(encode_to_base64, sam_output_mask_gallery[3:6])) result["masked_images"] = list(map(encode_to_base64, sam_output_mask_gallery[6:])) + result["boxes"] = boxes_filt.int().tolist() return result class DINOPredictRequest(BaseModel): @@ -94,7 +95,7 @@ class DINOPredictRequest(BaseModel): async def api_dino_predict(payload: DINOPredictRequest = Body(...)) -> Any: print(f"SAM API /sam/dino-predict received request") payload.input_image = decode_to_pil(payload.input_image) - dino_output_img, _, dino_msg = dino_predict( + dino_output_img, _, dino_msg, boxes_filt = dino_predict( payload.input_image, payload.dino_model_name, payload.text_prompt, @@ -107,6 +108,7 @@ async def api_dino_predict(payload: DINOPredictRequest = Body(...)) -> Any: return { "msg": dino_msg, "image_with_box": encode_to_base64(dino_output_img) if dino_output_img is not None else None, + "boxes": boxes_filt.int().tolist() if boxes_filt is not None else None } class DilateMaskRequest(BaseModel): diff --git a/scripts/sam.py b/scripts/sam.py index 775bff2..310e748 100644 --- a/scripts/sam.py +++ b/scripts/sam.py @@ -235,7 +235,7 @@ def sam_predict(sam_model_name, input_image, positive_points, negative_points, multimask_output=True) masks = masks[:, None, ...] garbage_collect(sam) - return create_mask_output(image_np, masks, boxes_filt), sam_predict_status + sam_predict_result + (f" However, GroundingDINO installment has failed. Your process automatically fall back to local groundingdino. Check your terminal for more detail and {dino_install_issue_text}." if (dino_enabled and not install_success) else "") + return create_mask_output(image_np, masks, boxes_filt), sam_predict_status + sam_predict_result + (f" However, GroundingDINO installment has failed. Your process automatically fall back to local groundingdino. Check your terminal for more detail and {dino_install_issue_text}." if (dino_enabled and not install_success) else ""), boxes_filt def dino_predict(input_image, dino_model_name, text_prompt, box_threshold): @@ -245,9 +245,9 @@ def dino_predict(input_image, dino_model_name, text_prompt, box_threshold): return None, gr.update(), gr.update(visible=True, value=f"GroundingDINO requires text prompt.") image_np = np.array(input_image) boxes_filt, install_success = dino_predict_internal(input_image, dino_model_name, text_prompt, box_threshold) - boxes_filt = boxes_filt.numpy() - boxes_choice = [str(i) for i in range(boxes_filt.shape[0])] - return Image.fromarray(show_boxes(image_np, boxes_filt.astype(int), show_index=True)), gr.update(choices=boxes_choice, value=boxes_choice), gr.update(visible=False) if install_success else gr.update(visible=True, value=f"GroundingDINO installment failed. Your process automatically fall back to local groundingdino. See your terminal for more detail and {dino_install_issue_text}") + boxes_filt_numpy = boxes_filt.numpy() + boxes_choice = [str(i) for i in range(boxes_filt_numpy.shape[0])] + return Image.fromarray(show_boxes(image_np, boxes_filt_numpy.astype(int), show_index=True)), gr.update(choices=boxes_choice, value=boxes_choice), gr.update(visible=False) if install_success else gr.update(visible=True, value=f"GroundingDINO installment failed. Your process automatically fall back to local groundingdino. See your terminal for more detail and {dino_install_issue_text}"), boxes_filt def dino_batch_process(