Skip to content

Commit

Permalink
Update tutorial with new (working) example code
Browse files Browse the repository at this point in the history
Fixes #167
  • Loading branch information
mikeheddes committed Mar 15, 2024
1 parent 549ad53 commit 8806a44
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 40 deletions.
89 changes: 49 additions & 40 deletions docs/classification.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ HDC Learning
After learning about representing and manipulating information in hyperspace, we can implement our first HDC classification model! We will use as an example the famous MNIST dataset that contains images of handwritten digits.


We start by importing Torchhd and any other libraries we need:
We start by importing Torchhd and the other libraries we need, in addition to specifying the training parameters:

.. code-block:: python
Expand All @@ -13,11 +13,22 @@ We start by importing Torchhd and any other libraries we need:
import torch.nn.functional as F
import torchvision
from torchvision.datasets import MNIST
# Note: this example requires the torchmetrics library: https://torchmetrics.readthedocs.io
import torchmetrics
from torchhd import functional
import torchhd
from torchhd.models import Centroid
from torchhd import embeddings
# Use the GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using {} device".format(device))
DIMENSIONS = 10000
IMG_SIZE = 28
NUM_LEVELS = 1000
BATCH_SIZE = 1 # for GPUs with enough memory we can process multiple images at ones
Datasets
--------

Expand All @@ -34,55 +45,46 @@ Next, we load the training and testing datasets:
test_ld = torch.utils.data.DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)
In addition to the various datasets available in the Torch ecosystem, such as MNIST, the :ref:`datasets` module provides interface to several commonly used datasets in HDC. Such interfaces inherit from PyTorch's dataset class, ensuring interoperability with other datasets.
In addition to the various datasets available in the Torch ecosystem, such as MNIST, the :ref:`datasets` module provides an interface to several commonly used datasets in HDC. Such interfaces inherit from PyTorch's dataset class, ensuring interoperability with other datasets.

Training
--------

To perform the training, we start by defining a model. In addition to specifying the basis-hypervectors sets, the core part of the model is the encoding function. In the example below, we use random-hypervectors and level-hypervectors to encode the position and value of each pixel, respectively:
To perform the training, we start by defining an encoding. In addition to specifying the basis-hypervectors sets, a core part of learning is the encoding function. In the example below, we use random-hypervectors and level-hypervectors to encode the position and value of each pixel, respectively:

.. code-block:: python
class Model(nn.Module):
def __init__(self, num_classes, size):
super(Model, self).__init__()
self.flatten = torch.nn.Flatten()
self.position = embeddings.Random(size * size, DIMENSIONS)
self.value = embeddings.Level(NUM_LEVELS, DIMENSIONS)
self.classify = nn.Linear(DIMENSIONS, num_classes, bias=False)
self.classify.weight.data.fill_(0.0)
class Encoder(nn.Module):
def __init__(self, out_features, size, levels):
super(Encoder, self).__init__()
self.flatten = torch.nn.Flatten()
self.position = embeddings.Random(size * size, out_features)
self.value = embeddings.Level(levels, out_features)
def encode(self, x):
x = self.flatten(x)
sample_hv = functional.bind(self.position.weight, self.value(x))
sample_hv = functional.multiset(sample_hv)
return functional.hard_quantize(sample_hv)
def forward(self, x):
x = self.flatten(x)
sample_hv = torchhd.bind(self.position.weight, self.value(x))
sample_hv = torchhd.multiset(sample_hv)
return torchhd.hard_quantize(sample_hv)
def forward(self, x):
enc = self.encode(x)
logit = self.classify(enc)
return logit
encode = Encoder(DIMENSIONS, IMG_SIZE, NUM_LEVELS)
encode = encode.to(device)
model = Model(len(train_ds.classes), IMG_SIZE)
num_classes = len(train_ds.classes)
model = Centroid(DIMENSIONS, num_classes)
model = model.to(device)
Having defined the model, we iterate over the training samples to create the class-vectors:

.. code-block:: python
for samples, labels in train_ld:
samples = samples.to(device)
labels = labels.to(device)
samples_hv = model.encode(samples)
model.classify.weight[labels] += samples_hv
with torch.no_grad():
for samples, labels in tqdm(train_ld, desc="Training"):
samples = samples.to(device)
labels = labels.to(device)
model.classify.weight[:] = F.normalize(model.classify.weight)
samples_hv = encode(samples)
model.add(samples_hv, labels)
Testing
-------
Expand All @@ -91,9 +93,16 @@ With the model trained, we can classify the testing samples by encoding them and

.. code-block:: python
for samples, labels in test_ld:
samples = samples.to(device)
accuracy = torchmetrics.Accuracy("multiclass", num_classes=num_classes)
with torch.no_grad():
model.normalize()
for samples, labels in tqdm(test_ld, desc="Testing"):
samples = samples.to(device)
samples_hv = encode(samples)
outputs = model(samples_hv, dot=True)
accuracy.update(outputs.cpu(), labels)
outputs = model(samples)
predictions = torch.argmax(outputs, dim=-1)
accuracy.update(predictions.cpu(), labels)
print(f"Testing accuracy of {(accuracy.compute().item() * 100):.3f}%")
2 changes: 2 additions & 0 deletions docs/docutils.conf
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[restructuredtext parser]
tab_width: 4

0 comments on commit 8806a44

Please sign in to comment.