Commit
Merge pull request #74 from ayrna/development
Fixed bug in CLM thresholds
franberchez authored Jul 12, 2024
2 parents 5305630 + 5eb7999 commit 5fe8456
Showing 5 changed files with 126 additions and 68 deletions.
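The substantive change is in `dlordinal/output_layers/clm.py`: the `torch.tril` mask in `_convert_thresholds` moves from `diagonal=-1` to `diagonal=0`, and the threshold parameters are re-initialised (`thresholds_b_` to zero, `thresholds_a_` to ones), so the learned thresholds form a proper cumulative sum starting at the first threshold `b`. Below is a minimal standalone sketch of the corrected construction, not the library's API; it assumes, consistent with the docstring in the diff, that `min_distance` is added to each squared offset, and it uses broadcasting in place of the tile/reshape of the original.

import torch

num_classes = 5
b = torch.tensor([0.0])           # first threshold (thresholds_b_)
a = torch.ones(num_classes - 2)   # threshold offsets (thresholds_a_)
min_distance = 0.0

offsets = a**2 + min_distance     # squared to keep thresholds increasing
params = torch.cat((b, offsets))  # length num_classes - 1

# With diagonal=0 the mask keeps the main diagonal, so row i sums the
# first i + 1 parameters: th = [b, b + a0^2, b + a0^2 + a1^2, ...].
mask = torch.tril(torch.ones(num_classes - 1, num_classes - 1), diagonal=0)
th = torch.sum(mask * params, dim=1)
print(th)  # tensor([0., 1., 2., 3.])

# With the old diagonal=-1, row 0 of the mask was empty: every threshold
# was shifted one slot and the last offset was silently dropped.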
2 changes: 1 addition & 1 deletion README.md
@@ -16,7 +16,7 @@

## ⚙️ Installation

`dlordinal v2.1.0` is the last version that supports Python 3.8, Python 3.9 and Python 3.10.
`dlordinal v2.1.1` is the last version that supports Python 3.8, Python 3.9 and Python 3.10.

The easiest way to install `dlordinal` is via `pip`:

109 changes: 70 additions & 39 deletions dlordinal/output_layers/clm.py
@@ -1,14 +1,15 @@
import warnings
from math import sqrt
from typing import Literal

import torch
from torch.nn import Module


class CLM(Module):
"""
Implementation of the cumulative link model from :footcite:t:`vargas2020clm` as a torch layer.
Different link functions can be used, including logit, probit and cloglog.
Implementation of the cumulative link models from :footcite:t:`vargas2020clm` as a
torch layer. Different link functions can be used, including logit, probit
and cloglog.
Parameters
----------
@@ -21,52 +22,72 @@ class CLM(Module):
clip_warning : bool, default=True
Whether to print the clipping value warning or not.
Attributes
----------
num_classes : int
The number of classes.
link_function : str
The link function to use. Can be ``'logit'``, ``'probit'`` or ``'cloglog'``.
min_distance : float
The minimum distance between thresholds.
clip_warning : bool
Whether to print the clipping value warning or not.
dist_ : torch.distributions.Normal
The normal (0,1) distribution used to compute the probit link function.
thresholds_b_ : torch.nn.Parameter
The torch parameter for the first threshold.
thresholds_a_ : torch.nn.Parameter
The torch parameter for the alphas of the thresholds.
clip_warning_shown_ : bool
Whether the clipping warning has been shown or not.
Example
---------
>>> import torch
>>> from dlordinal.layers import CLM
>>> from dlordinal.output_layers import CLM
>>> inp = torch.randn(10, 5)
>>> fc = torch.nn.Linear(5, 1)
>>> clm = CLM(5, "logit")
>>> output = clm(fc(inp))
>>> print(output)
tensor([[0.4704, 0.0063, 0.0441, 0.0423, 0.4369],
[0.2496, 0.0048, 0.0349, 0.0363, 0.6745],
[0.6384, 0.0058, 0.0393, 0.0357, 0.2808],
[0.4862, 0.0063, 0.0441, 0.0420, 0.4214],
[0.3768, 0.0060, 0.0425, 0.0421, 0.5327],
[0.4740, 0.0063, 0.0441, 0.0422, 0.4334],
[0.2868, 0.0052, 0.0378, 0.0387, 0.6315],
[0.2583, 0.0049, 0.0356, 0.0369, 0.6643],
[0.1811, 0.0038, 0.0281, 0.0300, 0.7570],
[0.5734, 0.0062, 0.0423, 0.0392, 0.3389]], grad_fn=<CopySlices>)
tensor([[0.7944, 0.1187, 0.0531, 0.0211, 0.0127],
[0.4017, 0.2443, 0.1862, 0.0987, 0.0690],
[0.4619, 0.2381, 0.1638, 0.0814, 0.0548],
[0.4636, 0.2378, 0.1632, 0.0809, 0.0545],
[0.4330, 0.2419, 0.1746, 0.0893, 0.0612],
[0.5006, 0.2309, 0.1495, 0.0716, 0.0473],
[0.6011, 0.2027, 0.1138, 0.0504, 0.0320],
[0.5995, 0.2032, 0.1144, 0.0507, 0.0322],
[0.4014, 0.2443, 0.1863, 0.0988, 0.0691],
[0.6922, 0.1672, 0.0838, 0.0351, 0.0217]], grad_fn=<CopySlices>)
"""

def __init__(
self, num_classes, link_function, min_distance=0.0, clip_warning=True, **kwargs
self,
num_classes: int,
link_function: Literal["logit", "probit", "cloglog"],
min_distance: float = 0.0,
clip_warning: bool = True,
**kwargs,
):
super().__init__()
self.num_classes = num_classes
self.link_function = link_function
self.min_distance = min_distance
self.clip_warning = clip_warning
self.dist = torch.distributions.Normal(0, 1)
self.dist_ = torch.distributions.Normal(0, 1)

self.thresholds_b = torch.nn.Parameter(data=torch.Tensor(1), requires_grad=True)
torch.nn.init.uniform_(self.thresholds_b, 0.0, 0.1)

self.thresholds_a = torch.nn.Parameter(
data=torch.Tensor(self.num_classes - 2), requires_grad=True
self.thresholds_b_ = torch.nn.Parameter(
data=torch.Tensor([0]), requires_grad=True
)
torch.nn.init.uniform_(
self.thresholds_a,
sqrt((1.0 / (self.num_classes - 2)) / 2),
sqrt(1.0 / (self.num_classes - 2)),
self.thresholds_a_ = torch.nn.Parameter(
data=torch.Tensor([1.0 for _ in range(self.num_classes - 2)]),
requires_grad=True,
)

self.clip_warning_shown = False
self.clip_warning_shown_ = False

def _convert_thresholds(self, b, a, min_distance):
a = a**2
@@ -77,7 +98,7 @@ def _convert_thresholds(self, b, a, min_distance):
torch.ones(
(self.num_classes - 1, self.num_classes - 1), device=a.device
),
diagonal=-1,
diagonal=0,
)
* torch.reshape(
torch.tile(thresholds_param, (self.num_classes - 1,)).to(a.device),
@@ -87,9 +108,7 @@ def _convert_thresholds(self, b, a, min_distance):
)
return th

def _clm(self, projected: torch.Tensor, thresholds: torch.Tensor):
projected = torch.reshape(projected, shape=(-1,))

def _compute_z3(self, projected: torch.Tensor, thresholds: torch.Tensor):
m = projected.shape[0]
a = torch.reshape(torch.tile(thresholds, (m,)), shape=(m, -1))
b = torch.transpose(
@@ -102,22 +121,34 @@ def _clm(self, projected: torch.Tensor, thresholds: torch.Tensor):

z3 = a - b
if torch.any(z3 > 10) or torch.any(z3 < -10):
if self.clip_warning and not self.clip_warning_shown:
if self.clip_warning and not self.clip_warning_shown_:
warnings.warn(
"The output value of the CLM layer is out of the range [-10, 10]."
" Clipping value prior to applying the link function for numerical"
" stability."
f"The output value of the CLM layer (max: {z3.abs().max()}) is out "
"of the range [-10, 10]. Clipping value prior to applying the "
"link function for numerical stability."
)
z3 = torch.clip(a - b, -10, 10)
self.clip_warning_shown = True
self.clip_warning_shown_ = True

return z3

def _apply_link_function(self, z3):
if self.link_function == "probit":
a3T = self.dist.cdf(z3)
a3T = self.dist_.cdf(z3)
elif self.link_function == "cloglog":
a3T = 1 - torch.exp(-torch.exp(z3))
else:
a3T = 1.0 / (1.0 + torch.exp(-z3))

return a3T

def _clm(self, projected: torch.Tensor, thresholds: torch.Tensor):
projected = torch.reshape(projected, shape=(-1,))

m = projected.shape[0]
z3 = self._compute_z3(projected, thresholds)
a3T = self._apply_link_function(z3)

ones = torch.ones((m, 1), device=projected.device)
a3 = torch.cat((a3T, ones), dim=1)
a3[:, 1:] = a3[:, 1:] - a3[:, 0:-1]
@@ -138,7 +169,7 @@ def forward(self, x):
"""

thresholds = self._convert_thresholds(
self.thresholds_b, self.thresholds_a, self.min_distance
self.thresholds_b_, self.thresholds_a_, self.min_distance
)

return self._clm(x, thresholds)
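The fix is easy to sanity-check against the new default initialisation; a sketch mirroring `test_clm_thresholds` in the test diff below (`_convert_thresholds` is an internal method, so this is illustrative rather than a supported API):

import torch
from dlordinal.output_layers import CLM

clm = CLM(num_classes=5, link_function="logit", min_distance=0.0)
th = clm._convert_thresholds(clm.thresholds_b_, clm.thresholds_a_, 0.0)
print(th)  # tensor([0., 1., 2., 3.]) with the default init (b = 0, offsets = 1)

# Class probabilities still sum to one for any batch of projections.
probas = clm(torch.randn(4, 1))
print(probas.sum(dim=1))  # approximately tensor([1., 1., 1., 1.])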
79 changes: 53 additions & 26 deletions dlordinal/output_layers/tests/test_clm.py
@@ -1,11 +1,22 @@
import warnings

import numpy as np
import pytest
import torch

from dlordinal.output_layers import CLM


def _test_probas(clm):
projections = torch.rand(32, 1)
probas = clm(projections)
total_probas = torch.sum(probas, dim=1)
assert torch.allclose(total_probas, torch.ones_like(total_probas))
assert isinstance(probas, torch.Tensor)

return projections, probas, total_probas


def test_clm_creation():
num_classes = 3
link_function = "logit"
@@ -18,52 +29,65 @@ def test_clm_creation():
assert isinstance(clm, CLM)


def test_clm_logit():
input_shape = 10
def test_clm_probas():
num_classes = 5
link_function = "logit"
min_distance = 0.0

clm = CLM(
num_classes=num_classes, link_function=link_function, min_distance=min_distance
)
input_data = torch.rand(32, input_shape)
output = clm(input_data)

assert isinstance(output, torch.Tensor)
assert clm.link_function == "logit"
_test_probas(clm)


def test_clm_probit():
input_shape = 8
num_classes = 4
link_function = "probit"
def test_clm_thresholds():
num_classes = 5
link_function = "logit"
min_distance = 0.0

clm = CLM(
num_classes=num_classes, link_function=link_function, min_distance=min_distance
)
input_data = torch.rand(16, input_shape)
output = clm(input_data)

assert isinstance(output, torch.Tensor)
assert clm.link_function == "probit"
thresholds = clm._convert_thresholds(
clm.thresholds_b_, clm.thresholds_a_, min_distance
)
expected_thresholds = torch.tensor([float(i) for i in range(num_classes - 2 + 1)])

assert (
thresholds.shape[0] == clm.thresholds_b_.shape[0] + clm.thresholds_a_.shape[0]
)

def test_clm_cloglog():
input_shape = 12
num_classes = 6
link_function = "cloglog"
min_distance = 0.0
assert torch.allclose(thresholds, expected_thresholds)

clm = CLM(
num_classes=num_classes, link_function=link_function, min_distance=min_distance
)
input_data = torch.rand(8, input_shape)
output = clm(input_data)
_test_probas(clm)


def test_clm_link_functions():
for link in ["logit", "probit", "cloglog"]:
for num_classes in range(3, 12):
clm = CLM(num_classes=num_classes, link_function=link, min_distance=0.0)
assert clm.link_function == link
assert clm.num_classes == num_classes

_test_probas(clm)


def test_clm_all_combinations():
for link in ["logit", "probit", "cloglog"]:
for num_classes in range(3, 12):
for min_distance in np.linspace(0.0, 0.1, 10):
clm = CLM(
num_classes=num_classes,
link_function=link,
min_distance=min_distance,
)
assert clm.link_function == link
assert clm.num_classes == num_classes
assert clm.min_distance == min_distance

assert isinstance(output, torch.Tensor)
assert clm.link_function == "cloglog"
_test_probas(clm)


def test_clm_clip():
@@ -84,6 +108,7 @@ def test_clm_clip():

warnings.filterwarnings("error")
clm(input_data)
_test_probas(clm)

clm = CLM(
num_classes=num_classes,
@@ -92,6 +117,7 @@
clip_warning=False,
)
clm(input_data)
_test_probas(clm)

clm = CLM(
num_classes=num_classes,
@@ -101,4 +127,5 @@
)
input_data = torch.rand(8, input_shape) * 0.1
clm(input_data)
_test_probas(clm)
warnings.resetwarnings()
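The clipping path exercised by `test_clm_clip` above can also be reproduced in isolation; a sketch, assuming projections large enough to push `z3` outside `[-10, 10]`:

import warnings
import torch
from dlordinal.output_layers import CLM

clm = CLM(num_classes=5, link_function="logit", min_distance=0.0, clip_warning=True)
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    clm(torch.rand(8, 1) * 100)  # projections far beyond the learned thresholds
print(any("Clipping" in str(w.message) for w in caught))  # True: warned before clipping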
2 changes: 1 addition & 1 deletion docs/conf.py
@@ -9,7 +9,7 @@
project = "dlordinal"
copyright = "2023, Francisco Bérchez, Víctor Vargas"
author = "Francisco Bérchez, Víctor Vargas"
release = "2.1.0"
release = "2.1.1"

# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "dlordinal"
version = "2.1.0"
version = "2.1.1"
authors = [
{name = "Francisco Bérchez-Moreno", email = "i72bemof@uco.es"},
{name = "Víctor Manuel Vargas", email = "vvargas@uco.es"},
