Skip to content

Commit

Permalink
Fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
mikeheddes committed Mar 12, 2024
1 parent 5637645 commit 7267d86
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 121 deletions.
82 changes: 82 additions & 0 deletions examples/classifiers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import torch
import torchhd
from torchhd.datasets.isolet import ISOLET

# Names of the classifier classes, looked up on ``torchhd.classifiers``,
# that this example trains and evaluates in turn.
classifiers = [
    "Vanilla",
    "AdaptHD",
    "OnlineHD",
    "NeuralHD",
    "DistHD",
    "CompHD",
    "SparseHD",
    "QuantHD",
    "LeHDC",
    "IntRVFL",
]

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using {device} device")

DIMENSIONS = 1024  # number of hypervector dimensions
BATCH_SIZE = 12  # for GPUs with enough memory we can process multiple samples at once

# Training split; batches are drawn in shuffled order.
train_ds = ISOLET("../data", train=True, download=True)
train_ld = torch.utils.data.DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)

# Test split; order is kept fixed for evaluation.
test_ds = ISOLET("../data", train=False, download=True)
test_ld = torch.utils.data.DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)

# The feature count is the size of the last dimension of the first training sample.
num_features = train_ds[0][0].size(-1)
num_classes = len(train_ds.classes)

# Standard deviation and mean of the training data, reduced over dim 0.
# Both dataset splits are standardized with these training statistics.
std, mean = torch.std_mean(train_ds.data, dim=0, keepdim=False)

def transform(sample):
    """Standardize ``sample`` with the module-level ``mean`` and ``std``.

    Args:
        sample (Tensor): Sample to standardize; it must broadcast against
            ``mean`` and ``std``.

    Returns:
        Tensor: ``(sample - mean) / std``, where entries of ``std`` that are
        exactly zero are treated as one.
    """
    # A feature that never varies in the training data has zero standard
    # deviation; dividing by it would produce NaN or infinity, so such
    # features are only centered and left unscaled.
    safe_std = torch.where(std == 0, torch.ones_like(std), std)
    return (sample - mean) / safe_std

# Apply the same standardization to the samples of both dataset splits.
for dataset in (train_ds, test_ds):
    dataset.transform = transform

# Extra keyword arguments for each classifier's constructor, keyed by the
# classifier's class name. An empty dict means the defaults are used.
params = {
    "Vanilla": {},
    "AdaptHD": {"epochs": 10},
    "OnlineHD": {"epochs": 10},
    "NeuralHD": {"epochs": 10, "regen_freq": 5},
    "DistHD": {"epochs": 10, "regen_freq": 5},
    "CompHD": {},
    "SparseHD": {"epochs": 10},
    "QuantHD": {"epochs": 10},
    "LeHDC": {"epochs": 10},
    "IntRVFL": {},
}

# Train every listed classifier on the training loader and report its
# accuracy on the test loader.
for name in classifiers:
    print()
    print(name)

    # Resolve the class by name and build it with its specific arguments.
    classifier_cls = getattr(torchhd.classifiers, name)
    extra_kwargs = params[name]
    model: torchhd.classifiers.Classifier = classifier_cls(
        num_features,
        DIMENSIONS,
        num_classes,
        device=device,
        **extra_kwargs,
    )

    model.fit(train_ld)
    accuracy = model.accuracy(test_ld)
    print(f"Testing accuracy of {(accuracy * 100):.3f}%")
165 changes: 44 additions & 121 deletions torchhd/classifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from typing import Optional, Literal, Callable, Iterable, Tuple
import math
import scipy.linalg
from tqdm import trange
import torch
import torch.nn as nn
import torch.nn.functional as F
Expand Down Expand Up @@ -123,19 +124,28 @@ def predict(self, samples: Tensor) -> LongTensor:
"""
return torch.argmax(self(samples), dim=-1)

def accuracy(self, samples: Tensor, labels: LongTensor) -> float:
def accuracy(self, data_loader: DataLoader) -> float:
"""Accuracy in predicting the labels of the samples.
Args:
samples (Tensor): Batch of samples to be classified.
labels (LongTensor): Batch of true labels of the samples.
data_loader (DataLoader): Iterable of tuples containing a batch of samples and labels.
Returns:
float: The accuracy of predicting the true labels.
"""
predictions = self.predict(samples)
return torch.mean(predictions == labels, dtype=torch.float).item()
n_correct = 0
n_total = 0

for samples, labels in data_loader:
samples = samples.to(self.device)
labels = labels.to(self.device)

predictions = self.predict(samples)
n_correct += torch.sum(predictions == labels, dtype=torch.long).item()
n_total += predictions.numel()

return n_correct / n_total


class Vanilla(Classifier):
Expand Down Expand Up @@ -252,7 +262,7 @@ def encoder(self, samples: Tensor) -> Tensor:

def fit(self, data_loader: DataLoader):

for _ in range(self.epochs):
for _ in trange(self.epochs, desc="fit"):
for samples, labels in data_loader:
samples = samples.to(self.device)
labels = labels.to(self.device)
Expand Down Expand Up @@ -304,7 +314,7 @@ def __init__(

def fit(self, data_loader: DataLoader):

for _ in range(self.epochs):
for _ in trange(self.epochs, desc="fit"):
for samples, labels in data_loader:
samples = samples.to(self.device)
labels = labels.to(self.device)
Expand All @@ -313,100 +323,7 @@ def fit(self, data_loader: DataLoader):
self.model.add_online(encoded, labels, lr=self.lr)

return self


class RefineHD(Classifier):
    r"""Implements `RefineHD: Accurate and Efficient Single-Pass Adaptive Learning Using Hyperdimensional Computing <https://ieeexplore.ieee.org/abstract/document/10386671>`_.

    Args:
        n_features (int): Size of each input sample.
        n_dimensions (int): The number of hidden dimensions to use.
        n_classes (int): The number of classes.
        epochs (int): The number of iterations over the training data.
        lr (float): The learning rate.
        device (``torch.device``, optional): the desired device of the weights. Default: if ``None``, uses the current device for the default tensor type (see ``torch.set_default_tensor_type()``). ``device`` will be the CPU for CPU tensor types and the current CUDA device for CUDA tensor types.
        dtype (``torch.dtype``, optional): the desired data type of the weights. Default: if ``None``, uses ``torch.get_default_dtype()``.

    """

    encoder: Sinusoid
    model: Centroid

    def __init__(
        self,
        n_features: int,
        n_dimensions: int,
        n_classes: int,
        *,
        epochs: int = 120,
        lr: float = 0.035,
        device: torch.device = None,
        dtype: torch.dtype = None
    ) -> None:
        super().__init__(
            n_features, n_dimensions, n_classes, device=device, dtype=dtype
        )

        self.epochs = epochs
        self.lr = lr

        # Start with empty running statistics.
        self.adjust_reset()

        self.encoder = Sinusoid(n_features, n_dimensions, device=device, dtype=dtype)
        self.model = Centroid(n_dimensions, n_classes, device=device, dtype=dtype)

    def adjust_reset(self):
        """Reset the running statistics used by :meth:`fit`.

        ``similarity_sum`` / ``count`` accumulate the top score of every
        sample seen; ``error_similarity_sum`` / ``error_count`` accumulate the
        top score of the misclassified samples only.
        """
        self.similarity_sum = 0
        self.count = 0
        self.error_similarity_sum = 0
        self.error_count = 0

    def fit(self, data_loader: DataLoader):
        """Train the model on the batches produced by ``data_loader``.

        Args:
            data_loader (DataLoader): Iterable of tuples containing a batch of
                samples and labels. Batches may hold any number of samples; for
                single-sample batches the updates match the per-sample rule.

        Returns:
            RefineHD: ``self``, after ``epochs`` passes over ``data_loader``.
        """
        for _ in range(self.epochs):
            for samples, labels in data_loader:
                samples = samples.to(self.device)
                labels = labels.to(self.device)

                encoded = self.encoder(samples)
                logits = self.model(encoded)

                top2_pred = torch.topk(logits, 2)
                pred = top2_pred.indices[:, 0]
                # Highest score of every sample (same as ``logits.max(1).values``).
                top_sim = top2_pred.values[:, 0]
                is_wrong = labels != pred

                # Per-sample update weight, kept as a column so that it
                # broadcasts over the hypervector dimension.
                w = (1 - top_sim - top2_pred.values[:, 1]).unsqueeze(1)

                # Summing over the batch (instead of calling ``.item()`` on the
                # whole tensor) keeps this valid for batches of any size.
                self.similarity_sum += top_sim.sum().item()
                self.count += top_sim.numel()

                # Threshold: mean top score of the misclassified samples once
                # any error has been seen, otherwise of all samples so far.
                if self.error_count == 0:
                    val = self.similarity_sum / self.count
                else:
                    val = self.error_similarity_sum / self.error_count

                # Correctly classified samples whose top score is below the
                # threshold are added to the weight row of their class.
                reinforce = ~is_wrong & (top_sim < val)
                if reinforce.any():
                    self.model.weight.index_add_(
                        0,
                        labels[reinforce],
                        self.lr * w[reinforce] * encoded[reinforce],
                    )

                # Nothing was misclassified: move on to the next batch. The
                # loop must continue here; returning would end training early
                # and make ``fit`` return ``None``.
                if not is_wrong.any():
                    continue

                self.error_count += int(is_wrong.sum().item())
                self.error_similarity_sum += top_sim[is_wrong].sum().item()

                # Restrict every per-sample tensor to the misclassified samples.
                wrong_logits = logits[is_wrong]
                wrong_encoded = encoded[is_wrong]
                wrong_labels = labels[is_wrong]
                wrong_pred = pred[is_wrong]
                wrong_w = w[is_wrong]

                alpha1 = 1.0 - wrong_logits.gather(1, wrong_labels.unsqueeze(1))
                alpha2 = wrong_logits.gather(1, wrong_pred.unsqueeze(1)) - 1

                # Update the weight rows of the true class and of the wrongly
                # predicted class with their respective scaling factors.
                self.model.weight.index_add_(
                    0, wrong_labels, alpha1 * wrong_w * wrong_encoded, alpha=self.lr
                )
                self.model.weight.index_add_(
                    0, wrong_pred, alpha2 * wrong_w * wrong_encoded, alpha=self.lr
                )

        return self



# Adapted from: https://gitlab.com/biaslab/neuralhd
class NeuralHD(Classifier):
Expand Down Expand Up @@ -464,7 +381,7 @@ def fit(self, data_loader: DataLoader):
encoded = self.encoder(samples)
self.model.add(encoded, labels)

for epoch_idx in range(1, self.epochs):
for epoch_idx in trange(1, self.epochs, desc="fit"):
for samples, labels in data_loader:
samples = samples.to(self.device)
labels = labels.to(self.device)
Expand All @@ -480,7 +397,7 @@ def fit(self, data_loader: DataLoader):
regen_dims = torch.topk(scores, n_regen_dims, largest=False).indices
self.model.weight.data[:, regen_dims].zero_()
self.encoder.weight.data[regen_dims, :].normal_()
self.encoder.bias.data[regen_dims].uniform_(0, 2 * math.pi)
self.encoder.bias.data[:, regen_dims].uniform_(0, 2 * math.pi)

return self

Expand Down Expand Up @@ -543,7 +460,7 @@ def fit(self, data_loader: DataLoader):

n_regen_dims = math.ceil(self.regen_rate * self.n_dimensions)

for epoch_idx in range(self.epochs):
for epoch_idx in trange(self.epochs, desc="fit"):
for samples, labels in data_loader:
samples = samples.to(self.device)
labels = labels.to(self.device)
Expand All @@ -567,31 +484,37 @@ def fit(self, data_loader: DataLoader):
return self

def regen_score(self, samples, labels):
scores = self(samples)
encoded = self.encoder(samples)
scores = self.model(encoded)
top2_preds = torch.topk(scores, k=2).indices
pred1, pred2 = torch.unbind(top2_preds, dim=-1)
wrong = pred1 != labels
is_wrong = pred1 != labels

# cancel update if all predictions were correct
if is_wrong.sum().item() == 0:
return 0

samples = samples[wrong]
pred2 = pred2[wrong]
labels = labels[wrong]
pred1 = pred1[wrong]
encoded = encoded[is_wrong]
pred2 = pred2[is_wrong]
labels = labels[is_wrong]
pred1 = pred1[is_wrong]

weight = F.normalize(self.model.weight, dim=1)

# Partial correct
partial = pred2 == labels
dist2corr = torch.abs(weight[labels[partial]] - samples[partial])
dist2incorr = torch.abs(weight[pred1[partial]] - samples[partial])

dist2corr = torch.abs(weight[labels[partial]] - encoded[partial])
dist2incorr = torch.abs(weight[pred1[partial]] - encoded[partial])
partial_dist = torch.sum(
(self.beta * dist2incorr - self.alpha * dist2corr), dim=0
)

# Completely incorrect
complete = pred2 != labels
dist2corr = torch.abs(weight[labels[complete]] - samples[complete])
dist2incorr1 = torch.abs(weight[pred1[complete]] - samples[complete])
dist2incorr2 = torch.abs(weight[pred2[complete]] - samples[complete])
dist2corr = torch.abs(weight[labels[complete]] - encoded[complete])
dist2incorr1 = torch.abs(weight[pred1[complete]] - encoded[complete])
dist2incorr2 = torch.abs(weight[pred2[complete]] - encoded[complete])
complete_dist = torch.sum(
(
self.beta * dist2incorr1
Expand Down Expand Up @@ -637,7 +560,7 @@ def __init__(
max_level: int = 1,
epochs: int = 120,
lr: float = 0.01,
weight_decay: float = 0.03,
weight_decay: float = 0.003,
dropout_rate: float = 0.3,
device: torch.device = None,
dtype: torch.dtype = None
Expand Down Expand Up @@ -690,7 +613,7 @@ def fit(self, data_loader: DataLoader):

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=2)

for _ in range(self.epochs):
for _ in trange(self.epochs, desc="fit"):
accumulated_loss = 0

for samples, labels in data_loader:
Expand Down Expand Up @@ -747,7 +670,7 @@ def __init__(
n_levels: int = 100,
min_level: int = -1,
max_level: int = 1,
chunks: int = 10,
chunks: int = 4,
device: torch.device = None,
dtype: torch.dtype = None
) -> None:
Expand Down Expand Up @@ -787,7 +710,7 @@ def forward(self, samples: Tensor) -> Tensor:
return self.model(self.compress(self.encoder(samples)))

def compress(self, input):
shape = (self.chunks, self.n_dimensions // self.chunks)
shape = (input.size(0), self.chunks, self.n_dimensions // self.chunks)
return functional.hash_table(self.chunk_keys, torch.reshape(input, shape))

def fit(self, data_loader: DataLoader):
Expand Down Expand Up @@ -871,7 +794,7 @@ def encoder(self, samples: Tensor) -> Tensor:
return functional.hash_table(self.feat_keys.weight, self.levels(samples)).sign()

def fit(self, data_loader: DataLoader):
for _ in range(self.epochs):
for _ in trange(self.epochs, desc="fit"):
for samples, labels in data_loader:
samples = samples.to(self.device)
labels = labels.to(self.device)
Expand Down Expand Up @@ -991,7 +914,7 @@ def fit(self, data_loader: DataLoader):

self.binarize()

for _ in range(1, self.epochs):
for _ in trange(1, self.epochs, desc="fit"):
for samples, labels in data_loader:
samples = samples.to(self.device)
labels = labels.to(self.device)
Expand Down

0 comments on commit 7267d86

Please sign in to comment.