Add LeHDC implementation
mikeheddes committed Mar 8, 2024
1 parent 4d590db commit 0857091
Showing 1 changed file with 148 additions and 13 deletions.
161 changes: 148 additions & 13 deletions torchhd/classify.py
@@ -44,6 +44,7 @@
"OnlineHD",
"NeuralHD",
"DistHD",
"LeHDC",
]


@@ -67,6 +68,10 @@ def __init__(
self.n_dimensions = n_dimensions
self.n_classes = n_classes

@property
def device(self) -> torch.device:
raise NotImplementedError()

def forward(self, samples: Tensor) -> Tensor:
return self.model(self.encoder(samples))

@@ -94,16 +99,13 @@ def __init__(
n_levels: int = 100,
min_level: int = -1,
max_level: int = 1,
batch_size: Union[int, None] = 1024,
device: torch.device = None,
dtype: torch.dtype = None
) -> None:
super().__init__(
n_features, n_dimensions, n_classes, device=device, dtype=dtype
)

self.batch_size = batch_size

self.keys = Random(n_features, n_dimensions, device=device, dtype=dtype)
self.levels = Level(
n_levels,
@@ -115,12 +117,19 @@
)
self.model = Centroid(n_dimensions, n_classes, device=device, dtype=dtype)

@property
def device(self) -> torch.device:
return self.model.weight.device

def encoder(self, samples: Tensor) -> Tensor:
return functional.hash_table(self.keys.weight, self.levels(samples)).sign()

def fit(self, data_loader: DataLoader) -> Self:

for samples, labels in data_loader:
samples = samples.to(self.device)
labels = labels.to(self.device)

encoded = self.encoder(samples)
self.model.add(encoded, labels)

@@ -143,7 +152,6 @@ def __init__(
max_level: int = 1,
epochs: int = 120,
lr: float = 0.035,
batch_size: Union[int, None] = 1024,
device: torch.device = None,
dtype: torch.dtype = None
) -> None:
@@ -153,7 +161,6 @@ def __init__(

self.epochs = epochs
self.lr = lr
self.batch_size = batch_size

self.keys = Random(n_features, n_dimensions, device=device, dtype=dtype)
self.levels = Level(
@@ -166,13 +173,20 @@
)
self.model = Centroid(n_dimensions, n_classes, device=device, dtype=dtype)

@property
def device(self) -> torch.device:
return self.model.weight.device

def encoder(self, samples: Tensor) -> Tensor:
return functional.hash_table(self.keys.weight, self.levels(samples)).sign()

def fit(self, data_loader: DataLoader) -> Self:

for _ in range(self.epochs):
for samples, labels in data_loader:
samples = samples.to(self.device)
labels = labels.to(self.device)

encoded = self.encoder(samples)
self.model.add_adapt(encoded, labels, lr=self.lr)

@@ -194,7 +208,6 @@ def __init__(
*,
epochs: int = 120,
lr: float = 0.035,
batch_size: Union[int, None] = 1024,
device: torch.device = None,
dtype: torch.dtype = None
) -> None:
@@ -204,15 +217,21 @@

self.epochs = epochs
self.lr = lr
self.batch_size = batch_size

self.encoder = Sinusoid(n_features, n_dimensions, device=device, dtype=dtype)
self.model = Centroid(n_dimensions, n_classes, device=device, dtype=dtype)

@property
def device(self) -> torch.device:
return self.model.weight.device

def fit(self, data_loader: DataLoader) -> Self:

for _ in range(self.epochs):
for samples, labels in data_loader:
samples = samples.to(self.device)
labels = labels.to(self.device)

encoded = self.encoder(samples)
self.model.add_online(encoded, labels, lr=self.lr)

@@ -236,7 +255,6 @@ def __init__(
regen_rate: float = 0.04,
epochs: int = 120,
lr: float = 0.37,
batch_size: Union[int, None] = 1024,
device: torch.device = None,
dtype: torch.dtype = None
) -> None:
@@ -248,24 +266,34 @@
self.regen_rate = regen_rate
self.epochs = epochs
self.lr = lr
self.batch_size = batch_size

self.encoder = Sinusoid(n_features, n_dimensions, device=device, dtype=dtype)
self.model = Centroid(n_dimensions, n_classes, device=device, dtype=dtype)

@property
def device(self) -> torch.device:
return self.model.weight.device

def fit(self, data_loader: DataLoader) -> Self:

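# Number of dimensions to regenerate at each regeneration step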
n_regen_dims = math.ceil(self.regen_rate * self.n_dimensions)

for samples, labels in data_loader:
samples = samples.to(self.device)
labels = labels.to(self.device)

encoded = self.encoder(samples)
self.model.add(encoded, labels)

for epoch_idx in range(1, self.epochs):
for samples, labels in data_loader:
samples = samples.to(self.device)
labels = labels.to(self.device)

encoded = self.encoder(samples)
self.model.add_adapt(encoded, labels, lr=self.lr)

# Regenerate feature dimensions
if (epoch_idx % self.regen_freq) == (self.regen_freq - 1):
weight = F.normalize(self.model.weight, dim=1)
scores = torch.var(weight, dim=0)
@@ -298,7 +326,6 @@ def __init__(
theta: float = 0.25,
epochs: int = 120,
lr: float = 0.05,
batch_size: Union[int, None] = 1024,
device: torch.device = None,
dtype: torch.dtype = None
) -> None:
@@ -313,23 +340,33 @@
self.theta = theta
self.epochs = epochs
self.lr = lr
self.batch_size = batch_size

self.encoder = Projection(n_features, n_dimensions, device=device, dtype=dtype)
self.model = Centroid(n_dimensions, n_classes, device=device, dtype=dtype)

@property
def device(self) -> torch.device:
return self.model.weight.device

def fit(self, data_loader: DataLoader) -> Self:

n_regen_dims = math.ceil(self.regen_rate * self.n_dimensions)

for epoch_idx in range(self.epochs):
for samples, labels in data_loader:
samples = samples.to(self.device)
labels = labels.to(self.device)

encoded = self.encoder(samples)
self.model.add_online(encoded, labels, lr=self.lr)

# Regenerate feature dimensions
if (epoch_idx % self.regen_freq) == (self.regen_freq - 1):
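# Accumulate per-dimension regeneration scores over the full dataset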
scores = 0
for samples, labels in data_loader:
samples = samples.to(self.device)
labels = labels.to(self.device)

scores += self.regen_score(samples, labels)

regen_dims = torch.topk(scores, n_regen_dims, largest=False).indices
@@ -351,15 +388,15 @@ def regen_score(self, samples, labels):

weight = F.normalize(self.model.weight, dim=1)

# partial correct
# Partial correct
partial = pred2 == labels
dist2corr = torch.abs(weight[labels[partial]] - samples[partial])
dist2incorr = torch.abs(weight[pred1[partial]] - samples[partial])
partial_dist = torch.sum(
(self.beta * dist2incorr - self.alpha * dist2corr), dim=0
)

# completely incorrect
# Completely incorrect
complete = pred2 != labels
dist2corr = torch.abs(weight[labels[complete]] - samples[complete])
dist2incorr1 = torch.abs(weight[pred1[complete]] - samples[complete])
@@ -374,3 +411,101 @@ def regen_score(self, samples, labels):
)

return 0.5 * partial_dist + complete_dist


class LeHDC(Classifier):
r"""Implements `DistHD: A Learner-Aware Dynamic Encoding Method for Hyperdimensional Classification <https://ieeexplore.ieee.org/document/10247876>`_."""

model: Centroid

def __init__(
self,
n_features: int,
n_dimensions: int,
n_classes: int,
*,
n_levels: int = 100,
min_level: int = -1,
max_level: int = 1,
epochs: int = 120,
lr: float = 0.01,
weight_decay: float = 0.03,
dropout_rate: float = 0.3,
device: torch.device = None,
dtype: torch.dtype = None
) -> None:
super().__init__(
n_features, n_dimensions, n_classes, device=device, dtype=dtype
)

self.epochs = epochs
self.lr = lr
self.weight_decay = weight_decay

self.keys = Random(n_features, n_dimensions, device=device, dtype=dtype)
self.levels = Level(
n_levels,
n_dimensions,
low=min_level,
high=max_level,
device=device,
dtype=dtype,
)
self.dropout = torch.nn.Dropout(dropout_rate)
# Gradient model accumulates gradients
self.grad_model = Centroid(n_dimensions, n_classes, device=device, dtype=dtype, requires_grad=True)
# Regular model is a binarized version of the gradient model
self.model = Centroid(n_dimensions, n_classes, device=device, dtype=dtype, requires_grad=True)

@property
def device(self) -> torch.device:
return self.model.weight.device

def encoder(self, samples: Tensor) -> Tensor:
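# Bind each feature's random key with its quantized level hypervector,
# bundle across features, and binarize the result with sign()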
return functional.hash_table(self.keys.weight, self.levels(samples)).sign()

def forward(self, samples: Tensor) -> Tensor:
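# Dropout randomly zeroes encoded dimensions during training as regularization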
return self.model(self.dropout(self.encoder(samples)))

def fit(self, data_loader: DataLoader) -> Self:

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
self.grad_model.parameters(),
lr=self.lr,
weight_decay=self.weight_decay,
)

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=2)
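# Reduce the learning rate when the accumulated epoch loss plateaus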

for _ in range(self.epochs):
accumulated_loss = 0

for samples, labels in data_loader:
samples = samples.to(self.device)
labels = labels.to(self.device)

logits = self(samples)
loss = criterion(logits, labels)
accumulated_loss += loss.detach().item()

# Zero out all the gradients
self.grad_model.zero_grad()
self.model.zero_grad()

loss.backward()

# The gradient model is updated using the gradients from the binarized model
self.grad_model.weight.grad = self.model.weight.grad
optimizer.step()

# Quantize the weights
with torch.no_grad():
self.model.weight.data = self.grad_model.weight.sign()

scheduler.step(accumulated_loss)

return self
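
A minimal usage sketch, not part of this commit: the toy data, shapes, and loader below are hypothetical placeholders, the import path follows this file (torchhd/classify.py), and calling eval() assumes Classifier subclasses torch.nn.Module (consistent with its use of nn.Dropout and module submodels):

import torch
from torch.utils.data import DataLoader, TensorDataset
from torchhd.classify import LeHDC

# Hypothetical toy data: 1000 samples, 784 features, 10 classes
samples = torch.rand(1000, 784)
labels = torch.randint(0, 10, (1000,))
loader = DataLoader(TensorDataset(samples, labels), batch_size=128, shuffle=True)

model = LeHDC(n_features=784, n_dimensions=10000, n_classes=10, epochs=10)
model.fit(loader)

# Disable dropout for inference, then classify with the binarized centroids
model.eval()
with torch.no_grad():
    predictions = model(samples).argmax(dim=-1)
print((predictions == labels).float().mean())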
