Add IntRVFL
mikeheddes committed Mar 12, 2024
1 parent f46644d commit 5637645
Showing 3 changed files with 114 additions and 48 deletions.
10 changes: 10 additions & 0 deletions docs/_templates/class_classifier.rst
@@ -0,0 +1,10 @@
.. role:: hidden
    :class: hidden-section
.. currentmodule:: {{ module }}


{{ name | underline}}

.. autoclass:: {{ name }}
    :members: fit, predict, accuracy
    :special-members: __call__
10 changes: 6 additions & 4 deletions docs/classifiers.rst
@@ -6,17 +6,19 @@ torchhd.classifiers
.. currentmodule:: torchhd.classifiers

.. autosummary::
    :nosignatures:
    :toctree: generated/
    :template: class.rst

    :template: class_classifier.rst
    Classifier
    Vanilla
    AdaptHD
    OnlineHD
    RefineHD
    NeuralHD
    DistHD
    LeHDC
    CompHD
    SparseHD
    QuantHD
    QuantHD
    LeHDC
    IntRVFL
142 changes: 98 additions & 44 deletions torchhd/classifiers.py
@@ -21,7 +21,7 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
from typing import Type, Union, Optional, Literal, Callable, Iterable, Tuple
from typing import Optional, Literal, Callable, Iterable, Tuple
import math
import scipy.linalg
import torch
@@ -30,8 +30,8 @@
from torch import Tensor, LongTensor

import torchhd.functional as functional
from torchhd.embeddings import Random, Level, Projection, Sinusoid
from torchhd.models import Centroid
from torchhd.embeddings import Random, Level, Projection, Sinusoid, Density
from torchhd.models import Centroid, IntRVFL as IntRVFLModel

DataLoader = Iterable[Tuple[Tensor, LongTensor]]

@@ -40,13 +40,13 @@
"Vanilla",
"AdaptHD",
"OnlineHD",
"RefineHD",
"NeuralHD",
"DistHD",
"LeHDC",
"CompHD",
"SparseHD",
"QuantHD",
"LeHDC",
"IntRVFL",
]


@@ -85,6 +85,9 @@ def device(self) -> torch.device:
return self.model.weight.device

def forward(self, samples: Tensor) -> Tensor:
return self.model(self.encoder(samples))

def __call__(self, samples: Tensor) -> Tensor:
"""Evaluate the logits of the classifier for the given samples.
Args:
@@ -94,7 +97,7 @@ def forward(self, samples: Tensor) -> Tensor:
Tensor: Logits of each sample for each class.
"""
return self.model(self.encoder(samples))
return super().__call__(samples)

def fit(self, data_loader: DataLoader):
"""Fits the classifier to the provided data.
@@ -142,9 +145,9 @@ class Vanilla(Classifier):
n_features (int): Size of each input sample.
n_dimensions (int): The number of hidden dimensions to use.
n_classes (int): The number of classes.
n_levels (int): The number of discretized levels for the level-hypervectors.
min_level (int): The lower-bound of the range represented by the level-hypervectors.
max_level (int): The upper-bound of the range represented by the level-hypervectors.
n_levels (int, optional): The number of discretized levels for the level-hypervectors.
min_level (int, optional): The lower-bound of the range represented by the level-hypervectors.
max_level (int, optional): The upper-bound of the range represented by the level-hypervectors.
device (``torch.device``, optional): the desired device of the weights. Default: if ``None``, uses the current device for the default tensor type (see ``torch.set_default_tensor_type()``). ``device`` will be the CPU for CPU tensor types and the current CUDA device for CUDA tensor types.
dtype (``torch.dtype``, optional): the desired data type of the weights. Default: if ``None``, uses ``torch.get_default_dtype()``.
@@ -200,11 +203,11 @@ class AdaptHD(Classifier):
n_features (int): Size of each input sample.
n_dimensions (int): The number of hidden dimensions to use.
n_classes (int): The number of classes.
n_levels (int): The number of discretized levels for the level-hypervectors.
min_level (int): The lower-bound of the range represented by the level-hypervectors.
max_level (int): The upper-bound of the range represented by the level-hypervectors.
epochs (int): The number of iterations over the training data.
lr (float): The learning rate.
n_levels (int, optional): The number of discretized levels for the level-hypervectors.
min_level (int, optional): The lower-bound of the range represented by the level-hypervectors.
max_level (int, optional): The upper-bound of the range represented by the level-hypervectors.
epochs (int, optional): The number of iterations over the training data.
lr (float, optional): The learning rate.
device (``torch.device``, optional): the desired device of the weights. Default: if ``None``, uses the current device for the default tensor type (see ``torch.set_default_tensor_type()``). ``device`` will be the CPU for CPU tensor types and the current CUDA device for CUDA tensor types.
dtype (``torch.dtype``, optional): the desired data type of the weights. Default: if ``None``, uses ``torch.get_default_dtype()``.
@@ -268,8 +271,8 @@ class OnlineHD(Classifier):
n_features (int): Size of each input sample.
n_dimensions (int): The number of hidden dimensions to use.
n_classes (int): The number of classes.
epochs (int): The number of iterations over the training data.
lr (float): The learning rate.
epochs (int, optional): The number of iterations over the training data.
lr (float, optional): The learning rate.
device (``torch.device``, optional): the desired device of the weights. Default: if ``None``, uses the current device for the default tensor type (see ``torch.set_default_tensor_type()``). ``device`` will be the CPU for CPU tensor types and the current CUDA device for CUDA tensor types.
dtype (``torch.dtype``, optional): the desired data type of the weights. Default: if ``None``, uses ``torch.get_default_dtype()``.
@@ -413,10 +416,10 @@ class NeuralHD(Classifier):
n_features (int): Size of each input sample.
n_dimensions (int): The number of hidden dimensions to use.
n_classes (int): The number of classes.
regen_freq (int): The frequency in epochs at which to regenerate hidden dimensions.
regen_rate (int): The fraction of hidden dimensions to regenerate.
epochs (int): The number of iterations over the training data.
lr (float): The learning rate.
regen_freq (int, optional): The frequency in epochs at which to regenerate hidden dimensions.
regen_rate (int, optional): The fraction of hidden dimensions to regenerate.
epochs (int, optional): The number of iterations over the training data.
lr (float, optional): The learning rate.
device (``torch.device``, optional): the desired device of the weights. Default: if ``None``, uses the current device for the default tensor type (see ``torch.set_default_tensor_type()``). ``device`` will be the CPU for CPU tensor types and the current CUDA device for CUDA tensor types.
dtype (``torch.dtype``, optional): the desired data type of the weights. Default: if ``None``, uses ``torch.get_default_dtype()``.
@@ -604,18 +607,17 @@ def regen_score(self, samples, labels):
class LeHDC(Classifier):
r"""Implements `DistHD: A Learner-Aware Dynamic Encoding Method for Hyperdimensional Classification <https://ieeexplore.ieee.org/document/10247876>`_.
Args:
n_features (int): Size of each input sample.
n_dimensions (int): The number of hidden dimensions to use.
n_classes (int): The number of classes.
n_levels (int): The number of discretized levels for the level-hypervectors.
min_level (int): The lower-bound of the range represented by the level-hypervectors.
max_level (int): The upper-bound of the range represented by the level-hypervectors.
epochs (int): The number of iterations over the training data.
lr (float): The learning rate.
weight_decay (float): The rate at which the weights of the model are decayed during training.
dropout_rate (float): The fraction of hidden dimensions to randomly zero-out.
n_levels (int, optional): The number of discretized levels for the level-hypervectors.
min_level (int, optional): The lower-bound of the range represented by the level-hypervectors.
max_level (int, optional): The upper-bound of the range represented by the level-hypervectors.
epochs (int, optional): The number of iterations over the training data.
lr (float, optional): The learning rate.
weight_decay (float, optional): The rate at which the weights of the model are decayed during training.
dropout_rate (float, optional): The fraction of hidden dimensions to randomly zero-out.
device (``torch.device``, optional): the desired device of the weights. Default: if ``None``, uses the current device for the default tensor type (see ``torch.set_default_tensor_type()``). ``device`` will be the CPU for CPU tensor types and the current CUDA device for CUDA tensor types.
dtype (``torch.dtype``, optional): the desired data type of the weights. Default: if ``None``, uses ``torch.get_default_dtype()``.
@@ -725,10 +727,10 @@ class CompHD(Classifier):
n_features (int): Size of each input sample.
n_dimensions (int): The number of hidden dimensions to use.
n_classes (int): The number of classes.
n_levels (int): The number of discretized levels for the level-hypervectors.
min_level (int): The lower-bound of the range represented by the level-hypervectors.
max_level (int): The upper-bound of the range represented by the level-hypervectors.
chunks (int): The number of times the model is reduced in size.
n_levels (int, optional): The number of discretized levels for the level-hypervectors.
min_level (int, optional): The lower-bound of the range represented by the level-hypervectors.
max_level (int, optional): The upper-bound of the range represented by the level-hypervectors.
chunks (int, optional): The number of times the model is reduced in size.
device (``torch.device``, optional): the desired device of the weights. Default: if ``None``, uses the current device for the default tensor type (see ``torch.set_default_tensor_type()``). ``device`` will be the CPU for CPU tensor types and the current CUDA device for CUDA tensor types.
dtype (``torch.dtype``, optional): the desired data type of the weights. Default: if ``None``, uses ``torch.get_default_dtype()``.
@@ -814,13 +816,13 @@ class SparseHD(Classifier):
n_features (int): Size of each input sample.
n_dimensions (int): The number of hidden dimensions to use.
n_classes (int): The number of classes.
n_levels (int): The number of discretized levels for the level-hypervectors.
min_level (int): The lower-bound of the range represented by the level-hypervectors.
max_level (int): The upper-bound of the range represented by the level-hypervectors.
epochs (int): The number of iterations over the training data.
lr (float): The learning rate.
sparsity (float): The fraction of weights to be zero.
sparsity_type (str): The way in which to apply the sparsity, per hidden dimension, or per class.
n_levels (int, optional): The number of discretized levels for the level-hypervectors.
min_level (int, optional): The lower-bound of the range represented by the level-hypervectors.
max_level (int, optional): The upper-bound of the range represented by the level-hypervectors.
epochs (int, optional): The number of iterations over the training data.
lr (float, optional): The learning rate.
sparsity (float, optional): The fraction of weights to be zero.
sparsity_type (str, optional): The way in which to apply the sparsity, per hidden dimension, or per class.
device (``torch.device``, optional): the desired device of the weights. Default: if ``None``, uses the current device for the default tensor type (see ``torch.set_default_tensor_type()``). ``device`` will be the CPU for CPU tensor types and the current CUDA device for CUDA tensor types.
dtype (``torch.dtype``, optional): the desired data type of the weights. Default: if ``None``, uses ``torch.get_default_dtype()``.
@@ -908,11 +910,11 @@ class QuantHD(Classifier):
n_features (int): Size of each input sample.
n_dimensions (int): The number of hidden dimensions to use.
n_classes (int): The number of classes.
n_levels (int): The number of discretized levels for the level-hypervectors.
min_level (int): The lower-bound of the range represented by the level-hypervectors.
max_level (int): The upper-bound of the range represented by the level-hypervectors.
epochs (int): The number of iterations over the training data.
lr (float): The learning rate.
n_levels (int, optional): The number of discretized levels for the level-hypervectors.
min_level (int, optional): The lower-bound of the range represented by the level-hypervectors.
max_level (int, optional): The upper-bound of the range represented by the level-hypervectors.
epochs (int, optional): The number of iterations over the training data.
lr (float, optional): The learning rate.
device (``torch.device``, optional): the desired device of the weights. Default: if ``None``, uses the current device for the default tensor type (see ``torch.set_default_tensor_type()``). ``device`` will be the CPU for CPU tensor types and the current CUDA device for CUDA tensor types.
dtype (``torch.dtype``, optional): the desired data type of the weights. Default: if ``None``, uses ``torch.get_default_dtype()``.
@@ -1000,3 +1002,55 @@ def fit(self, data_loader: DataLoader):
self.binarize()

return self


class IntRVFL(Classifier):
    r"""Implements `Density Encoding Enables Resource-Efficient Randomly Connected Neural Networks <https://doi.org/10.1109/TNNLS.2020.3015971>`_.

    Args:
        n_features (int): Size of each input sample.
        n_dimensions (int): The number of hidden dimensions to use.
        n_classes (int): The number of classes.
        kappa (int, optional): Parameter of the clipping function that limits the range of the hidden-layer values; used as part of transforming the input data.
        alpha (float, optional): Scalar controlling the amount of ridge-regression regularization applied when fitting the model. Default: 1.
        device (``torch.device``, optional): the desired device of the weights. Default: if ``None``, uses the current device for the default tensor type (see ``torch.set_default_tensor_type()``). ``device`` will be the CPU for CPU tensor types and the current CUDA device for CUDA tensor types.
        dtype (``torch.dtype``, optional): the desired data type of the weights. Default: if ``None``, uses ``torch.get_default_dtype()``.
    """

    model: IntRVFLModel
    encoder: Density

    def __init__(
        self,
        n_features: int,
        n_dimensions: int,
        n_classes: int,
        *,
        kappa: Optional[int] = None,
        alpha: float = 1,
        device: torch.device = None,
        dtype: torch.dtype = None
    ) -> None:
        super().__init__(
            n_features, n_dimensions, n_classes, device=device, dtype=dtype
        )

        self.alpha = alpha

        self.model = IntRVFLModel(
            n_features, n_dimensions, n_classes, kappa=kappa, device=device, dtype=dtype
        )
        self.encoder = self.model.encoding

    def forward(self, samples: Tensor) -> Tensor:
        return self.model(samples)

    def fit(self, data_loader: DataLoader):
        samples, labels = list(zip(*data_loader))
        samples = torch.cat(samples, dim=0).to(self.device)
        labels = torch.cat(labels, dim=0).to(self.device)

        return self.model.fit_ridge_regression(samples, labels, alpha=self.alpha)
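
For reference, a minimal usage sketch of the new IntRVFL classifier (not part of the commit). The dataset, sizes, batch size, and kappa value below are illustrative assumptions; only the constructor, fit, and __call__ defined above are relied on.

import torch
from torch.utils.data import DataLoader, TensorDataset

from torchhd.classifiers import IntRVFL

n_features, n_dimensions, n_classes = 64, 1024, 10

# Toy stand-in data: real-valued features in [0, 1) with integer class labels.
train_x = torch.rand(500, n_features)
train_y = torch.randint(0, n_classes, (500,))
train_loader = DataLoader(TensorDataset(train_x, train_y), batch_size=64)

# kappa is an arbitrary illustrative choice; alpha keeps its default of 1.
model = IntRVFL(n_features, n_dimensions, n_classes, kappa=3)
model.fit(train_loader)  # concatenates all batches and solves the ridge regression once

test_x = torch.rand(100, n_features)
test_y = torch.randint(0, n_classes, (100,))
logits = model(test_x)  # (100, n_classes) class scores via the Density encoding
accuracy = (logits.argmax(dim=-1) == test_y).float().mean().item()
print(f"test accuracy: {accuracy:.3f}")

Unlike the iterative classifiers above, fit performs a single closed-form ridge-regression solve over the concatenated training batches rather than looping over epochs.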
