Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix sam session #531

Merged
merged 1 commit into from
Oct 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -159,10 +159,14 @@ rembg i -a path/to/input.png path/to/output.png
Passing extras parameters

```
rembg i -m sam -x '{"input_labels": [1], "input_points": [[100,100]]}' path/to/input.png path/to/output.png
SAM example

rembg i -m sam -x '{ "sam_prompt": [{"type": "point", "data": [724, 740], "label": 1}] }' examples/plants-1.jpg examples/plants-1.out.png
```

```
Custom model example

rembg i -m u2net_custom -x '{"model_path": "~/.u2net/u2net.onnx"}' path/to/input.png path/to/output.png
```

Expand Down
Binary file added examples/plants-1.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/plants-1.out.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
247 changes: 179 additions & 68 deletions rembg/sessions/sam.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import os
from copy import deepcopy
from typing import List

import cv2
import numpy as np
import onnxruntime as ort
import pooch
from jsonschema import validate
from PIL import Image
from PIL.Image import Image as PILImage

Expand All @@ -15,37 +18,58 @@ def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int):
newh, neww = oldh * scale, oldw * scale
neww = int(neww + 0.5)
newh = int(newh + 0.5)

return (newh, neww)


def apply_coords(coords: np.ndarray, original_size, target_length) -> np.ndarray:
def apply_coords(coords: np.ndarray, original_size, target_length):
old_h, old_w = original_size
new_h, new_w = get_preprocess_shape(
original_size[0], original_size[1], target_length
)
coords = coords.copy().astype(float)

coords = deepcopy(coords).astype(float)
coords[..., 0] = coords[..., 0] * (new_w / old_w)
coords[..., 1] = coords[..., 1] * (new_h / old_h)

return coords


def resize_longes_side(img: PILImage, size=1024):
w, h = img.size
if h > w:
new_h, new_w = size, int(w * size / h)
else:
new_h, new_w = int(h * size / w), size
def get_input_points(prompt):
points = []
labels = []

for mark in prompt:
if mark["type"] == "point":
points.append(mark["data"])
labels.append(mark["label"])
elif mark["type"] == "rectangle":
points.append([mark["data"][0], mark["data"][1]])
points.append([mark["data"][2], mark["data"][3]])
labels.append(2)
labels.append(3)

return img.resize((new_w, new_h))
points, labels = np.array(points), np.array(labels)
return points, labels


def pad_to_square(img: np.ndarray, size=1024):
h, w = img.shape[:2]
padh = size - h
padw = size - w
img = np.pad(img, ((0, padh), (0, padw), (0, 0)), mode="constant")
img = img.astype(np.float32)
return img
def transform_masks(masks, original_size, transform_matrix):
output_masks = []

for batch in range(masks.shape[0]):
batch_masks = []
for mask_id in range(masks.shape[1]):
mask = masks[batch, mask_id]
mask = cv2.warpAffine(
mask,
transform_matrix[:2],
(original_size[1], original_size[0]),
flags=cv2.INTER_LINEAR,
)
batch_masks.append(mask)
output_masks.append(batch_masks)

return np.array(output_masks)


class SamSession(BaseSession):
Expand All @@ -70,7 +94,7 @@ def __init__(self, model_name: str, sess_opts: ort.SessionOptions, *args, **kwar
**kwargs: Arbitrary keyword arguments.
"""
self.model_name = model_name
paths = self.__class__.download_models()
paths = self.__class__.download_models(*args, **kwargs)
self.encoder = ort.InferenceSession(
str(paths[0]),
providers=ort.get_available_providers(),
Expand All @@ -85,9 +109,9 @@ def __init__(self, model_name: str, sess_opts: ort.SessionOptions, *args, **kwar
def normalize(
self,
img: np.ndarray,
mean=(123.675, 116.28, 103.53),
std=(58.395, 57.12, 57.375),
size=(1024, 1024),
mean=(),
std=(),
size=(),
*args,
**kwargs,
):
Expand All @@ -96,19 +120,16 @@ def normalize(

Args:
img (np.ndarray): The input image.
mean (tuple, optional): The mean values for normalization. Defaults to (123.675, 116.28, 103.53).
std (tuple, optional): The standard deviation values for normalization. Defaults to (58.395, 57.12, 57.375).
size (tuple, optional): The target size of the image. Defaults to (1024, 1024).
mean (tuple, optional): The mean values for normalization. Defaults to ().
std (tuple, optional): The standard deviation values for normalization. Defaults to ().
size (tuple, optional): The target size of the image. Defaults to ().
*args: Variable length argument list.
**kwargs: Arbitrary keyword arguments.

Returns:
np.ndarray: The normalized image.
"""
pixel_mean = np.array([*mean]).reshape(1, 1, -1)
pixel_std = np.array([*std]).reshape(1, 1, -1)
x = (img - pixel_mean) / pixel_std
return x
return img

def predict(
self,
Expand All @@ -129,36 +150,89 @@ def predict(
Returns:
List[PILImage]: A list of masks generated by the decoder.
"""
# Preprocess image
image = resize_longes_side(img)
image = np.array(image)
image = self.normalize(image)
image = pad_to_square(image)

input_labels = kwargs.get("input_labels")
input_points = kwargs.get("input_points")

if input_labels is None:
raise ValueError("input_labels is required")
if input_points is None:
raise ValueError("input_points is required")

# Transpose
image = image.transpose(2, 0, 1)[None, :, :, :]
# Run encoder (Image embedding)
encoded = self.encoder.run(None, {"x": image})
image_embedding = encoded[0]

# Add a batch index, concatenate a padding point, and transform.
prompt = kwargs.get("sam_prompt", "{}")
schema = {
"type": "array",
"items": {
"type": "object",
"properties": {
"type": {"type": "string"},
"label": {"type": "integer"},
"data": {
"type": "array",
"items": {"type": "number"},
},
},
},
}

validate(instance=prompt, schema=schema)

target_size = 1024
input_size = (684, 1024)
encoder_input_name = self.encoder.get_inputs()[0].name

img = img.convert("RGB")
cv_image = np.array(img)
original_size = cv_image.shape[:2]

scale_x = input_size[1] / cv_image.shape[1]
scale_y = input_size[0] / cv_image.shape[0]
scale = min(scale_x, scale_y)

transform_matrix = np.array(
[
[scale, 0, 0],
[0, scale, 0],
[0, 0, 1],
]
)

cv_image = cv2.warpAffine(
cv_image,
transform_matrix[:2],
(input_size[1], input_size[0]),
flags=cv2.INTER_LINEAR,
)

## encoder

encoder_inputs = {
encoder_input_name: cv_image.astype(np.float32),
}

encoder_output = self.encoder.run(None, encoder_inputs)
image_embedding = encoder_output[0]

embedding = {
"image_embedding": image_embedding,
"original_size": original_size,
"transform_matrix": transform_matrix,
}

## decoder

input_points, input_labels = get_input_points(prompt)
onnx_coord = np.concatenate([input_points, np.array([[0.0, 0.0]])], axis=0)[
None, :, :
]
onnx_label = np.concatenate([input_labels, np.array([-1])], axis=0)[
None, :
].astype(np.float32)
onnx_coord = apply_coords(onnx_coord, img.size[::1], 1024).astype(np.float32)
onnx_coord = apply_coords(onnx_coord, input_size, target_size).astype(
np.float32
)

onnx_coord = np.concatenate(
[
onnx_coord,
np.ones((1, onnx_coord.shape[1], 1), dtype=np.float32),
],
axis=2,
)
onnx_coord = np.matmul(onnx_coord, transform_matrix.T)
onnx_coord = onnx_coord[:, :, :2].astype(np.float32)

# Create an empty mask input and an indicator for no mask.
onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
onnx_has_mask_input = np.zeros(1, dtype=np.float32)

Expand All @@ -168,17 +242,19 @@ def predict(
"point_labels": onnx_label,
"mask_input": onnx_mask_input,
"has_mask_input": onnx_has_mask_input,
"orig_im_size": np.array(img.size[::-1], dtype=np.float32),
"orig_im_size": np.array(input_size, dtype=np.float32),
}

masks, _, low_res_logits = self.decoder.run(None, decoder_inputs)
masks = masks > 0.0
masks = [
Image.fromarray((masks[i, 0] * 255).astype(np.uint8))
for i in range(masks.shape[0])
]
masks, _, _ = self.decoder.run(None, decoder_inputs)
inv_transform_matrix = np.linalg.inv(transform_matrix)
masks = transform_masks(masks, original_size, inv_transform_matrix)

mask = np.zeros((masks.shape[2], masks.shape[3], 3), dtype=np.uint8)
for m in masks[0, :, :, :]:
mask[m > 0.0] = [255, 255, 255]

return masks
mask = Image.fromarray(mask).convert("L")
return [mask]

@classmethod
def download_models(cls, *args, **kwargs):
Expand All @@ -195,29 +271,64 @@ def download_models(cls, *args, **kwargs):
Returns:
tuple: A tuple containing the file paths of the downloaded encoder and decoder models.
"""
fname_encoder = f"{cls.name(*args, **kwargs)}_encoder.onnx"
fname_decoder = f"{cls.name(*args, **kwargs)}_decoder.onnx"
model_name = kwargs.get("sam_model", "sam_vit_b_01ec64")
quant = kwargs.get("sam_quant", False)

fname_encoder = f"{model_name}.encoder.onnx"
fname_decoder = f"{model_name}.decoder.onnx"

if quant:
fname_encoder = f"{model_name}.encoder.quant.onnx"
fname_decoder = f"{model_name}.decoder.quant.onnx"

pooch.retrieve(
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-encoder-quant.onnx",
None
if cls.checksum_disabled(*args, **kwargs)
else "md5:13d97c5c79ab13ef86d67cbde5f1b250",
f"https://github.com/danielgatis/rembg/releases/download/v0.0.0/{fname_encoder}",
None,
fname=fname_encoder,
path=cls.u2net_home(*args, **kwargs),
progressbar=True,
)

pooch.retrieve(
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-decoder-quant.onnx",
None
if cls.checksum_disabled(*args, **kwargs)
else "md5:fa3d1c36a3187d3de1c8deebf33dd127",
f"https://github.com/danielgatis/rembg/releases/download/v0.0.0/{fname_decoder}",
None,
fname=fname_decoder,
path=cls.u2net_home(*args, **kwargs),
progressbar=True,
)

if fname_encoder == "sam_vit_h_4b8939.encoder.onnx" and not os.path.exists(
os.path.join(
cls.u2net_home(*args, **kwargs), "sam_vit_h_4b8939.encoder_data.bin"
)
):
content = bytearray()

for i in range(1, 4):
pooch.retrieve(
f"https://github.com/danielgatis/rembg/releases/download/v0.0.0/sam_vit_h_4b8939.encoder_data.{i}.bin",
None,
fname=f"sam_vit_h_4b8939.encoder_data.{i}.bin",
path=cls.u2net_home(*args, **kwargs),
progressbar=True,
)

fbin = os.path.join(
cls.u2net_home(*args, **kwargs),
f"sam_vit_h_4b8939.encoder_data.{i}.bin",
)
content.extend(open(fbin, "rb").read())
os.remove(fbin)

with open(
os.path.join(
cls.u2net_home(*args, **kwargs),
"sam_vit_h_4b8939.encoder_data.bin",
),
"wb",
) as fp:
fp.write(content)

return (
os.path.join(cls.u2net_home(*args, **kwargs), fname_encoder),
os.path.join(cls.u2net_home(*args, **kwargs), fname_decoder),
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
long_description = (here / "README.md").read_text(encoding="utf-8")

install_requires = [
"jsonschema",
"numpy",
"onnxruntime",
"opencv-python-headless",
Expand Down
Binary file added tests/fixtures/plants-1.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/results/anime-girl-1.sam.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/results/car-1.sam.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/results/cloth-1.sam.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/results/plants-1.isnet-anime.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/results/plants-1.isnet-general-use.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/results/plants-1.sam.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/results/plants-1.silueta.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/results/plants-1.u2net.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/results/plants-1.u2net_cloth_seg.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/results/plants-1.u2net_human_seg.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/results/plants-1.u2netp.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Loading