-
Notifications
You must be signed in to change notification settings - Fork 11
/
my_sam_LST.py
73 lines (42 loc) · 1.96 KB
/
my_sam_LST.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
from segment_anything import sam_model_registry
import torch
import torch.nn as nn
from torchvision.models.resnet import resnet18
class SAM_LST(nn.Module):
    """SAM (ViT-B) combined with a ResNet-18 side branch (ladder side tuning).

    Loads a pretrained SAM model, freezes every SAM parameter except those
    whose names contain ``"alpha"`` or ``"output_upscaling"``, and feeds
    intermediate ResNet-18 features into SAM's forward pass as ``CNN_input``.
    """

    def __init__(self, checkpoint="/mnt/data3/chai/SAM/sam_vit_b_01ec64.pth"):
        """Build the SAM + CNN side-branch model.

        Args:
            checkpoint: Path to the pretrained SAM ViT-B weights. Defaults to
                the original hard-coded path for backward compatibility.
        """
        super(SAM_LST, self).__init__()
        # pixel_mean=[0,0,0] / pixel_std=[1,1,1] disable SAM's built-in
        # input normalization.
        self.sam, img_embedding_size = sam_model_registry["vit_b"](
            image_size=512,
            num_classes=8,
            checkpoint=checkpoint,
            pixel_mean=[0, 0, 0],
            pixel_std=[1, 1, 1])
        # NOTE(review): `pretrained=True` is deprecated in newer torchvision
        # (replaced by `weights=`); kept as-is for the version this file targets.
        self.CNN_encoder = resnet18(pretrained=True)
        self.sam_encoder = self.sam.image_encoder
        # Freeze all SAM parameters except the "alpha" gates and the mask
        # "output_upscaling" head (single pass instead of two separate loops).
        for n, p in self.sam.named_parameters():
            p.requires_grad = ("alpha" in n) or ("output_upscaling" in n)

    def forward(self, x, multimask_output=None, image_size=None):
        """Run the CNN side branch, then SAM with the CNN features injected.

        Args:
            x: Input image batch — presumably (B, 3, 512, 512) given the
               image_size=512 used at construction; confirm against callers.
            multimask_output: Forwarded to ``self.sam`` unchanged.
            image_size: Forwarded to ``self.sam`` unchanged.

        Returns:
            Whatever ``self.sam`` returns (the __main__ demo below indexes
            the result with ``['masks']``, so it is dict-like).
        """
        # ResNet-18 stem + layer1..layer3 produce the side features for SAM.
        cnn_out = self.CNN_encoder.conv1(x)
        cnn_out = self.CNN_encoder.bn1(cnn_out)
        cnn_out = self.CNN_encoder.relu(cnn_out)
        cnn_out = self.CNN_encoder.maxpool(cnn_out)
        cnn_out = self.CNN_encoder.layer1(cnn_out)
        cnn_out = self.CNN_encoder.layer2(cnn_out)
        cnn_out = self.CNN_encoder.layer3(cnn_out)
        return self.sam(x, multimask_output=multimask_output,
                        image_size=image_size, CNN_input=cnn_out)
if __name__ == "__main__":
    # Smoke test: forward a random 512x512 RGB image through the model.
    net = SAM_LST().cuda()
    out = net(torch.rand(1, 3, 512, 512).cuda(), 1, 512)

    # Report the percentage of trainable parameters — only the unfrozen
    # "alpha"/"output_upscaling" SAM weights plus the CNN branch train.
    total_params = 0
    trainable_params = 0
    for _, p in net.named_parameters():
        # numel() gives the element count directly; no flattened view needed.
        total_params += p.numel()
        if p.requires_grad:
            trainable_params += p.numel()
    print(trainable_params / total_params * 100)
    print(out['masks'].shape)