diff --git a/core/src/main/scala/com/salesforce/op/dsl/RichDateFeature.scala b/core/src/main/scala/com/salesforce/op/dsl/RichDateFeature.scala index ee8d1f1e20..26f4afc71c 100644 --- a/core/src/main/scala/com/salesforce/op/dsl/RichDateFeature.scala +++ b/core/src/main/scala/com/salesforce/op/dsl/RichDateFeature.scala @@ -56,7 +56,7 @@ trait RichDateFeature { f.transformWith( new UnaryLambdaTransformer[Date, DateList]( operationName = "dateToList", - RichDateFeatureLambdas.toDateList + new RichDateFeatureLambdas.ToDateList ) ) } @@ -137,7 +137,7 @@ trait RichDateFeature { f.transformWith( new UnaryLambdaTransformer[DateTime, DateTimeList]( operationName = "dateTimeToList", - RichDateFeatureLambdas.toDateTimeList + new RichDateFeatureLambdas.ToDateTimeList ) ) } @@ -204,7 +204,13 @@ trait RichDateFeature { } object RichDateFeatureLambdas { - def toDateList: Date => DateList = (x: Date) => x.value.toSeq.toDateList - def toDateTimeList: DateTime => DateTimeList = (x: DateTime) => x.value.toSeq.toDateTimeList + class ToDateList extends Function1[Date, DateList] with Serializable { + def apply(v: Date): DateList = v.value.toSeq.toDateList + } + + class ToDateTimeList extends Function1[Date, DateTimeList] with Serializable { + def apply(v: Date): DateTimeList = v.value.toSeq.toDateTimeList + } + } diff --git a/core/src/main/scala/com/salesforce/op/dsl/RichMapFeature.scala b/core/src/main/scala/com/salesforce/op/dsl/RichMapFeature.scala index 55c93d44f2..d602ec1656 100644 --- a/core/src/main/scala/com/salesforce/op/dsl/RichMapFeature.scala +++ b/core/src/main/scala/com/salesforce/op/dsl/RichMapFeature.scala @@ -30,7 +30,6 @@ package com.salesforce.op.dsl -import com.salesforce.op.dsl.RichMapFeatureLambdas._ import com.salesforce.op.features.FeatureLike import com.salesforce.op.features.types._ import com.salesforce.op.stages.impl.feature._ @@ -1098,9 +1097,10 @@ trait RichMapFeature { * @return prediction, rawPrediction, probability */ def tupled(): (FeatureLike[RealNN], FeatureLike[OPVector], FeatureLike[OPVector]) = { - (f.map[RealNN](predictionToRealNN), - f.map[OPVector](predictionToRaw), - f.map[OPVector](predictionToProbability) + import RichMapFeatureLambdas._ + (f.map[RealNN](new PredictionToRealNN), + f.map[OPVector](new PredictionToRaw), + f.map[OPVector](new PredictionToProbability) ) } @@ -1121,11 +1121,17 @@ trait RichMapFeature { object RichMapFeatureLambdas { - def predictionToRealNN: Prediction => RealNN = _.prediction.toRealNN + class PredictionToRealNN extends Function1[Prediction, RealNN] with Serializable { + def apply(p: Prediction): RealNN = p.prediction.toRealNN + } - def predictionToRaw: Prediction => OPVector = p => Vectors.dense(p.rawPrediction).toOPVector + class PredictionToRaw extends Function1[Prediction, OPVector] with Serializable { + def apply(p: Prediction): OPVector = Vectors.dense(p.rawPrediction).toOPVector + } - def predictionToProbability: Prediction => OPVector = p => Vectors.dense(p.probability).toOPVector + class PredictionToProbability extends Function1[Prediction, OPVector] with Serializable { + def apply(p: Prediction): OPVector = Vectors.dense(p.probability).toOPVector + } } diff --git a/core/src/main/scala/com/salesforce/op/dsl/RichTextFeature.scala b/core/src/main/scala/com/salesforce/op/dsl/RichTextFeature.scala index fa175d8153..bd78b58aa1 100644 --- a/core/src/main/scala/com/salesforce/op/dsl/RichTextFeature.scala +++ b/core/src/main/scala/com/salesforce/op/dsl/RichTextFeature.scala @@ -37,6 +37,8 @@ import com.salesforce.op.stages.impl.feature._ import com.salesforce.op.utils.text._ import scala.reflect.runtime.universe.TypeTag + + trait RichTextFeature { self: RichFeature => @@ -48,7 +50,7 @@ trait RichTextFeature { * * @return A new MultiPickList feature */ - def toMultiPickList: FeatureLike[MultiPickList] = f.map[MultiPickList](textToMultiPickList) + def toMultiPickList: FeatureLike[MultiPickList] = f.map[MultiPickList](new TextToMultiPickList) /** @@ -560,14 +562,14 @@ trait RichTextFeature { * * @return email prefix */ - def toEmailPrefix: FeatureLike[Text] = f.map[Text](emailToPrefix, "prefix") + def toEmailPrefix: FeatureLike[Text] = f.map[Text](new EmailPrefixToText, "prefix") /** * Extract email domains * * @return email domain */ - def toEmailDomain: FeatureLike[Text] = f.map[Text](emailToDomain, "domain") + def toEmailDomain: FeatureLike[Text] = f.map[Text](new EmailDomainToText, "domain") /** * Check if email is valid @@ -600,7 +602,7 @@ trait RichTextFeature { others: Array[FeatureLike[Email]] = Array.empty, maxPctCardinality: Double = OpOneHotVectorizer.MaxPctCardinality ): FeatureLike[OPVector] = { - val domains = (f +: others).map(_.map[PickList](emailToPickList)) + val domains = (f +: others).map(_.map[PickList](new EmailDomainToPickList)) domains.head.pivot(others = domains.tail, topK = topK, minSupport = minSupport, cleanText = cleanText, trackNulls = trackNulls, maxPctCardinality = maxPctCardinality ) @@ -613,19 +615,19 @@ trait RichTextFeature { /** * Extract url domain, i.e. salesforce.com, data.com etc. */ - def toDomain: FeatureLike[Text] = f.map[Text](urlToDomain, "urlDomain") + def toDomain: FeatureLike[Text] = f.map[Text](new URLDomainToText, "urlDomain") /** * Extracts url protocol, i.e. http, https, ftp etc. */ - def toProtocol: FeatureLike[Text] = f.map[Text](urlToProtocol, "urlProtocol") + def toProtocol: FeatureLike[Text] = f.map[Text](new URLProtocolToText, "urlProtocol") /** * Verifies if the url is of correct form of "Uniform Resource Identifiers (URI): Generic Syntax" * RFC2396 (http://www.ietf.org/rfc/rfc2396.txt) * Default valid protocols are: http, https, ftp. */ - def isValidUrl: FeatureLike[Binary] = f.exists(urlIsValid) + def isValidUrl: FeatureLike[Binary] = f.exists(new URLIsValid) /** * Converts a sequence of [[URL]] features into a vector, extracting the domains of the valid urls @@ -650,7 +652,7 @@ trait RichTextFeature { others: Array[FeatureLike[URL]] = Array.empty, maxPctCardinality: Double = OpOneHotVectorizer.MaxPctCardinality ): FeatureLike[OPVector] = { - val domains = (f +: others).map(_.map[PickList](urlToPickList)) + val domains = (f +: others).map(_.map[PickList](new URLDomainToPickList)) domains.head.pivot(others = domains.tail, topK = topK, minSupport = minSupport, cleanText = cleanText, trackNulls = trackNulls, maxPctCardinality = maxPctCardinality ) @@ -697,7 +699,7 @@ trait RichTextFeature { ): FeatureLike[OPVector] = { val feats: Array[FeatureLike[PickList]] = - (f +: others).map(_.detectMimeTypes(typeHint).map[PickList](textToPickList)) + (f +: others).map(_.detectMimeTypes(typeHint).map[PickList](new TextToPickList)) feats.head.vectorize( topK = topK, minSupport = minSupport, cleanText = cleanText, trackNulls = trackNulls, others = feats.tail, @@ -801,22 +803,40 @@ trait RichTextFeature { object RichTextFeatureLambdas { - def emailToPickList: Email => PickList = _.domain.toPickList + class EmailDomainToPickList extends Function1[Email, PickList] with Serializable { + def apply(v: Email): PickList = v.domain.toPickList + } - def emailToPrefix: Email => Text = _.prefix.toText + class EmailDomainToText extends Function1[Email, Text] with Serializable { + def apply(v: Email): Text = v.domain.toText + } - def emailToDomain: Email => Text = _.domain.toText + class EmailPrefixToText extends Function1[Email, Text] with Serializable { + def apply(v: Email): Text = v.prefix.toText + } - def urlToPickList: URL => PickList = (v: URL) => if (v.isValid) v.domain.toPickList else PickList.empty + class URLDomainToPickList extends Function1[URL, PickList] with Serializable { + def apply(v: URL): PickList = if (v.isValid) v.domain.toPickList else PickList.empty + } - def urlToDomain: URL => Text = _.domain.toText + class URLDomainToText extends Function1[URL, Text] with Serializable { + def apply(v: URL): Text = v.domain.toText + } - def urlToProtocol: URL => Text = _.protocol.toText + class URLProtocolToText extends Function1[URL, Text] with Serializable { + def apply(v: URL): Text = v.protocol.toText + } - def urlIsValid: URL => Boolean = _.isValid + class URLIsValid extends Function1[URL, Boolean] with Serializable { + def apply(v: URL): Boolean = v.isValid + } - def textToPickList: Text => PickList = _.value.toPickList + class TextToPickList extends Function1[Text, PickList] with Serializable { + def apply(v: Text): PickList = v.value.toPickList + } - def textToMultiPickList: Text => MultiPickList = _.value.toSet[String].toMultiPickList + class TextToMultiPickList extends Function1[Text, MultiPickList] with Serializable { + def apply(v: Text): MultiPickList = v.value.toSet[String].toMultiPickList + } } diff --git a/core/src/main/scala/com/salesforce/op/stages/impl/feature/EmailToPickListMapTransformer.scala b/core/src/main/scala/com/salesforce/op/stages/impl/feature/EmailToPickListMapTransformer.scala index 748bdb4058..31934e1418 100644 --- a/core/src/main/scala/com/salesforce/op/stages/impl/feature/EmailToPickListMapTransformer.scala +++ b/core/src/main/scala/com/salesforce/op/stages/impl/feature/EmailToPickListMapTransformer.scala @@ -41,10 +41,14 @@ class EmailToPickListMapTransformer(uid: String = UID[EmailToPickListMapTransfor operationName = "emailToPickListMap", transformer = new UnaryLambdaTransformer[Email, PickList]( operationName = "emailToPickList", - transformFn = EmailToPickListMapTransformer.emailToPickList + transformFn = new EmailToPickListMapTransformer.EmailToPickList ) ) object EmailToPickListMapTransformer { - def emailToPickList: Email => PickList = email => email.domain.toPickList + + class EmailToPickList extends Function1[Email, PickList] with Serializable { + def apply(v: Email): PickList = v.domain.toPickList + } + } diff --git a/core/src/main/scala/com/salesforce/op/stages/impl/feature/ScalerTransformer.scala b/core/src/main/scala/com/salesforce/op/stages/impl/feature/ScalerTransformer.scala index 437737d02a..14419bec77 100644 --- a/core/src/main/scala/com/salesforce/op/stages/impl/feature/ScalerTransformer.scala +++ b/core/src/main/scala/com/salesforce/op/stages/impl/feature/ScalerTransformer.scala @@ -33,10 +33,8 @@ package com.salesforce.op.stages.impl.feature import com.salesforce.op.UID import com.salesforce.op.features.types._ import com.salesforce.op.stages.base.unary.UnaryTransformer -import com.salesforce.op.utils.json.{JsonLike, JsonUtils} +import com.salesforce.op.utils.json.JsonUtils import org.apache.spark.sql.types.{Metadata, MetadataBuilder} -import org.json4s.JsonAST.{JField, JNothing} -import org.json4s.{CustomSerializer, JObject} import scala.reflect.runtime.universe.TypeTag import scala.util.{Failure, Try} diff --git a/core/src/main/scala/com/salesforce/op/stages/impl/feature/ToOccurTransformer.scala b/core/src/main/scala/com/salesforce/op/stages/impl/feature/ToOccurTransformer.scala index 066e780a58..de92c524cd 100644 --- a/core/src/main/scala/com/salesforce/op/stages/impl/feature/ToOccurTransformer.scala +++ b/core/src/main/scala/com/salesforce/op/stages/impl/feature/ToOccurTransformer.scala @@ -47,7 +47,7 @@ import scala.reflect.runtime.universe.TypeTag class ToOccurTransformer[I <: FeatureType] ( uid: String = UID[ToOccurTransformer[I]], - val matchFn: I => Boolean = ToOccurTransformer.defaultMatches[I] + val matchFn: I => Boolean = new ToOccurTransformer.DefaultMatches[I] )(implicit tti: TypeTag[I]) extends UnaryTransformer[I, RealNN](operationName = "toOccur", uid = uid) { @@ -60,11 +60,13 @@ class ToOccurTransformer[I <: FeatureType] object ToOccurTransformer { - def defaultMatches[T <: FeatureType]: T => Boolean = { - case num: OPNumeric[_] if num.nonEmpty => num.toDouble.get > 0.0 - case text: Text if text.nonEmpty => text.value.get.length > 0 - case collection: OPCollection => collection.nonEmpty - case _ => false + class DefaultMatches[T <: FeatureType] extends Function1[T, Boolean] with Serializable { + def apply(t: T): Boolean = t match { + case num: OPNumeric[_] if num.nonEmpty => num.toDouble.get > 0.0 + case text: Text if text.nonEmpty => text.value.get.length > 0 + case collection: OPCollection => collection.nonEmpty + case _ => false + } } } diff --git a/core/src/test/scala/com/salesforce/op/OpWorkflowModelReaderWriterTest.scala b/core/src/test/scala/com/salesforce/op/OpWorkflowModelReaderWriterTest.scala index d65b5c2352..977c19270a 100644 --- a/core/src/test/scala/com/salesforce/op/OpWorkflowModelReaderWriterTest.scala +++ b/core/src/test/scala/com/salesforce/op/OpWorkflowModelReaderWriterTest.scala @@ -141,7 +141,7 @@ class OpWorkflowModelReaderWriterTest } trait SwSingleStageFlow { - val vec = FeatureBuilder.OPVector[Passenger].extract(OpWorkflowModelReaderWriterTest.emptyVectorFn).asPredictor + val vec = FeatureBuilder.OPVector[Passenger].extract(new OpWorkflowModelReaderWriterTest.EmptyVectorFn).asPredictor val scaler = new StandardScaler().setWithStd(false).setWithMean(false) val schema = FeatureSparkTypes.toStructType(vec) val data = spark.createDataFrame(List(Row(Vectors.dense(1.0))).asJava, schema) @@ -158,12 +158,12 @@ class OpWorkflowModelReaderWriterTest trait OldVectorizedFlow extends UIDReset { val cat = Seq(gender, boarded, height, age, description).transmogrify() - val catHead = cat.map[Real](OpWorkflowModelReaderWriterTest.catHeadFn) + val catHead = cat.map[Real](new OpWorkflowModelReaderWriterTest.CatHeadFn) val wf = new OpWorkflow().setParameters(workflowParams).setResultFeatures(catHead) } trait VectorizedFlow extends UIDReset { - val catHead = rawFeatures.transmogrify().map[Real](OpWorkflowModelReaderWriterTest.catHeadFn) + val catHead = rawFeatures.transmogrify().map[Real](new OpWorkflowModelReaderWriterTest.CatHeadFn) val wf = new OpWorkflow().setParameters(workflowParams).setResultFeatures(catHead) } @@ -386,6 +386,13 @@ trait UIDReset { } object OpWorkflowModelReaderWriterTest { - def catHeadFn: OPVector => Real = v => Real(v.value.toArray.headOption) - def emptyVectorFn: Passenger => OPVector = _ => OPVector.empty + + class CatHeadFn extends Function1[OPVector, Real] with Serializable { + def apply(v: OPVector): Real = Real(v.value.toArray.headOption) + } + + class EmptyVectorFn extends Function1[Passenger, OPVector] with Serializable { + def apply(v: Passenger): OPVector = OPVector.empty + } + } diff --git a/core/src/test/scala/com/salesforce/op/stages/Lambdas.scala b/core/src/test/scala/com/salesforce/op/stages/Lambdas.scala index 9ffb467bd8..4492350696 100644 --- a/core/src/test/scala/com/salesforce/op/stages/Lambdas.scala +++ b/core/src/test/scala/com/salesforce/op/stages/Lambdas.scala @@ -34,38 +34,37 @@ import com.salesforce.op.features.types.Real import com.salesforce.op.features.types._ object Lambdas { - def fncUnary: Real => Real = (x: Real) => x.v.map(_ * 0.1234).toReal - def fncSequence: Seq[DateList] => Real = (x: Seq[DateList]) => { - val v = x.foldLeft(0.0)((a, b) => a + b.value.sum) - Math.round(v / 1E6).toReal + class FncSequence extends Function1[Seq[DateList], Real] with Serializable { + def apply(x: Seq[DateList]): Real = { + val v = x.foldLeft(0.0)((a, b) => a + b.value.sum) + Math.round(v / 1E6).toReal + } } - def fncBinarySequence: (Real, Seq[DateList]) => Real = (y: Real, x: Seq[DateList]) => { - val v = x.foldLeft(0.0)((a, b) => a + b.value.sum) - (Math.round(v / 1E6) + y.value.getOrElse(0.0)).toReal + class FncBinarySequence extends Function2[Real, Seq[DateList], Real] with Serializable { + def apply(y: Real, x: Seq[DateList]): Real = { + val v = x.foldLeft(0.0)((a, b) => a + b.value.sum) + (Math.round(v / 1E6) + y.value.getOrElse(0.0)).toReal + } } - def fncBinary: (Real, Real) => Real = (x: Real, y: Real) => ( - for { - yv <- y.value - xv <- x.value - } yield xv * yv - ).toReal + class FncUnary extends Function1[Real, Real] with Serializable { + def apply(x: Real): Real = x.v.map(_ * 0.1234).toReal + } - def fncTernary: (Real, Real, Real) => Real = (x: Real, y: Real, z: Real) => - (for { - xv <- x.value - yv <- y.value - zv <- z.value - } yield xv * yv + zv).toReal + class FncBinary extends Function2[Real, Real, Real] with Serializable { + def apply(x: Real, y: Real): Real = (for {yv <- y.value; xv <- x.value} yield xv * yv).toReal + } - def fncQuaternary: (Real, Real, Text, Real) => Real = (x: Real, y: Real, t: Text, z: Real) => - (for { - xv <- x.value - yv <- y.value - tv <- t.value - zv <- z.value - } yield xv * yv + zv * tv.length).toReal + class FncTernary extends Function3[Real, Real, Real, Real] with Serializable { + def apply(x: Real, y: Real, z: Real): Real = + (for {yv <- y.value; xv <- x.value; zv <- z.value} yield xv * yv + zv).toReal + } + + class FncQuaternary extends Function4[Real, Real, Text, Real, Real] with Serializable { + def apply(x: Real, y: Real, t: Text, z: Real): Real = + (for {yv <- y.value; xv <- x.value; tv <- t.value; zv <- z.value} yield xv * yv + zv * tv.length).toReal + } } diff --git a/core/src/test/scala/com/salesforce/op/stages/OpPipelineStagesTest.scala b/core/src/test/scala/com/salesforce/op/stages/OpPipelineStagesTest.scala index 8e7c862f56..32d1984859 100644 --- a/core/src/test/scala/com/salesforce/op/stages/OpPipelineStagesTest.scala +++ b/core/src/test/scala/com/salesforce/op/stages/OpPipelineStagesTest.scala @@ -134,7 +134,7 @@ class OpPipelineStagesTest val testOp = new com.salesforce.op.stages.base.unary.UnaryLambdaTransformer[Real, Real]( operationName = "test", - transformFn = OpPipelineStagesTest.fnc0, + transformFn = new OpPipelineStagesTest.RealIdentity, uid = "myID" ) @@ -162,7 +162,10 @@ class OpPipelineStagesTest } object OpPipelineStagesTest { - def fnc0: Real => Real = x => x + + class RealIdentity extends Function1[Real, Real] with Serializable { + def apply(v: Real): Real = v + } class TestStage(implicit val tto: TypeTag[RealNN], val ttov: TypeTag[RealNN#Value]) extends Pipeline with OpPipelineStage1[RealNN, RealNN] { diff --git a/core/src/test/scala/com/salesforce/op/stages/OpTransformerBinaryReaderWriterTest.scala b/core/src/test/scala/com/salesforce/op/stages/OpTransformerBinaryReaderWriterTest.scala index 91ce639439..f585e8375f 100644 --- a/core/src/test/scala/com/salesforce/op/stages/OpTransformerBinaryReaderWriterTest.scala +++ b/core/src/test/scala/com/salesforce/op/stages/OpTransformerBinaryReaderWriterTest.scala @@ -44,7 +44,7 @@ class OpTransformerBinaryReaderWriterTest extends OpPipelineStageReaderWriterTes val stage = new BinaryLambdaTransformer[Real, Real, Real]( operationName = "test", - transformFn = Lambdas.fncBinary + transformFn = new Lambdas.FncBinary ).setInput(weight, age).setMetadata(meta) val expected = Array(8600.toReal, 134.toReal, Real.empty, 2574.toReal, Real.empty, 2144.toReal) diff --git a/core/src/test/scala/com/salesforce/op/stages/OpTransformerBinarySequenceReaderWriterTest.scala b/core/src/test/scala/com/salesforce/op/stages/OpTransformerBinarySequenceReaderWriterTest.scala index fc169b2dc0..cddf0a919d 100644 --- a/core/src/test/scala/com/salesforce/op/stages/OpTransformerBinarySequenceReaderWriterTest.scala +++ b/core/src/test/scala/com/salesforce/op/stages/OpTransformerBinarySequenceReaderWriterTest.scala @@ -44,7 +44,7 @@ class OpTransformerBinarySequenceReaderWriterTest extends OpPipelineStageReaderW val stage = new BinarySequenceLambdaTransformer[Real, DateList, Real]( operationName = "test", - transformFn = Lambdas.fncBinarySequence + transformFn = new Lambdas.FncBinarySequence ).setInput(weight, boarded).setMetadata(meta) val expected = Array(3114.toReal, 1538.toReal, 0.toReal, 1549.toReal, 1567.toReal, 1538.toReal) diff --git a/core/src/test/scala/com/salesforce/op/stages/OpTransformerQuaternaryReaderWriterTest.scala b/core/src/test/scala/com/salesforce/op/stages/OpTransformerQuaternaryReaderWriterTest.scala index ff839e6e25..fa338c6fdf 100644 --- a/core/src/test/scala/com/salesforce/op/stages/OpTransformerQuaternaryReaderWriterTest.scala +++ b/core/src/test/scala/com/salesforce/op/stages/OpTransformerQuaternaryReaderWriterTest.scala @@ -44,7 +44,7 @@ class OpTransformerQuaternaryReaderWriterTest extends OpPipelineStageReaderWrite val stage = new QuaternaryLambdaTransformer[Real, Real, Text, Real, Real]( operationName = "test", - transformFn = Lambdas.fncQuaternary, + transformFn = new Lambdas.FncQuaternary, uid = "uid_1234" ).setInput(weight, age, description, weight).setMetadata(meta) diff --git a/core/src/test/scala/com/salesforce/op/stages/OpTransformerReaderWriterTest.scala b/core/src/test/scala/com/salesforce/op/stages/OpTransformerReaderWriterTest.scala index 3548301670..c4243b8c5b 100644 --- a/core/src/test/scala/com/salesforce/op/stages/OpTransformerReaderWriterTest.scala +++ b/core/src/test/scala/com/salesforce/op/stages/OpTransformerReaderWriterTest.scala @@ -44,7 +44,7 @@ class OpTransformerReaderWriterTest extends OpPipelineStageReaderWriterTest { val stage = new UnaryLambdaTransformer[Real, Real]( operationName = "test", - transformFn = Lambdas.fncUnary, + transformFn = new Lambdas.FncUnary, uid = "uid_1234" ).setInput(weight).setMetadata(meta) diff --git a/core/src/test/scala/com/salesforce/op/stages/OpTransformerSequenceReaderWriterTest.scala b/core/src/test/scala/com/salesforce/op/stages/OpTransformerSequenceReaderWriterTest.scala index c6886170bd..9db2fafa82 100644 --- a/core/src/test/scala/com/salesforce/op/stages/OpTransformerSequenceReaderWriterTest.scala +++ b/core/src/test/scala/com/salesforce/op/stages/OpTransformerSequenceReaderWriterTest.scala @@ -44,7 +44,7 @@ class OpTransformerSequenceReaderWriterTest extends OpPipelineStageReaderWriterT val stage = new SequenceLambdaTransformer[DateList, Real]( operationName = "test", - transformFn = Lambdas.fncSequence, + transformFn = new Lambdas.FncSequence, uid = "uid_1234" ).setInput(boarded).setMetadata(meta) diff --git a/core/src/test/scala/com/salesforce/op/stages/OpTransformerTernaryReaderWriterTest.scala b/core/src/test/scala/com/salesforce/op/stages/OpTransformerTernaryReaderWriterTest.scala index 3d2cbbe708..5ca0c09a91 100644 --- a/core/src/test/scala/com/salesforce/op/stages/OpTransformerTernaryReaderWriterTest.scala +++ b/core/src/test/scala/com/salesforce/op/stages/OpTransformerTernaryReaderWriterTest.scala @@ -44,7 +44,7 @@ class OpTransformerTernaryReaderWriterTest extends OpPipelineStageReaderWriterTe val stage = new TernaryLambdaTransformer[Real, Real, Real, Real]( operationName = "test", - transformFn = Lambdas.fncTernary, + transformFn = new Lambdas.FncTernary, uid = "uid_1234" ).setInput(weight, age, weight).setMetadata(meta) diff --git a/core/src/test/scala/com/salesforce/op/stages/impl/feature/DropIndicesByTransformerTest.scala b/core/src/test/scala/com/salesforce/op/stages/impl/feature/DropIndicesByTransformerTest.scala index c0d0116630..44d645c708 100644 --- a/core/src/test/scala/com/salesforce/op/stages/impl/feature/DropIndicesByTransformerTest.scala +++ b/core/src/test/scala/com/salesforce/op/stages/impl/feature/DropIndicesByTransformerTest.scala @@ -54,9 +54,9 @@ class DropIndicesByTransformerTest extends OpTransformerSpec[OPVector, DropIndic val (data, v) = TestFeatureBuilder(vecData) val meta = OpVectorMetadata(v.name, Array(TransientFeature(v).toColumnMetaData()), Map.empty).toMetadata val inputData = data.withColumn(v.name, col(v.name).as(v.name, meta)) - val stage = new DropIndicesByTransformer( - DropIndicesByTransformerTest.matchFn - ).setInput(v).setInputSchema(inputData.schema) + val stage = + new DropIndicesByTransformer(new DropIndicesByTransformerTest.MatchFn) + .setInput(v).setInputSchema(inputData.schema) inputData -> stage } @@ -85,7 +85,7 @@ class DropIndicesByTransformerTest extends OpTransformerSpec[OPVector, DropIndic val rawMeta = OpVectorMetadata(vectorizedPicklist.name, vectorizedPicklist.originStage.getMetadata()) val trimmedMeta = OpVectorMetadata(materializedFeatures.schema(prunedVector.name)) rawMeta.columns.length - 1 shouldBe trimmedMeta.columns.length - trimmedMeta.columns.foreach(_.indicatorValue == "Red" shouldBe false) + trimmedMeta.columns.foreach(_.indicatorValue.contains("Red") shouldBe false) } it should "work with its shortcut" in { @@ -109,12 +109,15 @@ class DropIndicesByTransformerTest extends OpTransformerSpec[OPVector, DropIndic val nonSer = new NonSerializable(5) val vectorizedPicklist = picklistFeature.vectorize(topK = 10, minSupport = 3, cleanText = false) intercept[IllegalArgumentException]( - vectorizedPicklist.dropIndicesBy(_.indicatorValue.get == nonSer.in) + vectorizedPicklist.dropIndicesBy(_.indicatorValue.get == nonSer.in.toString) ).getMessage shouldBe "Provided function is not serializable" } } object DropIndicesByTransformerTest { - def matchFn: OpVectorColumnMetadata => Boolean = _.isNullIndicator + + class MatchFn extends Function1[OpVectorColumnMetadata, Boolean] with Serializable { + def apply(m: OpVectorColumnMetadata): Boolean = m.isNullIndicator + } } diff --git a/features/src/test/scala/com/salesforce/op/stages/base/binary/BinaryTransformerTest.scala b/features/src/test/scala/com/salesforce/op/stages/base/binary/BinaryTransformerTest.scala index d92ef5641f..bb628b3ae1 100644 --- a/features/src/test/scala/com/salesforce/op/stages/base/binary/BinaryTransformerTest.scala +++ b/features/src/test/scala/com/salesforce/op/stages/base/binary/BinaryTransformerTest.scala @@ -43,7 +43,7 @@ class BinaryTransformerTest extends OpTransformerSpec[Real, BinaryTransformer[Re val (inputData, f1, f2) = TestFeatureBuilder(sample) val transformer = new BinaryLambdaTransformer[Real, RealNN, Real]( - operationName = "bmi", transformFn = BinaryTransformerTest.fn + operationName = "bmi", transformFn = new BinaryTransformerTest.Fun ).setInput(f1, f2) val expectedResult = Seq(Real(Double.PositiveInfinity), Real(0.5), Real.empty) @@ -51,5 +51,8 @@ class BinaryTransformerTest extends OpTransformerSpec[Real, BinaryTransformer[Re } object BinaryTransformerTest { - def fn: (Real, RealNN) => Real = (i1, i2) => new Real(for {v1 <- i1.value; v2 <- i2.value} yield v1 / (v2 * v2)) + + class Fun extends Function2[Real, RealNN, Real] with Serializable { + def apply(i1: Real, i2: RealNN): Real = new Real(for {v1 <- i1.value; v2 <- i2.value} yield v1 / (v2 * v2)) + } } diff --git a/features/src/test/scala/com/salesforce/op/stages/base/quaternary/QuaternaryTransformerTest.scala b/features/src/test/scala/com/salesforce/op/stages/base/quaternary/QuaternaryTransformerTest.scala index 0e029e8d55..fbb789b3e9 100644 --- a/features/src/test/scala/com/salesforce/op/stages/base/quaternary/QuaternaryTransformerTest.scala +++ b/features/src/test/scala/com/salesforce/op/stages/base/quaternary/QuaternaryTransformerTest.scala @@ -49,7 +49,7 @@ class QuaternaryTransformerTest val (inputData, f1, f2, f3, f4) = TestFeatureBuilder(sample) val transformer = new QuaternaryLambdaTransformer[Real, Integral, Text, Binary, Real]( - operationName = "quatro", transformFn = QuaternaryTransformerTest.fn + operationName = "quatro", transformFn = new QuaternaryTransformerTest.Fun ).setInput(f1, f2, f3, f4) val expectedResult = Seq(4.toReal, 6.toReal, 11.toReal) @@ -57,7 +57,11 @@ class QuaternaryTransformerTest } object QuaternaryTransformerTest { - def fn: (Real, Integral, Text, Binary) => Real = (r, i, t, b) => - (r.v.getOrElse(0.0) + i.toDouble.getOrElse(0.0) + b.toDouble.getOrElse(0.0) + - t.value.map(_.length.toDouble).getOrElse(0.0)).toReal + + class Fun extends Function4[Real, Integral, Text, Binary, Real] with Serializable { + def apply(r: Real, i: Integral, t: Text, b: Binary): Real = + (r.v.getOrElse(0.0) + i.toDouble.getOrElse(0.0) + b.toDouble.getOrElse(0.0) + + t.value.map(_.length.toDouble).getOrElse(0.0)).toReal + } + } diff --git a/features/src/test/scala/com/salesforce/op/stages/base/sequence/BinarySequenceTransformerTest.scala b/features/src/test/scala/com/salesforce/op/stages/base/sequence/BinarySequenceTransformerTest.scala index 6ec1bca0fc..73d59d6bec 100644 --- a/features/src/test/scala/com/salesforce/op/stages/base/sequence/BinarySequenceTransformerTest.scala +++ b/features/src/test/scala/com/salesforce/op/stages/base/sequence/BinarySequenceTransformerTest.scala @@ -50,7 +50,7 @@ class BinarySequenceTransformerTest val (inputData, f1, f2, f3) = TestFeatureBuilder(sample) val transformer = new BinarySequenceLambdaTransformer[Real, Text, MultiPickList]( - operationName = "realToMultiPicklist", transformFn = Lambda.fn + operationName = "realToMultiPicklist", transformFn = new BinarySequenceTransformerTest.Fun ).setInput(f1, f2, f3) val expectedResult = Seq( @@ -61,7 +61,11 @@ class BinarySequenceTransformerTest ).map(_.toMultiPickList) } -object Lambda { - def fn: (Real, Seq[Text]) => MultiPickList = - (r, texts) => MultiPickList(texts.map(_.value.get).toSet + r.value.get.toString) +object BinarySequenceTransformerTest { + + class Fun extends Function2[Real, Seq[Text], MultiPickList] with Serializable { + def apply(r: Real, texts: Seq[Text]): MultiPickList = + MultiPickList(texts.map(_.value.get).toSet + r.value.get.toString) + } + } diff --git a/features/src/test/scala/com/salesforce/op/stages/base/sequence/SequenceTransformerTest.scala b/features/src/test/scala/com/salesforce/op/stages/base/sequence/SequenceTransformerTest.scala index a03e3b1ec3..7dceafee1a 100644 --- a/features/src/test/scala/com/salesforce/op/stages/base/sequence/SequenceTransformerTest.scala +++ b/features/src/test/scala/com/salesforce/op/stages/base/sequence/SequenceTransformerTest.scala @@ -48,7 +48,7 @@ class SequenceTransformerTest extends OpTransformerSpec[MultiPickList, SequenceT val (inputData, f1, f2) = TestFeatureBuilder(sample) val transformer = new SequenceLambdaTransformer[Real, MultiPickList]( - operationName = "realToMultiPicklist", transformFn = SequenceTransformerTest.fn + operationName = "realToMultiPicklist", transformFn = new SequenceTransformerTest.Fun ).setInput(f1, f2) val expectedResult = Seq( @@ -61,5 +61,8 @@ class SequenceTransformerTest extends OpTransformerSpec[MultiPickList, SequenceT } object SequenceTransformerTest { - def fn: Seq[Real] => MultiPickList = value => MultiPickList(value.flatMap(_.v.map(_.toString)).toSet) + + class Fun extends Function1[Seq[Real], MultiPickList] with Serializable { + def apply(value: Seq[Real]): MultiPickList = MultiPickList(value.flatMap(_.v.map(_.toString)).toSet) + } } diff --git a/features/src/test/scala/com/salesforce/op/stages/base/ternary/TernaryTransformerTest.scala b/features/src/test/scala/com/salesforce/op/stages/base/ternary/TernaryTransformerTest.scala index 26bdd38533..2e6e56e08d 100644 --- a/features/src/test/scala/com/salesforce/op/stages/base/ternary/TernaryTransformerTest.scala +++ b/features/src/test/scala/com/salesforce/op/stages/base/ternary/TernaryTransformerTest.scala @@ -48,14 +48,18 @@ class TernaryTransformerTest extends OpTransformerSpec[Real, TernaryTransformer[ val (inputData, f1, f2, f3) = TestFeatureBuilder(sample) val transformer = new TernaryLambdaTransformer[Real, Integral, Binary, Real]( - operationName = "trio", transformFn = Lambda.fn + operationName = "trio", transformFn = new TernaryTransformerTest.Fun ).setInput(f1, f2, f3) val expectedResult = Seq(1.toReal, 5.toReal, 4.toReal) } -object Lambda { - def fn: (Real, Integral, Binary) => Real = - (r, i, b) => (r.v.getOrElse(0.0) + i.toDouble.getOrElse(0.0) + b.toDouble.getOrElse(0.0)).toReal +object TernaryTransformerTest { + + class Fun extends Function3[Real, Integral, Binary, Real] with Serializable { + def apply(r: Real, i: Integral, b: Binary): Real = + (r.v.getOrElse(0.0) + i.toDouble.getOrElse(0.0) + b.toDouble.getOrElse(0.0)).toReal + } + } diff --git a/features/src/test/scala/com/salesforce/op/stages/base/unary/UnaryTransformerTest.scala b/features/src/test/scala/com/salesforce/op/stages/base/unary/UnaryTransformerTest.scala index d7cbbfcccc..a163c659b2 100644 --- a/features/src/test/scala/com/salesforce/op/stages/base/unary/UnaryTransformerTest.scala +++ b/features/src/test/scala/com/salesforce/op/stages/base/unary/UnaryTransformerTest.scala @@ -48,7 +48,7 @@ class UnaryTransformerTest extends OpTransformerSpec[Real, UnaryLambdaTransforme * [[OpTransformer]] instance to be tested */ val transformer = new UnaryLambdaTransformer[Real, Real]( - operationName = "unary", transformFn = UnaryTransformerTest.fn + operationName = "unary", transformFn = new UnaryTransformerTest.Fun ).setInput(f1) /** @@ -59,5 +59,9 @@ class UnaryTransformerTest extends OpTransformerSpec[Real, UnaryLambdaTransforme } object UnaryTransformerTest { - def fn: Real => Real = r => r.v.map(_ * 2.0).toReal + + class Fun extends Function1[Real, Real] with Serializable { + def apply(r: Real): Real = r.v.map(_ * 2.0).toReal + } + } diff --git a/readers/src/main/scala/com/salesforce/op/test/PassengerFeaturesTest.scala b/readers/src/main/scala/com/salesforce/op/test/PassengerFeaturesTest.scala index 975bbf567a..ffd4b77398 100644 --- a/readers/src/main/scala/com/salesforce/op/test/PassengerFeaturesTest.scala +++ b/readers/src/main/scala/com/salesforce/op/test/PassengerFeaturesTest.scala @@ -34,24 +34,25 @@ import com.salesforce.op.aggregators.MaxReal import com.salesforce.op.features.types._ import com.salesforce.op.features.{FeatureBuilder, OPFeature} import org.joda.time.Duration -import PassengerFeaturesTestLambdas._ +import PassengerFeaturesTest._ trait PassengerFeaturesTest { - val age = FeatureBuilder.Real[Passenger].extract(ageFn).aggregate(MaxReal).asPredictor - val gender = FeatureBuilder.MultiPickList[Passenger].extract(genderFn).asPredictor - val genderPL = FeatureBuilder.PickList[Passenger].extract(genderPLFn).asPredictor - val height = FeatureBuilder.RealNN[Passenger].extract(heightFn).window(Duration.millis(300)).asPredictor - val heightNoWindow = FeatureBuilder.Real[Passenger].extract(heightToReal).asPredictor - val weight = FeatureBuilder.Real[Passenger].extract(weightToReal).asPredictor - val description = FeatureBuilder.Text[Passenger].extract(descriptionFn).asPredictor - val boarded = FeatureBuilder.DateList[Passenger].extract(boardedToDL).asPredictor - val stringMap = FeatureBuilder.TextMap[Passenger].extract(stringMapFn).asPredictor - val numericMap = FeatureBuilder.RealMap[Passenger].extract(numericMapFn).asPredictor - val booleanMap = FeatureBuilder.BinaryMap[Passenger].extract(booleanMapFn).asPredictor - val survived = FeatureBuilder.Binary[Passenger].extract(survivedFn).asResponse - val boardedTime = FeatureBuilder.Date[Passenger].extract(boardedTimeFn).asPredictor - val boardedTimeAsDateTime = FeatureBuilder.DateTime[Passenger].extract(boardedDTFn).asPredictor + val age = FeatureBuilder.Real[Passenger].extract(new AgeExtract).aggregate(MaxReal).asPredictor + val gender = FeatureBuilder.MultiPickList[Passenger].extract(new GenderAsMultiPickListExtract).asPredictor + val genderPL = FeatureBuilder.PickList[Passenger].extract(new GenderAsPickListExtract).asPredictor + val height = FeatureBuilder.RealNN[Passenger].extract(new HeightToRealNNExtract) + .window(Duration.millis(300)).asPredictor + val heightNoWindow = FeatureBuilder.Real[Passenger].extract(new HeightToRealExtract).asPredictor + val weight = FeatureBuilder.Real[Passenger].extract(new WeightToRealExtract).asPredictor + val description = FeatureBuilder.Text[Passenger].extract(new DescriptionExtract).asPredictor + val boarded = FeatureBuilder.DateList[Passenger].extract(new BoardedToDateListExtract).asPredictor + val stringMap = FeatureBuilder.TextMap[Passenger].extract(new StringMapExtract).asPredictor + val numericMap = FeatureBuilder.RealMap[Passenger].extract(new NumericMapExtract).asPredictor + val booleanMap = FeatureBuilder.BinaryMap[Passenger].extract(new BooleanMapExtract).asPredictor + val survived = FeatureBuilder.Binary[Passenger].extract(new SurvivedExtract).asResponse + val boardedTime = FeatureBuilder.Date[Passenger].extract(new BoardedToDateExtract).asPredictor + val boardedTimeAsDateTime = FeatureBuilder.DateTime[Passenger].extract(new BoardedToDateTimeExtract).asPredictor val rawFeatures: Array[OPFeature] = Array( survived, age, gender, height, weight, description, boarded, stringMap, numericMap, booleanMap @@ -59,19 +60,49 @@ trait PassengerFeaturesTest { } -object PassengerFeaturesTestLambdas { - def genderFn: Passenger => MultiPickList = p => Set(p.getGender).toMultiPickList - def genderPLFn: Passenger => PickList = p => p.getGender.toPickList - def heightFn: Passenger => RealNN = p => Option(p.getHeight).map(_.toDouble).toRealNN(0.0) - def heightToReal: Passenger => Real = _.getHeight.toReal - def weightToReal: Passenger => Real = _.getWeight.toReal - def descriptionFn: Passenger => Text = _.getDescription.toText - def boardedToDL: Passenger => DateList = p => Seq(p.getBoarded.toLong).toDateList - def stringMapFn: Passenger => TextMap = p => p.getStringMap.toTextMap - def numericMapFn: Passenger => RealMap = p => p.getNumericMap.toRealMap - def booleanMapFn: Passenger => BinaryMap = p => p.getBooleanMap.toBinaryMap - def survivedFn: Passenger => Binary = p => Option(p.getSurvived).map(_ == 1).toBinary - def boardedTimeFn: Passenger => Date = _.getBoarded.toLong.toDate - def boardedDTFn: Passenger => DateTime = _.getBoarded.toLong.toDateTime - def ageFn: Passenger => Real = _.getAge.toReal +object PassengerFeaturesTest { + + class GenderAsMultiPickListExtract extends Function1[Passenger, MultiPickList] with Serializable { + def apply(p: Passenger): MultiPickList = Set(p.getGender).toMultiPickList + } + class GenderAsPickListExtract extends Function1[Passenger, PickList] with Serializable { + def apply(p: Passenger): PickList = p.getGender.toPickList + } + class HeightToRealNNExtract extends Function1[Passenger, RealNN] with Serializable { + def apply(p: Passenger): RealNN = Option(p.getHeight).map(_.toDouble).toRealNN(0.0) + } + class HeightToRealExtract extends Function1[Passenger, Real] with Serializable { + def apply(p: Passenger): Real = p.getHeight.toReal + } + class WeightToRealExtract extends Function1[Passenger, Real] with Serializable { + def apply(p: Passenger): Real = p.getWeight.toReal + } + class DescriptionExtract extends Function1[Passenger, Text] with Serializable { + def apply(p: Passenger): Text = p.getDescription.toText + } + class BoardedToDateListExtract extends Function1[Passenger, DateList] with Serializable { + def apply(p: Passenger): DateList = Seq(p.getBoarded.toLong).toDateList + } + class BoardedToDateExtract extends Function1[Passenger, Date] with Serializable { + def apply(p: Passenger): Date = p.getBoarded.toLong.toDate + } + class BoardedToDateTimeExtract extends Function1[Passenger, DateTime] with Serializable { + def apply(p: Passenger): DateTime = p.getBoarded.toLong.toDateTime + } + class SurvivedExtract extends Function1[Passenger, Binary] with Serializable { + def apply(p: Passenger): Binary = Option(p.getSurvived).map(_ == 1).toBinary + } + class StringMapExtract extends Function1[Passenger, TextMap] with Serializable { + def apply(p: Passenger): TextMap = p.getStringMap.toTextMap + } + class NumericMapExtract extends Function1[Passenger, RealMap] with Serializable { + def apply(p: Passenger): RealMap = p.getNumericMap.toRealMap + } + class BooleanMapExtract extends Function1[Passenger, BinaryMap] with Serializable { + def apply(p: Passenger): BinaryMap = p.getBooleanMap.toBinaryMap + } + class AgeExtract extends Function1[Passenger, Real] with Serializable { + def apply(p: Passenger): Real = p.getAge.toReal + } + }