def _rescale_to_255(array):
    """Linearly map a NumPy array onto [0, 255]; a constant array maps to zeros."""
    low = array.min()
    span = array.max() - low
    if span == 0:
        # Flat input: the plain min-max formula would divide by zero and
        # produce NaNs, which cannot be meaningfully cast to int for display.
        return array - low
    return 255 * (array - low) / span


def plot_filtered_images(images, filters):
    """Plot each image next to the result of convolving it with each filter.

    One figure row is drawn per image: the first column shows the original
    image, the following columns show one feature map per filter. Every panel
    is independently min-max rescaled to [0, 255] before display.

    Args:
        images: Sequence of image tensors, each laid out as (C, H, W).
        filters: Convolution weights as accepted by ``F.conv2d``; the size of
            the first dimension is used as the number of filters.

    Returns:
        The tensor returned by ``F.conv2d(images, filters)``, on the CPU.

    Note:
        The figure is created but not shown; the caller decides when to
        display or save it.
    """
    images = torch.stack(list(images), dim=0).detach().cpu()
    print(images.shape)
    # Detach so the ``.numpy()`` conversions below also work when the weights
    # are passed in with ``requires_grad=True``.
    filters = filters.detach().cpu()

    n_images = images.shape[0]
    n_filters = filters.shape[0]
    n_cols = n_filters + 1  # one "Original" column plus one per filter

    filtered_images = F.conv2d(images, filters)

    fig = plt.figure(figsize=(20, 20))

    for i in range(n_images):
        # Subplot indices are 1-based and run row by row.
        first_cell = i * n_cols + 1

        # (C, H, W) -> (H, W, C) for imshow.
        img = images[i].numpy().transpose((1, 2, 0))
        if img.shape[-1] == 1:
            # Single-channel image: hand imshow a plain 2-D array.
            img = img[..., 0]
        img = _rescale_to_255(img)

        ax = fig.add_subplot(n_images, n_cols, first_cell)
        ax.imshow(img.astype(int), cmap='bone')
        ax.set_title('Original')
        ax.axis('off')

        for j in range(n_filters):
            image = filtered_images[i][j].numpy().astype(float)
            image = _rescale_to_255(image)

            ax = fig.add_subplot(n_images, n_cols, first_cell + j + 1)
            ax.imshow(image.astype(int), cmap='bone')
            ax.set_title(f'Filter {j+1}')
            ax.axis('off')

    return filtered_images
# Number of test samples whose first-layer responses are visualised.
N_IMAGES = 5

# Take the first N_IMAGES samples of the test set and keep only the image
# half of each (image, label) pair.
images = [
    sample
    for sample, _target in (testset[k] for k in range(N_IMAGES))
]

# Weights of the model's first convolution layer, used as the filter bank.
filters = model.conv1.weight.data

filtered_images = plot_filtered_images(images, filters)
# --- Per-class accuracy on the test set -------------------------------------
n_classes = 10
n_samples = len(testset)

class_correct = torch.zeros(n_classes)
class_total = torch.zeros(n_classes)

labels = []
predicts = []

model.eval()
with torch.no_grad():
    for x, y in tqdm(testloader):
        x = x.to(device)
        y = y.to('cpu')

        # Call the module itself rather than ``forward`` so that any
        # registered hooks run as well.
        y_pred = model(x).to('cpu')
        cls_pred = torch.argmax(y_pred, dim=1)

        # Boolean mask of correctly classified samples. No ``squeeze`` here:
        # it would turn a one-sample batch into a 0-d tensor that cannot be
        # indexed.
        c = cls_pred == y

        # Count samples and hits per class for the whole batch at once.
        class_total += torch.bincount(y, minlength=n_classes)
        class_correct += torch.bincount(y[c], minlength=n_classes)

        labels.append(y)
        predicts.append(cls_pred)

labels = torch.cat(labels, dim=0)  # flatten the per-batch tensors into one 1-D tensor
predicts = torch.cat(predicts, dim=0)  # flatten the per-batch tensors into one 1-D tensor

for i in range(n_classes):
    if class_total[i] == 0:
        # No sample of this class was seen; the ratio would be NaN, which
        # the integer format below cannot render.
        print('Accuracy of %5s : n/a (no samples)' % (names_classes[i],))
        continue
    print('Accuracy of %5s : %2d %%' % (
        names_classes[i], 100 * class_correct[i] / class_total[i]))

# --- Submission file --------------------------------------------------------
# Ids 1..n_samples are ordered as strings ("1", "10", "100", ...).
# NOTE(review): presumably this mirrors the order in which the test samples
# were loaded -- confirm against the dataset definition.
sorted_ids = sorted(range(1, n_samples + 1), key=str)

# Map predicted class indices to class names using plain Python ints, so the
# DataFrame column holds strings rather than 0-d tensors.
df = pd.DataFrame({
    'id': sorted_ids,
    'label': [names_classes[idx] for idx in predicts.tolist()],
})
df.to_csv('submission.csv', index=False)