[WIP] Scala 2.12 / Spark 3 upgrade #550

Open · wants to merge 100 commits into base: master

Changes from 85 commits (100 commits total)

Commits
f6264a7
Update to Spark 2.4.3 and XGBoost 0.90
tovbinm May 30, 2019
685d6e1
special double serializer fix
tovbinm May 30, 2019
e62772d
fix serialization
tovbinm May 30, 2019
69247ac
fix serialization
tovbinm May 30, 2019
330bf50
docs
tovbinm May 30, 2019
d6b0723
fixed missng value for test
wsuchy May 30, 2019
63b77b5
meta fix
tovbinm May 30, 2019
4e46e31
Merge branch 'mt/spark-2.4' of github.com:salesforce/TransmogrifAI in…
tovbinm May 30, 2019
5a528e1
Updated DecisionTreeNumericMapBucketizer test to deal with the change…
Jauntbox May 31, 2019
5f39603
Merge branch 'mt/spark-2.4' of github.com:salesforce/TransmogrifAI in…
Jauntbox May 31, 2019
0d1a0c0
fix params meta test
tovbinm May 31, 2019
0a4f906
FIxed failing xgboost test
wsuchy May 31, 2019
660db62
Merge branch 'mt/spark-2.4' of github.com:salesforce/TransmogrifAI in…
wsuchy May 31, 2019
3ecca64
ident
tovbinm May 31, 2019
507503a
cleanup
tovbinm May 31, 2019
348a392
added dataframe reader and writer extensions
tovbinm Jun 3, 2019
f43cb26
added const
tovbinm Jun 3, 2019
4455034
Merge branch 'master' into mt/spark-2.4
tovbinm Jun 3, 2019
a0978bf
Merge branch 'master' into mt/spark-2.4
tovbinm Jun 10, 2019
82aa188
build for scala 2.12
koertkuipers Jun 20, 2019
b27b47a
Merge branch 'master' of github.com:salesforce/TransmogrifAI into mt/…
tovbinm Jun 21, 2019
6535e4e
added xgboost params + update models to use public predict method
tovbinm Jun 21, 2019
d1d7b9a
blarg
tovbinm Jun 21, 2019
ac75e15
double ser test
tovbinm Jun 21, 2019
761b889
Merge remote-tracking branch 'upstream/mt/spark-2.4' into feat-scala212
koertkuipers Jun 21, 2019
95095ed
fix unit tests by have lambdas implement concrete classes
koertkuipers Jul 9, 2019
76b411b
Merge branch 'master' into feat-scala212
koertkuipers Aug 5, 2019
ecfb902
remove unnecessary method defaultMatches
koertkuipers Aug 5, 2019
a1a2579
Merge branch 'master' into feat-scala212
koertkuipers Aug 8, 2019
785ddc5
Merge branch 'master' into feat-scala212
koertkuipers Aug 27, 2019
aacf00c
Merge branch 'master' into feat-scala212
koertkuipers Aug 28, 2019
c0a888f
Merge branch 'master' into feat-scala212
koertkuipers Aug 31, 2019
9ececc9
use mleap release
koertkuipers Sep 3, 2019
25a4449
Merge commit '51037a80ee6ef48c5c905ee967187288d78559cb' into feat-sca…
koertkuipers Sep 14, 2019
53df597
Merge commit '95a77b17269a71bf0d53c54df7d76f0bfe862275' into feat-sca…
koertkuipers Sep 14, 2019
4460fe5
Merge branch 'master' into feat-scala212
koertkuipers Sep 14, 2019
5b29d8b
Merge branch 'master' into feat-scala212
koertkuipers Sep 25, 2019
713a9f4
Merge branch 'master' into feat-scala212
koertkuipers Oct 6, 2019
69a3678
Merge branch 'feat-scala212' of server02:oss/TransmogrifAI into feat-…
koertkuipers Oct 7, 2019
f4b3f01
Merge branch 'master' into feat-scala212
koertkuipers Oct 10, 2019
142f121
Merge branch 'master' into feat-scala212
koertkuipers Oct 23, 2019
5ee32b1
Merge branch 'master' into feat-scala212
koertkuipers Nov 22, 2019
6e8e130
Merge branch 'feat-scala212' of server02:oss/TransmogrifAI into feat-…
koertkuipers Nov 22, 2019
c3ccdee
Merge branch 'master' into feat-scala212
koertkuipers Jan 23, 2020
ae1dfcf
Merge branch 'master' into feat-scala212
koertkuipers Feb 24, 2020
fd723d6
Increment scala hotfix prompted test change for random based doubles
tresata-gbernard Feb 28, 2020
e0f0bd8
Merge branch 'master' into feat-scala212
koertkuipers May 6, 2020
98dafde
fix random numbers somehow being different in scala 2.12
koertkuipers May 6, 2020
accd2ba
Merge branch 'master' into feat-scala212
koertkuipers Jun 17, 2020
27fdd3e
Merge branch 'master' into feat-scala212
koertkuipers Aug 21, 2020
f0cbc9e
WIP scala-multiversion-plugin
nicodv Sep 11, 2020
7fb9f0a
Merge remote-tracking branch 'tresata/feat-scala212' into ndv/scala212
nicodv Sep 11, 2020
ff29d1b
upgrade xgboost to version that has 2.11 and 2.12 versions published
nicodv Sep 11, 2020
20b8584
version string fixes
nicodv Sep 11, 2020
ca30345
add TODO
nicodv Sep 11, 2020
e2078e1
update TODO
nicodv Sep 11, 2020
fb16bd9
Merge branch 'master' into ndv/scala212
nicodv Mar 10, 2021
5b61508
update version strings
nicodv Mar 11, 2021
807eca9
update several versions to be scala 2.12 and spark 3 compatible
nicodv Mar 11, 2021
3fba576
various compilation fixes
nicodv Mar 11, 2021
dc4adbc
stack is deprecated, use var List
nicodv Mar 11, 2021
2cca254
use new udf interface
nicodv Mar 11, 2021
d3fbf8f
fix test
nicodv Mar 11, 2021
9fbc9da
compilation fix
nicodv Mar 11, 2021
e8c5b7a
compilation fix
nicodv Mar 11, 2021
c61a5b7
deal with moved csv utils
nicodv Mar 11, 2021
017676a
deal with deprecated operator
nicodv Mar 11, 2021
0538892
disable test for now
nicodv Mar 11, 2021
3e252db
add TODO
nicodv Mar 18, 2021
4fe2fdf
Merge branch 'master' into ndv/scala212
tovbinm Mar 18, 2021
c1941e1
be explicit about xgboost4j dependency
nicodv Mar 18, 2021
bdfae00
Merge remote-tracking branch 'origin/ndv/scala212' into ndv/scala212
nicodv Mar 18, 2021
fe4f2fb
drop support for joined data readers and update docs accordingly
nicodv Mar 21, 2021
c649974
deal with deprecated operator
nicodv Mar 22, 2021
c391aac
refactor for Spark API changes to bin. class. metrics
nicodv Mar 22, 2021
5f55dd9
use new 2.12 optimization options
nicodv Mar 22, 2021
642d27c
adhere to new xgboost interface
nicodv Mar 22, 2021
1605bd4
deal with deprecated syntax
nicodv Mar 22, 2021
64ea9d2
update TODO
nicodv Mar 22, 2021
51806fd
fix tree param overrides
crupley Mar 23, 2021
09b2960
replace deprecated range with bigdecimal range
crupley Mar 23, 2021
a946ffb
Use public wrapper to SparkUserDefinedFunction (SparkUDFFactory) to g…
emitc2h Apr 9, 2021
ec7da39
update stack in while loop in FeatureLike.prettyParentStages
emitc2h Apr 16, 2021
5b555e3
re-enabling @JSONdeserialize annotations while preserving the missing…
emitc2h Apr 17, 2021
30e61a3
ensuring consistent behavior between FeatureDistribution equals and h…
emitc2h Apr 23, 2021
9363b20
Merge branch 'master' into ndv/scala212
tovbinm Apr 23, 2021
6f7c841
Added MomentsSerializer to allow json4s to serialize Algebird's Momen…
emitc2h Apr 28, 2021
9a04faf
Merge branch 'ndv/scala212' of github.com:salesforce/TransmogrifAI in…
emitc2h Apr 28, 2021
4f752ab
Fix random seed issues + coefficient ordering issues in ModelInsights
emitc2h Apr 28, 2021
6731b9d
Fix expected results that changed due to changes in random number gen…
emitc2h Apr 28, 2021
b9e18ce
handle nulls and missing keys in cardinality calculations in SmartTex…
emitc2h Apr 28, 2021
c42163d
make test hash function consistent with OpHashingTF hashing (both now…
emitc2h Apr 28, 2021
7082707
Don't shut down sparkContext after running a test suite, clear cache …
emitc2h Apr 29, 2021
355bbe2
fixing unit tests in features
emitc2h May 3, 2021
2cb1827
fixing unit test failures in testkit due to rng outcome changes
emitc2h May 10, 2021
fc5cdc8
Allow for some tolerance when comparing scores after model write/read…
emitc2h May 10, 2021
dc014fa
use legacy mode to read parquet files written with Spark 2.x (SPARK-3…
emitc2h May 10, 2021
f31ce9f
Store input schema column metadata in its own param during stage exec…
emitc2h May 17, 2021
421b9bc
remove debug line
emitc2h May 17, 2021
0038823
Rolling back most of the ColumnMetadata infra since inputSchema metad…
emitc2h May 18, 2021
11 changes: 6 additions & 5 deletions README.md
@@ -128,7 +128,8 @@ Start by picking TransmogrifAI version to match your project dependencies from t

| TransmogrifAI Version | Spark Version | Scala Version | Java Version |
|-------------------------------------------------------|:-------------:|:-------------:|:------------:|
| 0.7.1 (unreleased, master), **0.7.0 (stable)** | **2.4** | **2.11** | **1.8** |
| 0.8.0 (unreleased, master) | 3.1 | 2.12 | 1.8 |
| **0.7.1 (stable)**, 0.7.0 | **2.4** | **2.11** | **1.8** |
| 0.6.1, 0.6.0, 0.5.3, 0.5.2, 0.5.1, 0.5.0 | 2.3 | 2.11 | 1.8 |
| 0.4.0, 0.3.4 | 2.2 | 2.11 | 1.8 |

@@ -140,10 +141,10 @@ repositories {
}
dependencies {
// TransmogrifAI core dependency
compile 'com.salesforce.transmogrifai:transmogrifai-core_2.11:0.7.0'
compile 'com.salesforce.transmogrifai:transmogrifai-core_2.12:0.8.0'

// TransmogrifAI pretrained models, e.g. OpenNLP POS/NER models etc. (optional)
// compile 'com.salesforce.transmogrifai:transmogrifai-models_2.11:0.7.0'
// compile 'com.salesforce.transmogrifai:transmogrifai-models_2.12:0.8.0'
}
```

@@ -154,10 +155,10 @@ scalaVersion := "2.11.12"
resolvers += Resolver.jcenterRepo

// TransmogrifAI core dependency
libraryDependencies += "com.salesforce.transmogrifai" %% "transmogrifai-core" % "0.7.0"
libraryDependencies += "com.salesforce.transmogrifai" %% "transmogrifai-core" % "0.8.0"

// TransmogrifAI pretrained models, e.g. OpenNLP POS/NER models etc. (optional)
// libraryDependencies += "com.salesforce.transmogrifai" %% "transmogrifai-models" % "0.7.0"
// libraryDependencies += "com.salesforce.transmogrifai" %% "transmogrifai-models" % "0.8.0"
```

Then import TransmogrifAI into your code:
63 changes: 32 additions & 31 deletions build.gradle
@@ -7,6 +7,7 @@ buildscript {
dependencies {
classpath 'org.github.ngbinh.scalastyle:gradle-scalastyle-plugin_2.11:1.0.1'
classpath 'com.commercehub.gradle.plugin:gradle-avro-plugin:0.16.0'
classpath 'com.adtran:scala-multiversion-plugin:1.+'
Contributor Author:

Initially I wanted to go with cross-compilation for 2.11/2.12, but now that we're upgrading to Spark 3 too we'll only build for 2.12.

This cross-compilation plugin seems to work well, though, so keeping it around for future 2.13 support.

}
}

@@ -46,6 +47,7 @@ configure(allProjs) {
apply plugin: 'net.minecrell.licenser'
apply plugin: 'com.github.jk1.dependency-license-report'
apply plugin: 'com.github.johnrengelman.shadow'
apply plugin: 'com.adtran.scala-multiversion-plugin'

sourceCompatibility = 1.8
targetCompatibility = 1.8
@@ -54,23 +56,21 @@ configure(allProjs) {
mainClassName = "please.set.main.class.in.build.gradle"

ext {
scalaVersion = '2.11'
scalaVersionRevision = '12'
scalaTestVersion = '3.0.5'
scalaCheckVersion = '1.14.0'
junitVersion = '4.12'
avroVersion = '1.8.2'
sparkVersion = '2.4.5'
sparkVersion = '3.1.1'
scalaGraphVersion = '1.12.5'
scalafmtVersion = '1.5.1'
hadoopVersion = 'hadoop2'
json4sVersion = '3.5.3' // matches Spark dependency version
json4sVersion = '3.7.0-M5' // matches Spark dependency version
jodaTimeVersion = '2.9.4'
jodaConvertVersion = '1.8.1'
algebirdVersion = '0.13.4'
jacksonVersion = '2.7.3'
jacksonVersion = '2.12.2'
luceneVersion = '7.3.0'
enumeratumVersion = '1.4.12'
enumeratumVersion = '1.4.18'
scoptVersion = '3.5.0'
googleLibPhoneNumberVersion = '8.8.5'
googleGeoCoderVersion = '2.82'
@@ -80,15 +80,15 @@ configure(allProjs) {
collectionsVersion = '3.2.2'
optimaizeLangDetectorVersion = '0.0.1'
tikaVersion = '1.22'
sparkTestingBaseVersion = '2.4.3_0.12.0'
sparkTestingBaseVersion = '3.0.1_1.0.0'
sourceCodeVersion = '0.1.3'
pegdownVersion = '1.4.2'
commonsValidatorVersion = '1.6'
commonsIOVersion = '2.6'
scoveragePluginVersion = '1.3.1'
xgboostVersion = '0.90'
akkaSlf4jVersion = '2.3.11'
mleapVersion = '0.16.0'
xgboostVersion = '1.3.1'
akkaSlf4jVersion = '2.5.23'
mleapVersion = '0.16.0' // TODO: upgrade to Spark 3-compatibel 0.17 when ready: https://github.com/combust/mleap/issues/727
Collaborator: typo

memoryFilesystemVersion = '2.1.0'
}

@@ -100,37 +100,37 @@ configure(allProjs) {
dependencies {
// Scala
zinc 'com.typesafe.zinc:zinc:0.3.15'
scoverage "org.scoverage:scalac-scoverage-plugin_$scalaVersion:$scoveragePluginVersion"
scoverage "org.scoverage:scalac-scoverage-runtime_$scalaVersion:$scoveragePluginVersion"
scalaLibrary "org.scala-lang:scala-library:$scalaVersion.$scalaVersionRevision"
scalaCompiler "org.scala-lang:scala-compiler:$scalaVersion.$scalaVersionRevision"
compile "org.scala-lang:scala-library:$scalaVersion.$scalaVersionRevision"
scoverage "org.scoverage:scalac-scoverage-plugin_%%:$scoveragePluginVersion"
scoverage "org.scoverage:scalac-scoverage-runtime_%%:$scoveragePluginVersion"
scalaLibrary "org.scala-lang:scala-library:$scalaVersion"
scalaCompiler "org.scala-lang:scala-compiler:$scalaVersion"
compile "org.scala-lang:scala-library:$scalaVersion"

// Spark
compileOnly "org.apache.spark:spark-core_$scalaVersion:$sparkVersion"
testCompile "org.apache.spark:spark-core_$scalaVersion:$sparkVersion"
compileOnly "org.apache.spark:spark-mllib_$scalaVersion:$sparkVersion"
testCompile "org.apache.spark:spark-mllib_$scalaVersion:$sparkVersion"
compileOnly "org.apache.spark:spark-sql_$scalaVersion:$sparkVersion"
testCompile "org.apache.spark:spark-sql_$scalaVersion:$sparkVersion"
compileOnly "org.apache.spark:spark-core_%%:$sparkVersion"
testCompile "org.apache.spark:spark-core_%%:$sparkVersion"
compileOnly "org.apache.spark:spark-mllib_%%:$sparkVersion"
testCompile "org.apache.spark:spark-mllib_%%:$sparkVersion"
compileOnly "org.apache.spark:spark-sql_%%:$sparkVersion"
testCompile "org.apache.spark:spark-sql_%%:$sparkVersion"

// Test
compileOnly "org.scalatest:scalatest_$scalaVersion:$scalaTestVersion"
testCompile "org.scalatest:scalatest_$scalaVersion:$scalaTestVersion"
compileOnly "org.scalacheck:scalacheck_$scalaVersion:$scalaCheckVersion"
testCompile "org.scoverage:scalac-scoverage-plugin_$scalaVersion:$scoveragePluginVersion"
testCompile "org.scoverage:scalac-scoverage-runtime_$scalaVersion:$scoveragePluginVersion"
testCompile "org.scalacheck:scalacheck_$scalaVersion:$scalaCheckVersion"
testCompile ("com.holdenkarau:spark-testing-base_$scalaVersion:$sparkTestingBaseVersion") { transitive = false }
compileOnly "org.scalatest:scalatest_%%:$scalaTestVersion"
testCompile "org.scalatest:scalatest_%%:$scalaTestVersion"
compileOnly "org.scalacheck:scalacheck_%%:$scalaCheckVersion"
testCompile "org.scoverage:scalac-scoverage-plugin_%%:$scoveragePluginVersion"
testCompile "org.scoverage:scalac-scoverage-runtime_%%:$scoveragePluginVersion"
testCompile "org.scalacheck:scalacheck_%%:$scalaCheckVersion"
testCompile ("com.holdenkarau:spark-testing-base_%%:$sparkTestingBaseVersion") { transitive = false }
testCompile "junit:junit:$junitVersion"
testRuntime "org.pegdown:pegdown:$pegdownVersion"
}

configurations.all {
resolutionStrategy {
force "commons-collections:commons-collections:$collectionsVersion",
"org.scala-lang:scala-library:$scalaVersion.$scalaVersionRevision",
"org.scala-lang:scala-reflect:$scalaVersion.$scalaVersionRevision"
"org.scala-lang:scala-library:$scalaVersion",
"org.scala-lang:scala-reflect:$scalaVersion"
}
}
configurations.zinc {
@@ -149,7 +149,7 @@ configure(allProjs) {
"-language:implicitConversions", "-language:existentials", "-language:postfixOps"
]
}
compileScala.scalaCompileOptions.additionalParameters += "-optimize"
Collaborator: why remove optimization option?

Contributor Author: -optimize is deprecated. I've now added the new optimization flags that replace it.

compileScala.scalaCompileOptions.additionalParameters += ["-opt:l:inline", "-opt-inline-from:**"]
[compileJava, compileTestJava]*.options.collect { options -> options.encoding = 'UTF-8' }

jar {
@@ -161,6 +161,7 @@ }
}

scalaStyle {
scalaVersion = '$scalaVersion'
configLocation = "$rootProject.rootDir/gradle/scalastyle-config.xml"
includeTestSourceDirectory = true
source = "src/main/scala"
7 changes: 3 additions & 4 deletions cli/build.gradle
@@ -1,14 +1,14 @@
dependencies {
// scopt
compile "com.github.scopt:scopt_$scalaVersion:$scoptVersion"
compile "com.github.scopt:scopt_%%:$scoptVersion"

// scalafmt
compile "com.geirsson:scalafmt-core_$scalaVersion:$scalafmtVersion"
compile "com.geirsson:scalafmt-core_%%:$scalafmtVersion"

// Reflections
compile "org.reflections:reflections:$reflectionsVersion"

compile "org.apache.spark:spark-sql_$scalaVersion:$sparkVersion"
compile "org.apache.spark:spark-sql_%%:$sparkVersion"

testCompile project(':utils')

@@ -71,7 +71,6 @@ task copyTemplates(type: Copy) {
expand([
version: scalaVersion,
scalaVersion: scalaVersion,
scalaVersionRevision: scalaVersionRevision,
scalaTestVersion: scalaTestVersion,
junitVersion: junitVersion,
sparkVersion: sparkVersion,
7 changes: 4 additions & 3 deletions core/build.gradle
@@ -21,13 +21,14 @@ dependencies {
compile "org.apache.lucene:lucene-suggest:$luceneVersion"

// Scopt
compile "com.github.scopt:scopt_$scalaVersion:$scoptVersion"
compile "com.github.scopt:scopt_%%:$scoptVersion"

// Zip util
compile 'org.zeroturnaround:zt-zip:1.14'

// XGBoost
compile ("ml.dmlc:xgboost4j-spark:$xgboostVersion") { exclude group: 'com.esotericsoftware.kryo', module: 'kryo' }
compile ("ml.dmlc:xgboost4j_%%:$xgboostVersion") { exclude group: 'com.esotericsoftware.kryo', module: 'kryo' }
compile ("ml.dmlc:xgboost4j-spark_%%:$xgboostVersion") { exclude group: 'com.esotericsoftware.kryo', module: 'kryo' }
Collaborator: I think we need both xgboost4j and xgboost4j-spark here, according to this. However, can we also use the GPU-enabled version of xgboost now? The artifact names are xgboost4j-gpu_2.12 and xgboost4j-spark-gpu_2.12.

Contributor Author: xgboost4j already gets pulled in, but let me add it to be explicit.

W.r.t. GPUs: let's leave that baby alone for now and get this upgrade done first. 😅

Contributor: When you finish this PR you can aim at an e2e on GPU throughout data transformations and XGBoost 😎

Collaborator: yes @gerashegalov i was trying to get people to use rapids ;)

Contributor: On the Spark side we require 3.x, so you are on the right track @TuanNguyen27

// Akka slfj4 logging (version matches XGBoost dependency)
testCompile "com.typesafe.akka:akka-slf4j_$scalaVersion:$akkaSlf4jVersion"
testCompile "com.typesafe.akka:akka-slf4j_%%:$akkaSlf4jVersion"
}
@@ -453,7 +453,7 @@ case object ModelInsights
): ModelInsights = {

// TODO support other model types?
val models = stages.collect{
val models: Array[OPStage with Model[_]] = stages.collect{
case s: SelectedModel => s
case s: OpPredictorWrapperModel[_] => s
case s: SelectedCombinerModel => s
8 changes: 3 additions & 5 deletions core/src/main/scala/com/salesforce/op/OpWorkflow.scala
@@ -39,12 +39,11 @@ import com.salesforce.op.stages.impl.preparators.CorrelationType
import com.salesforce.op.stages.impl.selector.ModelSelector
import com.salesforce.op.utils.reflection.ReflectionUtils
import com.salesforce.op.utils.spark.{JobGroupUtil, OpStep}
import com.salesforce.op.utils.spark.RichDataset._
import com.salesforce.op.utils.stages.FitStagesUtil
import com.salesforce.op.utils.stages.FitStagesUtil.{CutDAG, FittedDAG, Layer, StagesDAG}
import enumeratum.{Enum, EnumEntry}
import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.{Estimator, Transformer}
import org.apache.spark.ml.Transformer
import org.apache.spark.sql.{DataFrame, SparkSession}

import scala.collection.mutable.{MutableList => MList}
@@ -91,7 +90,6 @@ class OpWorkflow(val uid: String = UID[OpWorkflow]) extends OpWorkflowCore {
val featuresArr = features.toArray
resultFeatures = featuresArr
rawFeatures = featuresArr.flatMap(_.rawFeatures).distinct.sortBy(_.name)
checkUnmatchedFeatures()
setStagesDAG(features = featuresArr)
validateStages()

@@ -238,7 +236,7 @@ class OpWorkflow(val uid: String = UID[OpWorkflow]) extends OpWorkflowCore {
case (None, None) => throw new IllegalArgumentException(
"Data reader must be set either directly on the workflow or through the RawFeatureFilter")
case (Some(r), None) =>
checkReadersAndFeatures()
checkFeatures()
r.generateDataFrame(rawFeatures, parameters).persist()
case (rd, Some(rf)) =>
rd match {
Expand All @@ -247,7 +245,7 @@ class OpWorkflow(val uid: String = UID[OpWorkflow]) extends OpWorkflowCore {
"Workflow data reader and RawFeatureFilter training reader do not match! " +
"The RawFeatureFilter training reader will be used to generate the data for training")
}
checkReadersAndFeatures()
checkFeatures()

val FilteredRawData(cleanedData, featuresToDrop, mapKeysToDrop, rawFeatureFilterResults) =
rf.generateFilteredRaw(rawFeatures, parameters)
36 changes: 2 additions & 34 deletions core/src/main/scala/com/salesforce/op/OpWorkflowCore.scala
@@ -122,7 +122,6 @@ private[op] trait OpWorkflowCore {
*/
final def setReader(r: Reader[_]): this.type = {
reader = Option(r)
checkUnmatchedFeatures()
this
}

@@ -149,7 +148,6 @@
def readFn(params: OpParams)(implicit spark: SparkSession): Either[RDD[T], Dataset[T]] = Right(ds)
}
reader = Option(newReader)
checkUnmatchedFeatures()
this
}

@@ -166,7 +164,6 @@
def readFn(params: OpParams)(implicit spark: SparkSession): Either[RDD[T], Dataset[T]] = Left(rdd)
}
reader = Option(newReader)
checkUnmatchedFeatures()
this
}

@@ -247,40 +244,11 @@
*/
final def getRawFeatureFilterResults(): RawFeatureFilterResults = rawFeatureFilterResults


/**
* Determine if any of the raw features do not have a matching reader
* Check that features are set and that params match them
*/
protected def checkUnmatchedFeatures(): Unit = {
if (rawFeatures.nonEmpty && reader.nonEmpty) {
val readerInputTypes = reader.get.subReaders.map(_.fullTypeName).toSet
val unmatchedFeatures = rawFeatures.filterNot(f =>
readerInputTypes
.contains(f.originStage.asInstanceOf[FeatureGeneratorStage[_, _ <: FeatureType]].tti.tpe.toString)
)
require(
unmatchedFeatures.isEmpty,
s"No matching data readers for ${unmatchedFeatures.length} input features:" +
s" ${unmatchedFeatures.mkString(",")}. Readers had types: ${readerInputTypes.mkString(",")}"
)
}
}

/**
* Check that readers and features are set and that params match them
*/
protected def checkReadersAndFeatures() = {
protected def checkFeatures() = {
require(rawFeatures.nonEmpty, "Result features must be set")
checkUnmatchedFeatures()

val subReaderTypes = reader.get.subReaders.map(_.typeName).toSet
val unmatchedReaders = subReaderTypes.filterNot { t => parameters.readerParams.contains(t) }

if (unmatchedReaders.nonEmpty) {
log.info(
"Readers for types: {} do not have an override path in readerParams, so the default will be used",
unmatchedReaders.mkString(","))
}
}

/**
@@ -94,7 +94,7 @@ class OpWorkflowModel(val uid: String = UID[OpWorkflowModel], val trainingParams
protected def generateRawData()(implicit spark: SparkSession): DataFrame = {
JobGroupUtil.withJobGroup(OpStep.DataReadingAndFiltering) {
require(reader.nonEmpty, "Data reader must be set")
checkReadersAndFeatures()
checkFeatures()
reader.get.generateDataFrame(rawFeatures, parameters).persist() // don't want to redo this
}
}
@@ -116,9 +116,10 @@ private[op] class OpBinaryClassificationEvaluator
val aUPR = sparkMLMetrics.areaUnderPR()

val confusionMatrixByThreshold = sparkMLMetrics.confusionMatrixByThreshold().collect()
// Since we're not using sample weights, we simply cast the counts back to Longs.
val (copiedTupPos, copiedTupNeg) = confusionMatrixByThreshold.map { case (_, confusionMatrix) =>
((confusionMatrix.numTruePositives, confusionMatrix.numFalsePositives),
(confusionMatrix.numTrueNegatives, confusionMatrix.numFalseNegatives))
((confusionMatrix.weightedTruePositives.toLong, confusionMatrix.weightedFalsePositives.toLong),
(confusionMatrix.weightedTrueNegatives.toLong, confusionMatrix.weightedFalseNegatives.toLong))
}.unzip
val (tpByThreshold, fpByThreshold) = copiedTupPos.unzip
val (tnByThreshold, fnByThreshold) = copiedTupNeg.unzip
@@ -67,7 +67,9 @@ private[op] class OpRegressionEvaluator
isValid = l => l.nonEmpty && (l sameElements l.sorted)
)
setDefault(signedPercentageErrorHistogramBins,
Array(Double.NegativeInfinity) ++ (-100.0 to 100.0 by 10) ++ Array(Double.PositiveInfinity)
Array(Double.NegativeInfinity)
++ (Range.BigDecimal(-100, 100, 10)).map(_.toDouble)
++ Array(Double.PositiveInfinity)
)

def setPercentageErrorHistogramBins(v: Array[Double]): this.type = set(signedPercentageErrorHistogramBins, v)
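The histogram-bin change above replaces the deprecated fractional `Double` range with `Range.BigDecimal`, which steps in exact decimals and converts to `Double` only at the end. A minimal standalone sketch of the idea (not code from this PR; the PR uses the end-exclusive `Range.BigDecimal(...)` form, while this sketch uses `inclusive` to keep the upper edge, and `bins` is just an illustrative name):

```scala
// Standalone sketch: building histogram bin edges without Double-range drift.
// `-100.0 to 100.0 by 10.0` is deprecated on Scala 2.12 because repeated Double
// addition can accumulate rounding error in the generated steps.
val bins: Array[Double] =
  Array(Double.NegativeInfinity) ++
    Range.BigDecimal.inclusive(-100, 100, 10).map(_.toDouble) ++ // exact decimal steps
    Array(Double.PositiveInfinity)
```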
@@ -180,8 +180,8 @@ case class FeatureDistribution
case _ => false
}

override def hashCode(): Int = Objects.hashCode(name, key, count, nulls, distribution,
summaryInfo, moments, cardEstimate, `type`)
override def hashCode(): Int = Objects.hashCode((name, key, count, nulls, distribution.deep,
summaryInfo.deep, moments, cardEstimate, `type`))
}

object FeatureDistribution {
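The `hashCode` change above (wrapping the fields in a tuple and calling `.deep` on the array members) lines up with the earlier `equals` fix: arrays compare by reference, so two `FeatureDistribution`s with identical contents could otherwise compare equal but hash differently. A small standalone sketch of the underlying behaviour (not code from this PR):

```scala
// Standalone sketch: Scala/Java arrays use reference equality and identity hashes.
val a = Array(1.0, 2.0, 3.0)
val b = Array(1.0, 2.0, 3.0)

a == b                              // false: compares references, not contents
a.sameElements(b)                   // true: element-wise comparison
a.deep == b.deep                    // true on Scala 2.12: structural view of the array
a.deep.hashCode == b.deep.hashCode  // true: hashing the deep view is content-based
```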