Skip to content

Commit

Permalink
[enhancement] create sklearnex/test/utils package for sklearnex…
Browse files Browse the repository at this point in the history
… testing (#2036)

* move to separate directory

* expose sklearn_clone_dict

* forgotten init

* forgotten name change
  • Loading branch information
icfaust authored Sep 18, 2024
1 parent b97e713 commit 1c81337
Show file tree
Hide file tree
Showing 19 changed files with 68 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
_convert_to_dataframe,
get_dataframes_and_queues,
)
from sklearnex.tests._utils_spmd import (
from sklearnex.tests.utils.spmd import (
_generate_statistic_data,
_get_local_tensor,
_mpi_libs_and_gpu_available,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
_convert_to_dataframe,
get_dataframes_and_queues,
)
from sklearnex.tests._utils_spmd import (
from sklearnex.tests.utils.spmd import (
_generate_statistic_data,
_get_local_tensor,
_mpi_libs_and_gpu_available,
Expand Down
2 changes: 1 addition & 1 deletion sklearnex/spmd/cluster/tests/test_dbscan_spmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
_convert_to_dataframe,
get_dataframes_and_queues,
)
from sklearnex.tests._utils_spmd import (
from sklearnex.tests.utils.spmd import (
_generate_clustering_data,
_get_local_tensor,
_mpi_libs_and_gpu_available,
Expand Down
2 changes: 1 addition & 1 deletion sklearnex/spmd/cluster/tests/test_kmeans_spmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
_convert_to_dataframe,
get_dataframes_and_queues,
)
from sklearnex.tests._utils_spmd import (
from sklearnex.tests.utils.spmd import (
_assert_kmeans_labels_allclose,
_assert_unordered_allclose,
_generate_clustering_data,
Expand Down
2 changes: 1 addition & 1 deletion sklearnex/spmd/covariance/tests/test_covariance_spmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
_convert_to_dataframe,
get_dataframes_and_queues,
)
from sklearnex.tests._utils_spmd import (
from sklearnex.tests.utils.spmd import (
_generate_statistic_data,
_get_local_tensor,
_mpi_libs_and_gpu_available,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
_convert_to_dataframe,
get_dataframes_and_queues,
)
from sklearnex.tests._utils_spmd import (
from sklearnex.tests.utils.spmd import (
_generate_statistic_data,
_get_local_tensor,
_mpi_libs_and_gpu_available,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
_convert_to_dataframe,
get_dataframes_and_queues,
)
from sklearnex.tests._utils_spmd import (
from sklearnex.tests.utils.spmd import (
_generate_statistic_data,
_get_local_tensor,
_mpi_libs_and_gpu_available,
Expand Down
2 changes: 1 addition & 1 deletion sklearnex/spmd/decomposition/tests/test_pca_spmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
_convert_to_dataframe,
get_dataframes_and_queues,
)
from sklearnex.tests._utils_spmd import (
from sklearnex.tests.utils.spmd import (
_generate_statistic_data,
_get_local_tensor,
_mpi_libs_and_gpu_available,
Expand Down
2 changes: 1 addition & 1 deletion sklearnex/spmd/ensemble/tests/test_forest_spmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
_convert_to_dataframe,
get_dataframes_and_queues,
)
from sklearnex.tests._utils_spmd import (
from sklearnex.tests.utils.spmd import (
_generate_classification_data,
_generate_regression_data,
_get_local_tensor,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
_convert_to_dataframe,
get_dataframes_and_queues,
)
from sklearnex.tests._utils_spmd import (
from sklearnex.tests.utils.spmd import (
_generate_regression_data,
_get_local_tensor,
_mpi_libs_and_gpu_available,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
_convert_to_dataframe,
get_dataframes_and_queues,
)
from sklearnex.tests._utils_spmd import (
from sklearnex.tests.utils.spmd import (
_generate_regression_data,
_get_local_tensor,
_mpi_libs_and_gpu_available,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
_convert_to_dataframe,
get_dataframes_and_queues,
)
from sklearnex.tests._utils_spmd import (
from sklearnex.tests.utils.spmd import (
_generate_classification_data,
_get_local_tensor,
_mpi_libs_and_gpu_available,
Expand Down
2 changes: 1 addition & 1 deletion sklearnex/spmd/neighbors/tests/test_neighbors_spmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
_convert_to_dataframe,
get_dataframes_and_queues,
)
from sklearnex.tests._utils_spmd import (
from sklearnex.tests.utils.spmd import (
_assert_unordered_allclose,
_generate_classification_data,
_generate_regression_data,
Expand Down
2 changes: 1 addition & 1 deletion sklearnex/tests/test_memory_usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
)
from onedal.tests.utils._device_selection import get_queues, is_dpctl_available
from sklearnex import config_context
from sklearnex.tests._utils import PATCHED_FUNCTIONS, PATCHED_MODELS, SPECIAL_INSTANCES
from sklearnex.tests.utils import PATCHED_FUNCTIONS, PATCHED_MODELS, SPECIAL_INSTANCES
from sklearnex.utils._array_api import get_namespace

if _is_dpc_backend:
Expand Down
2 changes: 1 addition & 1 deletion sklearnex/tests/test_patching.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from sklearnex import is_patched_instance
from sklearnex.dispatcher import _is_preview_enabled
from sklearnex.metrics import pairwise_distances, roc_auc_score
from sklearnex.tests._utils import (
from sklearnex.tests.utils import (
DTYPES,
PATCHED_FUNCTIONS,
PATCHED_MODELS,
Expand Down
20 changes: 10 additions & 10 deletions sklearnex/tests/test_run_to_run_stability.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,6 @@

import numpy as np
import pytest
from _utils import (
PATCHED_MODELS,
SPECIAL_INSTANCES,
_sklearn_clone_dict,
call_method,
gen_dataset,
gen_models_info,
)
from numpy.testing import assert_allclose
from scipy import sparse
from sklearn.datasets import (
Expand All @@ -52,6 +44,14 @@
NearestNeighbors,
)
from sklearnex.svm import SVC
from sklearnex.tests.utils import (
PATCHED_MODELS,
SPECIAL_INSTANCES,
call_method,
gen_dataset,
gen_models_info,
sklearn_clone_dict,
)

# to reproduce errors even in CI
d4p.daalinit(nthreads=100)
Expand Down Expand Up @@ -124,9 +124,9 @@ def _run_test(estimator, method, datasets):
KMeans(init="k-means++"),
]
)
SPARSE_INSTANCES = _sklearn_clone_dict({str(i): i for i in _sparse_instances})
SPARSE_INSTANCES = sklearn_clone_dict({str(i): i for i in _sparse_instances})

STABILITY_INSTANCES = _sklearn_clone_dict(
STABILITY_INSTANCES = sklearn_clone_dict(
{
str(i): i
for i in [
Expand Down
41 changes: 41 additions & 0 deletions sklearnex/tests/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# ==============================================================================
# Copyright 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from .base import (
DTYPES,
PATCHED_FUNCTIONS,
PATCHED_MODELS,
SPECIAL_INSTANCES,
UNPATCHED_FUNCTIONS,
UNPATCHED_MODELS,
call_method,
gen_dataset,
gen_models_info,
sklearn_clone_dict,
)

__all__ = [
"DTYPES",
"PATCHED_FUNCTIONS",
"PATCHED_MODELS",
"UNPATCHED_FUNCTIONS",
"UNPATCHED_MODELS",
"SPECIAL_INSTANCES",
"call_method",
"gen_models_info",
"gen_dataset",
"sklearn_clone_dict",
]
4 changes: 2 additions & 2 deletions sklearnex/tests/_utils.py → sklearnex/tests/utils/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def _load_all_models(with_sklearnex=True, estimator=True):
]


class _sklearn_clone_dict(dict):
class sklearn_clone_dict(dict):
"""Special dict type for returning state-free sklearn/sklearnex estimators
with the same parameters"""

Expand All @@ -118,7 +118,7 @@ def __getitem__(self, key):
# could be because of supported non-default parameters, blocked support via sklearn's
# 'available_if' decorator, or not being a native sklearn estimator (i.e. those not in
# the default PATCHED_MODELS dictionary)
SPECIAL_INSTANCES = _sklearn_clone_dict(
SPECIAL_INSTANCES = sklearn_clone_dict(
{
str(i): i
for i in [
Expand Down
File renamed without changes.

0 comments on commit 1c81337

Please sign in to comment.