Skip to content

Commit

Permalink
FIX make sure the decision function of weak learner is symmetric (#37)
Browse files Browse the repository at this point in the history
  • Loading branch information
pierrenodet committed Jun 17, 2023
1 parent ecac4da commit 0a48d24
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,15 @@ class BoostingClassificationModel(
val res = Array.fill(numClasses)(0.0)
var i = 0
while (i < numModels) {
res(models(i).predict(features).toInt) += weights(i)
val prediction = models(i).predict(features).toInt
val weight = weights(i)
var c = 0
while (c < numClasses) {
if (prediction == c) {
res(c) += weight
} else { res(c) -= 1.0 / (numClasses - 1) * weight }
c += 1
}
i += 1
}
Vectors.dense(res)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.matchers.should.Matchers._

import scala.collection.mutable.ListBuffer
import org.apache.spark.ml.linalg.DenseVector

class BoostingClassifierSuite extends AnyFunSuite with BeforeAndAfterAll {

Expand Down Expand Up @@ -81,7 +82,7 @@ class BoostingClassifierSuite extends AnyFunSuite with BeforeAndAfterAll {
bcModel.models.take(i))
metrics += mce.evaluate(model.transform(test))
})

assert(
metrics.toList
.sliding(2)
Expand Down Expand Up @@ -122,6 +123,36 @@ class BoostingClassifierSuite extends AnyFunSuite with BeforeAndAfterAll {
mce.evaluate(bcrModel.transform(test)) === mce.evaluate(bcdModel.transform(test)) +- 0.02)
}

List("discrete", "real").foreach { algorithm =>
test(
f"$algorithm boosting decision function respects the symmetric constraint for weak learners") {
val data =
spark.read
.format("libsvm")
.load("../data/letter/letter.svm")
.withColumn("label", col("label") - lit(1))
.cache()
data.count()
val numClasses = data.select("label").distinct().count()

val lr = new DecisionTreeClassifier().setMaxDepth(1)
val boost = new BoostingClassifier()
.setBaseLearner(lr)
.setNumBaseLearners(10)
.setAlgorithm(algorithm)
val boostModel = boost.fit(data)

val predictions = boostModel
.transform(data)
.select("rawPrediction")
.distinct()
.collect()
.map(_.getAs[DenseVector](0).toArray)

predictions.foreach(raw => assert(raw.sum === 0.0 +- 1e-6))
}
}

test("read/write") {
val data =
spark.read
Expand Down
4 changes: 2 additions & 2 deletions docs/bagging.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ new BaggingClassifier()
.setBaseLearner(new DecisionTreeClassifier()) //Base learner used by the meta-estimator.
.setNumBaseLearners(10) //Number of base learners.
.setSubsampleRatio(0.8) //Ratio sampling of examples.
.setReplacement(true) //Exemples drawn with replacement or not.
.setReplacement(true) //Samples drawn with replacement or not.
.setSubspaceRatio(0.8) //Ratio sampling of features.
.setVotingStrategy("soft") //Soft or Hard majority vote.
.setParallelism(4) //Number of base learners trained simultaneously.
Expand All @@ -37,7 +37,7 @@ new BaggingRegressor()
.setBaseLearner(new DecisionTreeRegressor()) //Base learner used by the meta-estimator.
.setNumBaseLearners(10) //Number of base learners.
.setSubsampleRatio(0.8) //Sampling ratio of examples.
.setReplacement(true) //Exemples drawn with replacement or not.
.setReplacement(true) //Samples drawn with replacement or not.
.setSubspaceRatio(0.8) //Sampling ratio of features.
.setParallelism(4) //Number of base learners trained simultaneously.
```
Expand Down

0 comments on commit 0a48d24

Please sign in to comment.