Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/api sam predict boxes #198

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions scripts/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class SamPredictRequest(BaseModel):
async def api_sam_predict(payload: SamPredictRequest = Body(...)) -> Any:
print(f"SAM API /sam/sam-predict received request")
payload.input_image = decode_to_pil(payload.input_image).convert('RGBA')
sam_output_mask_gallery, sam_message = sam_predict(
sam_output_mask_gallery, sam_message, boxes_filt = sam_predict(
payload.sam_model_name,
payload.input_image,
payload.sam_positive_points,
Expand All @@ -82,6 +82,7 @@ async def api_sam_predict(payload: SamPredictRequest = Body(...)) -> Any:
result["blended_images"] = list(map(encode_to_base64, sam_output_mask_gallery[:3]))
result["masks"] = list(map(encode_to_base64, sam_output_mask_gallery[3:6]))
result["masked_images"] = list(map(encode_to_base64, sam_output_mask_gallery[6:]))
result["boxes"] = boxes_filt.int().tolist()
return result

class DINOPredictRequest(BaseModel):
Expand All @@ -94,7 +95,7 @@ class DINOPredictRequest(BaseModel):
async def api_dino_predict(payload: DINOPredictRequest = Body(...)) -> Any:
print(f"SAM API /sam/dino-predict received request")
payload.input_image = decode_to_pil(payload.input_image)
dino_output_img, _, dino_msg = dino_predict(
dino_output_img, _, dino_msg, boxes_filt = dino_predict(
payload.input_image,
payload.dino_model_name,
payload.text_prompt,
Expand All @@ -107,6 +108,7 @@ async def api_dino_predict(payload: DINOPredictRequest = Body(...)) -> Any:
return {
"msg": dino_msg,
"image_with_box": encode_to_base64(dino_output_img) if dino_output_img is not None else None,
"boxes": boxes_filt.int().tolist() if boxes_filt is not None else None
}

class DilateMaskRequest(BaseModel):
Expand Down
8 changes: 4 additions & 4 deletions scripts/sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def sam_predict(sam_model_name, input_image, positive_points, negative_points,
multimask_output=True)
masks = masks[:, None, ...]
garbage_collect(sam)
return create_mask_output(image_np, masks, boxes_filt), sam_predict_status + sam_predict_result + (f" However, GroundingDINO installment has failed. Your process automatically fall back to local groundingdino. Check your terminal for more detail and {dino_install_issue_text}." if (dino_enabled and not install_success) else "")
return create_mask_output(image_np, masks, boxes_filt), sam_predict_status + sam_predict_result + (f" However, GroundingDINO installment has failed. Your process automatically fall back to local groundingdino. Check your terminal for more detail and {dino_install_issue_text}." if (dino_enabled and not install_success) else ""), boxes_filt


def dino_predict(input_image, dino_model_name, text_prompt, box_threshold):
Expand All @@ -245,9 +245,9 @@ def dino_predict(input_image, dino_model_name, text_prompt, box_threshold):
return None, gr.update(), gr.update(visible=True, value=f"GroundingDINO requires text prompt.")
image_np = np.array(input_image)
boxes_filt, install_success = dino_predict_internal(input_image, dino_model_name, text_prompt, box_threshold)
boxes_filt = boxes_filt.numpy()
boxes_choice = [str(i) for i in range(boxes_filt.shape[0])]
return Image.fromarray(show_boxes(image_np, boxes_filt.astype(int), show_index=True)), gr.update(choices=boxes_choice, value=boxes_choice), gr.update(visible=False) if install_success else gr.update(visible=True, value=f"GroundingDINO installment failed. Your process automatically fall back to local groundingdino. See your terminal for more detail and {dino_install_issue_text}")
boxes_filt_numpy = boxes_filt.numpy()
boxes_choice = [str(i) for i in range(boxes_filt_numpy.shape[0])]
return Image.fromarray(show_boxes(image_np, boxes_filt_numpy.astype(int), show_index=True)), gr.update(choices=boxes_choice, value=boxes_choice), gr.update(visible=False) if install_success else gr.update(visible=True, value=f"GroundingDINO installment failed. Your process automatically fall back to local groundingdino. See your terminal for more detail and {dino_install_issue_text}"), boxes_filt


def dino_batch_process(
Expand Down