ML_TRANSFORM and New flavor: Spark ML #16

Open
da-tubi opened this issue Jan 6, 2023 · 0 comments

Comments

@da-tubi (Collaborator) commented Jan 6, 2023

Conclusion

The new Spark ML flavor requires ML_TRANSFORM but not ML_PREDICT.

The difference between ML_TRANSFORM and ML_PREDICT is the size of the model each can handle.

ML_PREDICT is implemented with PySpark pandas_udf; it works well with small models that can be loaded on a single node. ML_TRANSFORM is for big models that cannot be loaded on a single node.
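
To make the contrast concrete, here is a minimal sketch (not Rikai's actual implementation) of an ML_PREDICT-style pandas_udf: the whole model is deserialized on each executor and applied batch by batch, which is why it only suits models that fit on one node. The model path, loader, and column names are assumptions.

# Sketch of an ML_PREDICT-style pandas_udf (hypothetical, not Rikai's code).
# The full model is loaded on each executor, so it must fit on a single node.
import pandas as pd
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import DoubleType

@pandas_udf(DoubleType())
def ml_predict(features: pd.Series) -> pd.Series:
    import joblib
    # Hypothetical small model (e.g. scikit-learn), loaded inside the executor process.
    model = joblib.load("/tmp/small_model.pkl")  # assumed path
    return pd.Series(model.predict(list(features)))

# Usage (assumed: df has a "features" column of numeric arrays):
# df.withColumn("prediction", ml_predict("features")).show()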

Two Previous Attempts by Renkai

eto-ai/rikai#338
Tried to implement ML_PREDICT for SparkML like @da-tubi did for eto-ai/rikai#326, but it is much more complex than I thought. Perhaps the best way to complete it is to implement the ML_PREDICT UDF for SparkML in Scala, so that the worker does not need a SparkContext to get a properly configured JVM.

However, this is independent of this issue: we can still implement training with the SparkML feature; we just cannot use ML_PREDICT for SparkML.


eto-ai/rikai#343
Another attempt to implement ML_PREDICT for SparkML failed, even though I already tried it in Scala. The key issue that caused the failure is that a Spark ML model can only operate on a Dataset, which is not reachable from inside a UDF; we need to turn ML_PREDICT into a driver-side code generator rather than just another UDF. A rough sketch of that driver-side idea follows.
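
Here is a rough sketch of what a driver-side ML_TRANSFORM rewrite could generate (hypothetical helper and URIs, not Rikai's API): instead of wrapping the model in a UDF, the generated code loads the PipelineModel on the driver and calls transform on the whole DataFrame.

# Sketch of driver-side code generation for ML_TRANSFORM (hypothetical).
# Key point: PipelineModel.transform operates on a whole DataFrame/Dataset,
# which a row- or batch-level UDF cannot do.
from pyspark.ml import PipelineModel
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

def ml_transform(input_view: str, model_uri: str, output_view: str) -> None:
    """Hypothetical expansion of 'ML_TRANSFORM(input, model)' on the driver."""
    df = spark.table(input_view)
    model = PipelineModel.load(model_uri)  # Spark ML model loaded on the driver
    model.transform(df).createOrReplaceTempView(output_view)

# Usage (view names and model URI are assumed):
# ml_transform("test_data", "s3://bucket/models/rf_pipeline", "predictions")
# spark.sql("SELECT * FROM predictions").show()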

Let us try again

Demo Code: RandomForestClassifier

https://spark.apache.org/docs/latest/ml-classification-regression.html#random-forest-classifier

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}

// Load and parse the data file, converting it to a DataFrame.
val data = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")

// Index labels, adding metadata to the label column.
// Fit on whole dataset to include all labels in index.
val labelIndexer = new StringIndexer()
  .setInputCol("label")
  .setOutputCol("indexedLabel")
  .fit(data)
// Automatically identify categorical features, and index them.
// Set maxCategories so features with > 4 distinct values are treated as continuous.
val featureIndexer = new VectorIndexer()
  .setInputCol("features")
  .setOutputCol("indexedFeatures")
  .setMaxCategories(4)
  .fit(data)

// Split the data into training and test sets (30% held out for testing).
val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))

// Train a RandomForest model.
val rf = new RandomForestClassifier()
  .setLabelCol("indexedLabel")
  .setFeaturesCol("indexedFeatures")
  .setNumTrees(10)

// Convert indexed labels back to original labels.
val labelConverter = new IndexToString()
  .setInputCol("prediction")
  .setOutputCol("predictedLabel")
  .setLabels(labelIndexer.labelsArray(0))

// Chain indexers and forest in a Pipeline.
val pipeline = new Pipeline()
  .setStages(Array(labelIndexer, featureIndexer, rf, labelConverter))

// Train model. This also runs the indexers.
val model = pipeline.fit(trainingData)

// Make predictions.
val predictions = model.transform(testData)

// Select example rows to display.
predictions.select("predictedLabel", "label", "features").show(5)

// Select (prediction, true label) and compute test error.
val evaluator = new MulticlassClassificationEvaluator()
  .setLabelCol("indexedLabel")
  .setPredictionCol("prediction")
  .setMetricName("accuracy")
val accuracy = evaluator.evaluate(predictions)
println(s"Test Error = ${(1.0 - accuracy)}")

val rfModel = model.stages(2).asInstanceOf[RandomForestClassificationModel]
println(s"Learned classification forest model:\n ${rfModel.toDebugString}")

MLflow API
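
As a starting point, a minimal sketch of logging and loading the trained pipeline with MLflow's built-in Spark flavor (mlflow.spark), assuming MLflow is installed and the pipeline was fitted via PySpark (the Python counterpart of the Scala demo above); the new Rikai flavor discussed in this issue may expose a different API.

# Sketch using MLflow's standard Spark flavor (mlflow.spark); the Rikai
# "Spark ML" flavor would presumably wrap something similar.
# Assumed: `model` is a fitted pyspark.ml.PipelineModel.
import mlflow
import mlflow.spark

with mlflow.start_run():
    mlflow.spark.log_model(model, artifact_path="rf_pipeline")
    model_uri = mlflow.get_artifact_uri("rf_pipeline")

# Later: load the model back as a Spark ML PipelineModel and transform a DataFrame.
loaded = mlflow.spark.load_model(model_uri)
# predictions = loaded.transform(testData)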

da-tubi changed the title from "New flavor: Spark ML" to "ML_TRANSFORM and New flavor: Spark ML" on Jan 11, 2023