Skip to content

Commit

Permalink
Example: Add ResNet18/VGG11 to feed_forward.py
Browse files Browse the repository at this point in the history
- add to the supported models in `share/example/feed_forward.py`:
  - `resnet18`
  - `vgg11`
  - `vgg11_bn`
  • Loading branch information
chr5tphr committed Aug 23, 2023
1 parent e9c6e09 commit 60a2c08
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion share/example/feed_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from torch.utils.data import DataLoader, Subset
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor
from torchvision.datasets import ImageFolder
from torchvision.models import vgg16, vgg16_bn, resnet50
from torchvision.models import vgg11, vgg11_bn, vgg16, vgg16_bn, resnet18, resnet50

from zennit.attribution import Gradient, SmoothGrad, IntegratedGradients, Occlusion
from zennit.composites import COMPOSITES
Expand All @@ -19,6 +19,9 @@
MODELS = {
'vgg16': (vgg16, VGGCanonizer),
'vgg16_bn': (vgg16_bn, VGGCanonizer),
'vgg11': (vgg11, VGGCanonizer),
'vgg11_bn': (vgg11_bn, VGGCanonizer),
'resnet18': (resnet18, ResNetCanonizer),
'resnet50': (resnet50, ResNetCanonizer),
}

Expand Down

0 comments on commit 60a2c08

Please sign in to comment.