Refine segmentation masks with the Segment Anything Model (SAM).
Install SAM: https://github.com/facebookresearch/segment-anything
Also requires: numpy and cv2 (OpenCV).
# Usage example: refine an existing segmentation mask with SAM.
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
from segmentBooster import refineMask

# Build the SAM model and wrap it in an automatic mask generator.
SAMtype = "..."  # key into sam_model_registry (e.g. "vit_h"); must match the checkpoint
SAMcheckpoint = "..."  # path to the downloaded SAM checkpoint file
sam = sam_model_registry[SAMtype](checkpoint=SAMcheckpoint)
mask_generator = SamAutomaticMaskGenerator(sam)

imagePath = "..."  # path to the input image
# 2D numpy array of per-pixel class IDs with the image's height and width,
# i.e. shape (H, W) == image.shape[:2].
segmentationMask = ...  # replace with your (H, W) numpy array of class IDs

# Output: a 2D numpy array of the same (H, W) shape holding the refined
# per-pixel class IDs.
refinedMask = refineMask(imagePath, segmentationMask, mask_generator)