Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Add OTOS oversampling algorithm #43

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyloras/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from ._loras import LORAS
from ._prowras import ProWRAS
from ._gamus import GAMUS
from ._otos import OTOS
94 changes: 94 additions & 0 deletions pyloras/_otos.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
from imblearn.over_sampling.base import BaseOverSampler
from imblearn.utils import Substitution
from imblearn.utils._docstring import (
_random_state_docstring,
_n_jobs_docstring,
)
from sklearn.svm import LinearSVC
import numpy as np

from ._common import check_random_state, safe_random_state


@Substitution(
    sampling_strategy=BaseOverSampler._sampling_strategy_docstring,
    random_state=_random_state_docstring,
    n_jobs=_n_jobs_docstring,
)
class OTOS(BaseOverSampler):
    """Oversample minority classes using an optimal-transport plan.

    For every class that needs new samples, a linear SVM separating that
    class from all other samples is fitted, and a transport plan ``T``
    between the synthetic samples and the existing samples of that class
    is re-estimated with Sinkhorn iterations. The synthetic samples are
    the combinations ``diag(1 / mu_r) @ T @ X_p`` of the existing samples
    ``X_p`` of the class.

    Parameters
    ----------
    {sampling_strategy}

    svc_reg : float, default=1.0
        Value passed as ``C`` to :class:`sklearn.svm.LinearSVC`.

    ot_reg : float, default=1.0
        Regularisation value passed to ``ot.sinkhorn``.

    tradeoff : float, default=1.0
        Weight applied to the pairwise-distance term of the transport cost.

    {random_state}

    max_iter : int, default=100
        Number of times the transport plan is re-estimated for each class.

    Notes
    -----
    The optional ``POT`` package (import name ``ot``) must be installed;
    it is imported only when resampling is performed.
    """

    def __init__(
        self,
        *,
        sampling_strategy="auto",
        svc_reg=1.0,
        ot_reg=1.0,
        tradeoff=1.0,
        random_state=None,
        max_iter=100,
    ):
        super().__init__(sampling_strategy=sampling_strategy)
        self.svc_reg = svc_reg
        self.ot_reg = ot_reg
        self.tradeoff = tradeoff
        self.random_state = random_state
        self.max_iter = max_iter

    def _fit_resample(self, X, y):
        """Return ``X`` and ``y`` with the synthetic samples appended.

        Raises
        ------
        ModuleNotFoundError
            If the optional ``POT`` package is not installed.
        """
        # POT is an optional dependency, so it is imported lazily and a
        # missing installation is reported with an actionable message.
        try:
            import ot
        except ModuleNotFoundError as err:
            raise ModuleNotFoundError(
                "OTOS requires the optional dependency 'POT' "
                "(import name 'ot'). Install it with `pip install POT`."
            ) from err

        random_state = check_random_state(self.random_state)
        X_res = [X.copy()]
        y_res = [y.copy()]
        svc = LinearSVC(
            loss="hinge",
            C=self.svc_reg,
            random_state=safe_random_state(random_state),
        )
        for minority_class, samples_to_make in self.sampling_strategy_.items():
            if samples_to_make == 0:
                continue
            is_minority = y == minority_class
            X_p = X[is_minority]
            X_n = X[~is_minority]
            n_p = X_p.shape[0]
            n_n = X_n.shape[0]
            n_r = samples_to_make
            one_r = np.ones((n_r, 1))
            one_n = np.ones((n_n, 1))

            # Start from uniform marginals; T is their outer product.
            mu_r = np.full(n_r, 1.0 / n_r)
            T = np.outer(mu_r, np.full(n_p, 1.0 / n_p))

            # Manufacture a binary problem: current class -> 0, rest -> 1.
            # Building the target from the mask keeps this working for
            # non-numeric class labels as well.
            _y = np.where(is_minority, 0, 1)
            svc.fit(X, _y)
            w = svc.coef_.T
            # NOTE(review): because the current class is encoded as 0, w is
            # oriented towards the *other* samples. Confirm this matches the
            # sign convention the Theta term below was derived with.

            # Hinge loss of the fitted SVC on the samples of the current
            # class. Label 0 is the negative class for the classifier, so
            # the signed label is -1 and the loss is max(0, 1 + f(x)). The
            # raw class label must not be used as the sign here.
            hingelosses = np.maximum(1.0 + svc.decision_function(X_p), 0.0)
            # Subtracting the maximum leaves the normalised weights unchanged
            # but prevents overflow in exp() for large losses.
            mu_p = np.exp(hingelosses - hingelosses.max())
            mu_p /= mu_p.sum()

            D_r = np.diag(1.0 / mu_r)
            X_r = D_r @ T @ X_p
            # Euclidean distance between every initial synthetic sample (rows)
            # and every existing sample of the class (columns).
            C_p = np.linalg.norm(X_r[:, None, :] - X_p[None, :, :], axis=-1)

            wwT = w @ w.T
            Theta = (
                self.tradeoff * C_p.T
                - X_p @ np.kron(one_r.T, wwT @ X_n.T @ one_n + n_n * w) @ D_r
            )
            Phi = X_p @ wwT @ X_p.T
            Psi = D_r.T @ D_r

            # Alternate between rebuilding the cost from the current plan and
            # solving the regularised transport problem for a new plan.
            for _ in range(self.max_iter):
                transport_cost = Theta.T + n_n * Psi @ T @ Phi
                T = ot.sinkhorn(mu_r, mu_p, transport_cost, self.ot_reg)

            # One batch of n_r synthetic samples per class, built from the
            # final plan.
            X_res.append(D_r @ T @ X_p)
            y_res.append([minority_class] * n_r)

        return np.concatenate(X_res), np.concatenate(y_res)
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ classifiers = [
"Operating System :: MacOS :: MacOS X",
]

[project.optional-dependencies]
otos = ["POT"]

[project.urls]
source = "https://github.com/zoj613/pyloras"
tracker = "https://github.com/zoj613/pyloras/issues"
Expand Down
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@
build==0.10.0
imbalanced-learn==0.10.1
numpy==1.23.2
POT==0.9.0
pytest==7.3.1
pytest-cov==4.0.0
23 changes: 23 additions & 0 deletions tests/test_otos.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import numpy as np
import pytest
from sklearn.datasets import make_classification

from pyloras import OTOS


@pytest.fixture
def data():
    """Provide a small, heavily imbalanced three-class dataset as ``(X, y)``."""
    dataset_options = {
        "n_samples": 150,
        "n_features": 4,
        "n_informative": 4,
        "n_redundant": 0,
        "n_repeated": 0,
        "n_classes": 3,
        "n_clusters_per_class": 2,
        "weights": [0.01, 0.05, 0.94],
        "class_sep": 0.8,
        "random_state": 0,
    }
    return make_classification(**dataset_options)


def test_otos(data):
    """Default resampling must lift every class to the majority count."""
    X, y = data
    rng = np.random.RandomState(12345)
    otos = OTOS(random_state=rng)
    X_res, y_res = otos.fit_resample(X, y)

    # Features and labels must stay aligned and keep the feature dimension.
    assert X_res.shape == (y_res.shape[0], X.shape[1])

    classes_before, counts_before = np.unique(y, return_counts=True)
    classes_after, y_counts = np.unique(y_res, return_counts=True)
    # No class may disappear or be invented by the sampler.
    np.testing.assert_array_equal(classes_after, classes_before)
    # Compare every class (not just the first two) against the original
    # majority size; counts are integers, so the comparison is exact.
    np.testing.assert_array_equal(
        y_counts, np.full_like(y_counts, counts_before.max())
    )