diff --git a/sklearnex/ensemble/_forest.py b/sklearnex/ensemble/_forest.py index 9c697d7023..716eb71dff 100644 --- a/sklearnex/ensemble/_forest.py +++ b/sklearnex/ensemble/_forest.py @@ -453,14 +453,12 @@ def __init__( # The estimator is checked against the class attribute for conformance. # This should only trigger if the user uses this class directly. - if ( - self.estimator.__class__ == DecisionTreeClassifier - and self._onedal_factory != onedal_RandomForestClassifier + if self.estimator.__class__ == DecisionTreeClassifier and not issubclass( + self._onedal_factory, onedal_RandomForestClassifier ): self._onedal_factory = onedal_RandomForestClassifier - elif ( - self.estimator.__class__ == ExtraTreeClassifier - and self._onedal_factory != onedal_ExtraTreesClassifier + elif self.estimator.__class__ == ExtraTreeClassifier and not issubclass( + self._onedal_factory, onedal_ExtraTreesClassifier ): self._onedal_factory = onedal_ExtraTreesClassifier @@ -843,14 +841,12 @@ def __init__( # The splitter is checked against the class attribute for conformance # This should only trigger if the user uses this class directly. - if ( - self.estimator.__class__ == DecisionTreeRegressor - and self._onedal_factory != onedal_RandomForestRegressor + if self.estimator.__class__ == DecisionTreeRegressor and not issubclass( + self._onedal_factory, onedal_RandomForestRegressor ): self._onedal_factory = onedal_RandomForestRegressor - elif ( - self.estimator.__class__ == ExtraTreeRegressor - and self._onedal_factory != onedal_ExtraTreesRegressor + elif self.estimator.__class__ == ExtraTreeRegressor and not issubclass( + self._onedal_factory, onedal_ExtraTreesRegressor ): self._onedal_factory = onedal_ExtraTreesRegressor diff --git a/sklearnex/spmd/ensemble/forest.py b/sklearnex/spmd/ensemble/forest.py index c0f1cb7042..428cf97b77 100644 --- a/sklearnex/spmd/ensemble/forest.py +++ b/sklearnex/spmd/ensemble/forest.py @@ -23,16 +23,9 @@ from ...ensemble import RandomForestRegressor as RandomForestRegressor_Batch -class BaseForestSPMD(ABC): - def _onedal_classifier(self, **onedal_params): - return onedal_RandomForestClassifier(**onedal_params) - - def _onedal_regressor(self, **onedal_params): - return onedal_RandomForestRegressor(**onedal_params) - - -class RandomForestClassifier(BaseForestSPMD, RandomForestClassifier_Batch): +class RandomForestClassifier(RandomForestClassifier_Batch): __doc__ = RandomForestClassifier_Batch.__doc__ + _onedal_factory = onedal_RandomForestClassifier def _onedal_cpu_supported(self, method_name, *data): # TODO: @@ -55,8 +48,9 @@ def _onedal_gpu_supported(self, method_name, *data): return ready -class RandomForestRegressor(BaseForestSPMD, RandomForestRegressor_Batch): +class RandomForestRegressor(RandomForestRegressor_Batch): __doc__ = RandomForestRegressor_Batch.__doc__ + _onedal_factory = onedal_RandomForestRegressor def _onedal_cpu_supported(self, method_name, *data): # TODO: