From 85e300843796a7ac3b8fe0929cbc26fced7f392d Mon Sep 17 00:00:00 2001 From: Zolisa Bleki Date: Mon, 8 May 2023 00:36:34 -0400 Subject: [PATCH] First attempt at emplemnting OTOS. Still needs work --- pyloras/_otos.py | 73 +++++++++++++++++++++++++++++++++++++++----- pyproject.toml | 3 ++ requirements-dev.txt | 1 + 3 files changed, 70 insertions(+), 7 deletions(-) diff --git a/pyloras/_otos.py b/pyloras/_otos.py index bc7c7b8..fca27ef 100644 --- a/pyloras/_otos.py +++ b/pyloras/_otos.py @@ -1,10 +1,10 @@ -import ot from imblearn.over_sampling.base import BaseOverSampler from imblearn.utils import Substitution from imblearn.utils._docstring import ( _random_state_docstring, _n_jobs_docstring, ) +from sklearn.svm import LinearSVC import numpy as np from ._common import check_random_state, safe_random_state @@ -20,16 +20,75 @@ def __init__( self, *, sampling_strategy="auto", - svm_regularization=1.0, - ot_regularization=1.0, + svc_reg=1.0, + ot_reg=1.0, tradeoff=1.0, random_state=None, + max_iter=100, ): super().__init__(sampling_strategy=sampling_strategy) - self.svm_regularization = svm_regularization - self.ot_regularization = ot_regularization + self.svc_reg = svc_reg + self.ot_reg = ot_reg self.tradeoff = tradeoff self.random_state = random_state + self.max_iter = max_iter - def fit_resample(self, X, y): - return X, y + def _fit_resample(self, X, y): + import ot + random_state = check_random_state(self.random_state) + X_res = [X.copy()] + y_res = [y.copy()] + svc = LinearSVC( + loss="hinge", C=self.svc_reg, random_state=safe_random_state(random_state) + ) + for minority_class, samples_to_make in self.sampling_strategy_.items(): + if samples_to_make == 0: + continue + X_p = X[y == minority_class] + X_n = X[y != minority_class] + n_p = X_p.shape[0] + n_n = X_n.shape[0] + n_r = samples_to_make + one_r = np.ones((n_r, 1)) + one_n = np.ones((n_n, 1)) + # set initial distribution for mu_r and mu_p + mu_r = np.asarray([1.0 / n_r] * n_r) + mu_p = np.asarray([1.0 / n_p] * n_p) + T = mu_r[:, None] @ mu_p[:, None].T + # manufactor a binary classification problem + _y = np.empty_like(y) + _y[y == minority_class] = 0 + _y[y != minority_class] = 1 + svc.fit(X, _y) + w = svc.coef_.T + + hingelosses = np.concatenate( + [ + np.atleast_1d(max(1 - y_i * svc.coef_ @ x_row, 0.0)) + for y_i, x_row in zip(y[y == minority_class], X_p) + ] + ) + mu_p = np.exp(hingelosses) + mu_p /= mu_p.sum() + + D_r = np.diag(1 / mu_r) + X_r = D_r @ T @ X_p + # C_p = np.apply_along_axis(c_row, axis=-1, arr=X_r) + C_p = np.asarray( + [ + [np.linalg.norm(x_row - row) for row in X_p] + for x_row in X_r + ] + ) + wwT = w @ w.T + Theta = self.tradeoff * C_p.T - X_p @ np.kron(one_r.T, wwT @ X_n.T @ one_n + n_n * w) @ D_r + Phi = X_p @ wwT @ X_p.T + Psi = D_r.T @ D_r + + for _ in range(self.max_iter): + transport_cost = Theta.T + n_n * Psi @ T @ Phi + T = ot.sinkhorn(mu_r, mu_p, transport_cost, self.ot_reg) + X_res.append(D_r @ T @ X_p) + y_res.append([minority_class] * n_r) + + return np.concatenate(X_res), np.concatenate(y_res) diff --git a/pyproject.toml b/pyproject.toml index 3300df6..26afa86 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,9 @@ classifiers = [ "Operating System :: MacOS :: MacOS X", ] +[project.optional-dependencies] +otos = ["POT"] + [project.urls] source = "https://github.com/zoj613/pyloras" tracker = "https://github.com/zoj613/pyloras/issues" diff --git a/requirements-dev.txt b/requirements-dev.txt index 154c6b2..f9a7d57 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -2,5 +2,6 @@ build==0.10.0 imbalanced-learn==0.10.1 numpy==1.23.2 +POT==0.9.0 pytest==7.3.1 pytest-cov==4.0.0