Skip to content

Commit

Permalink
Fix for #1296 (#1312)
Browse files Browse the repository at this point in the history
Fix of copying categorical features in data splitting
  • Loading branch information
aPovidlo committed Jul 23, 2024
1 parent e0b4ee7 commit a7e4243
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 12 deletions.
31 changes: 19 additions & 12 deletions fedot/core/data/data_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,21 +33,28 @@ def _split_input_data_by_indexes(origin_input_data: Union[InputData, MultiModalD
target = np.take(origin_input_data.target, index, 0)
features = np.take(origin_input_data.features, index, 0)

if origin_input_data.categorical_features is not None:
categorical_features = np.take(origin_input_data.categorical_features, index, 0)
else:
categorical_features = origin_input_data.categorical_features

if retain_first_target and len(target.shape) > 1:
target = target[:, 0]

data = InputData(idx=idx,
features=features,
target=target,
task=deepcopy(origin_input_data.task),
data_type=origin_input_data.data_type,
supplementary_data=origin_input_data.supplementary_data,
categorical_features=origin_input_data.categorical_features,
categorical_idx=origin_input_data.categorical_idx,
numerical_idx=origin_input_data.numerical_idx,
encoded_idx=origin_input_data.encoded_idx,
features_names=origin_input_data.features_names,
)
data = InputData(
idx=idx,
features=features,
target=target,
task=deepcopy(origin_input_data.task),
data_type=origin_input_data.data_type,
supplementary_data=origin_input_data.supplementary_data,
categorical_features=categorical_features,
categorical_idx=origin_input_data.categorical_idx,
numerical_idx=origin_input_data.numerical_idx,
encoded_idx=origin_input_data.encoded_idx,
features_names=origin_input_data.features_names,
)

return data
else:
raise TypeError(f'Unknown data type {type(origin_input_data)}')
Expand Down
33 changes: 33 additions & 0 deletions test/unit/data/test_data_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pandas as pd
import pytest

from fedot.api.api_utils.api_data import ApiDataProcessor
from fedot.core.data.data import InputData
from fedot.core.data.data_split import train_test_data_setup
from fedot.core.data.multi_modal import MultiModalData
Expand All @@ -18,6 +19,8 @@
from test.unit.tasks.test_forecasting import get_ts_data

TABULAR_SIMPLE = {'train_features_size': (8, 5), 'test_features_size': (2, 5), 'test_idx': (8, 9)}
TABULAR_CATEGORICAL = {'train_features_size': (11, 26), 'test_features_size': (3, 26), 'test_idx': (11, 12, 13),
'train_category_size': (11, 4), 'test_category_size': (3, 4)}
TS_SIMPLE = {'train_features_size': (18,), 'test_features_size': (18,), 'test_idx': (18, 19)}
TEXT_SIMPLE = {'train_features_size': (8,), 'test_features_size': (2,), 'test_idx': (8, 9)}
IMAGE_SIMPLE = {'train_features_size': (8, 5, 5, 2), 'test_features_size': (2, 5, 5, 2), 'test_idx': (8, 9)}
Expand Down Expand Up @@ -107,6 +110,32 @@ def get_balanced_data_to_test_mismatch():
return input_data


def get_tabular_classification_data_with_cats():
task = Task(TaskTypesEnum.classification)
x = np.array([[0, 0, 15, 'cat', 'left'],
[0, 1, 2, 'cat', 'right'],
[8, 12, 0, 'dog', 'left'],
[0, 1, 0, 'dog', 'right'],
[1, 1, 0, 'cat', 'left'],
[0, 11, 9, 'cow', 'right'],
[5, 1, 10, 'cat', 'left'],
[8, 16, 4, 'dog', 'right'],
[3, 1, 5, 'cat', 'left'],
[0, 1, 6, 'dog', 'right'],
[2, 7, 9, 'cat', 'left'],
[0, 1, 2, 'dog', 'right'],
[14, 1, 0, 'cat', 'right'],
[0, 4, 10, 'dog', 'left']])
y = np.array([0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1])
input_data = InputData(idx=np.arange(0, len(x)), features=x,
target=y, task=task, data_type=DataTypesEnum.table)

data_preprocessor = ApiDataProcessor(task=task)
preprocessed_input_data = data_preprocessor.fit_transform(input_data)

return preprocessed_input_data


def check_shuffle(sample):
unique = np.unique(np.diff(sample.idx))
test_result = len(unique) > 1 or np.min(unique) > 1
Expand All @@ -133,6 +162,7 @@ def test_split_data():

@pytest.mark.parametrize('data_generator, expected_output',
[(get_tabular_classification_data, TABULAR_SIMPLE),
(get_tabular_classification_data_with_cats, TABULAR_CATEGORICAL),
(get_ts_data_to_forecast_two_elements, TS_SIMPLE),
(get_text_classification_data, TEXT_SIMPLE),
(get_image_classification_data, IMAGE_SIMPLE)])
Expand All @@ -144,6 +174,9 @@ def test_default_train_test_simple(data_generator: Callable, expected_output: di
assert train_data.features.shape == expected_output['train_features_size']
assert test_data.features.shape == expected_output['test_features_size']
assert tuple(test_data.idx) == expected_output['test_idx']
if 'train_category_size' in expected_output and 'test_category_size' in expected_output:
assert train_data.categorical_features.shape == expected_output['train_category_size']
assert test_data.categorical_features.shape == expected_output['test_category_size']


def test_multitarget_train_test_split():
Expand Down

0 comments on commit a7e4243

Please sign in to comment.