Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert lambda functions into concrete classes to allow compatibility with Scala 2.11/2.12 #357

Merged
merged 3 commits into from
Jul 11, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions core/src/main/scala/com/salesforce/op/dsl/RichDateFeature.scala
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ trait RichDateFeature {
f.transformWith(
new UnaryLambdaTransformer[Date, DateList](
operationName = "dateToList",
RichDateFeatureLambdas.toDateList
new RichDateFeatureLambdas.ToDateList
)
)
}
Expand Down Expand Up @@ -137,7 +137,7 @@ trait RichDateFeature {
f.transformWith(
new UnaryLambdaTransformer[DateTime, DateTimeList](
operationName = "dateTimeToList",
RichDateFeatureLambdas.toDateTimeList
new RichDateFeatureLambdas.ToDateTimeList
)
)
}
Expand Down Expand Up @@ -204,7 +204,13 @@ trait RichDateFeature {
}

object RichDateFeatureLambdas {
def toDateList: Date => DateList = (x: Date) => x.value.toSeq.toDateList

def toDateTimeList: DateTime => DateTimeList = (x: DateTime) => x.value.toSeq.toDateTimeList
/**
 * Concrete, serializable function that takes the value of a [[Date]],
 * converts it to a sequence and builds a [[DateList]] from that sequence.
 */
class ToDateList extends (Date => DateList) with Serializable {
  def apply(v: Date): DateList = {
    val dates = v.value.toSeq
    dates.toDateList
  }
}

/**
 * Concrete, serializable function that takes the value of a [[DateTime]],
 * converts it to a sequence and builds a [[DateTimeList]] from that sequence.
 *
 * The input type is [[DateTime]] (not [[Date]]) so the class carries the same
 * signature as the `DateTime => DateTimeList` lambda it stands in for and as the
 * `UnaryLambdaTransformer[DateTime, DateTimeList]` it is handed to.
 */
class ToDateTimeList extends Function1[DateTime, DateTimeList] with Serializable {
  def apply(v: DateTime): DateTimeList = v.value.toSeq.toDateTimeList
}

}
20 changes: 13 additions & 7 deletions core/src/main/scala/com/salesforce/op/dsl/RichMapFeature.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@

package com.salesforce.op.dsl

import com.salesforce.op.dsl.RichMapFeatureLambdas._
import com.salesforce.op.features.FeatureLike
import com.salesforce.op.features.types._
import com.salesforce.op.stages.impl.feature._
Expand Down Expand Up @@ -1098,9 +1097,10 @@ trait RichMapFeature {
* @return prediction, rawPrediction, probability
*/
def tupled(): (FeatureLike[RealNN], FeatureLike[OPVector], FeatureLike[OPVector]) = {
(f.map[RealNN](predictionToRealNN),
f.map[OPVector](predictionToRaw),
f.map[OPVector](predictionToProbability)
import RichMapFeatureLambdas._
(f.map[RealNN](new PredictionToRealNN),
f.map[OPVector](new PredictionToRaw),
f.map[OPVector](new PredictionToProbability)
)
}

Expand All @@ -1121,11 +1121,17 @@ trait RichMapFeature {

object RichMapFeatureLambdas {

def predictionToRealNN: Prediction => RealNN = _.prediction.toRealNN
/**
 * Concrete, serializable function that reads `prediction` from a [[Prediction]]
 * and converts it to a [[RealNN]].
 */
class PredictionToRealNN extends (Prediction => RealNN) with Serializable {
  def apply(p: Prediction): RealNN = {
    val predicted = p.prediction
    predicted.toRealNN
  }
}

def predictionToRaw: Prediction => OPVector = p => Vectors.dense(p.rawPrediction).toOPVector
/**
 * Concrete, serializable function that builds a dense vector from the
 * `rawPrediction` of a [[Prediction]] and converts it to an [[OPVector]].
 */
class PredictionToRaw extends (Prediction => OPVector) with Serializable {
  def apply(p: Prediction): OPVector = {
    val raw = Vectors.dense(p.rawPrediction)
    raw.toOPVector
  }
}

def predictionToProbability: Prediction => OPVector = p => Vectors.dense(p.probability).toOPVector
/**
 * Concrete, serializable function that builds a dense vector from the
 * `probability` of a [[Prediction]] and converts it to an [[OPVector]].
 */
class PredictionToProbability extends (Prediction => OPVector) with Serializable {
  def apply(p: Prediction): OPVector = {
    val probabilities = Vectors.dense(p.probability)
    probabilities.toOPVector
  }
}

}

Expand Down
56 changes: 38 additions & 18 deletions core/src/main/scala/com/salesforce/op/dsl/RichTextFeature.scala
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ import com.salesforce.op.stages.impl.feature._
import com.salesforce.op.utils.text._

import scala.reflect.runtime.universe.TypeTag


trait RichTextFeature {
self: RichFeature =>

Expand All @@ -48,7 +50,7 @@ trait RichTextFeature {
*
* @return A new MultiPickList feature
*/
def toMultiPickList: FeatureLike[MultiPickList] = f.map[MultiPickList](textToMultiPickList)
def toMultiPickList: FeatureLike[MultiPickList] = f.map[MultiPickList](new TextToMultiPickList)


/**
Expand Down Expand Up @@ -560,14 +562,14 @@ trait RichTextFeature {
*
* @return email prefix
*/
def toEmailPrefix: FeatureLike[Text] = f.map[Text](emailToPrefix, "prefix")
def toEmailPrefix: FeatureLike[Text] = f.map[Text](new EmailPrefixToText, "prefix")

/**
* Extract email domains
*
* @return email domain
*/
def toEmailDomain: FeatureLike[Text] = f.map[Text](emailToDomain, "domain")
def toEmailDomain: FeatureLike[Text] = f.map[Text](new EmailDomainToText, "domain")

/**
* Check if email is valid
Expand Down Expand Up @@ -600,7 +602,7 @@ trait RichTextFeature {
others: Array[FeatureLike[Email]] = Array.empty,
maxPctCardinality: Double = OpOneHotVectorizer.MaxPctCardinality
): FeatureLike[OPVector] = {
val domains = (f +: others).map(_.map[PickList](emailToPickList))
val domains = (f +: others).map(_.map[PickList](new EmailDomainToPickList))
domains.head.pivot(others = domains.tail, topK = topK, minSupport = minSupport, cleanText = cleanText,
trackNulls = trackNulls, maxPctCardinality = maxPctCardinality
)
Expand All @@ -613,19 +615,19 @@ trait RichTextFeature {
/**
* Extract url domain, i.e. salesforce.com, data.com etc.
*/
def toDomain: FeatureLike[Text] = f.map[Text](urlToDomain, "urlDomain")
def toDomain: FeatureLike[Text] = f.map[Text](new URLDomainToText, "urlDomain")

/**
* Extracts url protocol, i.e. http, https, ftp etc.
*/
def toProtocol: FeatureLike[Text] = f.map[Text](urlToProtocol, "urlProtocol")
def toProtocol: FeatureLike[Text] = f.map[Text](new URLProtocolToText, "urlProtocol")

/**
* Verifies if the url is of correct form of "Uniform Resource Identifiers (URI): Generic Syntax"
* RFC2396 (http://www.ietf.org/rfc/rfc2396.txt)
* Default valid protocols are: http, https, ftp.
*/
def isValidUrl: FeatureLike[Binary] = f.exists(urlIsValid)
def isValidUrl: FeatureLike[Binary] = f.exists(new URLIsValid)

/**
* Converts a sequence of [[URL]] features into a vector, extracting the domains of the valid urls
Expand All @@ -650,7 +652,7 @@ trait RichTextFeature {
others: Array[FeatureLike[URL]] = Array.empty,
maxPctCardinality: Double = OpOneHotVectorizer.MaxPctCardinality
): FeatureLike[OPVector] = {
val domains = (f +: others).map(_.map[PickList](urlToPickList))
val domains = (f +: others).map(_.map[PickList](new URLDomainToPickList))
domains.head.pivot(others = domains.tail, topK = topK, minSupport = minSupport, cleanText = cleanText,
trackNulls = trackNulls, maxPctCardinality = maxPctCardinality
)
Expand Down Expand Up @@ -697,7 +699,7 @@ trait RichTextFeature {
): FeatureLike[OPVector] = {

val feats: Array[FeatureLike[PickList]] =
(f +: others).map(_.detectMimeTypes(typeHint).map[PickList](textToPickList))
(f +: others).map(_.detectMimeTypes(typeHint).map[PickList](new TextToPickList))

feats.head.vectorize(
topK = topK, minSupport = minSupport, cleanText = cleanText, trackNulls = trackNulls, others = feats.tail,
Expand Down Expand Up @@ -801,22 +803,40 @@ trait RichTextFeature {

object RichTextFeatureLambdas {

def emailToPickList: Email => PickList = _.domain.toPickList
/**
 * Concrete, serializable function that reads the `domain` of an [[Email]]
 * and converts it to a [[PickList]].
 */
class EmailDomainToPickList extends (Email => PickList) with Serializable {
  def apply(v: Email): PickList = {
    val domain = v.domain
    domain.toPickList
  }
}

def emailToPrefix: Email => Text = _.prefix.toText
/**
 * Concrete, serializable function that reads the `domain` of an [[Email]]
 * and converts it to a [[Text]].
 */
class EmailDomainToText extends (Email => Text) with Serializable {
  def apply(v: Email): Text = {
    val domain = v.domain
    domain.toText
  }
}

def emailToDomain: Email => Text = _.domain.toText
/**
 * Concrete, serializable function that reads the `prefix` of an [[Email]]
 * and converts it to a [[Text]].
 */
class EmailPrefixToText extends (Email => Text) with Serializable {
  def apply(v: Email): Text = {
    val prefix = v.prefix
    prefix.toText
  }
}

def urlToPickList: URL => PickList = (v: URL) => if (v.isValid) v.domain.toPickList else PickList.empty
/**
 * Concrete, serializable function that maps a [[URL]] to a [[PickList]]:
 * the url's `domain` when `isValid` is true, an empty pick list otherwise.
 */
class URLDomainToPickList extends (URL => PickList) with Serializable {
  def apply(v: URL): PickList = {
    // Invalid urls contribute nothing; the domain is only read for valid ones.
    if (!v.isValid) PickList.empty
    else v.domain.toPickList
  }
}

def urlToDomain: URL => Text = _.domain.toText
/**
 * Concrete, serializable function that reads the `domain` of a [[URL]]
 * and converts it to a [[Text]].
 */
class URLDomainToText extends (URL => Text) with Serializable {
  def apply(v: URL): Text = {
    val domain = v.domain
    domain.toText
  }
}

def urlToProtocol: URL => Text = _.protocol.toText
/**
 * Concrete, serializable function that reads the `protocol` of a [[URL]]
 * and converts it to a [[Text]].
 */
class URLProtocolToText extends (URL => Text) with Serializable {
  def apply(v: URL): Text = {
    val protocol = v.protocol
    protocol.toText
  }
}

def urlIsValid: URL => Boolean = _.isValid
/**
 * Concrete, serializable predicate that returns the `isValid` flag of a [[URL]].
 */
class URLIsValid extends (URL => Boolean) with Serializable {
  def apply(v: URL): Boolean = {
    v.isValid
  }
}

def textToPickList: Text => PickList = _.value.toPickList
/**
 * Concrete, serializable function that converts the value of a [[Text]]
 * to a [[PickList]].
 */
class TextToPickList extends (Text => PickList) with Serializable {
  def apply(v: Text): PickList = {
    val text = v.value
    text.toPickList
  }
}

def textToMultiPickList: Text => MultiPickList = _.value.toSet[String].toMultiPickList
/**
 * Concrete, serializable function that turns the value of a [[Text]] into a
 * `Set[String]` and converts that set to a [[MultiPickList]].
 */
class TextToMultiPickList extends (Text => MultiPickList) with Serializable {
  def apply(v: Text): MultiPickList = {
    val items = v.value.toSet[String]
    items.toMultiPickList
  }
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,14 @@ class EmailToPickListMapTransformer(uid: String = UID[EmailToPickListMapTransfor
operationName = "emailToPickListMap",
transformer = new UnaryLambdaTransformer[Email, PickList](
operationName = "emailToPickList",
transformFn = EmailToPickListMapTransformer.emailToPickList
transformFn = new EmailToPickListMapTransformer.EmailToPickList
)
)

object EmailToPickListMapTransformer {
def emailToPickList: Email => PickList = email => email.domain.toPickList

/**
 * Concrete, serializable function that reads the `domain` of an [[Email]]
 * and converts it to a [[PickList]].
 */
class EmailToPickList extends (Email => PickList) with Serializable {
  def apply(v: Email): PickList = {
    val domain = v.domain
    domain.toPickList
  }
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,8 @@ package com.salesforce.op.stages.impl.feature
import com.salesforce.op.UID
import com.salesforce.op.features.types._
import com.salesforce.op.stages.base.unary.UnaryTransformer
import com.salesforce.op.utils.json.{JsonLike, JsonUtils}
import com.salesforce.op.utils.json.JsonUtils
import org.apache.spark.sql.types.{Metadata, MetadataBuilder}
import org.json4s.JsonAST.{JField, JNothing}
import org.json4s.{CustomSerializer, JObject}

import scala.reflect.runtime.universe.TypeTag
import scala.util.{Failure, Try}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ import scala.reflect.runtime.universe.TypeTag
class ToOccurTransformer[I <: FeatureType]
(
uid: String = UID[ToOccurTransformer[I]],
val matchFn: I => Boolean = ToOccurTransformer.defaultMatches[I]
val matchFn: I => Boolean = new ToOccurTransformer.DefaultMatches[I]
)(implicit tti: TypeTag[I])
extends UnaryTransformer[I, RealNN](operationName = "toOccur", uid = uid) {

Expand All @@ -60,11 +60,13 @@ class ToOccurTransformer[I <: FeatureType]

object ToOccurTransformer {

def defaultMatches[T <: FeatureType]: T => Boolean = {
case num: OPNumeric[_] if num.nonEmpty => num.toDouble.get > 0.0
case text: Text if text.nonEmpty => text.value.get.length > 0
case collection: OPCollection => collection.nonEmpty
case _ => false
/**
 * Default matching function, as a concrete serializable class.
 *
 * Returns true for:
 *  - a non-empty numeric whose double value is greater than zero,
 *  - a non-empty text whose string has at least one character,
 *  - a collection that is non-empty;
 * and false for every other input.
 *
 * @tparam T feature type being tested
 */
class DefaultMatches[T <: FeatureType] extends Function1[T, Boolean] with Serializable {
  def apply(t: T): Boolean = t match {
    // The guards are kept so empty numerics/texts still fall through to the later cases;
    // `exists` replaces `.get` so an unexpectedly empty option yields false instead of throwing.
    case num: OPNumeric[_] if num.nonEmpty => num.toDouble.exists(_ > 0.0)
    case text: Text if text.nonEmpty => text.value.exists(_.length > 0)
    case collection: OPCollection => collection.nonEmpty
    case _ => false
  }
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ class OpWorkflowModelReaderWriterTest
}

trait SwSingleStageFlow {
val vec = FeatureBuilder.OPVector[Passenger].extract(OpWorkflowModelReaderWriterTest.emptyVectorFn).asPredictor
val vec = FeatureBuilder.OPVector[Passenger].extract(new OpWorkflowModelReaderWriterTest.EmptyVectorFn).asPredictor
val scaler = new StandardScaler().setWithStd(false).setWithMean(false)
val schema = FeatureSparkTypes.toStructType(vec)
val data = spark.createDataFrame(List(Row(Vectors.dense(1.0))).asJava, schema)
Expand All @@ -158,12 +158,12 @@ class OpWorkflowModelReaderWriterTest

trait OldVectorizedFlow extends UIDReset {
val cat = Seq(gender, boarded, height, age, description).transmogrify()
val catHead = cat.map[Real](OpWorkflowModelReaderWriterTest.catHeadFn)
val catHead = cat.map[Real](new OpWorkflowModelReaderWriterTest.CatHeadFn)
val wf = new OpWorkflow().setParameters(workflowParams).setResultFeatures(catHead)
}

trait VectorizedFlow extends UIDReset {
val catHead = rawFeatures.transmogrify().map[Real](OpWorkflowModelReaderWriterTest.catHeadFn)
val catHead = rawFeatures.transmogrify().map[Real](new OpWorkflowModelReaderWriterTest.CatHeadFn)
val wf = new OpWorkflow().setParameters(workflowParams).setResultFeatures(catHead)
}

Expand Down Expand Up @@ -386,6 +386,13 @@ trait UIDReset {
}

object OpWorkflowModelReaderWriterTest {
def catHeadFn: OPVector => Real = v => Real(v.value.toArray.headOption)
def emptyVectorFn: Passenger => OPVector = _ => OPVector.empty

/**
 * Concrete, serializable function that wraps the first element (if any) of the
 * vector's array representation in a [[Real]].
 */
class CatHeadFn extends (OPVector => Real) with Serializable {
  def apply(v: OPVector): Real = {
    val first = v.value.toArray.headOption
    Real(first)
  }
}

/**
 * Concrete, serializable function that ignores its [[Passenger]] argument
 * and always returns `OPVector.empty`.
 */
class EmptyVectorFn extends (Passenger => OPVector) with Serializable {
  def apply(v: Passenger): OPVector = {
    OPVector.empty
  }
}

}
51 changes: 25 additions & 26 deletions core/src/test/scala/com/salesforce/op/stages/Lambdas.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,38 +34,37 @@ import com.salesforce.op.features.types.Real
import com.salesforce.op.features.types._

object Lambdas {
def fncUnary: Real => Real = (x: Real) => x.v.map(_ * 0.1234).toReal

def fncSequence: Seq[DateList] => Real = (x: Seq[DateList]) => {
val v = x.foldLeft(0.0)((a, b) => a + b.value.sum)
Math.round(v / 1E6).toReal
/**
 * Concrete, serializable function over a sequence of [[DateList]]s: accumulates
 * the sum of every list's values (starting from 0.0), divides by 1E6, rounds,
 * and converts the result to a [[Real]].
 */
class FncSequence extends (Seq[DateList] => Real) with Serializable {
  def apply(x: Seq[DateList]): Real = {
    val total = x.foldLeft(0.0) { case (acc, dates) => acc + dates.value.sum }
    val scaled = Math.round(total / 1E6)
    scaled.toReal
  }
}

def fncBinarySequence: (Real, Seq[DateList]) => Real = (y: Real, x: Seq[DateList]) => {
val v = x.foldLeft(0.0)((a, b) => a + b.value.sum)
(Math.round(v / 1E6) + y.value.getOrElse(0.0)).toReal
/**
 * Concrete, serializable function of a [[Real]] and a sequence of [[DateList]]s:
 * accumulates the sum of every list's values (starting from 0.0), divides by 1E6,
 * rounds, adds the value of `y` (0.0 when absent) and converts the result to a [[Real]].
 */
class FncBinarySequence extends ((Real, Seq[DateList]) => Real) with Serializable {
  def apply(y: Real, x: Seq[DateList]): Real = {
    val total = x.foldLeft(0.0) { case (acc, dates) => acc + dates.value.sum }
    val scaled = Math.round(total / 1E6)
    val offset = y.value.getOrElse(0.0)
    (scaled + offset).toReal
  }
}

def fncBinary: (Real, Real) => Real = (x: Real, y: Real) => (
for {
yv <- y.value
xv <- x.value
} yield xv * yv
).toReal
/**
 * Concrete, serializable function that multiplies the value of a [[Real]]
 * by 0.1234 (when present) and converts the result back to a [[Real]].
 */
class FncUnary extends (Real => Real) with Serializable {
  def apply(x: Real): Real = {
    val scaled = x.v.map(d => d * 0.1234)
    scaled.toReal
  }
}

def fncTernary: (Real, Real, Real) => Real = (x: Real, y: Real, z: Real) =>
(for {
xv <- x.value
yv <- y.value
zv <- z.value
} yield xv * yv + zv).toReal
/**
 * Concrete, serializable function that multiplies the values of two [[Real]]s
 * when both are present and converts the result to a [[Real]].
 */
class FncBinary extends ((Real, Real) => Real) with Serializable {
  def apply(x: Real, y: Real): Real = {
    val product = y.value.flatMap(yv => x.value.map(xv => xv * yv))
    product.toReal
  }
}

def fncQuaternary: (Real, Real, Text, Real) => Real = (x: Real, y: Real, t: Text, z: Real) =>
(for {
xv <- x.value
yv <- y.value
tv <- t.value
zv <- z.value
} yield xv * yv + zv * tv.length).toReal
/**
 * Concrete, serializable function of three [[Real]]s: when all three values are
 * present, computes `x * y + z` and converts the result to a [[Real]].
 */
class FncTernary extends ((Real, Real, Real) => Real) with Serializable {
  def apply(x: Real, y: Real, z: Real): Real = {
    val combined =
      y.value.flatMap { yv =>
        x.value.flatMap { xv =>
          z.value.map(zv => xv * yv + zv)
        }
      }
    combined.toReal
  }
}

/**
 * Concrete, serializable function of three [[Real]]s and a [[Text]]: when all four
 * values are present, computes `x * y + z * length(t)` and converts the result
 * to a [[Real]].
 */
class FncQuaternary extends ((Real, Real, Text, Real) => Real) with Serializable {
  def apply(x: Real, y: Real, t: Text, z: Real): Real = {
    val combined =
      y.value.flatMap { yv =>
        x.value.flatMap { xv =>
          t.value.flatMap { tv =>
            z.value.map(zv => xv * yv + zv * tv.length)
          }
        }
      }
    combined.toReal
  }
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ class OpPipelineStagesTest

val testOp = new com.salesforce.op.stages.base.unary.UnaryLambdaTransformer[Real, Real](
operationName = "test",
transformFn = OpPipelineStagesTest.fnc0,
transformFn = new OpPipelineStagesTest.RealIdentity,
uid = "myID"
)

Expand Down Expand Up @@ -162,7 +162,10 @@ class OpPipelineStagesTest
}

object OpPipelineStagesTest {
def fnc0: Real => Real = x => x

/**
 * Concrete, serializable identity function on [[Real]]: returns its argument unchanged.
 */
class RealIdentity extends (Real => Real) with Serializable {
  def apply(v: Real): Real = {
    v
  }
}

class TestStage(implicit val tto: TypeTag[RealNN], val ttov: TypeTag[RealNN#Value])
extends Pipeline with OpPipelineStage1[RealNN, RealNN] {
Expand Down
Loading