diff --git a/examples/classifiers.py b/examples/classifiers.py new file mode 100644 index 00000000..05f3a6c0 --- /dev/null +++ b/examples/classifiers.py @@ -0,0 +1,82 @@ +import torch +import torchhd +from torchhd.datasets.isolet import ISOLET + +classifiers = [ + "Vanilla", + "AdaptHD", + "OnlineHD", + "NeuralHD", + "DistHD", + "CompHD", + "SparseHD", + "QuantHD", + "LeHDC", + "IntRVFL", +] + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +print("Using {} device".format(device)) + +DIMENSIONS = 1024 # number of hypervector dimensions +BATCH_SIZE = 12 # for GPUs with enough memory we can process multiple samples at once + +train_ds = ISOLET("../data", train=True, download=True) +train_ld = torch.utils.data.DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True) + +test_ds = ISOLET("../data", train=False, download=True) +test_ld = torch.utils.data.DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False) + +num_features = train_ds[0][0].size(-1) +num_classes = len(train_ds.classes) + +std, mean = torch.std_mean(train_ds.data, dim=0, keepdim=False) + +def transform(sample): + return (sample - mean) / std + +train_ds.transform = transform +test_ds.transform = transform + +params = { + "Vanilla": { + }, + "AdaptHD": { + "epochs": 10, + }, + "OnlineHD": { + "epochs": 10, + }, + "NeuralHD": { + "epochs": 10, + "regen_freq": 5, + }, + "DistHD": { + "epochs": 10, + "regen_freq": 5, + }, + "CompHD": { + }, + "SparseHD": { + "epochs": 10, + }, + "QuantHD": { + "epochs": 10, + }, + "LeHDC": { + "epochs": 10, + }, + "IntRVFL": { + }, +} + +for classifier in classifiers: + print() + print(classifier) + + model_cls = getattr(torchhd.classifiers, classifier) + model: torchhd.classifiers.Classifier = model_cls(num_features, DIMENSIONS, num_classes, device=device, **params[classifier]) + + model.fit(train_ld) + accuracy = model.accuracy(test_ld) + print(f"Testing accuracy of {(accuracy * 100):.3f}%") diff --git a/torchhd/classifiers.py 
b/torchhd/classifiers.py index a17d9735..287dc527 100644 --- a/torchhd/classifiers.py +++ b/torchhd/classifiers.py @@ -24,6 +24,7 @@ from typing import Optional, Literal, Callable, Iterable, Tuple import math import scipy.linalg +from tqdm import trange import torch import torch.nn as nn import torch.nn.functional as F @@ -123,19 +124,28 @@ def predict(self, samples: Tensor) -> LongTensor: """ return torch.argmax(self(samples), dim=-1) - def accuracy(self, samples: Tensor, labels: LongTensor) -> float: + def accuracy(self, data_loader: DataLoader) -> float: """Accuracy in predicting the labels of the samples. Args: - samples (Tensor): Batch of samples to be classified. - labels (LongTensor): Batch of true labels of the samples. + data_loader (DataLoader): Iterable of tuples containing a batch of samples and labels. Returns: float: The accuracy of predicting the true labels. """ - predictions = self.predict(samples) - return torch.mean(predictions == labels, dtype=torch.float).item() + n_correct = 0 + n_total = 0 + + for samples, labels in data_loader: + samples = samples.to(self.device) + labels = labels.to(self.device) + + predictions = self.predict(samples) + n_correct += torch.sum(predictions == labels, dtype=torch.long).item() + n_total += predictions.numel() + + return n_correct / n_total class Vanilla(Classifier): @@ -252,7 +262,7 @@ def encoder(self, samples: Tensor) -> Tensor: def fit(self, data_loader: DataLoader): - for _ in range(self.epochs): + for _ in trange(self.epochs, desc="fit"): for samples, labels in data_loader: samples = samples.to(self.device) labels = labels.to(self.device) @@ -304,7 +314,7 @@ def __init__( def fit(self, data_loader: DataLoader): - for _ in range(self.epochs): + for _ in trange(self.epochs, desc="fit"): for samples, labels in data_loader: samples = samples.to(self.device) labels = labels.to(self.device) @@ -313,100 +323,7 @@ def fit(self, data_loader: DataLoader): self.model.add_online(encoded, labels, lr=self.lr) return 
self - - -class RefineHD(Classifier): - r"""Implements `RefineHD: : Accurate and Efficient Single-Pass Adaptive Learning Using Hyperdimensional Computing `_. - - Args: - n_features (int): Size of each input sample. - n_dimensions (int): The number of hidden dimensions to use. - n_classes (int): The number of classes. - epochs (int): The number of iteration over the training data. - lr (float): The learning rate. - device (``torch.device``, optional): the desired device of the weights. Default: if ``None``, uses the current device for the default tensor type (see ``torch.set_default_tensor_type()``). ``device`` will be the CPU for CPU tensor types and the current CUDA device for CUDA tensor types. - dtype (``torch.dtype``, optional): the desired data type of the weights. Default: if ``None``, uses ``torch.get_default_dtype()``. - - """ - - encoder: Sinusoid - model: Centroid - - def __init__( - self, - n_features: int, - n_dimensions: int, - n_classes: int, - *, - epochs: int = 120, - lr: float = 0.035, - device: torch.device = None, - dtype: torch.dtype = None - ) -> None: - super().__init__( - n_features, n_dimensions, n_classes, device=device, dtype=dtype - ) - - self.epochs = epochs - self.lr = lr - - self.adjust_reset() - - self.encoder = Sinusoid(n_features, n_dimensions, device=device, dtype=dtype) - self.model = Centroid(n_dimensions, n_classes, device=device, dtype=dtype) - - def adjust_reset(self): - self.similarity_sum = 0 - self.count = 0 - self.error_similarity_sum = 0 - self.error_count = 0 - - def fit(self, data_loader: DataLoader): - for _ in range(self.epochs): - for samples, labels in data_loader: - samples = samples.to(self.device) - labels = labels.to(self.device) - - encoded = self.encoder(samples) - logits = self.model(encoded) - - top2_pred = torch.topk(logits, 2) - pred = top2_pred.indices[:, 0] - is_wrong = labels != pred - - w = 1 - top2_pred.values[:, 0] - top2_pred.values[:, 1] - - self.similarity_sum += logits.max(1).values.item() - 
self.count += 1 - if self.error_count == 0: - val = self.similarity_sum / self.count - else: - val = self.error_similarity_sum / self.error_count - if is_wrong.sum().item() == 0: - if logits.max(1).values.item() < val: - self.model.weight.index_add_(0, labels, self.lr * w * encoded) - return - - self.error_count += 1 - self.error_similarity_sum += logits.max(1).values.item() - - logits = logits[is_wrong] - encoded = encoded[is_wrong] - labels = labels[is_wrong] - pred = pred[is_wrong] - - alpha1 = 1.0 - logits.gather(1, labels.unsqueeze(1)) - alpha2 = logits.gather(1, pred.unsqueeze(1)) - 1 - - self.model.weight.index_add_( - 0, labels, alpha1 * w * encoded, alpha=self.lr - ) - self.model.weight.index_add_( - 0, pred, alpha2 * w * encoded, alpha=self.lr - ) - - return self - + # Adapted from: https://gitlab.com/biaslab/neuralhd class NeuralHD(Classifier): @@ -464,7 +381,7 @@ def fit(self, data_loader: DataLoader): encoded = self.encoder(samples) self.model.add(encoded, labels) - for epoch_idx in range(1, self.epochs): + for epoch_idx in trange(1, self.epochs, desc="fit"): for samples, labels in data_loader: samples = samples.to(self.device) labels = labels.to(self.device) @@ -480,7 +397,7 @@ def fit(self, data_loader: DataLoader): regen_dims = torch.topk(scores, n_regen_dims, largest=False).indices self.model.weight.data[:, regen_dims].zero_() self.encoder.weight.data[regen_dims, :].normal_() - self.encoder.bias.data[regen_dims].uniform_(0, 2 * math.pi) + self.encoder.bias.data[:, regen_dims].uniform_(0, 2 * math.pi) return self @@ -543,7 +460,7 @@ def fit(self, data_loader: DataLoader): n_regen_dims = math.ceil(self.regen_rate * self.n_dimensions) - for epoch_idx in range(self.epochs): + for epoch_idx in trange(self.epochs, desc="fit"): for samples, labels in data_loader: samples = samples.to(self.device) labels = labels.to(self.device) @@ -567,31 +484,37 @@ def fit(self, data_loader: DataLoader): return self def regen_score(self, samples, labels): - scores = 
self(samples) + encoded = self.encoder(samples) + scores = self.model(encoded) top2_preds = torch.topk(scores, k=2).indices pred1, pred2 = torch.unbind(top2_preds, dim=-1) - wrong = pred1 != labels + is_wrong = pred1 != labels + + # cancel update if all predictions were correct + if is_wrong.sum().item() == 0: + return 0 - samples = samples[wrong] - pred2 = pred2[wrong] - labels = labels[wrong] - pred1 = pred1[wrong] + encoded = encoded[is_wrong] + pred2 = pred2[is_wrong] + labels = labels[is_wrong] + pred1 = pred1[is_wrong] weight = F.normalize(self.model.weight, dim=1) # Partial correct partial = pred2 == labels - dist2corr = torch.abs(weight[labels[partial]] - samples[partial]) - dist2incorr = torch.abs(weight[pred1[partial]] - samples[partial]) + + dist2corr = torch.abs(weight[labels[partial]] - encoded[partial]) + dist2incorr = torch.abs(weight[pred1[partial]] - encoded[partial]) partial_dist = torch.sum( (self.beta * dist2incorr - self.alpha * dist2corr), dim=0 ) # Completely incorrect complete = pred2 != labels - dist2corr = torch.abs(weight[labels[complete]] - samples[complete]) - dist2incorr1 = torch.abs(weight[pred1[complete]] - samples[complete]) - dist2incorr2 = torch.abs(weight[pred2[complete]] - samples[complete]) + dist2corr = torch.abs(weight[labels[complete]] - encoded[complete]) + dist2incorr1 = torch.abs(weight[pred1[complete]] - encoded[complete]) + dist2incorr2 = torch.abs(weight[pred2[complete]] - encoded[complete]) complete_dist = torch.sum( ( self.beta * dist2incorr1 @@ -637,7 +560,7 @@ def __init__( max_level: int = 1, epochs: int = 120, lr: float = 0.01, - weight_decay: float = 0.03, + weight_decay: float = 0.003, dropout_rate: float = 0.3, device: torch.device = None, dtype: torch.dtype = None @@ -690,7 +613,7 @@ def fit(self, data_loader: DataLoader): scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=2) - for _ in range(self.epochs): + for _ in trange(self.epochs, desc="fit"): accumulated_loss = 0 for samples, 
labels in data_loader: @@ -747,7 +670,7 @@ def __init__( n_levels: int = 100, min_level: int = -1, max_level: int = 1, - chunks: int = 10, + chunks: int = 4, device: torch.device = None, dtype: torch.dtype = None ) -> None: @@ -787,7 +710,7 @@ def forward(self, samples: Tensor) -> Tensor: return self.model(self.compress(self.encoder(samples))) def compress(self, input): - shape = (self.chunks, self.n_dimensions // self.chunks) + shape = (input.size(0), self.chunks, self.n_dimensions // self.chunks) return functional.hash_table(self.chunk_keys, torch.reshape(input, shape)) def fit(self, data_loader: DataLoader): @@ -871,7 +794,7 @@ def encoder(self, samples: Tensor) -> Tensor: return functional.hash_table(self.feat_keys.weight, self.levels(samples)).sign() def fit(self, data_loader: DataLoader): - for _ in range(self.epochs): + for _ in trange(self.epochs, desc="fit"): for samples, labels in data_loader: samples = samples.to(self.device) labels = labels.to(self.device) @@ -991,7 +914,7 @@ def fit(self, data_loader: DataLoader): self.binarize() - for _ in range(1, self.epochs): + for _ in trange(1, self.epochs, desc="fit"): for samples, labels in data_loader: samples = samples.to(self.device) labels = labels.to(self.device)