Skip to content

Commit

Permalink
ENH: Add OTOS oversampling algorithm
Browse files Browse the repository at this point in the history
  • Loading branch information
zoj613 committed May 1, 2023
1 parent a6f446c commit d325cdf
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 0 deletions.
1 change: 1 addition & 0 deletions pyloras/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from ._loras import LORAS
from ._prowras import ProWRAS
from ._gamus import GAMUS
from ._otos import OTOS
35 changes: 35 additions & 0 deletions pyloras/_otos.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import ot
from imblearn.over_sampling.base import BaseOverSampler
from imblearn.utils import Substitution
from imblearn.utils._docstring import (
_random_state_docstring,
_n_jobs_docstring,
)
import numpy as np

from ._common import check_random_state, safe_random_state


@Substitution(
sampling_strategy=BaseOverSampler._sampling_strategy_docstring,
random_state=_random_state_docstring,
n_jobs=_n_jobs_docstring
)
class OTOS(BaseOverSampler):
def __init__(
self,
*,
sampling_strategy="auto",
svm_regularization=1.0,
ot_regularization=1.0,
tradeoff=1.0,
random_state=None,
):
super().__init__(sampling_strategy=sampling_strategy)
self.svm_regularization = svm_regularization
self.ot_regularization = ot_regularization
self.tradeoff = tradeoff
self.random_state = random_state

def fit_resample(self, X, y):
return X, y
23 changes: 23 additions & 0 deletions tests/test_otos.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import numpy as np
import pytest
from sklearn.datasets import make_classification

from pyloras import OTOS


@pytest.fixture
def data():
return make_classification(n_samples=150, n_features=4, n_informative=4,
n_redundant=0, n_repeated=0, n_classes=3,
n_clusters_per_class=2,
weights=[0.01, 0.05, 0.94],
class_sep=0.8, random_state=0)


def test_otos(data):
X, y = data
rng = np.random.RandomState(12345)
otos = OTOS(random_state=rng)
X_res, y_res = otos.fit_resample(X, y)
_, y_counts = np.unique(y_res, return_counts=True)
np.testing.assert_allclose(y_counts[0], y_counts[1])

0 comments on commit d325cdf

Please sign in to comment.