From 13c46c661294006938b0b29c5c6121690c9aa42d Mon Sep 17 00:00:00 2001 From: Victor Vargas Date: Thu, 11 Jul 2024 08:19:55 +0200 Subject: [PATCH 1/2] fixed bug in CLM thresholds. Changed CLM initialisation for better convergence. Splitted CLM computation in different methods --- dlordinal/output_layers/clm.py | 109 ++++++++++++++-------- dlordinal/output_layers/tests/test_clm.py | 79 ++++++++++------ 2 files changed, 123 insertions(+), 65 deletions(-) diff --git a/dlordinal/output_layers/clm.py b/dlordinal/output_layers/clm.py index 1b87952..511ff6c 100644 --- a/dlordinal/output_layers/clm.py +++ b/dlordinal/output_layers/clm.py @@ -1,5 +1,5 @@ import warnings -from math import sqrt +from typing import Literal import torch from torch.nn import Module @@ -7,8 +7,9 @@ class CLM(Module): """ - Implementation of the cumulative link model from :footcite:t:`vargas2020clm` as a torch layer. - Different link functions can be used, including logit, probit and cloglog. + Implementation of the cumulative link models from :footcite:t:`vargas2020clm` as a + torch layer. Different link functions can be used, including logit, probit + and cloglog. Parameters ---------- @@ -21,52 +22,72 @@ class CLM(Module): clip_warning : bool, default=True Whether to print the clipping value warning or not. + Attributes + ---------- + num_classes : int + The number of classes. + link_function : str + The link function to use. Can be ``'logit'``, ``'probit'`` or ``'cloglog'``. + min_distance : float + The minimum distance between thresholds + clip_warning : bool + Whether to print the clipping value warning or not. + dist_ : torch.distributions.Normal + The normal (0,1) distribution used to compute the probit link function. + thresholds_b_ : torch.nn.Parameter + The torch parameter for the first threshold. + thresholds_a_ : torch.nn.Parameter + The torch parameter for the alphas of the thresholds. + clip_warning_shown_ : bool + Whether the clipping warning has been shown or not. + + Example --------- >>> import torch - >>> from dlordinal.layers import CLM - + >>> from dlordinal.output_layers import CLM >>> inp = torch.randn(10, 5) >>> fc = torch.nn.Linear(5, 1) >>> clm = CLM(5, "logit") - >>> output = clm(fc(inp)) >>> print(output) - tensor([[0.4704, 0.0063, 0.0441, 0.0423, 0.4369], - [0.2496, 0.0048, 0.0349, 0.0363, 0.6745], - [0.6384, 0.0058, 0.0393, 0.0357, 0.2808], - [0.4862, 0.0063, 0.0441, 0.0420, 0.4214], - [0.3768, 0.0060, 0.0425, 0.0421, 0.5327], - [0.4740, 0.0063, 0.0441, 0.0422, 0.4334], - [0.2868, 0.0052, 0.0378, 0.0387, 0.6315], - [0.2583, 0.0049, 0.0356, 0.0369, 0.6643], - [0.1811, 0.0038, 0.0281, 0.0300, 0.7570], - [0.5734, 0.0062, 0.0423, 0.0392, 0.3389]], grad_fn=) + tensor([[0.7944, 0.1187, 0.0531, 0.0211, 0.0127], + [0.4017, 0.2443, 0.1862, 0.0987, 0.0690], + [0.4619, 0.2381, 0.1638, 0.0814, 0.0548], + [0.4636, 0.2378, 0.1632, 0.0809, 0.0545], + [0.4330, 0.2419, 0.1746, 0.0893, 0.0612], + [0.5006, 0.2309, 0.1495, 0.0716, 0.0473], + [0.6011, 0.2027, 0.1138, 0.0504, 0.0320], + [0.5995, 0.2032, 0.1144, 0.0507, 0.0322], + [0.4014, 0.2443, 0.1863, 0.0988, 0.0691], + [0.6922, 0.1672, 0.0838, 0.0351, 0.0217]], grad_fn=) + """ def __init__( - self, num_classes, link_function, min_distance=0.0, clip_warning=True, **kwargs + self, + num_classes: int, + link_function: Literal["logit", "probit", "cloglog"], + min_distance: int = 0.0, + clip_warning: bool = True, + **kwargs, ): super().__init__() self.num_classes = num_classes self.link_function = link_function self.min_distance = min_distance self.clip_warning = clip_warning - self.dist = torch.distributions.Normal(0, 1) + self.dist_ = torch.distributions.Normal(0, 1) - self.thresholds_b = torch.nn.Parameter(data=torch.Tensor(1), requires_grad=True) - torch.nn.init.uniform_(self.thresholds_b, 0.0, 0.1) - - self.thresholds_a = torch.nn.Parameter( - data=torch.Tensor(self.num_classes - 2), requires_grad=True + self.thresholds_b_ = torch.nn.Parameter( + data=torch.Tensor([0]), requires_grad=True ) - torch.nn.init.uniform_( - self.thresholds_a, - sqrt((1.0 / (self.num_classes - 2)) / 2), - sqrt(1.0 / (self.num_classes - 2)), + self.thresholds_a_ = torch.nn.Parameter( + data=torch.Tensor([1.0 for _ in range(self.num_classes - 2)]), + requires_grad=True, ) - self.clip_warning_shown = False + self.clip_warning_shown_ = False def _convert_thresholds(self, b, a, min_distance): a = a**2 @@ -77,7 +98,7 @@ def _convert_thresholds(self, b, a, min_distance): torch.ones( (self.num_classes - 1, self.num_classes - 1), device=a.device ), - diagonal=-1, + diagonal=0, ) * torch.reshape( torch.tile(thresholds_param, (self.num_classes - 1,)).to(a.device), @@ -87,9 +108,7 @@ def _convert_thresholds(self, b, a, min_distance): ) return th - def _clm(self, projected: torch.Tensor, thresholds: torch.Tensor): - projected = torch.reshape(projected, shape=(-1,)) - + def _compute_z3(self, projected: torch.Tensor, thresholds: torch.Tensor): m = projected.shape[0] a = torch.reshape(torch.tile(thresholds, (m,)), shape=(m, -1)) b = torch.transpose( @@ -102,22 +121,34 @@ def _clm(self, projected: torch.Tensor, thresholds: torch.Tensor): z3 = a - b if torch.any(z3 > 10) or torch.any(z3 < -10): - if self.clip_warning and not self.clip_warning_shown: + if self.clip_warning and not self.clip_warning_shown_: warnings.warn( - "The output value of the CLM layer is out of the range [-10, 10]." - " Clipping value prior to applying the link function for numerical" - " stability." + f"The output value of the CLM layer (max: {z3.abs().max()}) is out " + "of the range [-10, 10]. Clipping value prior to applying the " + "link function for numerical stability." ) z3 = torch.clip(a - b, -10, 10) - self.clip_warning_shown = True + self.clip_warning_shown_ = True + + return z3 + def _apply_link_function(self, z3): if self.link_function == "probit": - a3T = self.dist.cdf(z3) + a3T = self.dist_.cdf(z3) elif self.link_function == "cloglog": a3T = 1 - torch.exp(-torch.exp(z3)) else: a3T = 1.0 / (1.0 + torch.exp(-z3)) + return a3T + + def _clm(self, projected: torch.Tensor, thresholds: torch.Tensor): + projected = torch.reshape(projected, shape=(-1,)) + + m = projected.shape[0] + z3 = self._compute_z3(projected, thresholds) + a3T = self._apply_link_function(z3) + ones = torch.ones((m, 1), device=projected.device) a3 = torch.cat((a3T, ones), dim=1) a3[:, 1:] = a3[:, 1:] - a3[:, 0:-1] @@ -138,7 +169,7 @@ def forward(self, x): """ thresholds = self._convert_thresholds( - self.thresholds_b, self.thresholds_a, self.min_distance + self.thresholds_b_, self.thresholds_a_, self.min_distance ) return self._clm(x, thresholds) diff --git a/dlordinal/output_layers/tests/test_clm.py b/dlordinal/output_layers/tests/test_clm.py index 6e89874..0bc8738 100644 --- a/dlordinal/output_layers/tests/test_clm.py +++ b/dlordinal/output_layers/tests/test_clm.py @@ -1,11 +1,22 @@ import warnings +import numpy as np import pytest import torch from dlordinal.output_layers import CLM +def _test_probas(clm): + projections = torch.rand(32, 1) + probas = clm(projections) + total_probas = torch.sum(probas, dim=1) + assert torch.allclose(total_probas, torch.ones_like(total_probas)) + assert isinstance(probas, torch.Tensor) + + return projections, probas, total_probas + + def test_clm_creation(): num_classes = 3 link_function = "logit" @@ -18,8 +29,7 @@ def test_clm_creation(): assert isinstance(clm, CLM) -def test_clm_logit(): - input_shape = 10 +def test_clm_probas(): num_classes = 5 link_function = "logit" min_distance = 0.0 @@ -27,43 +37,57 @@ def test_clm_logit(): clm = CLM( num_classes=num_classes, link_function=link_function, min_distance=min_distance ) - input_data = torch.rand(32, input_shape) - output = clm(input_data) - assert isinstance(output, torch.Tensor) - assert clm.link_function == "logit" + _test_probas(clm) -def test_clm_probit(): - input_shape = 8 - num_classes = 4 - link_function = "probit" +def test_clm_thresholds(): + num_classes = 5 + link_function = "logit" min_distance = 0.0 clm = CLM( num_classes=num_classes, link_function=link_function, min_distance=min_distance ) - input_data = torch.rand(16, input_shape) - output = clm(input_data) - assert isinstance(output, torch.Tensor) - assert clm.link_function == "probit" + thresholds = clm._convert_thresholds( + clm.thresholds_b_, clm.thresholds_a_, min_distance + ) + expected_thresholds = torch.tensor([float(i) for i in range(num_classes - 2 + 1)]) + assert ( + thresholds.shape[0] == clm.thresholds_b_.shape[0] + clm.thresholds_a_.shape[0] + ) -def test_clm_cloglog(): - input_shape = 12 - num_classes = 6 - link_function = "cloglog" - min_distance = 0.0 + assert torch.allclose(thresholds, expected_thresholds) - clm = CLM( - num_classes=num_classes, link_function=link_function, min_distance=min_distance - ) - input_data = torch.rand(8, input_shape) - output = clm(input_data) + _test_probas(clm) + + +def test_clm_link_functions(): + for link in ["logit", "probit", "cloglog"]: + for num_classes in range(3, 12): + clm = CLM(num_classes=num_classes, link_function=link, min_distance=0.0) + assert clm.link_function == link + assert clm.num_classes == num_classes + + _test_probas(clm) + + +def test_clm_all_combinations(): + for link in ["logit", "probit", "cloglog"]: + for num_classes in range(3, 12): + for min_distance in np.linspace(0.0, 0.1, 10): + clm = CLM( + num_classes=num_classes, + link_function=link, + min_distance=min_distance, + ) + assert clm.link_function == link + assert clm.num_classes == num_classes + assert clm.min_distance == min_distance - assert isinstance(output, torch.Tensor) - assert clm.link_function == "cloglog" + _test_probas(clm) def test_clm_clip(): @@ -84,6 +108,7 @@ def test_clm_clip(): warnings.filterwarnings("error") clm(input_data) + _test_probas(clm) clm = CLM( num_classes=num_classes, @@ -92,6 +117,7 @@ def test_clm_clip(): clip_warning=False, ) clm(input_data) + _test_probas(clm) clm = CLM( num_classes=num_classes, @@ -101,4 +127,5 @@ def test_clm_clip(): ) input_data = torch.rand(8, input_shape) * 0.1 clm(input_data) + _test_probas(clm) warnings.resetwarnings() From 5eb7999d2157953a66c2c32ac9b87c21bc90fbdf Mon Sep 17 00:00:00 2001 From: Francisco Berchez Moreno Date: Fri, 12 Jul 2024 12:11:39 +0200 Subject: [PATCH 2/2] Version update --- README.md | 2 +- docs/conf.py | 2 +- pyproject.toml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 4cda9ef..7a7d280 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ ## ⚙️ Installation -`dlordinal v2.1.0` is the last version supported by Python 3.8, Python 3.9 and Python 3.10. +`dlordinal v2.1.1` is the last version supported by Python 3.8, Python 3.9 and Python 3.10. The easiest way to install `dlordinal` is via `pip`: diff --git a/docs/conf.py b/docs/conf.py index 4123dfe..db108f7 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -9,7 +9,7 @@ project = "dlordinal" copyright = "2023, Francisco Bérchez, Víctor Vargas" author = "Francisco Bérchez, Víctor Vargas" -release = "2.1.0" +release = "2.1.1" # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration diff --git a/pyproject.toml b/pyproject.toml index 12d3edf..553b550 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "dlordinal" -version = "2.1.0" +version = "2.1.1" authors = [ {name = "Francisco Bérchez-Moreno", email = "i72bemof@uco.es"}, {name = "Víctor Manuel Vargas", email = "vvargas@uco.es"},