Skip to content

Commit

Permalink
upgrade API to torchlambertw=0.0.3; 10x faster MLE
Browse files Browse the repository at this point in the history
  • Loading branch information
gmgeorg committed Dec 25, 2023
1 parent a24d00d commit b6dbd5c
Show file tree
Hide file tree
Showing 7 changed files with 29 additions and 13 deletions.
4 changes: 4 additions & 0 deletions pylambertw/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
"""Init for module"""
from ._version import __version__

__all__ = ["__version__"]
3 changes: 3 additions & 0 deletions pylambertw/_version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"""Version."""

__version__ = "0.0.2"
4 changes: 2 additions & 2 deletions pylambertw/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,12 +158,12 @@ def __repr__(self):
def tau(self):
"""Converts Theta (distribution dependent) to Tau (transformation only)."""

distr_constr = lwd.get_distribution_constructor(self.distribution_name)
distr_constr = lwd.utils.get_distribution_constructor(self.distribution_name)
distr = distr_constr(**self.beta)

return Tau(
loc=distr.mean.numpy()
if lwd.is_location_family(self.distribution_name)
if lwd.utils.is_location_family(self.distribution_name)
else 0.0,
scale=distr.stddev.numpy(),
lambertw_params=self.lambertw_params,
Expand Down
8 changes: 4 additions & 4 deletions pylambertw/mle.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def __init__(
self.distribution_name = distribution_name
self.distribution_constructor = (
distribution_constructor
or lwd.get_distribution_constructor(self.distribution_name)
or lwd.utils.get_distribution_constructor(self.distribution_name)
)
self.lambertw_type = base.LambertWType(lambertw_type)
self.max_iter = max_iter
Expand All @@ -73,14 +73,14 @@ def _initialize_params(self, data: np.ndarray):
if self.lambertw_type == base.LambertWType.H:
self.igmm = igmm.IGMM(
lambertw_type=self.lambertw_type,
location_family=lwd.is_location_family(self.distribution_name),
location_family=lwd.utils.is_location_family(self.distribution_name),
)
self.igmm.fit(data)
x_init = self.igmm.transform(data)

lambertw_params_init = self.igmm.tau.lambertw_params
else:
if lwd.is_location_family(self.distribution_name):
if lwd.utils.is_location_family(self.distribution_name):
# Default to Normal distribution for location family.
params_data = ud.estimate_params(data, "Normal")
loc_init = params_data["loc"]
Expand All @@ -92,7 +92,7 @@ def _initialize_params(self, data: np.ndarray):
scale_init = 1.0 / params_data["rate"]

z_init = (data - loc_init) / scale_init
if lwd.is_location_family(self.distribution_name):
if lwd.utils.is_location_family(self.distribution_name):
gamma_init = igmm.gamma_taylor(z_init)
else:
gamma_init = 0.01
Expand Down
4 changes: 2 additions & 2 deletions pylambertw/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,6 @@ def test_estimate_params(dist_name):
rng = np.random.RandomState(42)
x = rng.normal(100)
params = ud.estimate_params(x, dist_name)
constr = lwd.get_distribution_constructor(dist_name)
param_names = lwd.get_distribution_args(constr)
constr = lwd.utils.get_distribution_constructor(dist_name)
param_names = lwd.utils.get_distribution_args(constr)
assert set(params.keys()) == set(param_names)
4 changes: 2 additions & 2 deletions pylambertw/utils/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ def torch_sigmoid(x: torch.tensor) -> torch.tensor:
def get_params_activations(distribution_name: str) -> Dict[str, Callable]:
"""Get activation functions for each distribution parameters."""
assert isinstance(distribution_name, str)
distr_constr = lwd.get_distribution_constructor(distribution_name)
param_names = lwd.get_distribution_args(distr_constr)
distr_constr = lwd.utils.get_distribution_constructor(distribution_name)
param_names = lwd.utils.get_distribution_args(distr_constr)

act_fns = {p: (torch_linear, linear_inverse) for p in param_names}

Expand Down
15 changes: 12 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,23 @@
from setuptools import find_packages, setup
import re

_VERSION_FILE = "pylambertw/_version.py"
verstrline = open(_VERSION_FILE, "rt").read()
_VERSION = r"^__version__ = ['\"]([^'\"]*)['\"]"
mo = re.search(_VERSION, verstrline, re.M)
if mo:
verstr = mo.group(1)
else:
raise RuntimeError("Unable to find version string in %s." % (_VERSION_FILE,))

pkg_descr = """
Python implementation of the Lambert W x F framework for analyzing skewed, heavy-tailed distributions
with an sklearn interface and torch based maximum likelihood estimation (MLE).
"""


setup(
name="pylambertw",
version="0.0.1",
version=verstr,
url="https://github.com/gmgeorg/pylambertw.git",
author="Georg M. Goerg",
author_email="im@gmge.org",
Expand All @@ -25,6 +34,6 @@
"tqdm>=4.46.1",
"dataclasses>=0.6",
"scikit-learn>=1.0.1",
"torchlambertw @ git+ssh://git@github.com/gmgeorg/torchlambertw.git#egg=torchlambertw-0.0.1",
"torchlambertw @ git+ssh://git@github.com/gmgeorg/torchlambertw.git#egg=torchlambertw-0.0.3",
],
)

0 comments on commit b6dbd5c

Please sign in to comment.