Implementation of Associative Compression Networks for Unsupervised Representation Learning (ACN) by Graves, Menick, and van den Oord. We also add a VQ-VAE style decoder to the ACN model and call the resulting architecture ACN-VQ.
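As a rough illustration of the idea (not the exact code in this repo): an ACN pairs a VAE-style encoder/decoder with a prior network that is conditioned on the code of a nearest neighbour drawn from the rest of the training set, so the KL term measures the extra nats needed to describe an example given its neighbour. In the full model a code is stored for every training example and refreshed as training proceeds. The sketch below uses made-up names and layer sizes (`TinyACN`, `pick_neighbour_codes`); ACN-VQ would additionally swap the decoder for a VQ-VAE style (vector-quantized) decoder, which is not shown here.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def pick_neighbour_codes(batch_codes, codebook, k=5):
    """For each code in the batch, pick one of its k nearest codes from the
    stored codebook of training codes, uniformly at random."""
    d = torch.cdist(batch_codes, codebook)             # (B, N) pairwise distances
    knn = d.topk(k + 1, largest=False).indices[:, 1:]  # drop closest match (usually the example itself)
    pick = torch.randint(0, k, (batch_codes.size(0),))
    return codebook[knn[torch.arange(batch_codes.size(0)), pick]]

class TinyACN(nn.Module):
    """Minimal ACN sketch: encoder -> Gaussian code, prior conditioned on a
    neighbour's code, decoder reconstructs the input."""
    def __init__(self, code_dim=48):
        super().__init__()
        self.enc = nn.Sequential(nn.Flatten(), nn.Linear(784, 256), nn.ReLU(),
                                 nn.Linear(256, 2 * code_dim))
        self.prior = nn.Sequential(nn.Linear(code_dim, 256), nn.ReLU(),
                                   nn.Linear(256, 2 * code_dim))
        self.dec = nn.Sequential(nn.Linear(code_dim, 256), nn.ReLU(),
                                 nn.Linear(256, 784), nn.Sigmoid())

    def forward(self, x, neighbour_codes):
        mu, logvar = self.enc(x).chunk(2, dim=1)
        z = mu + torch.randn_like(mu) * (0.5 * logvar).exp()
        p_mu, p_logvar = self.prior(neighbour_codes).chunk(2, dim=1)
        recon = self.dec(z).view_as(x)
        # KL(q(z|x) || p(z | neighbour code)): cost of coding x given its neighbour
        kl = 0.5 * (p_logvar - logvar
                    + (logvar.exp() + (mu - p_mu) ** 2) / p_logvar.exp()
                    - 1).sum(1)
        rec = F.binary_cross_entropy(recon, x, reduction='none').flatten(1).sum(1)
        return (rec + kl).mean(), mu
```

During training one would keep a buffer of posterior means for the whole training set, refresh the entries for each batch, and pass `pick_neighbour_codes(mu.detach(), buffer)` in as `neighbour_codes`.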
Train an ACN on FashionMNIST:

`python train_acn.py --dataset_name FashionMNIST`

Train an ACN-VQ (ACN with a VQ-VAE style decoder) on MNIST:

`python train_acn.py --dataset_name MNIST --vq_decoder`

Load a trained model to draw samples and plot PCA/TSNE projections of its codebook:

`python train_acn.py -l path_to_model.pt --sample --pca --tsne`
In the following table, we look at several properties of the learned ACN codebook and reconstructions. We trained a KNN classifier on the ACN codes and report its accuracy on the FashionMNIST and MNIST validation sets. We also show PCA and TSNE plots of each model's codebook, with points colored by their true labels. (A rough sketch of this evaluation follows the table.)
Dataset | Eval | ACN | ACN-VQ |
---|---|---|---|
FashionMNIST | KNN Accuracy | 89% | 89% |
FashionMNIST | PCA | fashion-acn-pca | mnist-acnvq-pca |
FashionMNIST | TSNE | fashion-acn-tsne | mnist-acnvq-pca |
MNIST | KNN Accuracy | 97% | 97% |
MNIST | PCA | mnist-acn-pca | mnist-acnvq-pca |
MNIST | TSNE | mnist-acn-tsne | mnist-acnvq-pca |
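The numbers and plots above can be reproduced with standard scikit-learn tools. The sketch below is only an outline: the variable names (`train_codes`, `val_codes`, ...) and the choice of `n_neighbors` are ours rather than what `train_acn.py` necessarily uses, and codes are assumed to be flattened to arrays of shape `(num_examples, code_dim)`.

```python
from sklearn.neighbors import KNeighborsClassifier
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

def evaluate_codes(train_codes, train_labels, val_codes, val_labels):
    # KNN accuracy on the validation codes
    knn = KNeighborsClassifier(n_neighbors=5)
    knn.fit(train_codes, train_labels)
    acc = knn.score(val_codes, val_labels)

    # 2D projections of the codebook, colored by true label
    for name, proj in [('pca', PCA(n_components=2)),
                       ('tsne', TSNE(n_components=2, init='pca'))]:
        xy = proj.fit_transform(val_codes)
        plt.figure()
        plt.scatter(xy[:, 0], xy[:, 1], c=val_labels, s=2, cmap='tab10')
        plt.title(f'{name} of ACN codes (KNN acc {acc:.2%})')
        plt.savefig(f'codes_{name}.png')
    return acc
```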
In each image in the following table, we encode an example from the validation set (upper-left image) and show its nearest neighbors (remaining columns) according to the learned ACN codes. Reconstructions of the codes are shown in the second row of each image. The bottom two rows of each image show the two channels of the ACN code (shape 2x7x7). Notice the differences between the ACN codes of the pure ACN (left column of the table) and ACN-VQ (right column). Each column within an image is labeled with its class label (L) and its index (I) into the training dataset. (A sketch of the neighbor lookup follows the table.)
ACN | ACN-VQ |
---|---|
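Retrieving the neighbour columns shown in each image amounts to a k-nearest-neighbour lookup in code space. A minimal sketch, assuming precomputed codes for the training set and a single validation code (the function and variable names are illustrative):

```python
import torch

def neighbour_table(val_code, train_codes, train_labels, k=5):
    """Indices (I) and labels (L) of the k training codes closest to one
    validation code, measured by Euclidean distance in code space."""
    d = torch.cdist(val_code.flatten().unsqueeze(0),  # (1, D)
                    train_codes.flatten(1))           # (N, D) -> distances (1, N)
    idx = d.squeeze(0).topk(k, largest=False).indices
    return idx.tolist(), train_labels[idx].tolist()
```

The retrieved training images, their reconstructions, and the two 7x7 code channels can then be tiled into a single figure like the ones above.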