
Spark 2.4 support #402

Merged: 64 commits, Feb 19, 2020

Commits
77e2229
Revert "Revert back to Spark 2.3 (#399)"
tovbinm Sep 4, 2019
485fcd5
Update to Spark 2.4.3 and XGBoost 0.90
tovbinm May 30, 2019
12f5333
special double serializer fix
tovbinm May 30, 2019
a4ee986
fix serialization
tovbinm May 30, 2019
47703fa
fix serialization
tovbinm May 30, 2019
ea4b11f
docs
tovbinm May 30, 2019
ab81e31
fixed missing value for test
wsuchy May 30, 2019
c017869
meta fix
tovbinm May 30, 2019
99ea7e1
Updated DecisionTreeNumericMapBucketizer test to deal with the change…
Jauntbox May 31, 2019
daa2672
fix params meta test
tovbinm May 31, 2019
7d3ebb7
Fixed failing xgboost test
wsuchy May 31, 2019
5661640
ident
tovbinm May 31, 2019
8852d69
cleanup
tovbinm May 31, 2019
2ab8924
added dataframe reader and writer extensions
tovbinm Jun 3, 2019
7c8b988
added const
tovbinm Jun 3, 2019
d98c8a9
cherrypick fixes
nicodv Oct 14, 2019
cce0d8f
added xgboost params + update models to use public predict method
tovbinm Jun 21, 2019
8afbae7
blarg
tovbinm Jun 21, 2019
e8770f6
double ser test
tovbinm Jun 21, 2019
afffc56
update mleap and spark testing base
tovbinm Jun 24, 2019
ed43719
Update README.md
tovbinm Jul 16, 2019
8804acd
type fix
nicodv Oct 14, 2019
a1461c7
Merge branch 'master' into ndv/spark2.4
nicodv Oct 15, 2019
b54d0f5
bump minor version
nicodv Oct 15, 2019
861f862
Merge branch 'master' into ndv/spark2.4
tovbinm Oct 18, 2019
41f9dc4
Merge branch 'master' into revert-399-mt/revert-spark-2.4
tovbinm Oct 18, 2019
cb4cb7b
Update Spark version in the README
tovbinm Oct 18, 2019
abca58b
bump version
nicodv Oct 18, 2019
ab8daf8
Merge remote-tracking branch 'origin/ndv/spark2.4' into ndv/spark2.4
nicodv Oct 18, 2019
e58da6c
Update build.gradle
tovbinm Oct 19, 2019
37614d7
Update pom.xml
tovbinm Oct 19, 2019
e8a27ea
Merge branch 'master' into revert-399-mt/revert-spark-2.4
tovbinm Oct 19, 2019
e205856
Merge branch 'ndv/spark2.4' into revert-399-mt/revert-spark-2.4
nicodv Oct 21, 2019
f93db0b
Merge branch 'master' into revert-399-mt/revert-spark-2.4
nicodv Oct 21, 2019
505a881
set correct json4s version
nicodv Nov 5, 2019
e21d06c
Merge branch 'master' into revert-399-mt/revert-spark-2.4
tovbinm Nov 7, 2019
84bb70a
Merge branch 'master' into revert-399-mt/revert-spark-2.4
nicodv Nov 9, 2019
08db1db
Merge remote-tracking branch 'origin/revert-399-mt/revert-spark-2.4' …
nicodv Nov 9, 2019
3ba95b6
upgrade helloworld deps
nicodv Nov 15, 2019
d77515c
upgrade notebook deps on TMog and Spark
nicodv Nov 15, 2019
d3b53be
Merge branch 'master' into revert-399-mt/revert-spark-2.4
tovbinm Nov 26, 2019
1634927
Merge branch 'master' into revert-399-mt/revert-spark-2.4
nicodv Dec 5, 2019
b8a43b5
bump to version 0.7.0 for Spark update
nicodv Dec 5, 2019
094be05
align helloworld dependencies
nicodv Dec 5, 2019
cc7de97
Merge remote-tracking branch 'origin/revert-399-mt/revert-spark-2.4' …
nicodv Dec 5, 2019
4bd5977
align helloworld dependencies
nicodv Dec 5, 2019
46fc60a
get -> getOrElse with exception
nicodv Dec 6, 2019
0211fde
Merge branch 'master' into revert-399-mt/revert-spark-2.4
nicodv Dec 6, 2019
7147b72
Merge branch 'master' into revert-399-mt/revert-spark-2.4
nicodv Jan 7, 2020
4d686c7
fix helloworld compilation
nicodv Jan 7, 2020
f6a2a4e
Merge branch 'master' into revert-399-mt/revert-spark-2.4
tovbinm Jan 8, 2020
adc5b25
Merge branch 'master' into revert-399-mt/revert-spark-2.4
tovbinm Jan 22, 2020
c64eb51
Merge branch 'master' into revert-399-mt/revert-spark-2.4
nicodv Jan 27, 2020
e0c164d
Merge branch 'master' into revert-399-mt/revert-spark-2.4
tovbinm Jan 30, 2020
b3c0a74
Spark 2.4.5
tovbinm Feb 14, 2020
f4ab3fd
Spark 2.4.5
tovbinm Feb 14, 2020
50d9dfb
Spark 2.4.5
tovbinm Feb 14, 2020
3417972
Update OpTitanicSimple.ipynb
tovbinm Feb 14, 2020
df38bcc
Update OpIris.ipynb
tovbinm Feb 14, 2020
3af27f3
Revert "Spark 2.4.5"
tovbinm Feb 15, 2020
66afd79
Revert "Spark 2.4.5"
tovbinm Feb 15, 2020
39ff755
Revert "Spark 2.4.5"
tovbinm Feb 15, 2020
09c5f2b
Revert "Update OpTitanicSimple.ipynb"
tovbinm Feb 15, 2020
d79bcc3
Revert "Update OpIris.ipynb"
tovbinm Feb 15, 2020
2 changes: 1 addition & 1 deletion README.md
@@ -128,7 +128,7 @@ Start by picking TransmogrifAI version to match your project dependencies from the table below:

| TransmogrifAI Version | Spark Version | Scala Version | Java Version |
|-------------------------------------------------------|:-------------:|:-------------:|:------------:|
| 0.6.2 (unreleased, master) | 2.3 | 2.11 | 1.8 |
| 0.7.0 (unreleased, master) | 2.4 | 2.11 | 1.8 |
Contributor @gerashegalov commented Jan 29, 2020:

TODO for a future PR: I really hope we start making strides toward automating this stuff. We should have placeholders here that are populated from the build, even if it's some quick-and-dirty sed script.
| **0.6.1 (stable)**, 0.6.0, 0.5.3, 0.5.2, 0.5.1, 0.5.0 | **2.3** | **2.11** | **1.8** |
| 0.4.0, 0.3.4 | 2.2 | 2.11 | 1.8 |

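As a sketch of what picking a row from this table looks like in practice (the artifact coordinates follow the README's install instructions, but treat the exact module name and the Provided scoping as assumptions, not part of this diff):

// build.sbt sketch for the 0.7.0 row: TransmogrifAI 0.7.0 on Spark 2.4 / Scala 2.11
libraryDependencies ++= Seq(
  "com.salesforce.transmogrifai" %% "transmogrifai-core" % "0.7.0",
  "org.apache.spark"             %% "spark-mllib"        % "2.4.4" % Provided
)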
22 changes: 11 additions & 11 deletions build.gradle
@@ -1,15 +1,16 @@
buildscript {
repositories {
maven { url "https://plugins.gradle.org/m2/" }
mavenCentral()
jcenter()
maven { url "https://plugins.gradle.org/m2/" }
}
dependencies {
classpath 'org.github.ngbinh.scalastyle:gradle-scalastyle-plugin_2.11:1.0.1'
classpath 'com.commercehub.gradle.plugin:gradle-avro-plugin:0.16.0'
}
}

plugins {
id 'com.commercehub.gradle.plugin.avro' version '0.8.0'
id 'org.scoverage' version '2.5.0'
id 'net.minecrell.licenser' version '0.4.1'
id 'com.github.jk1.dependency-license-report' version '0.5.0'
@@ -57,14 +58,13 @@ configure(allProjs) {
scalaVersionRevision = '12'
scalaTestVersion = '3.0.5'
scalaCheckVersion = '1.14.0'
junitVersion = '4.11'
avroVersion = '1.7.7'
sparkVersion = '2.3.2'
sparkAvroVersion = '4.0.0'
junitVersion = '4.12'
avroVersion = '1.8.2'
sparkVersion = '2.4.4'
scalaGraphVersion = '1.12.5'
scalafmtVersion = '1.5.1'
hadoopVersion = 'hadoop2'
json4sVersion = '3.2.11' // matches Spark dependency version
json4sVersion = '3.5.3' // matches Spark dependency version
jodaTimeVersion = '2.9.4'
jodaConvertVersion = '1.8.1'
algebirdVersion = '0.13.4'
@@ -75,20 +75,20 @@ configure(allProjs) {
googleLibPhoneNumberVersion = '8.8.5'
googleGeoCoderVersion = '2.82'
googleCarrierVersion = '1.72'
chillVersion = '0.8.4'
chillVersion = '0.9.3'
reflectionsVersion = '0.9.11'
collectionsVersion = '3.2.2'
optimaizeLangDetectorVersion = '0.0.1'
tikaVersion = '1.22'
sparkTestingBaseVersion = '2.3.1_0.10.0'
sparkTestingBaseVersion = '2.4.3_0.12.0'
sourceCodeVersion = '0.1.3'
pegdownVersion = '1.4.2'
commonsValidatorVersion = '1.6'
commonsIOVersion = '2.6'
scoveragePluginVersion = '1.3.1'
xgboostVersion = '0.81'
xgboostVersion = '0.90'
akkaSlf4jVersion = '2.3.11'
mleapVersion = '0.13.0'
mleapVersion = '0.14.0'
memoryFilesystemVersion = '2.1.0'
}

6 changes: 1 addition & 5 deletions cli/build.gradle
@@ -69,20 +69,16 @@ task copyTemplates(type: Copy) {
fileName.replace(".gradle.template", ".gradle")
}
expand([
databaseHostname: 'db.company.com',
Contributor: do we absolutely need to change this file?

Collaborator (author): this was a cleanup, why not?

Contributor @gerashegalov commented Jan 30, 2020: because the PR is big enough, and for easier maintenance I hoped to keep it focused on the Spark upgrade.

Collaborator (author): Sure, this makes no material difference to the change. You always get a bonus like that with my PRs 😛
version: scalaVersion,
scalaVersion: scalaVersion,
scalaVersionRevision: scalaVersionRevision,
scalaTestVersion: scalaTestVersion,
junitVersion: junitVersion,
sparkVersion: sparkVersion,
avroVersion: avroVersion,
sparkAvroVersion: sparkAvroVersion,
hadoopVersion: hadoopVersion,
collectionsVersion: collectionsVersion,
transmogrifaiVersion: version,
buildNumber: (int)(Math.random() * 1000),
date: new Date()
transmogrifaiVersion: version
])
}

@@ -138,7 +138,7 @@ case class AutomaticSchema(recordClassName: String)(dataFile: File) extends Sche
case Some(actualType) =>
val newSchema = Schema.create(actualType)
val schemaField =
new Schema.Field(field.name, newSchema, "auto-generated", orgSchemaField.defaultValue)
new Schema.Field(field.name, newSchema, "auto-generated", orgSchemaField.defaultVal())
Contributor: do we need this change? defaultValue is deprecated but still there; technically it's not required to change until it's dropped entirely.

Collaborator (author): I am keeping the existing functionality as is. We can change the behavior separately.
AvroField.from(schemaField)
}
} else field
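For reference, a small sketch of the Avro API shift this hunk tracks (assuming the Avro 1.8.2 bump in build.gradle above): the deprecated Schema.Field#defaultValue() returns a Jackson JsonNode, while defaultVal() returns a plain Object that the non-deprecated Schema.Field constructor accepts.

import org.apache.avro.{Schema, SchemaBuilder}

// Build a record with a defaulted field, then copy that field the Avro 1.8 way.
val schema: Schema = SchemaBuilder.record("Example").fields()
  .name("x").`type`().stringType().stringDefault("n/a")
  .endRecord()
val orig = schema.getField("x")
// orig.defaultValue() (deprecated) yields a JsonNode; defaultVal() an Object.
val copied = new Schema.Field(orig.name, orig.schema, "auto-generated", orig.defaultVal())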
@@ -69,7 +69,7 @@ class AvroFieldTest extends FlatSpec with TestCommon with Assertions {
val allSchemas = (enum::unions)++simpleSchemas // NULL does not work

val fields = allSchemas.zipWithIndex map {
case (s, i) => new Schema.Field("x" + i, s, "Who", null)
case (s, i) => new Schema.Field("x" + i, s, "Who", null: Object)
}

val expected = List(
Expand All @@ -86,7 +86,7 @@ class AvroFieldTest extends FlatSpec with TestCommon with Assertions {

an[IllegalArgumentException] should be thrownBy {
val nullSchema = Schema.create(Schema.Type.NULL)
val nullField = new Schema.Field("xxx", null, "Nobody", null)
val nullField = new Schema.Field("xxx", null, "Nobody", null: Object)
AvroField from nullField
}

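A likely reason for the null: Object ascriptions above (an inference, not stated in the thread): Avro 1.8 keeps the deprecated Field(String, Schema, String, JsonNode) constructor alongside the new Field(String, Schema, String, Object) one, and a bare null resolves to the most specific overload, the deprecated JsonNode constructor; the ascription steers the call to the non-deprecated Object overload.

import org.apache.avro.Schema

// Selects the Field(String, Schema, String, Object) constructor explicitly.
val ok = new Schema.Field("x", Schema.create(Schema.Type.STRING), "doc", null: Object)
// A bare null here would bind to the deprecated JsonNode overload instead:
// new Schema.Field("x", Schema.create(Schema.Type.STRING), "doc", null)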
@@ -152,14 +152,14 @@ class OpLinearSVCModel
) extends OpPredictorWrapperModel[LinearSVCModel](uid = uid, operationName = operationName, sparkModel = sparkModel) {

@transient lazy private val predictRaw = reflectMethod(getSparkMlStage().get, "predictRaw")
@transient lazy private val predict = reflectMethod(getSparkMlStage().get, "predict")
@transient lazy private val predict: Vector => Double = getSparkMlStage().get.predict(_)

/**
* Function used to convert input to output
*/
override def transformFn: (RealNN, OPVector) => Prediction = (label, features) => {
val raw = predictRaw(features.value).asInstanceOf[Vector]
val pred = predict(features.value).asInstanceOf[Double]
val pred = predict(features.value)

Prediction(rawPrediction = raw, prediction = pred)
}
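The hunk above is one instance of a pattern repeated throughout this PR: Spark 2.3 did not expose PredictionModel.predict publicly, so TransmogrifAI invoked it through its reflectMethod helper; Spark 2.4 makes predict public, so a plain function reference suffices. A minimal sketch, assuming a fitted model is at hand:

import org.apache.spark.ml.classification.LinearSVCModel
import org.apache.spark.ml.linalg.Vector

// predict is public in Spark 2.4; under 2.3 this required reflectMethod(model, "predict").
def directPredict(model: LinearSVCModel): Vector => Double = model.predict(_)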
@@ -235,6 +235,11 @@ class OpXGBoostClassifier(uid: String = UID[OpXGBoostClassifier])
*/
def setMaxBins(value: Int): this.type = set(maxBins, value)

/**
* Maximum number of nodes to be added. Only relevant when grow_policy=lossguide is set.
*/
def setMaxLeaves(value: Int): this.type = set(maxLeaves, value)

/**
* This is only used for approximate greedy algorithm.
* This roughly translated into O(1 / sketch_eps) number of bins. Compared to directly select
@@ -282,8 +287,19 @@ class OpXGBoostClassifier(uid: String = UID[OpXGBoostClassifier])
def setLambdaBias(value: Double): this.type = set(lambdaBias, value)

// setters for learning params

/**
* Specify the learning task and the corresponding learning objective.
* options: reg:squarederror, reg:logistic, binary:logistic, binary:logitraw, count:poisson,
* multi:softmax, multi:softprob, rank:pairwise, reg:gamma. default: reg:squarederror
*/
def setObjective(value: String): this.type = set(objective, value)

/**
* Objective type used for training. For options see [[ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams]]
*/
def setObjectiveType(value: String): this.type = set(objectiveType, value)

/**
* Specify the learning task and the corresponding learning objective.
* options: reg:linear, reg:logistic, binary:logistic, binary:logitraw, count:poisson,
@@ -310,6 +326,11 @@ class OpXGBoostClassifier(uid: String = UID[OpXGBoostClassifier])
*/
def setNumEarlyStoppingRounds(value: Int): this.type = set(numEarlyStoppingRounds, value)

/**
* Define the expected optimization to the evaluation metrics, true to maximize otherwise minimize it
*/
def setMaximizeEvaluationMetrics(value: Boolean): this.type = set(maximizeEvaluationMetrics, value)

/**
* Customized objective function provided by user. default: null
*/
@@ -359,17 +380,18 @@ class OpXGBoostClassificationModel

private lazy val model = getSparkMlStage().get
private lazy val booster = model.nativeBooster
private lazy val treeLimit = model.getTreeLimit.toInt
private lazy val treeLimit = model.getTreeLimit
private lazy val missing = model.getMissing

override def transformFn: (RealNN, OPVector) => Prediction = (label, features) => {
val data = removeMissingValues(Iterator(features.value.asXGB), missing)
val data = processMissingValues(Iterator(features.value.asXGB), missing)
val dm = new DMatrix(dataIter = data)
val rawPred = booster.predict(dm, outPutMargin = true, treeLimit = treeLimit)(0).map(_.toDouble)
val rawPrediction = if (model.numClasses == 2) Array(-rawPred(0), rawPred(0)) else rawPred
val prob = booster.predict(dm, outPutMargin = false, treeLimit = treeLimit)(0).map(_.toDouble)
val probability = if (model.numClasses == 2) Array(1.0 - prob(0), prob(0)) else prob
val prediction = probability2predictionMirror(Vectors.dense(probability)).asInstanceOf[Double]

Prediction(prediction = prediction, rawPrediction = rawPrediction, probability = probability)
}
}
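A hedged usage sketch of the setters this file gains (values are illustrative; label: FeatureLike[RealNN] and features: FeatureLike[OPVector] are assumed to be defined elsewhere; the same setters are added to OpXGBoostRegressor below):

import com.salesforce.op.stages.impl.classification.OpXGBoostClassifier

val xgb = new OpXGBoostClassifier()
  .setMaxLeaves(64)                     // only honored when grow_policy=lossguide
  .setObjective("binary:logistic")
  .setObjectiveType("classification")   // see LearningTaskParams for options
  .setMaximizeEvaluationMetrics(false)  // minimize the evaluation metric
val prediction = xgb.setInput(label, features).getOutput()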
@@ -34,7 +34,6 @@ import com.salesforce.op.UID
import com.salesforce.op.features.types.{OPVector, Prediction, RealNN}
import com.salesforce.op.stages.impl.CheckIsResponseValues
import com.salesforce.op.stages.sparkwrappers.specific.{OpPredictionModel, OpPredictorWrapper}
import com.salesforce.op.utils.reflection.ReflectionUtils.reflectMethod
import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor, OpDecisionTreeRegressorParams}

import scala.reflect.runtime.universe.TypeTag
@@ -113,7 +112,4 @@ class OpDecisionTreeRegressionModel
ttov: TypeTag[Prediction#Value]
) extends OpPredictionModel[DecisionTreeRegressionModel](
sparkModel = sparkModel, uid = uid, operationName = operationName
) {
@transient lazy val predictMirror = reflectMethod(getSparkMlStage().get, "predict")
}

)
@@ -34,7 +34,6 @@ import com.salesforce.op.UID
import com.salesforce.op.features.types.{OPVector, Prediction, RealNN}
import com.salesforce.op.stages.impl.CheckIsResponseValues
import com.salesforce.op.stages.sparkwrappers.specific.{OpPredictionModel, OpPredictorWrapper}
import com.salesforce.op.utils.reflection.ReflectionUtils.reflectMethod
import org.apache.spark.ml.regression.{GBTRegressionModel, GBTRegressor, OpGBTRegressorParams}

import scala.reflect.runtime.universe.TypeTag
@@ -139,7 +138,4 @@ class OpGBTRegressionModel
ttov: TypeTag[Prediction#Value]
) extends OpPredictionModel[GBTRegressionModel](
sparkModel = sparkModel, uid = uid, operationName = operationName
) {
@transient lazy val predictMirror = reflectMethod(getSparkMlStage().get, "predict")
}

)
@@ -34,7 +34,6 @@ import com.salesforce.op._
import com.salesforce.op.features.types.{OPVector, Prediction, RealNN}
import com.salesforce.op.stages.impl.CheckIsResponseValues
import com.salesforce.op.stages.sparkwrappers.specific.{OpPredictionModel, OpPredictorWrapper}
import com.salesforce.op.utils.reflection.ReflectionUtils.reflectMethod
import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel, OpLinearRegressionParams}

import scala.reflect.runtime.universe.TypeTag
@@ -205,7 +204,4 @@ class OpLinearRegressionModel
ttov: TypeTag[Prediction#Value]
) extends OpPredictionModel[LinearRegressionModel](
sparkModel = sparkModel, uid = uid, operationName = operationName
) {
@transient lazy val predictMirror = reflectMethod(getSparkMlStage().get, "predict")
}

)
@@ -34,7 +34,6 @@ import com.salesforce.op.UID
import com.salesforce.op.features.types.{OPVector, Prediction, RealNN}
import com.salesforce.op.stages.impl.CheckIsResponseValues
import com.salesforce.op.stages.sparkwrappers.specific.{OpPredictionModel, OpPredictorWrapper}
import com.salesforce.op.utils.reflection.ReflectionUtils.reflectMethod
import org.apache.spark.ml.regression.{OpRandomForestRegressorParams, RandomForestRegressionModel, RandomForestRegressor}

import scala.reflect.runtime.universe.TypeTag
@@ -126,8 +125,4 @@ class OpRandomForestRegressionModel
ttov: TypeTag[Prediction#Value]
) extends OpPredictionModel[RandomForestRegressionModel](
sparkModel = sparkModel, uid = uid, operationName = operationName
) {
@transient lazy val predictMirror = reflectMethod(getSparkMlStage().get, "predict")
}


)
@@ -34,7 +34,6 @@ import com.salesforce.op.UID
import com.salesforce.op.features.types.{OPVector, Prediction, RealNN}
import com.salesforce.op.stages.impl.CheckIsResponseValues
import com.salesforce.op.stages.sparkwrappers.specific.{OpPredictionModel, OpPredictorWrapper}
import com.salesforce.op.utils.reflection.ReflectionUtils.reflectMethod
import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
import ml.dmlc.xgboost4j.scala.spark.{OpXGBoostRegressorParams, TrackerConf, XGBoostRegressionModel, XGBoostRegressor}

@@ -234,6 +233,11 @@ class OpXGBoostRegressor(uid: String = UID[OpXGBoostRegressor])
*/
def setMaxBins(value: Int): this.type = set(maxBins, value)

/**
* Maximum number of nodes to be added. Only relevant when grow_policy=lossguide is set.
*/
def setMaxLeaves(value: Int): this.type = set(maxLeaves, value)

/**
* This is only used for approximate greedy algorithm.
* This roughly translated into O(1 / sketch_eps) number of bins. Compared to directly select
@@ -281,8 +285,19 @@ class OpXGBoostRegressor(uid: String = UID[OpXGBoostRegressor])
def setLambdaBias(value: Double): this.type = set(lambdaBias, value)

// setters for learning params

/**
* Specify the learning task and the corresponding learning objective.
* options: reg:squarederror, reg:logistic, binary:logistic, binary:logitraw, count:poisson,
* multi:softmax, multi:softprob, rank:pairwise, reg:gamma. default: reg:squarederror
*/
def setObjective(value: String): this.type = set(objective, value)

/**
* Objective type used for training. For options see [[ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams]]
*/
def setObjectiveType(value: String): this.type = set(objectiveType, value)

/**
* Specify the learning task and the corresponding learning objective.
* options: reg:linear, reg:logistic, binary:logistic, binary:logitraw, count:poisson,
@@ -309,6 +324,11 @@ class OpXGBoostRegressor(uid: String = UID[OpXGBoostRegressor])
*/
def setNumEarlyStoppingRounds(value: Int): this.type = set(numEarlyStoppingRounds, value)

/**
* Define the expected optimization to the evaluation metrics, true to maximize otherwise minimize it
*/
def setMaximizeEvaluationMetrics(value: Boolean): this.type = set(maximizeEvaluationMetrics, value)

/**
* Customized objective function provided by user. default: null
*/
@@ -341,6 +361,4 @@ class OpXGBoostRegressionModel
ttov: TypeTag[Prediction#Value]
) extends OpPredictionModel[XGBoostRegressionModel](
sparkModel = sparkModel, uid = uid, operationName = operationName
) {
@transient lazy val predictMirror = reflectMethod(getSparkMlStage().get, "predict")
}
)
@@ -41,7 +41,6 @@ import com.salesforce.op.stages.impl.tuning.{BestEstimator, _}
import com.salesforce.op.stages.sparkwrappers.generic.SparkWrapperParams
import com.salesforce.op.stages.sparkwrappers.specific.{OpPredictorWrapperModel, SparkModelConverter}
import com.salesforce.op.utils.spark.RichMetadata._
import com.salesforce.op.utils.spark.RichDataset._
import com.salesforce.op.utils.spark.RichParamMap._
import com.salesforce.op.utils.stages.FitStagesUtil._
import org.apache.spark.ml.param._
@@ -52,9 +52,12 @@ abstract class OpPredictionModel[T <: PredictionModel[Vector, T]]
operationName: String
) extends OpPredictorWrapperModel[T](uid = uid, operationName = operationName, sparkModel = sparkModel) {

protected def predictMirror: MethodMirror

protected def predict(features: Vector): Double = predictMirror.apply(features).asInstanceOf[Double]
/**
* Predict label for the given features
*/
@transient protected lazy val predict: Vector => Double = getSparkMlStage().getOrElse(
throw new RuntimeException(s"Could not find the wrapped Spark stage.")
).predict(_)

/**
* Function used to convert input to output
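The getOrElse above (commit "get -> getOrElse with exception") trades Option.get's bare NoSuchElementException for a descriptive failure. The same pattern in isolation, as a hypothetical standalone helper:

// Fail with a message that names what was missing instead of a bare
// NoSuchElementException from Option.get.
def requireStage[T](stage: Option[T]): T = stage.getOrElse(
  throw new RuntimeException("Could not find the wrapped Spark stage.")
)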
@@ -104,8 +104,8 @@ case object OpXGBoost {
}

/**
* Hack to access [[ml.dmlc.xgboost4j.scala.spark.XGBoost.removeMissingValues]] private method
* Hack to access [[ml.dmlc.xgboost4j.scala.spark.XGBoost.processMissingValues]] private method
*/
def removeMissingValues(xgbLabelPoints: Iterator[LabeledPoint], missing: Float): Iterator[LabeledPoint] =
XGBoost.removeMissingValues(xgbLabelPoints, missing)
def processMissingValues(xgbLabelPoints: Iterator[LabeledPoint], missing: Float): Iterator[LabeledPoint] =
XGBoost.processMissingValues(xgbLabelPoints, missing)
}