Refactoring
mikeheddes committed Mar 18, 2024
1 parent 7953773 commit d79cdc7
Showing 2 changed files with 12 additions and 10 deletions.
1 change: 0 additions & 1 deletion docs/classifiers.rst
@@ -14,7 +14,6 @@ torchhd.classifiers
     Vanilla
     AdaptHD
     OnlineHD
-    RefineHD
     NeuralHD
     DistHD
     CompHD
21 changes: 12 additions & 9 deletions torchhd/classifiers.py
@@ -95,7 +95,7 @@ def __call__(self, samples: Tensor) -> Tensor:
             samples (Tensor): Batch of samples to be classified.

         Returns:
-            Tensor: Logits of each samples for each class.
+            Tensor: Logits of each sample for each class.

         """
         return super().__call__(samples)
@@ -151,6 +151,8 @@ def accuracy(self, data_loader: DataLoader) -> float:
 class Vanilla(Classifier):
     r"""Baseline centroid classifier.

+    This classifier uses level-hypervectors to encode the feature values, which are then combined using a hash table with random keys.
+
     Args:
         n_features (int): Size of each input sample.
         n_dimensions (int): The number of hidden dimensions to use.
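As a hedged aside on the new docstring line: the encoding it describes can be sketched with torchhd's public embeddings and `functional.hash_table` (hypothetical sizes; this mirrors the idea, not the class's exact internals):

```python
import torch
from torchhd import embeddings, functional

# Hypothetical sizes for illustration only.
n_features, n_dimensions, n_levels = 6, 10000, 100

keys = embeddings.Random(n_features, n_dimensions)  # one random key per feature
levels = embeddings.Level(n_levels, n_dimensions)   # level-hypervectors over [0, 1]

x = torch.rand(n_features)  # one sample of feature values in [0, 1]
values = levels(x)          # each value maps to its level-hypervector
encoded = functional.hash_table(keys.weight, values)  # bind keys to values, bundle
```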
@@ -539,6 +541,7 @@ class LeHDC(Classifier):
         max_level (int, optional): The upper-bound of the range represented by the level-hypervectors.
         epochs (int, optional): The number of iterations over the training data.
         lr (float, optional): The learning rate.
+        patience (int, optional): Number of epochs with no improvement after which the learning rate will be reduced.
         weight_decay (float, optional): The rate at which the weights of the model are decayed during training.
         dropout_rate (float, optional): The fraction of hidden dimensions to randomly zero-out.
         device (``torch.device``, optional): the desired device of the weights. Default: if ``None``, uses the current device for the default tensor type (see ``torch.set_default_tensor_type()``). ``device`` will be the CPU for CPU tensor types and the current CUDA device for CUDA tensor types.
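Because `patience` is now a constructor argument rather than the hard-coded 2 (see the `fit` hunk below), callers can tune how long training waits before the learning rate drops. A hedged usage sketch, with made-up problem sizes and the constructor signature assumed from the docstring above:

```python
from torchhd.classifiers import LeHDC

# Hypothetical sizes; only the patience argument is new in this commit.
model = LeHDC(
    n_features=784,
    n_dimensions=10000,
    n_classes=10,
    patience=5,  # wait 5 epochs without improvement before reducing the lr
)
```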
@@ -560,6 +563,7 @@ def __init__(
         max_level: int = 1,
         epochs: int = 120,
         lr: float = 0.01,
+        patience: int = 2,
         weight_decay: float = 0.003,
         dropout_rate: float = 0.3,
         device: torch.device = None,
@@ -571,6 +575,7 @@ def __init__(

         self.epochs = epochs
         self.lr = lr
+        self.patience = patience
         self.weight_decay = weight_decay

         self.keys = Random(n_features, n_dimensions, device=device, dtype=dtype)
@@ -611,7 +616,9 @@ def fit(self, data_loader: DataLoader):
             weight_decay=self.weight_decay,
         )

-        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=2)
+        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+            optimizer, patience=self.patience
+        )

         for _ in trange(self.epochs, desc="fit"):
             accumulated_loss = 0
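For context, `ReduceLROnPlateau` cuts the optimizer's learning rate once the metric passed to `scheduler.step` stops improving for `patience` consecutive epochs; the hunk above only threads the configurable value through. A minimal standalone sketch with a placeholder parameter and loss:

```python
import torch

param = torch.nn.Parameter(torch.zeros(10))
optimizer = torch.optim.Adam([param], lr=0.01)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=2)

for epoch in range(5):
    accumulated_loss = 1.0  # placeholder for the real accumulated epoch loss
    scheduler.step(accumulated_loss)  # step with the monitored metric
```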
@@ -711,7 +718,8 @@ def forward(self, samples: Tensor) -> Tensor:

     def compress(self, input):
         shape = (input.size(0), self.chunks, self.n_dimensions // self.chunks)
-        return functional.hash_table(self.chunk_keys, torch.reshape(input, shape))
+        keys = self.chunk_keys[None, ...].expand(input.size(0), -1, -1)
+        return functional.hash_table(keys, torch.reshape(input, shape))

     def fit(self, data_loader: DataLoader):
         for samples, labels in data_loader:
@@ -722,12 +730,7 @@ def fit(self, data_loader: DataLoader):
             self.model_count.add(encoded, labels)

         with torch.no_grad():
-            shape = (self.n_classes, self.chunks, self.n_dimensions // self.chunks)
-            weight_chunks = torch.reshape(self.model_count.weight, shape)
-
-            keys = self.chunk_keys[None, ...].expand(self.n_classes, -1, -1)
-            comp_weights = functional.hash_table(keys, weight_chunks)
-            self.model.weight.data = comp_weights
+            self.model.weight.data = self.compress(self.model_count.weight)

         return self
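The net effect of the last two hunks: `compress` now broadcasts `chunk_keys` over whatever leading dimension it is given, so the same method compresses both batches of encoded samples and the count model's class prototypes, replacing the duplicated logic deleted above. A hedged standalone sketch with made-up sizes:

```python
import torch
import torchhd
from torchhd import functional

n_dimensions, chunks = 10000, 10
chunk_keys = torchhd.random(chunks, n_dimensions // chunks)

def compress(input: torch.Tensor) -> torch.Tensor:
    # Split each vector into `chunks` pieces, bind every piece with its
    # chunk key, and bundle the results into one short hypervector.
    shape = (input.size(0), chunks, n_dimensions // chunks)
    keys = chunk_keys[None, ...].expand(input.size(0), -1, -1)
    return functional.hash_table(keys, torch.reshape(input, shape))

samples = torchhd.random(4, n_dimensions)  # stand-in encoded batch
weights = torchhd.random(3, n_dimensions)  # stand-in class prototypes
assert compress(samples).shape == (4, n_dimensions // chunks)
assert compress(weights).shape == (3, n_dimensions // chunks)
```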
