From 54f64c2d2b9e3c709d64b9786adb35a52971205f Mon Sep 17 00:00:00 2001 From: "mergify[bot]" <37929162+mergify[bot]@users.noreply.github.com> Date: Tue, 13 Feb 2024 22:34:59 +0000 Subject: [PATCH] enh: re-enabling spmd rf interfaces (#1700) (#1711) * enh: re-enabling spmd rf interfaces * restoring onedal_factor usage but adding __class__ * lint * another attempt at it * isinstance to issubclass and re-adding self (cherry picked from commit bfa470b1ed71e7b52f7f36a05cb2b6f00deb2047) Co-authored-by: ethanglaser <42726565+ethanglaser@users.noreply.github.com> --- sklearnex/ensemble/_forest.py | 20 ++++++++------------ sklearnex/spmd/ensemble/forest.py | 14 ++++---------- 2 files changed, 12 insertions(+), 22 deletions(-) diff --git a/sklearnex/ensemble/_forest.py b/sklearnex/ensemble/_forest.py index 9c697d7023..716eb71dff 100644 --- a/sklearnex/ensemble/_forest.py +++ b/sklearnex/ensemble/_forest.py @@ -453,14 +453,12 @@ def __init__( # The estimator is checked against the class attribute for conformance. # This should only trigger if the user uses this class directly. - if ( - self.estimator.__class__ == DecisionTreeClassifier - and self._onedal_factory != onedal_RandomForestClassifier + if self.estimator.__class__ == DecisionTreeClassifier and not issubclass( + self._onedal_factory, onedal_RandomForestClassifier ): self._onedal_factory = onedal_RandomForestClassifier - elif ( - self.estimator.__class__ == ExtraTreeClassifier - and self._onedal_factory != onedal_ExtraTreesClassifier + elif self.estimator.__class__ == ExtraTreeClassifier and not issubclass( + self._onedal_factory, onedal_ExtraTreesClassifier ): self._onedal_factory = onedal_ExtraTreesClassifier @@ -843,14 +841,12 @@ def __init__( # The splitter is checked against the class attribute for conformance # This should only trigger if the user uses this class directly. - if ( - self.estimator.__class__ == DecisionTreeRegressor - and self._onedal_factory != onedal_RandomForestRegressor + if self.estimator.__class__ == DecisionTreeRegressor and not issubclass( + self._onedal_factory, onedal_RandomForestRegressor ): self._onedal_factory = onedal_RandomForestRegressor - elif ( - self.estimator.__class__ == ExtraTreeRegressor - and self._onedal_factory != onedal_ExtraTreesRegressor + elif self.estimator.__class__ == ExtraTreeRegressor and not issubclass( + self._onedal_factory, onedal_ExtraTreesRegressor ): self._onedal_factory = onedal_ExtraTreesRegressor diff --git a/sklearnex/spmd/ensemble/forest.py b/sklearnex/spmd/ensemble/forest.py index c0f1cb7042..428cf97b77 100644 --- a/sklearnex/spmd/ensemble/forest.py +++ b/sklearnex/spmd/ensemble/forest.py @@ -23,16 +23,9 @@ from ...ensemble import RandomForestRegressor as RandomForestRegressor_Batch -class BaseForestSPMD(ABC): - def _onedal_classifier(self, **onedal_params): - return onedal_RandomForestClassifier(**onedal_params) - - def _onedal_regressor(self, **onedal_params): - return onedal_RandomForestRegressor(**onedal_params) - - -class RandomForestClassifier(BaseForestSPMD, RandomForestClassifier_Batch): +class RandomForestClassifier(RandomForestClassifier_Batch): __doc__ = RandomForestClassifier_Batch.__doc__ + _onedal_factory = onedal_RandomForestClassifier def _onedal_cpu_supported(self, method_name, *data): # TODO: @@ -55,8 +48,9 @@ def _onedal_gpu_supported(self, method_name, *data): return ready -class RandomForestRegressor(BaseForestSPMD, RandomForestRegressor_Batch): +class RandomForestRegressor(RandomForestRegressor_Batch): __doc__ = RandomForestRegressor_Batch.__doc__ + _onedal_factory = onedal_RandomForestRegressor def _onedal_cpu_supported(self, method_name, *data): # TODO: