diff --git a/15_transfer_learning.py b/15_transfer_learning.py index e68c4c85..60756f3b 100644 --- a/15_transfer_learning.py +++ b/15_transfer_learning.py @@ -167,7 +167,9 @@ def train_model(model, criterion, optimizer, scheduler, num_epochs=25): # Parameters of newly constructed modules have requires_grad=True by default num_ftrs = model_conv.fc.in_features -model_conv.fc = nn.Linear(num_ftrs, 2) + +# Creating a FC layer of the same size as our categories +model_conv.fc = nn.Linear(num_ftrs, len(class_names)) model_conv = model_conv.to(device)