Skip to content

Commit

Permalink
✨ feat(model): Support VGGs and ShuffleNets
Browse files Browse the repository at this point in the history
  • Loading branch information
KarhouTam committed Mar 26, 2024
1 parent 92db01d commit ff136dd
Showing 1 changed file with 60 additions and 0 deletions.
60 changes: 60 additions & 0 deletions src/utils/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,58 @@ def __init__(self, version, dataset):
self.base.classifier[-1] = nn.Identity()


class ShuffleNet(DecoupledModel):
def __init__(self, version, dataset):
super().__init__()
archs = {
"0_5": (
models.shufflenet_v2_x0_5,
models.ShuffleNet_V2_X0_5_Weights.DEFAULT,
),
"1_0": (
models.shufflenet_v2_x1_0,
models.ShuffleNet_V2_X1_0_Weights.DEFAULT,
),
"1_5": (
models.shufflenet_v2_x1_5,
models.ShuffleNet_V2_X1_5_Weights.DEFAULT,
),
"2_0": (
models.shufflenet_v2_x2_0,
models.ShuffleNet_V2_X2_0_Weights.DEFAULT,
),
}
# NOTE: If you don't want parameters pretrained, set `pretrained` as False
pretrained = True
shufflenet: models.ShuffleNetV2 = archs[version][0](
weights=archs[version][1] if pretrained else None
)
self.base = shufflenet
self.classifier = nn.Linear(shufflenet.fc.in_features, NUM_CLASSES[dataset])
self.base.fc = nn.Identity()


class VGG(DecoupledModel):
def __init__(self, version, dataset):
super().__init__()
archs = {
"11": (models.vgg11, models.VGG11_Weights.DEFAULT),
"13": (models.vgg13, models.VGG13_Weights.DEFAULT),
"16": (models.vgg16, models.VGG16_Weights.DEFAULT),
"19": (models.vgg19, models.VGG19_Weights.DEFAULT),
}
# NOTE: If you don't want parameters pretrained, set `pretrained` as False
pretrained = True
vgg: models.VGG = archs[version][0](
weights=archs[version][1] if pretrained else None
)
self.base = vgg
self.classifier = nn.Linear(
vgg.classifier[-1].in_features, NUM_CLASSES[dataset]
)
self.base.classifier[-1] = nn.Identity()


# NOTE: You can build your custom model here.
# What you only need to do is define the architecture in __init__().
# Don't need to consider anything else, which are handled by DecoupledModel well already.
Expand Down Expand Up @@ -398,4 +450,12 @@ def __init__(self, dataset):
"efficient5": partial(EfficientNet, version="5"),
"efficient6": partial(EfficientNet, version="6"),
"efficient7": partial(EfficientNet, version="7"),
"shuffle0_5": partial(ShuffleNet, version="0_5"),
"shuffle1_0": partial(ShuffleNet, version="1_0"),
"shuffle1_5": partial(ShuffleNet, version="1_5"),
"shuffle2_0": partial(ShuffleNet, version="2_0"),
"vgg11": partial(VGG, version="11"),
"vgg13": partial(VGG, version="13"),
"vgg16": partial(VGG, version="16"),
"vgg19": partial(VGG, version="19"),
}

0 comments on commit ff136dd

Please sign in to comment.