From 1c6c5a4c8a83e0d47dd498fec39f846efe431b20 Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Mon, 15 Jul 2024 13:25:42 +0800 Subject: [PATCH 1/7] init --- annotator/mobile_sam/__init__.py | 48 ++++++++++++++++++++++++++++++ requirements.txt | 1 + scripts/preprocessor/mobile_sam.py | 30 +++++++++++++++++++ 3 files changed, 79 insertions(+) create mode 100644 annotator/mobile_sam/__init__.py create mode 100644 scripts/preprocessor/mobile_sam.py diff --git a/annotator/mobile_sam/__init__.py b/annotator/mobile_sam/__init__.py new file mode 100644 index 000000000..4c7578dea --- /dev/null +++ b/annotator/mobile_sam/__init__.py @@ -0,0 +1,48 @@ +from __future__ import print_function + +import os +import numpy as np +from PIL import Image +from typing import Union + +from modules import devices +from annotator.util import load_model +from annotator.annotator_path import models_path + +from controlnet_aux import SamDetector +from controlnet_aux.segment_anything import sam_model_registry, SamAutomaticMaskGenerator + +class SamDetector_Aux(SamDetector): + + model_dir = os.path.join(models_path, "mobile_sam") + + def __init__(self, mask_generator: SamAutomaticMaskGenerator): + super().__init__(mask_generator) + + self.device = devices.device + self.model = SamDetector_Aux().to(self.device).eval() + self.from_pretrained(model_type="vit_t") + + @classmethod + def from_pretrained(cls, model_type="vit_t"): + """ + Possible model_type : vit_h, vit_l, vit_b, vit_t + download weights from https://huggingface.co/dhkim2810/MobileSAM + """ + remote_url = os.environ.get( + "CONTROLNET_MOBILE_SAM_MODEL_URL", + "https://huggingface.co/dhkim2810/MobileSAM/resolve/main/mobile_sam.pt", + ) + model_path = load_model( + "mobile_sam.pt", remote_url=remote_url, model_dir=cls.model_dir + ) + + sam = sam_model_registry[model_type](checkpoint=model_path) + + mask_generator = SamAutomaticMaskGenerator(sam) + + return cls(mask_generator) + + def __call__(self, input_image: Union[np.ndarray, Image.Image]=None, detect_resolution=512, image_resolution=512, output_type="pil", **kwargs) -> np.ndarray: + self.model.to(self.device) + super().__call__(image=input_image, detect_resolution=detect_resolution, image_resolution=image_resolution, output_type=output_type, **kwargs) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 10013834c..fef12cf3b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,3 +9,4 @@ matplotlib facexlib timm<=0.9.5 pydantic<=1.10.17 +controlnet_aux \ No newline at end of file diff --git a/scripts/preprocessor/mobile_sam.py b/scripts/preprocessor/mobile_sam.py new file mode 100644 index 000000000..8ab71bb09 --- /dev/null +++ b/scripts/preprocessor/mobile_sam.py @@ -0,0 +1,30 @@ +import numpy as np +from skimage import morphology + +from annotator.mobile_sam import SamDetector_Aux +from scripts.supported_preprocessor import Preprocessor, PreprocessorParameter +from scripts.utils import resize_image_with_pad + +class PreprocessorMobileSam(Preprocessor): + def __init__(self): + super().__init__(name="mobile_sam") + self.tags = ["Segmentation"] + self.model = None + + def __call__( + self, + input_image, + resolution, + slider_1=None, + slider_2=None, + slider_3=None, + **kwargs + ): + img, remove_pad = resize_image_with_pad(input_image, resolution) + if self.model is None: + self.model = SamDetector_Aux() + + result = self.model(img, detect_resolution=resolution, image_resolution=resolution) + return remove_pad(result) + +Preprocessor.add_supported_preprocessor(PreprocessorMobileSam()) \ No newline at end of file From 312dd8100ce60f286b307920f519b7aa52fa1b5e Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Mon, 15 Jul 2024 13:28:14 +0800 Subject: [PATCH 2/7] Update mobile_sam.py --- scripts/preprocessor/mobile_sam.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/scripts/preprocessor/mobile_sam.py b/scripts/preprocessor/mobile_sam.py index 8ab71bb09..36408729f 100644 --- a/scripts/preprocessor/mobile_sam.py +++ b/scripts/preprocessor/mobile_sam.py @@ -1,8 +1,5 @@ -import numpy as np -from skimage import morphology - from annotator.mobile_sam import SamDetector_Aux -from scripts.supported_preprocessor import Preprocessor, PreprocessorParameter +from scripts.supported_preprocessor import Preprocessor from scripts.utils import resize_image_with_pad class PreprocessorMobileSam(Preprocessor): From eb14e7a29f5a9bbb1ad2f4f695c24d70cc084c78 Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Mon, 15 Jul 2024 13:45:38 +0800 Subject: [PATCH 3/7] Update __init__.py --- scripts/preprocessor/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/scripts/preprocessor/__init__.py b/scripts/preprocessor/__init__.py index 081af9779..75cbab682 100644 --- a/scripts/preprocessor/__init__.py +++ b/scripts/preprocessor/__init__.py @@ -5,4 +5,5 @@ from .ip_adapter_auto import * from .normal_dsine import * from .model_free_preprocessors import * -from .legacy.legacy_preprocessors import * \ No newline at end of file +from .legacy.legacy_preprocessors import * +from .mobile_sam import * \ No newline at end of file From f46298f06976c8c01256a97eb195e0ef4042b135 Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Mon, 15 Jul 2024 13:52:52 +0800 Subject: [PATCH 4/7] update --- annotator/mobile_sam/__init__.py | 13 ++++++------- scripts/preprocessor/mobile_sam.py | 2 +- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/annotator/mobile_sam/__init__.py b/annotator/mobile_sam/__init__.py index 4c7578dea..43b29436e 100644 --- a/annotator/mobile_sam/__init__.py +++ b/annotator/mobile_sam/__init__.py @@ -19,12 +19,8 @@ class SamDetector_Aux(SamDetector): def __init__(self, mask_generator: SamAutomaticMaskGenerator): super().__init__(mask_generator) - self.device = devices.device - self.model = SamDetector_Aux().to(self.device).eval() - self.from_pretrained(model_type="vit_t") - @classmethod - def from_pretrained(cls, model_type="vit_t"): + def from_pretrained(cls): """ Possible model_type : vit_h, vit_l, vit_b, vit_t download weights from https://huggingface.co/dhkim2810/MobileSAM @@ -35,9 +31,12 @@ def from_pretrained(cls, model_type="vit_t"): ) model_path = load_model( "mobile_sam.pt", remote_url=remote_url, model_dir=cls.model_dir - ) + ) + + sam = sam_model_registry["vit_t"](checkpoint=model_path) - sam = sam_model_registry[model_type](checkpoint=model_path) + cls.device = devices.device + cls.model = SamDetector_Aux().to(cls.device).eval() mask_generator = SamAutomaticMaskGenerator(sam) diff --git a/scripts/preprocessor/mobile_sam.py b/scripts/preprocessor/mobile_sam.py index 36408729f..cce7065f2 100644 --- a/scripts/preprocessor/mobile_sam.py +++ b/scripts/preprocessor/mobile_sam.py @@ -19,7 +19,7 @@ def __call__( ): img, remove_pad = resize_image_with_pad(input_image, resolution) if self.model is None: - self.model = SamDetector_Aux() + self.model = SamDetector_Aux.from_pretrained() result = self.model(img, detect_resolution=resolution, image_resolution=resolution) return remove_pad(result) From 5f5b51fb6259e4248bea9f3fc6da6e115eb08aad Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Mon, 15 Jul 2024 15:11:15 +0800 Subject: [PATCH 5/7] update --- annotator/mobile_sam/__init__.py | 16 +++++++++------- scripts/preprocessor/mobile_sam.py | 6 +++--- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/annotator/mobile_sam/__init__.py b/annotator/mobile_sam/__init__.py index 43b29436e..57b4124f8 100644 --- a/annotator/mobile_sam/__init__.py +++ b/annotator/mobile_sam/__init__.py @@ -16,8 +16,10 @@ class SamDetector_Aux(SamDetector): model_dir = os.path.join(models_path, "mobile_sam") - def __init__(self, mask_generator: SamAutomaticMaskGenerator): + def __init__(self, mask_generator: SamAutomaticMaskGenerator, sam): super().__init__(mask_generator) + self.device = devices.device + self.model = sam.to(self.device).eval() @classmethod def from_pretrained(cls): @@ -35,13 +37,13 @@ def from_pretrained(cls): sam = sam_model_registry["vit_t"](checkpoint=model_path) - cls.device = devices.device - cls.model = SamDetector_Aux().to(cls.device).eval() + cls.model = sam.to(devices.device).eval() - mask_generator = SamAutomaticMaskGenerator(sam) + mask_generator = SamAutomaticMaskGenerator(cls.model) - return cls(mask_generator) + return cls(mask_generator, sam) - def __call__(self, input_image: Union[np.ndarray, Image.Image]=None, detect_resolution=512, image_resolution=512, output_type="pil", **kwargs) -> np.ndarray: + def __call__(self, input_image: Union[np.ndarray, Image.Image]=None, detect_resolution=512, image_resolution=512, output_type="cv2", **kwargs) -> np.ndarray: self.model.to(self.device) - super().__call__(image=input_image, detect_resolution=detect_resolution, image_resolution=image_resolution, output_type=output_type, **kwargs) \ No newline at end of file + image = super().__call__(input_image=input_image, detect_resolution=detect_resolution, image_resolution=image_resolution, output_type=output_type, **kwargs) + return np.array(image).astype(np.uint8) \ No newline at end of file diff --git a/scripts/preprocessor/mobile_sam.py b/scripts/preprocessor/mobile_sam.py index cce7065f2..5f7cd7849 100644 --- a/scripts/preprocessor/mobile_sam.py +++ b/scripts/preprocessor/mobile_sam.py @@ -17,11 +17,11 @@ def __call__( slider_3=None, **kwargs ): - img, remove_pad = resize_image_with_pad(input_image, resolution) + #img, remove_pad = resize_image_with_pad(input_image, resolution) if self.model is None: self.model = SamDetector_Aux.from_pretrained() - result = self.model(img, detect_resolution=resolution, image_resolution=resolution) - return remove_pad(result) + result = self.model(input_image, detect_resolution=resolution, image_resolution=resolution, output_type="cv2") + return result Preprocessor.add_supported_preprocessor(PreprocessorMobileSam()) \ No newline at end of file From d52a1235f249795b65f3a88372d906c7d977f6c9 Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Mon, 15 Jul 2024 15:13:50 +0800 Subject: [PATCH 6/7] Update mobile_sam.py --- scripts/preprocessor/mobile_sam.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/scripts/preprocessor/mobile_sam.py b/scripts/preprocessor/mobile_sam.py index 5f7cd7849..e394647c0 100644 --- a/scripts/preprocessor/mobile_sam.py +++ b/scripts/preprocessor/mobile_sam.py @@ -1,6 +1,5 @@ from annotator.mobile_sam import SamDetector_Aux from scripts.supported_preprocessor import Preprocessor -from scripts.utils import resize_image_with_pad class PreprocessorMobileSam(Preprocessor): def __init__(self): @@ -17,7 +16,6 @@ def __call__( slider_3=None, **kwargs ): - #img, remove_pad = resize_image_with_pad(input_image, resolution) if self.model is None: self.model = SamDetector_Aux.from_pretrained() From 1c2b7144d262f3d7e2c02b9d041d82fa7d23736f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= <865105819@qq.com> Date: Tue, 16 Jul 2024 01:51:44 +0800 Subject: [PATCH 7/7] specify controlnet_aux>=0.0.9 --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index fef12cf3b..a3319ea39 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,4 +9,4 @@ matplotlib facexlib timm<=0.9.5 pydantic<=1.10.17 -controlnet_aux \ No newline at end of file +controlnet_aux>=0.0.9