From 0857091bfdaf1cfea08fa001f856f62d3f8ba302 Mon Sep 17 00:00:00 2001 From: Mike Heddes Date: Fri, 8 Mar 2024 13:02:26 -0800 Subject: [PATCH] Add LeHDC implementation --- torchhd/classify.py | 161 ++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 148 insertions(+), 13 deletions(-) diff --git a/torchhd/classify.py b/torchhd/classify.py index a7f68e74..453e3698 100644 --- a/torchhd/classify.py +++ b/torchhd/classify.py @@ -44,6 +44,7 @@ "OnlineHD", "NeuralHD", "DistHD", + "LeHDC", ] @@ -67,6 +68,10 @@ def __init__( self.n_dimensions = n_dimensions self.n_classes = n_classes + @property + def device(self) -> torch.device: + raise NotImplementedError() + def forward(self, samples: Tensor) -> Tensor: return self.model(self.encoder(samples)) @@ -94,7 +99,6 @@ def __init__( n_levels: int = 100, min_level: int = -1, max_level: int = 1, - batch_size: Union[int, None] = 1024, device: torch.device = None, dtype: torch.dtype = None ) -> None: @@ -102,8 +106,6 @@ def __init__( n_features, n_dimensions, n_classes, device=device, dtype=dtype ) - self.batch_size = batch_size - self.keys = Random(n_features, n_dimensions, device=device, dtype=dtype) self.levels = Level( n_levels, @@ -115,12 +117,19 @@ def __init__( ) self.model = Centroid(n_dimensions, n_classes, device=device, dtype=dtype) + @property + def device(self) -> torch.device: + return self.model.weight.device + def encoder(self, samples: Tensor) -> Tensor: return functional.hash_table(self.keys.weight, self.levels(samples)).sign() def fit(self, data_loader: DataLoader) -> Self: for samples, labels in data_loader: + samples = samples.to(self.device) + labels = labels.to(self.device) + encoded = self.encoder(samples) self.model.add(encoded, labels) @@ -143,7 +152,6 @@ def __init__( max_level: int = 1, epochs: int = 120, lr: float = 0.035, - batch_size: Union[int, None] = 1024, device: torch.device = None, dtype: torch.dtype = None ) -> None: @@ -153,7 +161,6 @@ def __init__( self.epochs = epochs self.lr = lr - self.batch_size = batch_size self.keys = Random(n_features, n_dimensions, device=device, dtype=dtype) self.levels = Level( @@ -166,6 +173,10 @@ def __init__( ) self.model = Centroid(n_dimensions, n_classes, device=device, dtype=dtype) + @property + def device(self) -> torch.device: + return self.model.weight.device + def encoder(self, samples: Tensor) -> Tensor: return functional.hash_table(self.keys.weight, self.levels(samples)).sign() @@ -173,6 +184,9 @@ def fit(self, data_loader: DataLoader) -> Self: for _ in range(self.epochs): for samples, labels in data_loader: + samples = samples.to(self.device) + labels = labels.to(self.device) + encoded = self.encoder(samples) self.model.add_adapt(encoded, labels, lr=self.lr) @@ -194,7 +208,6 @@ def __init__( *, epochs: int = 120, lr: float = 0.035, - batch_size: Union[int, None] = 1024, device: torch.device = None, dtype: torch.dtype = None ) -> None: @@ -204,15 +217,21 @@ def __init__( self.epochs = epochs self.lr = lr - self.batch_size = batch_size self.encoder = Sinusoid(n_features, n_dimensions, device=device, dtype=dtype) self.model = Centroid(n_dimensions, n_classes, device=device, dtype=dtype) + @property + def device(self) -> torch.device: + return self.model.weight.device + def fit(self, data_loader: DataLoader) -> Self: for _ in range(self.epochs): for samples, labels in data_loader: + samples = samples.to(self.device) + labels = labels.to(self.device) + encoded = self.encoder(samples) self.model.add_online(encoded, labels, lr=self.lr) @@ -236,7 +255,6 @@ def __init__( regen_rate: float = 0.04, epochs: int = 120, lr: float = 0.37, - batch_size: Union[int, None] = 1024, device: torch.device = None, dtype: torch.dtype = None ) -> None: @@ -248,24 +266,34 @@ def __init__( self.regen_rate = regen_rate self.epochs = epochs self.lr = lr - self.batch_size = batch_size self.encoder = Sinusoid(n_features, n_dimensions, device=device, dtype=dtype) self.model = Centroid(n_dimensions, n_classes, device=device, dtype=dtype) + @property + def device(self) -> torch.device: + return self.model.weight.device + def fit(self, data_loader: DataLoader) -> Self: n_regen_dims = math.ceil(self.regen_rate * self.n_dimensions) for samples, labels in data_loader: + samples = samples.to(self.device) + labels = labels.to(self.device) + encoded = self.encoder(samples) self.model.add(encoded, labels) for epoch_idx in range(1, self.epochs): for samples, labels in data_loader: + samples = samples.to(self.device) + labels = labels.to(self.device) + encoded = self.encoder(samples) self.model.add_adapt(encoded, labels, lr=self.lr) + # Regenerate feature dimensions if (epoch_idx % self.regen_freq) == (self.regen_freq - 1): weight = F.normalize(self.model.weight, dim=1) scores = torch.var(weight, dim=0) @@ -298,7 +326,6 @@ def __init__( theta: float = 0.25, epochs: int = 120, lr: float = 0.05, - batch_size: Union[int, None] = 1024, device: torch.device = None, dtype: torch.dtype = None ) -> None: @@ -313,23 +340,33 @@ def __init__( self.theta = theta self.epochs = epochs self.lr = lr - self.batch_size = batch_size self.encoder = Projection(n_features, n_dimensions, device=device, dtype=dtype) self.model = Centroid(n_dimensions, n_classes, device=device, dtype=dtype) + @property + def device(self) -> torch.device: + return self.model.weight.device + def fit(self, data_loader: DataLoader) -> Self: n_regen_dims = math.ceil(self.regen_rate * self.n_dimensions) for epoch_idx in range(self.epochs): for samples, labels in data_loader: + samples = samples.to(self.device) + labels = labels.to(self.device) + encoded = self.encoder(samples) self.model.add_online(encoded, labels, lr=self.lr) + # Regenerate feature dimensions if (epoch_idx % self.regen_freq) == (self.regen_freq - 1): scores = 0 for samples, labels in data_loader: + samples = samples.to(self.device) + labels = labels.to(self.device) + scores += self.regen_score(samples, labels) regen_dims = torch.topk(scores, n_regen_dims, largest=False).indices @@ -351,7 +388,7 @@ def regen_score(self, samples, labels): weight = F.normalize(self.model.weight, dim=1) - # partial correct + # Partial correct partial = pred2 == labels dist2corr = torch.abs(weight[labels[partial]] - samples[partial]) dist2incorr = torch.abs(weight[pred1[partial]] - samples[partial]) @@ -359,7 +396,7 @@ def regen_score(self, samples, labels): (self.beta * dist2incorr - self.alpha * dist2corr), dim=0 ) - # completely incorrect + # Completely incorrect complete = pred2 != labels dist2corr = torch.abs(weight[labels[complete]] - samples[complete]) dist2incorr1 = torch.abs(weight[pred1[complete]] - samples[complete]) @@ -374,3 +411,101 @@ def regen_score(self, samples, labels): ) return 0.5 * partial_dist + complete_dist + + +class LeHDC(Classifier): + r"""Implements `DistHD: A Learner-Aware Dynamic Encoding Method for Hyperdimensional Classification `_.""" + + encoder: Projection + model: Centroid + + def __init__( + self, + n_features: int, + n_dimensions: int, + n_classes: int, + *, + n_levels: int = 100, + min_level: int = -1, + max_level: int = 1, + epochs: int = 120, + lr: float = 0.01, + weight_decay: float = 0.03, + dropout_rate: float = 0.3, + device: torch.device = None, + dtype: torch.dtype = None + ) -> None: + super().__init__( + n_features, n_dimensions, n_classes, device=device, dtype=dtype + ) + + self.epochs = epochs + self.lr = lr + self.weight_decay = weight_decay + + self.keys = Random(n_features, n_dimensions, device=device, dtype=dtype) + self.levels = Level( + n_levels, + n_dimensions, + low=min_level, + high=max_level, + device=device, + dtype=dtype, + ) + self.model = Centroid(n_dimensions, n_classes, device=device, dtype=dtype) + self.dropout = torch.nn.Dropout(dropout_rate) + # Gradient model accumulates gradients + self.grad_model = Centroid(n_dimensions, n_classes, device=device, dtype=dtype, requires_grad=True) + # Regular model is a binarized version of the gradient model + self.model = Centroid(n_dimensions, n_classes, device=device, dtype=dtype, requires_grad=True) + + @property + def device(self) -> torch.device: + return self.model.weight.device + + def encoder(self, samples: Tensor) -> Tensor: + return functional.hash_table(self.keys.weight, self.levels(samples)).sign() + + def forward(self, samples: Tensor) -> Tensor: + return self.model(self.dropout(self.encoder(samples))) + + def fit(self, data_loader: DataLoader) -> Self: + + criterion = torch.nn.CrossEntropyLoss() + optimizer = torch.optim.Adam( + self.grad_model.parameters(), + lr=self.lr, + weight_decay=self.weight_decay, + ) + + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=2) + + for _ in range(self.epochs): + accumulated_loss = 0 + + for samples, labels in data_loader: + samples = samples.to(self.device) + labels = labels.to(self.device) + + logits = self(samples) + loss = criterion(logits, labels) + accumulated_loss += loss.detach().item() + + # Zero out all the gradients + self.grad_model.zero_grad() + self.model.zero_grad() + + loss.backward() + + # The gradient model is updated using the gradients from the binarized model + self.grad_model.weight.grad = self.model.weight.grad + optimizer.step() + + # Quantize the weights + with torch.no_grad(): + self.model.weight.data = self.grad_model.weight.sign() + + scheduler.step(accumulated_loss) + + return self +