A PyTorch library of vector quantization methods. Vector quantization has been used successfully in high-quality image and audio generation, e.g., VQ-VAE and VQGAN.
Implemented methods:
- Vector Quantization
- Vector Quantization with momentum (exponential moving average) codebook updates
- Vector Quantization based on the Gumbel-softmax trick
- Product Quantization
- Residual Quantization
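In the momentum variant, the codebook is updated from running statistics rather than by gradients. Each code is pulled toward the mean of the encoder vectors recently assigned to it. The sketch below is a minimal version of that idea, not this library's implementation. The function name `ema_update`, the decay `gamma`, and the smoothing constant `eps` are hypothetical choices.

```python
import torch
import torch.nn.functional as F

def ema_update(codebook, cluster_size, embed_sum, z, gamma=0.99, eps=1e-5):
    """Hypothetical EMA codebook update, applied in place.

    codebook:     (n_e, e_dim) current code vectors
    cluster_size: (n_e,) running count of vectors assigned to each code
    embed_sum:    (n_e, e_dim) running sum of vectors assigned to each code
    z:            (M, e_dim) batch of encoder outputs
    """
    with torch.no_grad():
        # Nearest code for each input vector (Euclidean distance).
        idx = torch.cdist(z, codebook).argmin(dim=1)
        onehot = F.one_hot(idx, codebook.shape[0]).type(z.dtype)  # (M, n_e)
        # Exponential moving averages of assignment counts and vector sums.
        cluster_size.mul_(gamma).add_(onehot.sum(0), alpha=1 - gamma)
        embed_sum.mul_(gamma).add_(onehot.t() @ z, alpha=1 - gamma)
        # Laplace smoothing so rarely used codes never divide by ~0.
        n = cluster_size.sum()
        smoothed = (cluster_size + eps) / (n + codebook.shape[0] * eps) * n
        codebook.copy_(embed_sum / smoothed.unsqueeze(1))
    return idx

torch.manual_seed(0)
codebook = torch.randn(8, 4)
cluster_size = torch.ones(8)   # start each code with a pseudo-count of 1
embed_sum = codebook.clone()   # consistent with a count of 1 per code
z = torch.randn(32, 4)
idx = ema_update(codebook, cluster_size, embed_sum, z)
```

Starting each count at 1, with sums equal to the initial codes, keeps codes that receive no assignments close to where they began instead of letting them blow up.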
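The Gumbel-softmax variant replaces the hard nearest-neighbor lookup with a differentiable sample over codebook entries. The sketch below is minimal and assumes the encoder produces one logit per code at each spatial position. It relies on PyTorch's built-in `F.gumbel_softmax`. The class name `GumbelQuantizerSketch` and its `tau` and `hard` arguments are hypothetical and are not this library's API.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GumbelQuantizerSketch(nn.Module):
    """Hypothetical Gumbel-softmax quantizer over a learned codebook."""

    def __init__(self, n_e, e_dim, tau=1.0, hard=True):
        super().__init__()
        self.codebook = nn.Embedding(n_e, e_dim)
        self.tau = tau    # temperature of the relaxation
        self.hard = hard  # one-hot forward, soft backward (straight-through)

    def forward(self, logits):
        # logits: (N, n_e, H, W), one score per code at each position
        weights = F.gumbel_softmax(logits, tau=self.tau, hard=self.hard, dim=1)
        # Weighted sum of code vectors at each position: (N, e_dim, H, W)
        z_q = torch.einsum("nkhw,kd->ndhw", weights, self.codebook.weight)
        indices = weights.argmax(dim=1)  # (N, H, W)
        return z_q, indices

torch.manual_seed(0)
gq = GumbelQuantizerSketch(n_e=16, e_dim=8)
logits = torch.randn(2, 16, 4, 4, requires_grad=True)
z_q, indices = gq(logits)
z_q.sum().backward()  # gradients reach the logits through the relaxation
```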
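Product quantization splits each vector into `m` sub-vectors and quantizes every chunk against its own smaller codebook. With this scheme, `m` codebooks of `k` entries can represent k^m combinations. The sketch below shows only the forward lookup. It assumes `e_dim` is divisible by `m`, and the function name `product_quantize` is hypothetical.

```python
import torch

def product_quantize(z, codebooks):
    """Hypothetical chunk-wise quantizer.

    z:         (M, e_dim) input vectors
    codebooks: (m, k, e_dim // m), one codebook of k entries per chunk
    returns    z_q of shape (M, e_dim) and indices of shape (M, m)
    """
    m, k, d = codebooks.shape
    # (M, e_dim) -> (m, M, d): group the same chunk of every vector together
    chunks = z.reshape(z.shape[0], m, d).transpose(0, 1).contiguous()
    # Batched distances, one (M, k) matrix per chunk, then nearest code
    idx = torch.cdist(chunks, codebooks).argmin(dim=-1)  # (m, M)
    # Look up the selected code for each chunk and stitch the chunks back
    z_q = torch.stack([codebooks[i, idx[i]] for i in range(m)], dim=1)  # (M, m, d)
    return z_q.reshape(z.shape[0], m * d), idx.t()

torch.manual_seed(0)
codebooks = torch.randn(4, 16, 8)  # m=4 chunks, k=16 codes each, sub-dim 8
z = torch.randn(10, 32)
z_q, idx = product_quantize(z, codebooks)
```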
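Residual quantization applies a stack of quantizers in sequence. Each stage quantizes the error left over by the previous stages, and the final output is the sum of the codes chosen at every stage. The sketch below is minimal and assumes a plain nearest-neighbor lookup at each stage. The function name `residual_quantize` is hypothetical.

```python
import torch

def residual_quantize(z, codebooks):
    """Hypothetical residual quantizer.

    z:         (M, e_dim) input vectors
    codebooks: list of (k, e_dim) tensors, one per stage
    returns    z_q of shape (M, e_dim) and indices of shape (M, n_stages)
    """
    residual = z
    z_q = torch.zeros_like(z)
    indices = []
    for cb in codebooks:
        idx = torch.cdist(residual, cb).argmin(dim=1)  # nearest code to what is left
        code = cb[idx]
        z_q = z_q + code             # accumulate the reconstruction
        residual = residual - code   # the next stage sees only the remaining error
        indices.append(idx)
    return z_q, torch.stack(indices, dim=1)

torch.manual_seed(0)
z = torch.randn(64, 8)
codebooks = [torch.randn(32, 8) * s for s in (1.0, 0.5, 0.25)]  # coarse to fine
z_q, indices = residual_quantize(z, codebooks)
rel_err = ((z - z_q).norm() / z.norm()).item()
```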
```python
import torch
from vector_quantize import VectorQuantizer

vq = VectorQuantizer(
    n_e=1024,   # codebook vocabulary size (number of code vectors)
    e_dim=256,  # dimension of each code vector
    beta=1.0,   # weight on the commitment loss
)

x = torch.randn(1, 256, 16, 16)  # input of shape (N, C, H, W)
quantized, commit_loss, indices = vq(x)  # shapes: (1, 256, 16, 16), (1), (1, 16, 16)
```
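A VQ layer of this kind maps each of the H×W feature vectors to its nearest codebook entry, which is why `indices` has one entry per spatial position. The sketch below is a minimal re-implementation of the standard VQ-VAE forward pass and is not this library's code. It assumes the usual loss form, a codebook term plus `beta` times a commitment term, and the straight-through gradient estimator. The class name `VQSketch` is hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class VQSketch(nn.Module):
    """Hypothetical nearest-neighbor vector quantizer (VQ-VAE style)."""

    def __init__(self, n_e, e_dim, beta=1.0):
        super().__init__()
        self.beta = beta
        self.embedding = nn.Embedding(n_e, e_dim)
        nn.init.uniform_(self.embedding.weight, -1.0 / n_e, 1.0 / n_e)

    def forward(self, z):
        n, c, h, w = z.shape
        # NCHW -> (N*H*W, C): one vector per spatial position
        z_flat = z.permute(0, 2, 3, 1).reshape(-1, c)
        e = self.embedding.weight
        # Squared distances ||z||^2 - 2 z.e + ||e||^2 to every code
        d = z_flat.pow(2).sum(1, keepdim=True) - 2 * z_flat @ e.t() + e.pow(2).sum(1)
        indices = d.argmin(dim=1)
        z_q = self.embedding(indices).view(n, h, w, c).permute(0, 3, 1, 2)
        # Codebook term pulls codes toward encoder outputs; the beta-weighted
        # commitment term keeps encoder outputs close to their chosen codes.
        loss = F.mse_loss(z_q, z.detach()) + self.beta * F.mse_loss(z, z_q.detach())
        # Straight-through: forward uses z_q, backward passes gradients to z unchanged
        z_q = z + (z_q - z).detach()
        return z_q, loss, indices.view(n, h, w)

torch.manual_seed(0)
vq = VQSketch(n_e=1024, e_dim=256, beta=1.0)
x = torch.randn(1, 256, 16, 16, requires_grad=True)
quantized, commit_loss, indices = vq(x)
```

The straight-through trick is needed because `argmin` has no useful gradient. The encoder is trained as if quantization were the identity, while the loss terms handle moving codes and encoder outputs toward each other.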