Skip to content

Commit

Permalink
deterministic simplex sampling feature added
Browse files Browse the repository at this point in the history
  • Loading branch information
gykovacs committed Aug 21, 2022
1 parent 38ca6ab commit ef38e72
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 6 deletions.
2 changes: 1 addition & 1 deletion smote_variants/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@
@author: gykovacs
"""

__version__= '0.6.6'
__version__= '0.6.7'
62 changes: 58 additions & 4 deletions smote_variants/base/_simplexsampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,54 @@ def add_samples(*,
samples_by_count = [base_points + diffweight*s for s in splits]
return np.vstack(samples_by_count)

def counts_to_vector(counts):
    """
    Expand a count vector to a vector of repeated indices.

    Index ``i`` appears ``counts[i]`` times in the result, in increasing
    order of ``i``; for example ``[2, 0, 1]`` is expanded to ``[0, 0, 2]``.

    Args:
        counts (np.array): count vector (non-negative integers)

    Returns:
        np.array: the expanded vector of indices, its length is the sum
                    of the counts
    """
    counts = np.asarray(counts)

    if counts.size == 0:
        # an empty count vector expands to an empty index vector; handled
        # explicitly because an empty list is converted to a float array,
        # which cannot be used as repeat counts
        return np.empty(0, dtype=int)

    # one vectorized call instead of stacking one small array per index
    return np.repeat(np.arange(counts.size), counts)

def deterministic_sample(choices, n_to_sample, p):
    """
    Take a deterministic sample.

    Each choice is first allocated ``ceil(n_to_sample * p[i])`` samples.
    The surplus over ``n_to_sample`` caused by rounding up is then removed
    by decrementing the counts of evenly spaced choices among those with a
    non-zero count. The same input always leads to the same output.

    Args:
        choices (list/np.array): the list of choices
        n_to_sample (int): the number of samples to take
        p (list/np.array): the distribution, one entry for each choice

    Returns:
        np.array: the choices, each repeated according to its final
                    count, in the order of ``choices``

    Raises:
        ValueError: if the counts implied by ``p`` cannot be adjusted to
                    sum to ``n_to_sample`` (``p`` is not a distribution)
    """
    # array conversion makes list inputs work as documented: a list cannot
    # be indexed/repeated by an array, and ``int * list`` would repeat the
    # list instead of scaling the probabilities
    choices = np.asarray(choices)
    p = np.asarray(p, dtype=float)

    sample_counts = np.ceil(n_to_sample * p).astype(int)

    n_to_remove = int(np.sum(sample_counts)) - n_to_sample

    if n_to_remove == 0:
        # equivalent to indexing choices by the expanded count vector
        return np.repeat(choices, sample_counts)

    non_zero_mask = sample_counts > 0
    n_non_zero = int(np.sum(non_zero_mask))

    # each non-zero count can be decremented at most once, and a deficit
    # cannot be compensated; checked explicitly instead of by an assert,
    # which would be stripped when running with -O
    if n_to_remove < 0 or n_to_remove > n_non_zero:
        raise ValueError("the distribution p cannot be turned into "
                         f"{n_to_sample} samples, check that p is "
                         "non-negative and sums to 1")

    # evenly spaced positions among the non-zero counts; the spacing is at
    # least 1, so the floored positions are distinct
    removal_indices = np.floor(np.linspace(0.0,
                                           n_non_zero,
                                           n_to_remove,
                                           endpoint=False)).astype(int)

    reduced_counts = sample_counts[non_zero_mask]
    reduced_counts[removal_indices] -= 1
    sample_counts[non_zero_mask] = reduced_counts

    return np.repeat(choices, sample_counts)

class SimplexSamplingMixin(RandomStateMixin):
"""
The mixin class for all simplex sampling based techniques.
Expand Down Expand Up @@ -436,11 +484,17 @@ def simplices(self,

weights = weights * node_weights

# sample the simplices
choices = np.arange(all_simplices.shape[0])
selected_indices = self.random_state.choice(choices,
n_to_sample,
p=weights/np.sum(weights))

if self.simplex_sampling == 'random':
# sample the simplices
selected_indices = self.random_state.choice(choices,
n_to_sample,
p=weights/np.sum(weights))
elif self.simplex_sampling == 'deterministic':
selected_simplices = deterministic_sample(choices,
n_to_sample,
p=weights/np.sum(weights))
return all_simplices[selected_indices]

def add_gaussian_noise(self, samples):
Expand Down
3 changes: 2 additions & 1 deletion tests/oversampling/test_simplex_deterministic.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ def test_simplex_deterministic(smote_class):
Args:
smote_class (class): an oversampler class.
"""
ss_params = {'within_simplex_sampling': 'deterministic'}
ss_params = {'within_simplex_sampling': 'deterministic',
'simplex_sampling': 'deterministic'}
X, y = smote_class(ss_params=ss_params).sample(dataset['data'],
dataset['target'])

Expand Down

0 comments on commit ef38e72

Please sign in to comment.