Skip to content

Commit

Permalink
add tests to select features
Browse files Browse the repository at this point in the history
  • Loading branch information
paucablop committed Sep 21, 2023
1 parent 15eab96 commit 27498f9
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 5 deletions.
36 changes: 32 additions & 4 deletions tests/test_functionality.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
StandardNormalVariate,
)
from chemotools.smooth import MeanFilter, MedianFilter, WhittakerSmooth
from chemotools.variable_selection import RangeCut
from chemotools.variable_selection import RangeCut, SelectFeatures
from tests.fixtures import (
spectrum,
spectrum_arpls,
Expand Down Expand Up @@ -439,8 +439,8 @@ def test_point_scaler(spectrum):

def test_point_scaler_with_wavenumbers():
# Arrange
wavenumbers = np.array([1., 2., 3., 4., 5., 6., 7., 8., 9., 10.])
spectrum = np.array([[10., 12., 14., 16., 14., 12., 10., 12., 14., 16.]])
wavenumbers = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0])
spectrum = np.array([[10.0, 12.0, 14.0, 16.0, 14.0, 12.0, 10.0, 12.0, 14.0, 16.0]])

# Act
index_scaler = PointScaler(point=4, wavenumbers=wavenumbers)
Expand All @@ -450,7 +450,6 @@ def test_point_scaler_with_wavenumbers():
assert np.allclose(spectrum_corrected[0], spectrum[0] / spectrum[0][3], atol=1e-8)



def test_range_cut_by_index(spectrum):
# Arrange
range_cut = RangeCut(start=0, end=10)
Expand Down Expand Up @@ -544,6 +543,35 @@ def test_saviszky_golay_filter_3():
assert np.allclose(spectrum_corrected[0], np.ones((1, 10)), atol=1e-2)


def test_select_features():
# Arrange
spectrum = np.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
expected = np.array([[1, 2, 3, 8, 9, 10]])

# Act
select_features = SelectFeatures(features=np.array([0, 1, 2, 7, 8, 9]))
spectrum_corrected = select_features.fit_transform(spectrum)

# Assert
assert np.allclose(spectrum_corrected[0], expected, atol=1e-8)


def test_select_features_with_wavenumbers():
# Arrange
wavenumbers = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0])
spectrum = np.array([[1.0, 2.0, 3.0, 5.0, 8.0, 13.0, 21.0, 34.0, 55.0, 89.0]])
expected = np.array([[1.0, 2.0, 3.0, 34.0, 55.0, 89.0]])

# Act
select_features = SelectFeatures(
features=np.array([1, 2, 3, 8, 9, 10]), wavenumbers=wavenumbers
)
spectrum_corrected = select_features.fit_transform(spectrum)

# Assert
assert np.allclose(spectrum_corrected[0], expected, atol=1e-8)


def test_standard_normal_variate(spectrum, reference_snv):
# Arrange
snv = StandardNormalVariate()
Expand Down
10 changes: 9 additions & 1 deletion tests/test_sklearn_compliance.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
SavitzkyGolayFilter,
WhittakerSmooth,
)
from chemotools.variable_selection import RangeCut
from chemotools.variable_selection import RangeCut, SelectFeatures

from tests.fixtures import spectrum

Expand Down Expand Up @@ -173,6 +173,14 @@ def test_compliance_savitzky_golay_filter():
check_estimator(transformer)


# SelectFeatures
def test_compliance_select_features():
# Arrange
transformer = SelectFeatures()
# Act & Assert
check_estimator(transformer)


# StandardNormalVariate
def test_compliance_standard_normal_variate():
# Arrange
Expand Down

0 comments on commit 27498f9

Please sign in to comment.