Skip to content

Commit

Permalink
fix: force multinomial if both classes are specified in the pmml
Browse files Browse the repository at this point in the history
This is similar to the logic that was there before, but was broken due to scikit-learn refactoring.
  • Loading branch information
iamDecode committed Apr 14, 2024
1 parent 417da54 commit 1c0fd47
Showing 1 changed file with 12 additions and 7 deletions.
19 changes: 12 additions & 7 deletions sklearn_pmml_model/ensemble/gb.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,6 @@
import numpy as np
from sklearn.tree import DecisionTreeRegressor
from sklearn.ensemble import GradientBoostingClassifier, GradientBoostingRegressor
try:
from sklearn.ensemble._gb_losses import MultinomialDeviance
except ImportError:
pass
from sklearn_pmml_model.base import PMMLBaseClassifier, PMMLBaseRegressor, IntegerEncodingMixin
from sklearn_pmml_model.tree import get_tree
from scipy.special import expit
Expand Down Expand Up @@ -86,6 +82,7 @@ def __init__(self, pmml):
self.template_estimator = clf

try:
from sklearn.ensemble._gb_losses import MultinomialDeviance
self._check_params()

if self.n_classes_ == 2 and len(segments) == 3 and segments[-1].find('TreeModel') is None:
Expand All @@ -96,10 +93,18 @@ def __init__(self, pmml):
except AttributeError:
self._loss = MultinomialDeviance(self.n_classes_ + 1)
self._loss.K = 2
except AttributeError:
except ImportError:
from sklearn._loss.loss import HalfMultinomialLoss

self._set_max_features()
self._loss = self._get_loss(sample_weight=None)
self.n_trees_per_iteration_ = 1 if self.n_classes_ == 2 else self.n_classes_

if self.n_classes_ == 2 and len(segments) == 3 and segments[-1].find('TreeModel') is None:
# For binary classification where both sides are specified, we need to force multinomial deviance
self._loss = HalfMultinomialLoss(sample_weight=None, n_classes=self.n_classes_ + 1)
self.n_trees_per_iteration_ = self.n_classes_
else:
self._loss = self._get_loss(sample_weight=None)
self.n_trees_per_iteration_ = 1 if self.n_classes_ == 2 else self.n_classes_

try:
self.init = None
Expand Down

0 comments on commit 1c0fd47

Please sign in to comment.