diff --git a/connectors/connector-datahub/src/main/java/com/alibaba/alink/common/io/catalog/datahub/datastream/source/DatahubPublicSourceFunction.java b/connectors/connector-datahub/src/main/java/com/alibaba/alink/common/io/catalog/datahub/datastream/source/DatahubPublicSourceFunction.java index f9fb7e01d..7f034e6ce 100644 --- a/connectors/connector-datahub/src/main/java/com/alibaba/alink/common/io/catalog/datahub/datastream/source/DatahubPublicSourceFunction.java +++ b/connectors/connector-datahub/src/main/java/com/alibaba/alink/common/io/catalog/datahub/datastream/source/DatahubPublicSourceFunction.java @@ -12,7 +12,7 @@ import org.apache.flink.streaming.api.watermark.Watermark; import org.apache.flink.types.Row; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.io.catalog.datahub.datastream.util.DatahubClientProvider; import com.aliyun.datahub.client.model.Field; import com.aliyun.datahub.client.model.FieldType; diff --git a/connectors/connector-jdbc/connector-jdbc-sqlite/src/test/java/com/alibaba/alink/common/io/catalog/sqlite/SqliteCatalogTest.java b/connectors/connector-jdbc/connector-jdbc-sqlite/src/test/java/com/alibaba/alink/common/io/catalog/sqlite/SqliteCatalogTest.java index 43b525a6b..a9abf25ef 100644 --- a/connectors/connector-jdbc/connector-jdbc-sqlite/src/test/java/com/alibaba/alink/common/io/catalog/sqlite/SqliteCatalogTest.java +++ b/connectors/connector-jdbc/connector-jdbc-sqlite/src/test/java/com/alibaba/alink/common/io/catalog/sqlite/SqliteCatalogTest.java @@ -35,6 +35,7 @@ import org.junit.BeforeClass; import org.junit.ClassRule; import org.junit.Test; +import org.junit.Ignore; import org.junit.rules.TemporaryFolder; import java.math.BigDecimal; @@ -44,6 +45,7 @@ import java.util.Arrays; import java.util.Collections; +@Ignore public class SqliteCatalogTest { @ClassRule @@ -481,4 +483,4 @@ public void sinkBatch() throws Exception { ).collect().isEmpty() ); } -} \ No newline at end of file +} diff --git a/connectors/connector-kafka/src/main/java/com/alibaba/alink/common/io/kafka/plugin/KafkaMessageDeserialization.java b/connectors/connector-kafka/src/main/java/com/alibaba/alink/common/io/kafka/KafkaMessageDeserialization.java similarity index 95% rename from connectors/connector-kafka/src/main/java/com/alibaba/alink/common/io/kafka/plugin/KafkaMessageDeserialization.java rename to connectors/connector-kafka/src/main/java/com/alibaba/alink/common/io/kafka/KafkaMessageDeserialization.java index 14a3df754..ac7fab0c3 100644 --- a/connectors/connector-kafka/src/main/java/com/alibaba/alink/common/io/kafka/plugin/KafkaMessageDeserialization.java +++ b/connectors/connector-kafka/src/main/java/com/alibaba/alink/common/io/kafka/KafkaMessageDeserialization.java @@ -1,4 +1,4 @@ -package com.alibaba.alink.common.io.kafka.plugin; +package com.alibaba.alink.common.io.kafka; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeinfo.Types; diff --git a/connectors/connector-kafka/src/main/java/com/alibaba/alink/common/io/kafka/plugin/KafkaSinkBuilder.java b/connectors/connector-kafka/src/main/java/com/alibaba/alink/common/io/kafka/KafkaSinkBuilder.java similarity index 97% rename from connectors/connector-kafka/src/main/java/com/alibaba/alink/common/io/kafka/plugin/KafkaSinkBuilder.java rename to connectors/connector-kafka/src/main/java/com/alibaba/alink/common/io/kafka/KafkaSinkBuilder.java index 773d7a6e5..72c250354 100644 --- 
a/connectors/connector-kafka/src/main/java/com/alibaba/alink/common/io/kafka/plugin/KafkaSinkBuilder.java +++ b/connectors/connector-kafka/src/main/java/com/alibaba/alink/common/io/kafka/KafkaSinkBuilder.java @@ -1,4 +1,4 @@ -package com.alibaba.alink.common.io.kafka.plugin; +package com.alibaba.alink.common.io.kafka; import org.apache.flink.api.common.serialization.SerializationSchema; import org.apache.flink.api.common.typeinfo.TypeInformation; diff --git a/connectors/connector-kafka/src/main/java/com/alibaba/alink/common/io/kafka/plugin/KafkaSourceBuilder.java b/connectors/connector-kafka/src/main/java/com/alibaba/alink/common/io/kafka/KafkaSourceBuilder.java similarity index 98% rename from connectors/connector-kafka/src/main/java/com/alibaba/alink/common/io/kafka/plugin/KafkaSourceBuilder.java rename to connectors/connector-kafka/src/main/java/com/alibaba/alink/common/io/kafka/KafkaSourceBuilder.java index e09c7fb42..be6f596de 100644 --- a/connectors/connector-kafka/src/main/java/com/alibaba/alink/common/io/kafka/plugin/KafkaSourceBuilder.java +++ b/connectors/connector-kafka/src/main/java/com/alibaba/alink/common/io/kafka/KafkaSourceBuilder.java @@ -1,4 +1,4 @@ -package com.alibaba.alink.common.io.kafka.plugin; +package com.alibaba.alink.common.io.kafka; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.typeutils.RowTypeInfo; diff --git a/connectors/connector-kafka/src/main/java/com/alibaba/alink/common/io/kafka/plugin/KafkaSourceSinkInPluginFactory.java b/connectors/connector-kafka/src/main/java/com/alibaba/alink/common/io/kafka/KafkaSourceSinkInPluginFactory.java similarity index 95% rename from connectors/connector-kafka/src/main/java/com/alibaba/alink/common/io/kafka/plugin/KafkaSourceSinkInPluginFactory.java rename to connectors/connector-kafka/src/main/java/com/alibaba/alink/common/io/kafka/KafkaSourceSinkInPluginFactory.java index 8ff6af6c5..9303dee1c 100644 --- a/connectors/connector-kafka/src/main/java/com/alibaba/alink/common/io/kafka/plugin/KafkaSourceSinkInPluginFactory.java +++ b/connectors/connector-kafka/src/main/java/com/alibaba/alink/common/io/kafka/KafkaSourceSinkInPluginFactory.java @@ -1,4 +1,4 @@ -package com.alibaba.alink.common.io.kafka.plugin; +package com.alibaba.alink.common.io.kafka; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.ml.api.misc.param.Params; @@ -9,7 +9,8 @@ import org.apache.flink.util.Preconditions; import org.apache.flink.util.StringUtils; -import com.alibaba.alink.common.io.kafka.plugin.KafkaSourceBuilder.StartupMode; +import com.alibaba.alink.common.io.kafka.KafkaSourceBuilder.StartupMode; +import com.alibaba.alink.operator.stream.sink.KafkaSourceSinkFactory; import com.alibaba.alink.params.io.KafkaSinkParams; import com.alibaba.alink.params.io.KafkaSourceParams; diff --git a/connectors/connector-kafka/src/main/java/com/alibaba/alink/common/io/kafka/plugin/KafkaUtils.java b/connectors/connector-kafka/src/main/java/com/alibaba/alink/common/io/kafka/KafkaUtils.java similarity index 93% rename from connectors/connector-kafka/src/main/java/com/alibaba/alink/common/io/kafka/plugin/KafkaUtils.java rename to connectors/connector-kafka/src/main/java/com/alibaba/alink/common/io/kafka/KafkaUtils.java index b4ffeb19e..63e4f5308 100644 --- a/connectors/connector-kafka/src/main/java/com/alibaba/alink/common/io/kafka/plugin/KafkaUtils.java +++ b/connectors/connector-kafka/src/main/java/com/alibaba/alink/common/io/kafka/KafkaUtils.java @@ -1,4 +1,4 @@ -package 
com.alibaba.alink.common.io.kafka.plugin; +package com.alibaba.alink.common.io.kafka; import java.text.SimpleDateFormat; diff --git a/core/src/main/java/com/alibaba/alink/common/AlinkGlobalConfiguration.java b/core/src/main/java/com/alibaba/alink/common/AlinkGlobalConfiguration.java index 3e59ffe01..1f4bf1147 100644 --- a/core/src/main/java/com/alibaba/alink/common/AlinkGlobalConfiguration.java +++ b/core/src/main/java/com/alibaba/alink/common/AlinkGlobalConfiguration.java @@ -1,7 +1,5 @@ package com.alibaba.alink.common; -import org.apache.flink.runtime.util.EnvironmentInformation; - import com.alibaba.alink.common.io.plugin.PluginConfig; import com.alibaba.alink.common.io.plugin.PluginDownloader; diff --git a/core/src/main/java/com/alibaba/alink/common/LocalMLEnvironment.java b/core/src/main/java/com/alibaba/alink/common/LocalMLEnvironment.java index fcc075f57..7c547b4f3 100644 --- a/core/src/main/java/com/alibaba/alink/common/LocalMLEnvironment.java +++ b/core/src/main/java/com/alibaba/alink/common/LocalMLEnvironment.java @@ -5,9 +5,11 @@ import com.alibaba.alink.operator.local.sql.CalciteFunctionCompiler; public class LocalMLEnvironment { - private static final LocalMLEnvironment INSTANCE = new LocalMLEnvironment(); - - private final LocalOpCalciteSqlExecutor sqlExecutor; + private static final ThreadLocal threadLocalEnv = ThreadLocal.withInitial(() -> new LocalMLEnvironment()); + /** + * lazy load for speed. + */ + private LocalOpCalciteSqlExecutor sqlExecutor; // Compile user defined functions. We need to use its latest classloader when executing SQL. private final CalciteFunctionCompiler calciteFunctionCompiler; @@ -16,19 +18,21 @@ public class LocalMLEnvironment { private LocalMLEnvironment() { calciteFunctionCompiler = new CalciteFunctionCompiler(Thread.currentThread().getContextClassLoader()); - sqlExecutor = new LocalOpCalciteSqlExecutor(this); lazyObjectsManager = new LocalLazyObjectsManager(); } public static LocalMLEnvironment getInstance() { - return INSTANCE; + return threadLocalEnv.get(); } public CalciteFunctionCompiler getCalciteFunctionCompiler() { return calciteFunctionCompiler; } - public LocalOpCalciteSqlExecutor getSqlExecutor() { + public synchronized LocalOpCalciteSqlExecutor getSqlExecutor() { + if (sqlExecutor == null) { + sqlExecutor = new LocalOpCalciteSqlExecutor(this); + } return sqlExecutor; } diff --git a/core/src/main/java/com/alibaba/alink/common/MLEnvironment.java b/core/src/main/java/com/alibaba/alink/common/MLEnvironment.java index e14b5bbc8..64a7ffc17 100644 --- a/core/src/main/java/com/alibaba/alink/common/MLEnvironment.java +++ b/core/src/main/java/com/alibaba/alink/common/MLEnvironment.java @@ -13,13 +13,14 @@ import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; import org.apache.flink.types.Row; -import com.alibaba.alink.common.MTable.MTableKryoSerializer; +import com.alibaba.alink.common.MTable.MTableKryoSerializerV2; +import com.alibaba.alink.common.exceptions.AkIllegalArgumentException; import com.alibaba.alink.common.lazy.LazyObjectsManager; import com.alibaba.alink.common.linalg.tensor.Tensor; import com.alibaba.alink.common.linalg.tensor.TensorKryoSerializer; -import com.alibaba.alink.common.sql.builtin.BuildInAggRegister; -import com.alibaba.alink.common.utils.DataSetConversionUtil; -import com.alibaba.alink.common.utils.DataStreamConversionUtil; +import com.alibaba.alink.common.sql.builtin.BuiltInAggRegister; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; +import 
com.alibaba.alink.operator.stream.utils.DataStreamConversionUtil; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.batch.source.TableSourceBatchOp; import com.alibaba.alink.operator.stream.StreamOperator; @@ -104,20 +105,22 @@ public MLEnvironment( this.streamEnv = streamEnv; this.streamTableEnv = streamTableEnv; if (this.env != null) { - env.addDefaultKryoSerializer(MTable.class, new MTableKryoSerializer()); + env.addDefaultKryoSerializer(MTable.class, new MTableKryoSerializerV2()); env.addDefaultKryoSerializer(Tensor.class, new TensorKryoSerializer()); } if (this.streamEnv != null) { - streamEnv.addDefaultKryoSerializer(MTable.class, new MTableKryoSerializer()); + streamEnv.addDefaultKryoSerializer(MTable.class, new MTableKryoSerializerV2()); streamEnv.addDefaultKryoSerializer(Tensor.class, new TensorKryoSerializer()); } if (this.batchTableEnv != null) { - BuildInAggRegister.registerUdf(this.batchTableEnv); - BuildInAggRegister.registerUdaf(this.batchTableEnv); + BuiltInAggRegister.registerUdf(this.batchTableEnv); + BuiltInAggRegister.registerUdtf(this.batchTableEnv); + BuiltInAggRegister.registerUdaf(this.batchTableEnv); } if (this.streamTableEnv != null) { - BuildInAggRegister.registerUdf(this.streamTableEnv); - BuildInAggRegister.registerUdaf(this.streamTableEnv); + BuiltInAggRegister.registerUdf(this.streamTableEnv); + BuiltInAggRegister.registerUdtf(this.streamTableEnv); + BuiltInAggRegister.registerUdaf(this.streamTableEnv); } } @@ -150,7 +153,7 @@ public ExecutionEnvironment getExecutionEnvironment() { env = ExecutionEnvironment.getExecutionEnvironment(); } - env.addDefaultKryoSerializer(MTable.class, new MTableKryoSerializer()); + env.addDefaultKryoSerializer(MTable.class, new MTableKryoSerializerV2()); env.addDefaultKryoSerializer(Tensor.class, new TensorKryoSerializer()); } return env; @@ -166,7 +169,7 @@ public ExecutionEnvironment getExecutionEnvironment() { public StreamExecutionEnvironment getStreamExecutionEnvironment() { if (null == streamEnv) { streamEnv = StreamExecutionEnvironment.getExecutionEnvironment(); - streamEnv.addDefaultKryoSerializer(MTable.class, new MTableKryoSerializer()); + streamEnv.addDefaultKryoSerializer(MTable.class, new MTableKryoSerializerV2()); streamEnv.addDefaultKryoSerializer(Tensor.class, new TensorKryoSerializer()); } return streamEnv; @@ -182,8 +185,9 @@ public StreamExecutionEnvironment getStreamExecutionEnvironment() { public BatchTableEnvironment getBatchTableEnvironment() { if (null == batchTableEnv) { batchTableEnv = BatchTableEnvironment.create(getExecutionEnvironment()); - BuildInAggRegister.registerUdf(this.batchTableEnv); - BuildInAggRegister.registerUdaf(this.batchTableEnv); + BuiltInAggRegister.registerUdf(batchTableEnv); + BuiltInAggRegister.registerUdtf(batchTableEnv); + BuiltInAggRegister.registerUdaf(batchTableEnv); } return batchTableEnv; } @@ -205,8 +209,10 @@ public StreamTableEnvironment getStreamTableEnvironment() { .useOldPlanner() .build() ); - BuildInAggRegister.registerUdf(this.streamTableEnv); - BuildInAggRegister.registerUdaf(this.streamTableEnv); + + BuiltInAggRegister.registerUdf(streamTableEnv); + BuiltInAggRegister.registerUdtf(streamTableEnv); + BuiltInAggRegister.registerUdaf(streamTableEnv); } return streamTableEnv; } @@ -257,7 +263,7 @@ public Table createBatchTable(Row[] rows, String[] colNames) { */ public Table createBatchTable(List rows, String[] colNames) { if (rows == null || rows.size() < 1) { - throw new IllegalArgumentException("Values can not be empty."); + 
throw new AkIllegalArgumentException("Values can not be empty."); } Row first = rows.iterator().next(); @@ -297,7 +303,7 @@ public Table createStreamTable(Row[] rows, String[] colNames) { */ public Table createStreamTable(List rows, String[] colNames) { if (rows == null || rows.size() < 1) { - throw new IllegalArgumentException("Values can not be empty."); + throw new AkIllegalArgumentException("Values can not be empty."); } Row first = rows.iterator().next(); diff --git a/core/src/main/java/com/alibaba/alink/common/MTable.java b/core/src/main/java/com/alibaba/alink/common/MTable.java index 47d667d26..aefd17903 100644 --- a/core/src/main/java/com/alibaba/alink/common/MTable.java +++ b/core/src/main/java/com/alibaba/alink/common/MTable.java @@ -31,8 +31,10 @@ import com.alibaba.alink.common.io.filesystem.binary.RowStreamSerializerV2; import com.alibaba.alink.common.linalg.VectorUtil; import com.alibaba.alink.common.linalg.tensor.TensorUtil; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.utils.JsonConverter; import com.alibaba.alink.common.utils.TableUtil; +import com.alibaba.alink.common.viz.DataTypeDisplayInterface; import com.alibaba.alink.operator.common.io.csv.CsvFormatter; import com.alibaba.alink.operator.common.io.csv.CsvParser; import com.alibaba.alink.operator.common.statistics.basicstatistic.TableSummarizer; @@ -197,60 +199,25 @@ public MTable select(int... colIndexes) { return MTableUtil.select(this, colIndexes); } + /** + * summary for MTable. + */ public TableSummary summary(String... selectedColNames) { - TableSchema schema = getSchema(); - TableSummarizer srt = new TableSummarizer( - schema.getFieldNames(), - TableUtil.findColIndicesWithAssertAndHint(schema, getCalcCols(schema, selectedColNames)), true); - for (Row row : this.rows) { - srt.visit(row); - } - return srt.toSummary(selectedColNames); + return subSummary(selectedColNames, 0, this.getNumRow()); } //summary for data from fromId line to endId line, include fromId and exclude endId. public TableSummary subSummary(String[] selectedColNames, int fromId, int endId) { - TableSchema schema = getSchema(); - TableSummarizer srt = new TableSummarizer( - schema.getFieldNames(), - TableUtil.findColIndicesWithAssertAndHint(schema, getCalcCols(schema, selectedColNames)), true); - for (int i = Math.max(fromId, 0); i < Math.min(endId, this.getNumRow()); i++) { - srt.visit(this.rows.get(i)); + if (null == selectedColNames || 0 == selectedColNames.length) { + selectedColNames = this.getColNames(); } - return srt.toSummary(selectedColNames); - } - - /** - * exclude columns that are not supported types and not in selected columns - */ - private static String[] getCalcCols(TableSchema tableSchema, String[] selectedColNames) { - ArrayList calcCols = new ArrayList <>(); - String[] inColNames = selectedColNames.length == 0 ? 
tableSchema.getFieldNames() : selectedColNames; - int[] colIndices = TableUtil.findColIndices(tableSchema, inColNames); - TypeInformation [] inColTypes = tableSchema.getFieldTypes(); - - for (int i = 0; i < inColNames.length; i++) { - if (isSupportedType(inColTypes[colIndices[i]])) { - calcCols.add(inColNames[i]); - } + TableSchema schema = new TableSchema(selectedColNames, TableUtil.findColTypes(getSchema(), selectedColNames)); + int[] selectColIndices = TableUtil.findColIndices(getSchema(), selectedColNames); + TableSummarizer srt = new TableSummarizer(schema, false); + for (int i = Math.max(fromId, 0); i < Math.min(endId, this.getNumRow()); i++) { + srt.visit(Row.project(this.rows.get(i), selectColIndices)); } - - return calcCols.toArray(new String[0]); - } - - private static boolean isSupportedType(TypeInformation dataType) { - return Types.DOUBLE.equals(dataType) - || Types.LONG.equals(dataType) - || Types.BYTE.equals(dataType) - || Types.INT.equals(dataType) - || Types.FLOAT.equals(dataType) - || Types.SHORT.equals(dataType) - || Types.BIG_DEC.equals(dataType) - || Types.BOOLEAN.equals(dataType); - } - - public void printSummary(String... selectedColNames) { - System.out.println(summary(selectedColNames).toString()); + return srt.toSummary(); } /** diff --git a/core/src/main/java/com/alibaba/alink/common/annotation/ParamAnnotationUtils.java b/core/src/main/java/com/alibaba/alink/common/annotation/ParamAnnotationUtils.java index 46e7971c1..93b017fa7 100644 --- a/core/src/main/java/com/alibaba/alink/common/annotation/ParamAnnotationUtils.java +++ b/core/src/main/java/com/alibaba/alink/common/annotation/ParamAnnotationUtils.java @@ -1,17 +1,24 @@ package com.alibaba.alink.common.annotation; import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.ml.api.misc.param.WithParams; + +import org.reflections.Reflections; import java.lang.annotation.Annotation; import java.util.ArrayList; import java.util.Arrays; +import java.util.Comparator; import java.util.HashSet; import java.util.List; import java.util.Queue; import java.util.Set; import java.util.concurrent.LinkedBlockingDeque; +import java.util.stream.Collectors; public class ParamAnnotationUtils { + final static String BASE_PARAMS_PKG_NAME = "com.alibaba.alink.params"; + final static Class [] PARAM_BASES = new Class[] {WithParams.class}; static List > getAllInterfaces(Class clz) { Set > visited = new HashSet <>(); @@ -118,4 +125,17 @@ public static HashSet > getAllowedTypes(ParamSelectColumnSpe } return s; } + + public static List > listParamInfos(Class ... 
bases) { + Reflections ref = new Reflections(BASE_PARAMS_PKG_NAME); + List > params = new ArrayList <>(); + for (Class base : bases) { + params.addAll(ref.getSubTypesOf(base)); + } + return params.stream() + .filter(PublicOperatorUtils::isPublicUsable) + .sorted(Comparator.comparing(Class::toString)) + .collect(Collectors.toList()); + } + } diff --git a/core/src/main/java/com/alibaba/alink/common/annotation/PortDesc.java b/core/src/main/java/com/alibaba/alink/common/annotation/PortDesc.java index 13fcd2f40..18e5f7e2b 100644 --- a/core/src/main/java/com/alibaba/alink/common/annotation/PortDesc.java +++ b/core/src/main/java/com/alibaba/alink/common/annotation/PortDesc.java @@ -38,7 +38,14 @@ public enum PortDesc implements Internationalizable { ASSOCIATION_PATTERNS, ASSOCIATION_RULES, GRPAH_EDGES, - GRAPH_VERTICES; + GRAPH_VERTICES, + INPUT_DICT_DATA, + INPUT_QUERY_DATA, + FEATURE_FREQUENCY, + + SIMILAR_ITEM_PAIRS, + FEATURE_HASH_RESULTS, + MIN_HASH_RESULTS; public static final ResourceBundle PORT_DESC_CN_BUNDLE = ResourceBundle.getBundle( "i18n/port_desc", new Locale("zh", "CN"), new UTF8Control()); diff --git a/core/src/main/java/com/alibaba/alink/common/annotation/PublicOperatorUtils.java b/core/src/main/java/com/alibaba/alink/common/annotation/PublicOperatorUtils.java index 2addfbf69..a4e5538b3 100644 --- a/core/src/main/java/com/alibaba/alink/common/annotation/PublicOperatorUtils.java +++ b/core/src/main/java/com/alibaba/alink/common/annotation/PublicOperatorUtils.java @@ -1,6 +1,6 @@ package com.alibaba.alink.common.annotation; -import com.alibaba.alink.common.lazy.ExtractModelInfoBatchOp; +import com.alibaba.alink.operator.batch.utils.ExtractModelInfoBatchOp; import com.alibaba.alink.operator.AlgoOperator; import com.alibaba.alink.pipeline.Pipeline; import com.alibaba.alink.pipeline.PipelineModel; diff --git a/core/src/main/java/com/alibaba/alink/common/annotation/TypeCollections.java b/core/src/main/java/com/alibaba/alink/common/annotation/TypeCollections.java index 2ae250104..f3f05208e 100644 --- a/core/src/main/java/com/alibaba/alink/common/annotation/TypeCollections.java +++ b/core/src/main/java/com/alibaba/alink/common/annotation/TypeCollections.java @@ -1,21 +1,21 @@ package com.alibaba.alink.common.annotation; -import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo; import org.apache.flink.api.common.typeinfo.TypeInformation; -import static com.alibaba.alink.common.AlinkTypes.BOOL_TENSOR; -import static com.alibaba.alink.common.AlinkTypes.BYTE_TENSOR; -import static com.alibaba.alink.common.AlinkTypes.DENSE_VECTOR; -import static com.alibaba.alink.common.AlinkTypes.DOUBLE_TENSOR; -import static com.alibaba.alink.common.AlinkTypes.FLOAT_TENSOR; -import static com.alibaba.alink.common.AlinkTypes.INT_TENSOR; -import static com.alibaba.alink.common.AlinkTypes.LONG_TENSOR; -import static com.alibaba.alink.common.AlinkTypes.M_TABLE; -import static com.alibaba.alink.common.AlinkTypes.SPARSE_VECTOR; -import static com.alibaba.alink.common.AlinkTypes.STRING_TENSOR; -import static com.alibaba.alink.common.AlinkTypes.TENSOR; -import static com.alibaba.alink.common.AlinkTypes.UBYTE_TENSOR; -import static com.alibaba.alink.common.AlinkTypes.VECTOR; +import static com.alibaba.alink.common.type.AlinkTypes.BOOL_TENSOR; +import static com.alibaba.alink.common.type.AlinkTypes.BYTE_TENSOR; +import static com.alibaba.alink.common.type.AlinkTypes.DENSE_VECTOR; +import static com.alibaba.alink.common.type.AlinkTypes.DOUBLE_TENSOR; +import static 
com.alibaba.alink.common.type.AlinkTypes.FLOAT_TENSOR; +import static com.alibaba.alink.common.type.AlinkTypes.INT_TENSOR; +import static com.alibaba.alink.common.type.AlinkTypes.LONG_TENSOR; +import static com.alibaba.alink.common.type.AlinkTypes.M_TABLE; +import static com.alibaba.alink.common.type.AlinkTypes.SPARSE_VECTOR; +import static com.alibaba.alink.common.type.AlinkTypes.STRING_TENSOR; +import static com.alibaba.alink.common.type.AlinkTypes.TENSOR; +import static com.alibaba.alink.common.type.AlinkTypes.UBYTE_TENSOR; +import static com.alibaba.alink.common.type.AlinkTypes.VARBINARY; +import static com.alibaba.alink.common.type.AlinkTypes.VECTOR; import static org.apache.flink.api.common.typeinfo.Types.BIG_DEC; import static org.apache.flink.api.common.typeinfo.Types.BIG_INT; import static org.apache.flink.api.common.typeinfo.Types.BOOLEAN; @@ -24,7 +24,6 @@ import static org.apache.flink.api.common.typeinfo.Types.FLOAT; import static org.apache.flink.api.common.typeinfo.Types.INT; import static org.apache.flink.api.common.typeinfo.Types.LONG; -import static org.apache.flink.api.common.typeinfo.Types.PRIMITIVE_ARRAY; import static org.apache.flink.api.common.typeinfo.Types.SHORT; import static org.apache.flink.api.common.typeinfo.Types.SQL_DATE; import static org.apache.flink.api.common.typeinfo.Types.SQL_TIME; @@ -60,7 +59,7 @@ public enum TypeCollections { SQL_TIMESTAMP ), MTABLE_TYPES( - M_TABLE + M_TABLE, STRING ), TENSOR_TYPES( TENSOR, @@ -74,8 +73,12 @@ public enum TypeCollections { NAIVE_BAYES_CATEGORICAL_TYPES( STRING, BOOLEAN, BIG_INT, INT, LONG ), + NUMERIC_AND_VECTOR_TYPES( + INT, LONG, SHORT, BYTE, DOUBLE, FLOAT, BIG_DEC, BIG_INT, + STRING, VECTOR, SPARSE_VECTOR, DENSE_VECTOR + ), BYTES_TYPES( - PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO + VARBINARY ); private final TypeInformation [] types; diff --git a/core/src/main/java/com/alibaba/alink/common/comqueue/BaseComQueue.java b/core/src/main/java/com/alibaba/alink/common/comqueue/BaseComQueue.java index 218e22200..4a3d9c7b0 100644 --- a/core/src/main/java/com/alibaba/alink/common/comqueue/BaseComQueue.java +++ b/core/src/main/java/com/alibaba/alink/common/comqueue/BaseComQueue.java @@ -1,6 +1,5 @@ package com.alibaba.alink.common.comqueue; -import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.api.common.functions.BroadcastVariableInitializer; import org.apache.flink.api.common.functions.MapFunction; import org.apache.flink.api.common.functions.MapPartitionFunction; @@ -79,17 +78,14 @@ public Q add(ComQueueItem com) { return thisAsQ(); } - @VisibleForTesting int getMaxIter() { return maxIter; } - @VisibleForTesting List getQueue() { return queue; } - @VisibleForTesting CompareCriterionFunction getCompareCriterion() { return compareCriterion; } @@ -607,7 +603,6 @@ public void mapPartition(Iterable values, Collector out) throws Exc // } // } - @VisibleForTesting static class DistributeData extends ComputeFunction { private static final long serialVersionUID = -1105584217517972610L; private final List cacheDataObjNames; diff --git a/core/src/main/java/com/alibaba/alink/common/dl/BaseEasyTransferTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/common/dl/BaseEasyTransferTrainBatchOp.java index 88977d33e..e62ae5437 100644 --- a/core/src/main/java/com/alibaba/alink/common/dl/BaseEasyTransferTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/common/dl/BaseEasyTransferTrainBatchOp.java @@ -16,7 +16,7 @@ import com.alibaba.alink.common.annotation.PortType; import 
com.alibaba.alink.common.dl.utils.PythonFileUtils; import com.alibaba.alink.common.io.plugin.ResourcePluginFactory; -import com.alibaba.alink.common.utils.DataSetConversionUtil; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; import com.alibaba.alink.common.utils.JsonConverter; import com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.operator.batch.BatchOperator; diff --git a/core/src/main/java/com/alibaba/alink/common/dl/BaseKerasSequentialTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/common/dl/BaseKerasSequentialTrainBatchOp.java index 97ae7fcfb..8d6d326c6 100644 --- a/core/src/main/java/com/alibaba/alink/common/dl/BaseKerasSequentialTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/common/dl/BaseKerasSequentialTrainBatchOp.java @@ -12,7 +12,7 @@ import org.apache.flink.util.Collector; import org.apache.flink.util.StringUtils; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.MLEnvironmentFactory; import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.Internal; @@ -21,7 +21,7 @@ import com.alibaba.alink.common.annotation.PortSpec; import com.alibaba.alink.common.annotation.PortType; import com.alibaba.alink.common.annotation.TypeCollections; -import com.alibaba.alink.common.utils.DataSetConversionUtil; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; import com.alibaba.alink.common.utils.JsonConverter; import com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.operator.batch.BatchOperator; diff --git a/core/src/main/java/com/alibaba/alink/common/dl/DLLauncherBatchOp.java b/core/src/main/java/com/alibaba/alink/common/dl/DLLauncherBatchOp.java index ab2ff4503..4d6e9ffb6 100644 --- a/core/src/main/java/com/alibaba/alink/common/dl/DLLauncherBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/common/dl/DLLauncherBatchOp.java @@ -38,7 +38,7 @@ import com.alibaba.alink.common.io.plugin.OsUtils; import com.alibaba.alink.common.io.plugin.ResourcePluginFactory; import com.alibaba.alink.common.linalg.tensor.Tensor; -import com.alibaba.alink.common.utils.DataSetUtil; +import com.alibaba.alink.operator.batch.utils.DataSetUtil; import com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.common.dataproc.FirstReducer; diff --git a/core/src/main/java/com/alibaba/alink/common/dl/DLLauncherStreamOp.java b/core/src/main/java/com/alibaba/alink/common/dl/DLLauncherStreamOp.java index ac59e96ca..df6904cf7 100644 --- a/core/src/main/java/com/alibaba/alink/common/dl/DLLauncherStreamOp.java +++ b/core/src/main/java/com/alibaba/alink/common/dl/DLLauncherStreamOp.java @@ -8,10 +8,12 @@ import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; import org.apache.flink.configuration.Configuration; import org.apache.flink.ml.api.misc.param.Params; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.datastream.IterativeStream.ConnectedIterativeStreams; +import org.apache.flink.streaming.api.functions.co.RichCoFlatMapFunction; import org.apache.flink.table.api.TableSchema; import org.apache.flink.types.Row; import org.apache.flink.util.Collector; @@ -24,21 +26,39 @@ import com.alibaba.alink.common.annotation.PortSpec; import 
com.alibaba.alink.common.annotation.PortType; import com.alibaba.alink.common.dl.DLEnvConfig.Version; +import com.alibaba.alink.common.dl.utils.DLClusterUtils; import com.alibaba.alink.common.dl.utils.DLTypeUtils; import com.alibaba.alink.common.dl.utils.DLUtils; +import com.alibaba.alink.common.dl.utils.ExternalFilesUtils; import com.alibaba.alink.common.dl.utils.PythonFileUtils; +import com.alibaba.alink.common.exceptions.AkUnclassifiedErrorException; import com.alibaba.alink.common.io.plugin.ResourcePluginFactory; import com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.operator.stream.StreamOperator; import com.alibaba.alink.params.dl.DLLauncherParams; +import com.alibaba.flink.ml.cluster.ExecutionMode; +import com.alibaba.flink.ml.cluster.MLConfig; +import com.alibaba.flink.ml.cluster.node.MLContext; +import com.alibaba.flink.ml.data.DataExchange; import com.alibaba.flink.ml.tensorflow2.client.DLConfig; +import com.alibaba.flink.ml.util.IpHostUtil; import com.alibaba.flink.ml.util.MLConstants; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.io.File; +import java.io.IOException; +import java.io.InterruptedIOException; +import java.net.ServerSocket; +import java.nio.file.Paths; +import java.util.ArrayDeque; +import java.util.ArrayList; import java.util.HashMap; +import java.util.List; import java.util.Map; +import java.util.Queue; import java.util.Set; +import java.util.concurrent.FutureTask; import static com.alibaba.alink.common.dl.utils.DLLauncherUtils.adjustNumWorkersPSs; @@ -242,4 +262,192 @@ public Row map(Row value) throws Exception { setOutput(output, outputSchema); return this; } + + /** + * This co-flatmap function provides similar functions with {@link DLClusterMapPartitionFunc}, except that this one is + * designed for stream scenario. + *
<p>
+ * The following steps are performed in order: + *
<p>
+ * 1. Collect IP/port information of all workers and broadcast to all workers. + *
<p>
+ * 2. Prepare Python environment, and start the TF cluster. + *
<p>
+ * 3. Process data stream. + */ + public static class DLStreamCoFlatMapFunc extends RichCoFlatMapFunction { + + private static final Logger LOG = LoggerFactory.getLogger(DLStreamCoFlatMapFunc.class); + + private transient DataExchange dataExchange; + private FutureTask serverFuture; + private volatile Collector collector = null; + + private MLContext mlContext; + private final ResourcePluginFactory factory; + private final MLConfig mlConfig; + + private final int numWorkers; + private final int numPSs; + private int taskId; + + private final List > taskIpPorts = new ArrayList <>(); + private boolean isTfClusterStarted = false; + + private final Queue cachedRows = new ArrayDeque <>(); + + private boolean firstItem = true; + + public DLStreamCoFlatMapFunc(MLConfig mlConfig, int numWorkers, int numPSs, ResourcePluginFactory factory) { + this.factory = factory; + this.mlConfig = mlConfig; + this.numWorkers = numWorkers; + this.numPSs = numPSs; + } + + public static void prepareExternalFiles(MLContext mlContext, String workDir) throws Exception { + String entryFunc = mlContext.getProperties().get(DLConstants.ENTRY_FUNC); + mlContext.setFuncName(entryFunc); + DLUtils.safePutProperties(mlContext, DLConstants.WORK_DIR, workDir); + workDir = new File(workDir).getAbsolutePath(); + + ExternalFilesConfig externalFilesConfig = + ExternalFilesConfig.fromJson(mlContext.getProperties().get(DLConstants.EXTERNAL_FILE_CONFIG_JSON)); + ExternalFilesUtils.prepareExternalFiles(externalFilesConfig, workDir); + + String entryScript = mlContext.getProperties().get(DLConstants.ENTRY_SCRIPT); + String entryScriptName = PythonFileUtils.getFileName(entryScript); + mlContext.setPythonDir(Paths.get(workDir)); + mlContext.setPythonFiles(new String[] {entryScriptName}); + } + + @Override + public void open(Configuration parameters) throws Exception { + super.open(parameters); + this.taskId = getRuntimeContext().getIndexOfThisSubtask(); + } + + @Override + public void close() throws Exception { + super.close(); + if (isTfClusterStarted) { + DLClusterUtils.stopCluster(mlContext, serverFuture, (Void) -> drainRead(collector, true)); + mlContext = null; + isTfClusterStarted = false; + } + } + + @Override + public void flatMap1(Row value, Collector out) throws Exception { + if (!isTfClusterStarted) { + if (firstItem) { + ServerSocket serverSocket = IpHostUtil.getFreeSocket(); + int port = serverSocket.getLocalPort(); + serverSocket.close(); + String localIp = IpHostUtil.getIpAddress(); + for (int i = 0; i < numWorkers + numPSs; i += 1) { + out.collect(Row.of(-1, i, Row.of(taskId, localIp, port))); + } + System.out.println(String.format("%d select %s:%d", taskId, localIp, port)); + //out.collect(Row.of(String.format("%d-%s-%d", taskId, localIp, port))); + collector = out; + firstItem = false; + } + if (taskId < numWorkers) { // no input for ps nodes + cachedRows.add((Row) value.getField(1)); + } + } else { + if (taskId < numWorkers) { + dataExchange.write(DLUtils.encodeStringValue((Row) value.getField(1))); + drainRead(out, false); + } + } + } + + private void startDLCluster() { + System.out.println("Starting DL cluster..."); + try { + mlContext = DLClusterUtils.makeMLContext(taskId, mlConfig, ExecutionMode.TRAIN); + Map properties = mlConfig.getProperties(); + String workDir = properties.get(MLConstants.WORK_DIR); + DLClusterUtils.setMLContextIpPorts(taskId, mlContext, taskIpPorts); + + prepareExternalFiles(mlContext, workDir); + // Update external files-related properties according to workDir + { + String pythonEnv = 
properties.get(DLConstants.PYTHON_ENV); + if (StringUtils.isNullOrWhitespaceOnly(pythonEnv)) { + Version version = Version.valueOf(properties.get(DLConstants.ENV_VERSION)); + LOG.info(String.format("Use pythonEnv from plugin: %s", version)); + pythonEnv = DLEnvConfig.getDefaultPythonEnv(factory, version); + properties.put(MLConstants.VIRTUAL_ENV_DIR, pythonEnv.substring("file://".length())); + } else { + if (PythonFileUtils.isLocalFile(pythonEnv)) { + properties.put(MLConstants.VIRTUAL_ENV_DIR, pythonEnv.substring("file://".length())); + } else { + properties.put(MLConstants.VIRTUAL_ENV_DIR, new File(workDir, pythonEnv).getAbsolutePath()); + } + } + String entryScriptFileName = PythonFileUtils.getFileName(properties.get(DLConstants.ENTRY_SCRIPT)); + mlContext.setPythonDir(new File(workDir).toPath()); + mlContext.setPythonFiles(new String[] {new File(workDir, entryScriptFileName).getAbsolutePath()}); + } + + Tuple3 , FutureTask , Thread> dataExchangeFutureTaskThreadTuple3 + = DLClusterUtils.startDLCluster(mlContext); + dataExchange = dataExchangeFutureTaskThreadTuple3.f0; + serverFuture = dataExchangeFutureTaskThreadTuple3.f1; + } catch (Exception ex) { + throw new AkUnclassifiedErrorException("Start TF cluster failed: ", ex); + } + } + + /** + * collect ip and port to start the cluster. + * @param value + * @param out + * @throws Exception + */ + @Override + public void flatMap2(Row value, Collector out) throws Exception { + value = (Row) value.getField(2); + System.out.println(String.format("task %d received address: %s", taskId, value)); + taskIpPorts.add( + Tuple3.of((Integer) value.getField(0), (String) value.getField(1), (Integer) value.getField(2))); + if (taskIpPorts.size() == numWorkers + numPSs) { + startDLCluster(); + isTfClusterStarted = true; + System.out.println(String.format("task %d: TF cluster started", taskId)); + System.out.println(String.format("task %d: Handling %d cached rows", taskId, cachedRows.size())); + while (!cachedRows.isEmpty()) { + dataExchange.write(DLUtils.encodeStringValue(cachedRows.remove())); + drainRead(out, false); + } + } + } + + private void drainRead(Collector out, boolean readUntilEOF) { + while (true) { + try { + Row r = dataExchange.read(readUntilEOF); + if (r != null) { + out.collect(Row.of(1, r)); + } else { + break; + } + } catch (InterruptedIOException iioe) { + LOG.info("{} Reading from is interrupted, canceling the server", mlContext.getIdentity()); + serverFuture.cancel(true); + } catch (IOException e) { + LOG.error("Fail to read data from python.", e); + throw new AkUnclassifiedErrorException("Fail to read data from python.", e); + } + } + } + // + //@Override + //public TypeInformation getProducedType() { + // return tfFlatMapFunction.getProducedType(); + //} + } } diff --git a/core/src/main/java/com/alibaba/alink/common/dl/DLStreamCoFlatMapFunc.java b/core/src/main/java/com/alibaba/alink/common/dl/DLStreamCoFlatMapFunc.java deleted file mode 100644 index 9885fdbd4..000000000 --- a/core/src/main/java/com/alibaba/alink/common/dl/DLStreamCoFlatMapFunc.java +++ /dev/null @@ -1,224 +0,0 @@ -package com.alibaba.alink.common.dl; - -import org.apache.flink.api.java.tuple.Tuple3; -import org.apache.flink.configuration.Configuration; -import org.apache.flink.streaming.api.functions.co.RichCoFlatMapFunction; -import org.apache.flink.types.Row; -import org.apache.flink.util.Collector; -import org.apache.flink.util.StringUtils; - -import com.alibaba.alink.common.dl.DLEnvConfig.Version; -import com.alibaba.alink.common.dl.utils.DLClusterUtils; 
-import com.alibaba.alink.common.dl.utils.DLUtils; -import com.alibaba.alink.common.dl.utils.ExternalFilesUtils; -import com.alibaba.alink.common.dl.utils.PythonFileUtils; -import com.alibaba.alink.common.exceptions.AkUnclassifiedErrorException; -import com.alibaba.alink.common.io.plugin.ResourcePluginFactory; -import com.alibaba.flink.ml.cluster.ExecutionMode; -import com.alibaba.flink.ml.cluster.MLConfig; -import com.alibaba.flink.ml.cluster.node.MLContext; -import com.alibaba.flink.ml.data.DataExchange; -import com.alibaba.flink.ml.util.IpHostUtil; -import com.alibaba.flink.ml.util.MLConstants; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.io.File; -import java.io.IOException; -import java.io.InterruptedIOException; -import java.net.ServerSocket; -import java.nio.file.Paths; -import java.util.ArrayDeque; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.Queue; -import java.util.concurrent.FutureTask; - -/** - * This co-flatmap function provides similar functions with {@link DLClusterMapPartitionFunc}, except that this one is - * designed for stream scenario. - *
<p>
- * The following steps are performed in order: - *
<p>
- * 1. Collect IP/port information of all workers and broadcast to all workers. - *
<p>
- * 2. Prepare Python environment, and start the TF cluster. - *
<p>
- * 3. Process data stream. - */ -public class DLStreamCoFlatMapFunc extends RichCoFlatMapFunction { - - private static final Logger LOG = LoggerFactory.getLogger(DLStreamCoFlatMapFunc.class); - - private transient DataExchange dataExchange; - private FutureTask serverFuture; - private volatile Collector collector = null; - - private MLContext mlContext; - private final ResourcePluginFactory factory; - private final MLConfig mlConfig; - - private final int numWorkers; - private final int numPSs; - private int taskId; - - private final List > taskIpPorts = new ArrayList <>(); - private boolean isTfClusterStarted = false; - - private final Queue cachedRows = new ArrayDeque <>(); - - private boolean firstItem = true; - - public DLStreamCoFlatMapFunc(MLConfig mlConfig, int numWorkers, int numPSs, ResourcePluginFactory factory) { - this.factory = factory; - this.mlConfig = mlConfig; - this.numWorkers = numWorkers; - this.numPSs = numPSs; - } - - public static void prepareExternalFiles(MLContext mlContext, String workDir) throws Exception { - String entryFunc = mlContext.getProperties().get(DLConstants.ENTRY_FUNC); - mlContext.setFuncName(entryFunc); - DLUtils.safePutProperties(mlContext, DLConstants.WORK_DIR, workDir); - workDir = new File(workDir).getAbsolutePath(); - - ExternalFilesConfig externalFilesConfig = - ExternalFilesConfig.fromJson(mlContext.getProperties().get(DLConstants.EXTERNAL_FILE_CONFIG_JSON)); - ExternalFilesUtils.prepareExternalFiles(externalFilesConfig, workDir); - - String entryScript = mlContext.getProperties().get(DLConstants.ENTRY_SCRIPT); - String entryScriptName = PythonFileUtils.getFileName(entryScript); - mlContext.setPythonDir(Paths.get(workDir)); - mlContext.setPythonFiles(new String[] {entryScriptName}); - } - - @Override - public void open(Configuration parameters) throws Exception { - super.open(parameters); - this.taskId = getRuntimeContext().getIndexOfThisSubtask(); - } - - @Override - public void close() throws Exception { - super.close(); - if (isTfClusterStarted) { - DLClusterUtils.stopCluster(mlContext, serverFuture, (Void) -> drainRead(collector, true)); - mlContext = null; - isTfClusterStarted = false; - } - } - - @Override - public void flatMap1(Row value, Collector out) throws Exception { - if (!isTfClusterStarted) { - if (firstItem) { - ServerSocket serverSocket = IpHostUtil.getFreeSocket(); - int port = serverSocket.getLocalPort(); - serverSocket.close(); - String localIp = IpHostUtil.getIpAddress(); - for (int i = 0; i < numWorkers + numPSs; i += 1) { - out.collect(Row.of(-1, i, Row.of(taskId, localIp, port))); - } - System.out.println(String.format("%d select %s:%d", taskId, localIp, port)); - //out.collect(Row.of(String.format("%d-%s-%d", taskId, localIp, port))); - collector = out; - firstItem = false; - } - if (taskId < numWorkers) { // no input for ps nodes - cachedRows.add((Row) value.getField(1)); - } - } else { - if (taskId < numWorkers) { - dataExchange.write(DLUtils.encodeStringValue((Row) value.getField(1))); - drainRead(out, false); - } - } - } - - private void startDLCluster() { - System.out.println("Starting DL cluster..."); - try { - mlContext = DLClusterUtils.makeMLContext(taskId, mlConfig, ExecutionMode.TRAIN); - Map properties = mlConfig.getProperties(); - String workDir = properties.get(MLConstants.WORK_DIR); - DLClusterUtils.setMLContextIpPorts(taskId, mlContext, taskIpPorts); - - prepareExternalFiles(mlContext, workDir); - // Update external files-related properties according to workDir - { - String pythonEnv = 
properties.get(DLConstants.PYTHON_ENV); - if (StringUtils.isNullOrWhitespaceOnly(pythonEnv)) { - Version version = Version.valueOf(properties.get(DLConstants.ENV_VERSION)); - LOG.info(String.format("Use pythonEnv from plugin: %s", version)); - pythonEnv = DLEnvConfig.getDefaultPythonEnv(factory, version); - properties.put(MLConstants.VIRTUAL_ENV_DIR, pythonEnv.substring("file://".length())); - } else { - if (PythonFileUtils.isLocalFile(pythonEnv)) { - properties.put(MLConstants.VIRTUAL_ENV_DIR, pythonEnv.substring("file://".length())); - } else { - properties.put(MLConstants.VIRTUAL_ENV_DIR, new File(workDir, pythonEnv).getAbsolutePath()); - } - } - String entryScriptFileName = PythonFileUtils.getFileName(properties.get(DLConstants.ENTRY_SCRIPT)); - mlContext.setPythonDir(new File(workDir).toPath()); - mlContext.setPythonFiles(new String[] {new File(workDir, entryScriptFileName).getAbsolutePath()}); - } - - Tuple3 , FutureTask , Thread> dataExchangeFutureTaskThreadTuple3 - = DLClusterUtils.startDLCluster(mlContext); - dataExchange = dataExchangeFutureTaskThreadTuple3.f0; - serverFuture = dataExchangeFutureTaskThreadTuple3.f1; - } catch (Exception ex) { - throw new AkUnclassifiedErrorException("Start TF cluster failed: ", ex); - } - } - - /** - * collect ip and port to start the cluster. - * @param value - * @param out - * @throws Exception - */ - @Override - public void flatMap2(Row value, Collector out) throws Exception { - value = (Row) value.getField(2); - System.out.println(String.format("task %d received address: %s", taskId, value)); - taskIpPorts.add( - Tuple3.of((Integer) value.getField(0), (String) value.getField(1), (Integer) value.getField(2))); - if (taskIpPorts.size() == numWorkers + numPSs) { - startDLCluster(); - isTfClusterStarted = true; - System.out.println(String.format("task %d: TF cluster started", taskId)); - System.out.println(String.format("task %d: Handling %d cached rows", taskId, cachedRows.size())); - while (!cachedRows.isEmpty()) { - dataExchange.write(DLUtils.encodeStringValue(cachedRows.remove())); - drainRead(out, false); - } - } - } - - private void drainRead(Collector out, boolean readUntilEOF) { - while (true) { - try { - Row r = dataExchange.read(readUntilEOF); - if (r != null) { - out.collect(Row.of(1, r)); - } else { - break; - } - } catch (InterruptedIOException iioe) { - LOG.info("{} Reading from is interrupted, canceling the server", mlContext.getIdentity()); - serverFuture.cancel(true); - } catch (IOException e) { - LOG.error("Fail to read data from python.", e); - throw new AkUnclassifiedErrorException("Fail to read data from python.", e); - } - } - } - // - //@Override - //public TypeInformation getProducedType() { - // return tfFlatMapFunction.getProducedType(); - //} -} diff --git a/core/src/main/java/com/alibaba/alink/common/dl/EasyTransferUtils.java b/core/src/main/java/com/alibaba/alink/common/dl/EasyTransferUtils.java index cf160711b..52496c5c7 100644 --- a/core/src/main/java/com/alibaba/alink/common/dl/EasyTransferUtils.java +++ b/core/src/main/java/com/alibaba/alink/common/dl/EasyTransferUtils.java @@ -9,7 +9,7 @@ import org.apache.flink.table.api.TableSchema; import org.apache.flink.types.Row; -import com.alibaba.alink.common.utils.DataSetConversionUtil; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; import com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.operator.batch.BatchOperator; diff --git a/core/src/main/java/com/alibaba/alink/common/dl/utils/DLUtils.java 
b/core/src/main/java/com/alibaba/alink/common/dl/utils/DLUtils.java index 707b5b940..70bb18e2e 100644 --- a/core/src/main/java/com/alibaba/alink/common/dl/utils/DLUtils.java +++ b/core/src/main/java/com/alibaba/alink/common/dl/utils/DLUtils.java @@ -5,7 +5,7 @@ import org.apache.flink.types.Row; import com.alibaba.alink.common.AlinkGlobalConfiguration; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.dl.coding.ExampleCodingConfigV2; import com.alibaba.alink.common.dl.coding.ExampleCodingV2; import com.alibaba.alink.common.dl.data.DataTypesV2; diff --git a/core/src/main/java/com/alibaba/alink/common/dl/utils/DataSetDiskDownloader.java b/core/src/main/java/com/alibaba/alink/common/dl/utils/DataSetDiskDownloader.java index a84e3a666..33026f7c7 100644 --- a/core/src/main/java/com/alibaba/alink/common/dl/utils/DataSetDiskDownloader.java +++ b/core/src/main/java/com/alibaba/alink/common/dl/utils/DataSetDiskDownloader.java @@ -15,7 +15,7 @@ import com.alibaba.alink.common.MLEnvironmentFactory; import com.alibaba.alink.common.exceptions.AkUnclassifiedErrorException; -import com.alibaba.alink.common.utils.DataSetConversionUtil; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; import com.alibaba.alink.common.utils.DownloadUtils; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.flink.ml.util.IpHostUtil; diff --git a/core/src/main/java/com/alibaba/alink/common/dl/utils/TFExampleConversionUtils.java b/core/src/main/java/com/alibaba/alink/common/dl/utils/TFExampleConversionUtils.java index 938e3447c..1603bbdab 100644 --- a/core/src/main/java/com/alibaba/alink/common/dl/utils/TFExampleConversionUtils.java +++ b/core/src/main/java/com/alibaba/alink/common/dl/utils/TFExampleConversionUtils.java @@ -1,10 +1,9 @@ package com.alibaba.alink.common.dl.utils; -import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.types.Row; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.dl.coding.TFExampleConversionV2; import com.alibaba.alink.common.exceptions.AkIllegalDataException; import com.alibaba.alink.common.exceptions.AkPreconditions; @@ -176,7 +175,7 @@ public static Feature toFeature(Object val, TypeInformation type) { bb.addValue(ByteString.copyFrom(stringTensor.getString(i), StandardCharsets.UTF_8)); } featureBuilder.setBytesList(bb); - } else if (PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO.equals(type)) { + } else if (AlinkTypes.VARBINARY.equals(type)) { Builder bb = BytesList.newBuilder(); bb.addValue(ByteString.copyFrom((byte[]) val)); featureBuilder.setBytesList(bb); @@ -270,7 +269,7 @@ public static Object fromFeature(Feature feature, TypeInformation type) { AkPreconditions.checkArgument(byteStringList.size() > 0, new AkIllegalDataException("no BYTES values in the feature.")); return byteStringList.get(0).toString(StandardCharsets.UTF_8); - } else if (PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO.equals(type)) { + } else if (AlinkTypes.VARBINARY.equals(type)) { AkPreconditions.checkArgument(byteStringList.size() > 0, new AkIllegalDataException("no BYTES values in the feature.")); return byteStringList.get(0).toByteArray(); diff --git a/core/src/main/java/com/alibaba/alink/common/exceptions/ErrorCode.java b/core/src/main/java/com/alibaba/alink/common/exceptions/ErrorCode.java index 
fc0f19766..3cb740ccd 100644 --- a/core/src/main/java/com/alibaba/alink/common/exceptions/ErrorCode.java +++ b/core/src/main/java/com/alibaba/alink/common/exceptions/ErrorCode.java @@ -18,7 +18,8 @@ enum ErrorCode { ILLEGAL_ARGUMENT(Type.PLATFORM, Level.ERROR, 0x1005L, "Illegal argument"), ILLEGAL_STATE(Type.PLATFORM, Level.ERROR, 0x1006L, "Illegal state"), NULL_POINTER(Type.PLATFORM, Level.ERROR, 0x1007L, "Null pointer"), - ; + // For PyAlink usage + JAVA_SIDE_ERROR(Type.PLATFORM, Level.ERROR, 0x1008L, "Java side error, check enclosed error"); private final Type type; private final Level level; diff --git a/core/src/main/java/com/alibaba/alink/common/fe/define/statistics/NumericStatistics.java b/core/src/main/java/com/alibaba/alink/common/fe/define/statistics/NumericStatistics.java index 53cfb4686..8279b7b35 100644 --- a/core/src/main/java/com/alibaba/alink/common/fe/define/statistics/NumericStatistics.java +++ b/core/src/main/java/com/alibaba/alink/common/fe/define/statistics/NumericStatistics.java @@ -3,7 +3,7 @@ import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeinfo.Types; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; public enum NumericStatistics implements BaseNumericStatistics { diff --git a/core/src/main/java/com/alibaba/alink/common/io/catalog/DataHubCatalog.java b/core/src/main/java/com/alibaba/alink/common/io/catalog/DataHubCatalog.java index be6c04926..f5547dee8 100644 --- a/core/src/main/java/com/alibaba/alink/common/io/catalog/DataHubCatalog.java +++ b/core/src/main/java/com/alibaba/alink/common/io/catalog/DataHubCatalog.java @@ -33,8 +33,8 @@ import com.alibaba.alink.common.io.annotations.CatalogAnnotation; import com.alibaba.alink.common.io.catalog.plugin.DataHubClassLoaderFactory; -import com.alibaba.alink.common.io.plugin.wrapper.RichParallelSourceFunctionWithClassLoader; -import com.alibaba.alink.common.io.plugin.wrapper.RichSinkFunctionWithClassLoader; +import com.alibaba.alink.operator.stream.source.RichParallelSourceFunctionWithClassLoader; +import com.alibaba.alink.operator.stream.sink.RichSinkFunctionWithClassLoader; import com.alibaba.alink.params.io.DataHubParams; import java.lang.reflect.InvocationTargetException; diff --git a/core/src/main/java/com/alibaba/alink/common/io/catalog/HiveCatalog.java b/core/src/main/java/com/alibaba/alink/common/io/catalog/HiveCatalog.java index d69760764..4e510cc35 100644 --- a/core/src/main/java/com/alibaba/alink/common/io/catalog/HiveCatalog.java +++ b/core/src/main/java/com/alibaba/alink/common/io/catalog/HiveCatalog.java @@ -63,11 +63,11 @@ import com.alibaba.alink.common.io.plugin.wrapper.RichInputFormatWithClassLoader; import com.alibaba.alink.common.io.plugin.wrapper.RichOutputFormatWithClassLoader; import com.alibaba.alink.common.io.filesystem.FilePath; -import com.alibaba.alink.common.utils.DataSetConversionUtil; -import com.alibaba.alink.common.utils.DataStreamConversionUtil; import com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.operator.batch.BatchOperator; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; import com.alibaba.alink.operator.stream.StreamOperator; +import com.alibaba.alink.operator.stream.utils.DataStreamConversionUtil; import com.alibaba.alink.params.io.HiveCatalogParams; import org.apache.commons.lang3.ArrayUtils; import org.slf4j.Logger; diff --git a/core/src/main/java/com/alibaba/alink/common/io/catalog/InputOutputFormatCatalog.java 
b/core/src/main/java/com/alibaba/alink/common/io/catalog/InputOutputFormatCatalog.java index 742512ed9..cb949a1f7 100644 --- a/core/src/main/java/com/alibaba/alink/common/io/catalog/InputOutputFormatCatalog.java +++ b/core/src/main/java/com/alibaba/alink/common/io/catalog/InputOutputFormatCatalog.java @@ -14,9 +14,9 @@ import org.apache.flink.types.Row; import com.alibaba.alink.common.MLEnvironmentFactory; -import com.alibaba.alink.common.utils.DataSetConversionUtil; -import com.alibaba.alink.common.utils.DataStreamConversionUtil; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; import com.alibaba.alink.operator.batch.BatchOperator; +import com.alibaba.alink.operator.stream.utils.DataStreamConversionUtil; public abstract class InputOutputFormatCatalog extends BaseCatalog { public InputOutputFormatCatalog(Params params) { diff --git a/core/src/main/java/com/alibaba/alink/common/io/catalog/JdbcCatalog.java b/core/src/main/java/com/alibaba/alink/common/io/catalog/JdbcCatalog.java index 90ecf3719..16492c118 100644 --- a/core/src/main/java/com/alibaba/alink/common/io/catalog/JdbcCatalog.java +++ b/core/src/main/java/com/alibaba/alink/common/io/catalog/JdbcCatalog.java @@ -36,9 +36,9 @@ import org.apache.flink.types.Row; import com.alibaba.alink.common.MLEnvironmentFactory; -import com.alibaba.alink.common.utils.DataSetConversionUtil; -import com.alibaba.alink.common.utils.DataStreamConversionUtil; import com.alibaba.alink.operator.batch.BatchOperator; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; +import com.alibaba.alink.operator.stream.utils.DataStreamConversionUtil; import com.alibaba.alink.params.shared.HasOverwriteSink; import java.sql.Connection; diff --git a/core/src/main/java/com/alibaba/alink/common/io/catalog/SourceSinkFunctionCatalog.java b/core/src/main/java/com/alibaba/alink/common/io/catalog/SourceSinkFunctionCatalog.java index 3c70b333a..ed5868e9d 100644 --- a/core/src/main/java/com/alibaba/alink/common/io/catalog/SourceSinkFunctionCatalog.java +++ b/core/src/main/java/com/alibaba/alink/common/io/catalog/SourceSinkFunctionCatalog.java @@ -10,7 +10,7 @@ import org.apache.flink.types.Row; import com.alibaba.alink.common.MLEnvironmentFactory; -import com.alibaba.alink.common.utils.DataStreamConversionUtil; +import com.alibaba.alink.operator.stream.utils.DataStreamConversionUtil; public abstract class SourceSinkFunctionCatalog extends BaseCatalog { public SourceSinkFunctionCatalog(Params params) { diff --git a/core/src/main/java/com/alibaba/alink/common/io/filesystem/AkUtils.java b/core/src/main/java/com/alibaba/alink/common/io/filesystem/AkUtils.java index 11f2b963b..57f38c6e3 100644 --- a/core/src/main/java/com/alibaba/alink/common/io/filesystem/AkUtils.java +++ b/core/src/main/java/com/alibaba/alink/common/io/filesystem/AkUtils.java @@ -12,7 +12,7 @@ import org.apache.flink.table.api.TableSchema; import org.apache.flink.types.Row; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.exceptions.AkIllegalArgumentException; import com.alibaba.alink.common.exceptions.AkIllegalDataException; import com.alibaba.alink.common.exceptions.AkIllegalOperatorParameterException; diff --git a/core/src/main/java/com/alibaba/alink/common/io/filesystem/AkUtils2.java b/core/src/main/java/com/alibaba/alink/common/io/filesystem/AkUtils2.java index 5c5d42a03..190ed8a5d 100644 --- a/core/src/main/java/com/alibaba/alink/common/io/filesystem/AkUtils2.java +++ 
b/core/src/main/java/com/alibaba/alink/common/io/filesystem/AkUtils2.java @@ -8,10 +8,10 @@ import org.apache.flink.types.Row; import org.apache.flink.util.Collector; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.MLEnvironmentFactory; -import com.alibaba.alink.common.utils.DataSetConversionUtil; -import com.alibaba.alink.common.utils.DataStreamConversionUtil; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; +import com.alibaba.alink.operator.stream.utils.DataStreamConversionUtil; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.batch.source.TableSourceBatchOp; import com.alibaba.alink.operator.batch.sql.WhereBatchOp; diff --git a/core/src/main/java/com/alibaba/alink/common/io/filesystem/BaseFileSystem.java b/core/src/main/java/com/alibaba/alink/common/io/filesystem/BaseFileSystem.java index 6f502aedc..79005408b 100644 --- a/core/src/main/java/com/alibaba/alink/common/io/filesystem/BaseFileSystem.java +++ b/core/src/main/java/com/alibaba/alink/common/io/filesystem/BaseFileSystem.java @@ -10,7 +10,7 @@ import org.apache.flink.core.fs.RecoverableWriter; import org.apache.flink.ml.api.misc.param.Params; -import com.alibaba.alink.common.exceptions.AkIllegalOperatorParameterException; +import com.alibaba.alink.common.exceptions.AkIllegalArgumentException; import com.alibaba.alink.common.exceptions.AkUnclassifiedErrorException; import com.alibaba.alink.common.io.annotations.AnnotationUtils; import com.alibaba.alink.params.io.HasIoName; @@ -35,7 +35,9 @@ public BaseFileSystem(Params params) { } else { this.params = params.clone(); } - this.params.set(HasIoName.IO_NAME, AnnotationUtils.annotatedName(this.getClass())); + if (!this.params.contains(HasIoName.IO_NAME)) { + this.params.set(HasIoName.IO_NAME, AnnotationUtils.annotatedName(this.getClass())); + } } Params getParams() { @@ -47,10 +49,14 @@ public static BaseFileSystem of(Params params) { try { return AnnotationUtils.createFileSystem(params.get(HasIoName.IO_NAME), params); } catch (Exception e) { - throw new AkUnclassifiedErrorException("create file system failed. ",e); + throw new AkUnclassifiedErrorException(String.format("create %s file system failed. 
", params.get(HasIoName.IO_NAME)),e); } } else { - throw new AkIllegalOperatorParameterException("NOT a FileSystem parameter."); + String errorMsg = "params doesn't contain parameter " + HasIoName.IO_NAME; + if (params.contains(HasIoName.IO_NAME)) { + errorMsg = params.get(HasIoName.IO_NAME) + " is not a supported FileSystem."; + } + throw new AkIllegalArgumentException(errorMsg); } } diff --git a/core/src/main/java/com/alibaba/alink/common/io/filesystem/FilePath.java b/core/src/main/java/com/alibaba/alink/common/io/filesystem/FilePath.java index 6597c255a..277194242 100644 --- a/core/src/main/java/com/alibaba/alink/common/io/filesystem/FilePath.java +++ b/core/src/main/java/com/alibaba/alink/common/io/filesystem/FilePath.java @@ -12,6 +12,7 @@ import com.alibaba.alink.common.utils.JsonConverter; import com.alibaba.alink.operator.common.io.csv.CsvUtil; import com.alibaba.alink.operator.common.io.reader.HttpFileSplitReader; +import com.alibaba.alink.params.io.HasIoName; import org.apache.commons.io.IOUtils; import java.io.File; @@ -92,6 +93,9 @@ public FilePathJsonable(String path, Params params) { } public FilePath toFilePath() { + if (params != null && params.contains(HasIoName.IO_NAME) && params.get(HasIoName.IO_NAME).equals("local")) { + return new FilePath(path, new LocalFileSystem()); + } return new FilePath(path, params == null ? null : BaseFileSystem.of(params)); } @@ -174,15 +178,21 @@ private void init() { URI uri = path.toUri(); String schema = uri.getScheme(); + //String schema = rewriteUri(uri).getScheme(); + + if (null==schema || schema.equals("file")) { + fileSystem = new LocalFileSystem(); + return; + } // for http - if (schema != null && (schema.equals("http") || schema.equals("https"))) { + if (schema.equals("http") || schema.equals("https")) { fileSystem = new HttpFileReadOnlyFileSystem(); return; } // for oss - if (schema != null && schema.equals("oss")) { + if (schema.equals("oss")) { String authority = CsvUtil.unEscape(uri.getAuthority()); if (authority.contains("\u0001") && authority.contains("\u0002")) { @@ -237,7 +247,6 @@ private void init() { } } - schema = rewriteUri(path.toUri()).getScheme(); List allFileSystemNames = AnnotationUtils.allFileSystemNames(); @@ -270,52 +279,52 @@ private void init() { } } - private static URI rewriteUri(URI fsUri) { - final URI uri; - - if (fsUri.getScheme() != null) { - uri = fsUri; - } else { - // Apply the default fs scheme - final URI defaultUri = org.apache.flink.core.fs.local.LocalFileSystem.getLocalFsURI(); - URI rewrittenUri = null; - - try { - rewrittenUri = new URI(defaultUri.getScheme(), null, defaultUri.getHost(), - defaultUri.getPort(), fsUri.getPath(), null, null); - } catch (URISyntaxException e) { - // for local URIs, we make one more try to repair the path by making it absolute - if (defaultUri.getScheme().equals("file")) { - try { - rewrittenUri = new URI( - "file", null, - new Path(new File(fsUri.getPath()).getAbsolutePath()).toUri().getPath(), - null); - } catch (URISyntaxException ignored) { - // could not help it... 
- } - } - } - - if (rewrittenUri != null) { - uri = rewrittenUri; - } else { - throw new AkIllegalOperatorParameterException("The file system URI '" + fsUri + - "' declares no scheme and cannot be interpreted relative to the default file system URI (" - + defaultUri + ")."); - } - } - - // print a helpful pointer for malformed local URIs (happens a lot to new users) - if (uri.getScheme().equals("file") && uri.getAuthority() != null && !uri.getAuthority().isEmpty()) { - String supposedUri = "file:///" + uri.getAuthority() + uri.getPath(); - - throw new AkIllegalOperatorParameterException( - "Found local file path with authority '" + uri.getAuthority() + "' in path '" - + uri.toString() + "'. Hint: Did you forget a slash? (correct path would be '" + supposedUri - + "')"); - } - - return uri; - } + //private static URI rewriteUri(URI fsUri) { + // final URI uri; + // + // if (fsUri.getScheme() != null) { + // uri = fsUri; + // } else { + // // Apply the default fs scheme + // final URI defaultUri = org.apache.flink.core.fs.local.LocalFileSystem.getLocalFsURI(); + // URI rewrittenUri = null; + // + // try { + // rewrittenUri = new URI(defaultUri.getScheme(), null, defaultUri.getHost(), + // defaultUri.getPort(), fsUri.getPath(), null, null); + // } catch (URISyntaxException e) { + // // for local URIs, we make one more try to repair the path by making it absolute + // if (defaultUri.getScheme().equals("file")) { + // try { + // rewrittenUri = new URI( + // "file", null, + // new Path(new File(fsUri.getPath()).getAbsolutePath()).toUri().getPath(), + // null); + // } catch (URISyntaxException ignored) { + // // could not help it... + // } + // } + // } + // + // if (rewrittenUri != null) { + // uri = rewrittenUri; + // } else { + // throw new AkIllegalOperatorParameterException("The file system URI '" + fsUri + + // "' declares no scheme and cannot be interpreted relative to the default file system URI (" + // + defaultUri + ")."); + // } + // } + // + // // print a helpful pointer for malformed local URIs (happens a lot to new users) + // if (uri.getScheme().equals("file") && uri.getAuthority() != null && !uri.getAuthority().isEmpty()) { + // String supposedUri = "file:///" + uri.getAuthority() + uri.getPath(); + // + // throw new AkIllegalOperatorParameterException( + // "Found local file path with authority '" + uri.getAuthority() + "' in path '" + // + uri.toString() + "'. Hint: Did you forget a slash? 
(correct path would be '" + supposedUri // + "')"); // } // // return uri; //} } diff --git a/core/src/main/java/com/alibaba/alink/common/io/filesystem/HttpFileReadOnlyFileSystem.java b/core/src/main/java/com/alibaba/alink/common/io/filesystem/HttpFileReadOnlyFileSystem.java index 2eab566f6..4995c5c47 100644 --- a/core/src/main/java/com/alibaba/alink/common/io/filesystem/HttpFileReadOnlyFileSystem.java +++ b/core/src/main/java/com/alibaba/alink/common/io/filesystem/HttpFileReadOnlyFileSystem.java @@ -9,8 +9,7 @@ import org.apache.flink.core.fs.Path; import org.apache.flink.ml.api.misc.param.Params; -import com.alibaba.alink.common.exceptions.AkIllegalArgumentException; -import com.alibaba.alink.common.exceptions.AkUnclassifiedErrorException; +import com.alibaba.alink.common.exceptions.AkIllegalDataException; import com.alibaba.alink.common.exceptions.AkUnsupportedOperationException; import com.alibaba.alink.common.io.annotations.FSAnnotation; import org.slf4j.Logger; @@ -153,22 +152,22 @@ public FileStatus[] listStatus(Path f) throws IOException { @Override public boolean delete(Path f, boolean recursive) throws IOException { - throw new AkUnsupportedOperationException("Not support exception. "); + throw new AkUnsupportedOperationException("The HTTP file system does not support the delete operation."); } @Override public boolean mkdirs(Path f) throws IOException { - throw new AkUnsupportedOperationException("Not support exception. "); + throw new AkUnsupportedOperationException("The HTTP file system does not support the mkdirs operation."); } @Override public FSDataOutputStream create(Path f, WriteMode overwriteMode) throws IOException { - throw new AkUnsupportedOperationException("Not support exception. "); + throw new AkUnsupportedOperationException("The HTTP file system does not support the create operation."); } @Override public boolean rename(Path src, Path dst) throws IOException { - throw new AkUnsupportedOperationException("Not support exception. "); + throw new AkUnsupportedOperationException("The HTTP file system does not support the rename operation."); } @Override @@ -202,19 +201,20 @@ static long doGetLen(Path path) { LOG.info("contentLength of {}, acceptRanges of {} to download {}", contentLength, acceptRanges, path); if (contentLength < 0) { - throw new AkUnsupportedOperationException("The content length can't be determined."); + throw new AkIllegalDataException("The content length can't be determined (content length < 0)."); } // If the http server does not accept ranges, then we quit the program. // This is because 'accept ranges' is required to achieve robustness (through re-connection), // and efficiency (through concurrent read).
if (!splittable) { - throw new AkUnsupportedOperationException("The http server does not support range reading."); + throw new AkIllegalDataException("The HTTP response has no 'Accept-Ranges' header, or its value is not 'bytes'; " + + "the HTTP server does not support range reading."); } return contentLength; } catch (Exception e) { - throw new AkUnclassifiedErrorException("Fail to connect to http server", e); + throw new AkIllegalDataException(String.format("Failed to connect to HTTP address %s", path.getPath()), e); } finally { if (headerConnection != null) { headerConnection.disconnect(); @@ -281,7 +281,7 @@ public int read() throws IOException { private void createInternal(long start, long end) throws IOException { if (start >= end) { - throw new AkIllegalArgumentException("start position of http file is is lager than end position"); + throw new AkIllegalDataException("start position of http file is larger than end position"); } connection = (HttpURLConnection) path.toUri().toURL().openConnection(); diff --git a/core/src/main/java/com/alibaba/alink/common/io/filesystem/LocalFileSystem.java b/core/src/main/java/com/alibaba/alink/common/io/filesystem/LocalFileSystem.java index 021b29eb3..b62ea652a 100644 --- a/core/src/main/java/com/alibaba/alink/common/io/filesystem/LocalFileSystem.java +++ b/core/src/main/java/com/alibaba/alink/common/io/filesystem/LocalFileSystem.java @@ -1,44 +1,32 @@ package com.alibaba.alink.common.io.filesystem; import org.apache.flink.core.fs.FileSystem; -import org.apache.flink.core.fs.FileSystemFactory; import org.apache.flink.core.fs.Path; -import org.apache.flink.core.fs.local.LocalFileSystemFactory; import org.apache.flink.ml.api.misc.param.Params; -import com.alibaba.alink.common.exceptions.AkUnclassifiedErrorException; import com.alibaba.alink.common.io.annotations.FSAnnotation; - -import java.io.IOException; +import com.alibaba.alink.params.io.HasIoName; @FSAnnotation(name = "local") public final class LocalFileSystem extends BaseFileSystem { private static final long serialVersionUID = 1818806211030723090L; - private transient FileSystemFactory loaded; public LocalFileSystem() { this(new Params()); } public LocalFileSystem(Params params) { - super(params); + super(params.set(HasIoName.IO_NAME, LocalFileSystem.class.getAnnotation(FSAnnotation.class).name())); } @Override public String getSchema() { - return new LocalFileSystemFactory().getScheme(); + return com.alibaba.alink.common.io.filesystem.copy.local.LocalFileSystem.getLocalFsURI().getScheme(); } @Override protected FileSystem load(Path path) { - if (loaded == null) { - loaded = new LocalFileSystemFactory(); - } - - try { - return loaded.create(null); - } catch (IOException e) { - throw new AkUnclassifiedErrorException("local file load error",e); - } + return com.alibaba.alink.common.io.filesystem.copy.local.LocalFileSystem.getSharedInstance(); } + } diff --git a/core/src/main/java/com/alibaba/alink/common/io/filesystem/OssFileSystem.java b/core/src/main/java/com/alibaba/alink/common/io/filesystem/OssFileSystem.java index a98f8bf8b..0bd6dead6 100644 --- a/core/src/main/java/com/alibaba/alink/common/io/filesystem/OssFileSystem.java +++ b/core/src/main/java/com/alibaba/alink/common/io/filesystem/OssFileSystem.java @@ -7,6 +7,9 @@ import org.apache.flink.core.fs.Path; import org.apache.flink.ml.api.misc.param.Params; +import com.alibaba.alink.common.exceptions.AkIllegalArgumentException; +import com.alibaba.alink.common.exceptions.AkParseErrorException; +import
com.alibaba.alink.common.exceptions.AkPluginErrorException; import com.alibaba.alink.common.exceptions.AkUnclassifiedErrorException; import com.alibaba.alink.common.io.annotations.FSAnnotation; import com.alibaba.alink.common.io.filesystem.plugin.FileSystemClassLoaderFactory; @@ -53,7 +56,8 @@ public OssFileSystem(String ossVersion, String endPoint, String bucketName, Stri new URI(getSchema(), bucketName, null, null).toString() ); } catch (URISyntaxException e) { - throw new AkUnclassifiedErrorException("Error. ", e); + throw new AkParseErrorException( + "Syntax error in OSS file system URI, please check your bucket name: " + bucketName, e); } } @@ -79,23 +83,28 @@ && getParams().get(OssFileSystemParams.ACCESS_KEY) != null) { FileSystemFactory factory = createFactory(); factory.configure(conf); + URI uri = null; try { if (getParams().get(OssFileSystemParams.FS_URI) != null) { + uri = new Path(getParams().get(OssFileSystemParams.FS_URI)).toUri(); try (TemporaryClassLoaderContext context = TemporaryClassLoaderContext.of(factory.getClassLoader())) { - loaded = factory.create(new Path(getParams().get(OssFileSystemParams.FS_URI)).toUri()); + loaded = factory.create(uri); } return loaded; } else if (path != null) { + uri = path.toUri(); try (TemporaryClassLoaderContext context = TemporaryClassLoaderContext.of(factory.getClassLoader())) { - loaded = factory.create(path.toUri()); + loaded = factory.create(uri); } return loaded; + } else { + throw new AkIllegalArgumentException( + "Could not create the oss file system, as both the bucket and the filePath are null."); } } catch (IOException e) { - throw new AkUnclassifiedErrorException("Error. ", e); + throw new AkUnclassifiedErrorException( + "Failed to create OSS file system from URI: " + uri.toString(), e); } - - throw new AkUnclassifiedErrorException("Could not create the oss file system. Both the bucket the filePath are null."); } @Override @@ -122,7 +131,7 @@ private FileSystemFactory createFactory() { .newInstance(); } catch (InstantiationException | IllegalAccessException | InvocationTargetException | NoSuchMethodException | ClassNotFoundException e) { - throw new AkUnclassifiedErrorException("Error. ", e); + throw new AkPluginErrorException("Failed to load OSSFileSystemFactory or create its instance. ", e); } } diff --git a/core/src/main/java/com/alibaba/alink/common/io/filesystem/copy/local/LocalBlockLocation.java b/core/src/main/java/com/alibaba/alink/common/io/filesystem/copy/local/LocalBlockLocation.java new file mode 100644 index 000000000..0f2037020 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/common/io/filesystem/copy/local/LocalBlockLocation.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.alibaba.alink.common.io.filesystem.copy.local; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.core.fs.BlockLocation; + +import java.io.IOException; + +/** + * Implementation of the {@link BlockLocation} interface for a local file system. + */ +@Internal +public class LocalBlockLocation implements BlockLocation { + + private final long length; + + private final String[] hosts; + + public LocalBlockLocation(final String host, final long length) { + this.hosts = new String[] { host }; + this.length = length; + } + + @Override + public String[] getHosts() throws IOException { + return this.hosts; + } + + @Override + public long getLength() { + return this.length; + } + + @Override + public long getOffset() { + return 0; + } + + @Override + public int compareTo(final BlockLocation o) { + return 0; + } +} diff --git a/core/src/main/java/com/alibaba/alink/common/io/filesystem/copy/local/LocalDataInputStream.java b/core/src/main/java/com/alibaba/alink/common/io/filesystem/copy/local/LocalDataInputStream.java new file mode 100644 index 000000000..95174355e --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/common/io/filesystem/copy/local/LocalDataInputStream.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.alink.common.io.filesystem.copy.local; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.core.fs.FSDataInputStream; + +import javax.annotation.Nonnull; + +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; +import java.nio.channels.FileChannel; + +/** + * The LocalDataInputStream class is a wrapper class for a data + * input stream to the local file system. + */ +@Internal +public class LocalDataInputStream extends FSDataInputStream { + + /** The file input stream used to read data from.*/ + private final FileInputStream fis; + private final FileChannel fileChannel; + + /** + * Constructs a new LocalDataInputStream object from a given {@link File} object. + * + * @param file The File the data stream is read from + * + * @throws IOException Thrown if the data input stream cannot be created. 
+ */ + public LocalDataInputStream(File file) throws IOException { + this.fis = new FileInputStream(file); + this.fileChannel = fis.getChannel(); + } + + @Override + public void seek(long desired) throws IOException { + if (desired != getPos()) { + this.fileChannel.position(desired); + } + } + + @Override + public long getPos() throws IOException { + return this.fileChannel.position(); + } + + @Override + public int read() throws IOException { + return this.fis.read(); + } + + @Override + public int read(@Nonnull byte[] buffer, int offset, int length) throws IOException { + return this.fis.read(buffer, offset, length); + } + + @Override + public void close() throws IOException { + // According to javadoc, this also closes the channel + this.fis.close(); + } + + @Override + public int available() throws IOException { + return this.fis.available(); + } + + @Override + public long skip(final long n) throws IOException { + return this.fis.skip(n); + } +} diff --git a/core/src/main/java/com/alibaba/alink/common/io/filesystem/copy/local/LocalDataOutputStream.java b/core/src/main/java/com/alibaba/alink/common/io/filesystem/copy/local/LocalDataOutputStream.java new file mode 100644 index 000000000..e9138b342 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/common/io/filesystem/copy/local/LocalDataOutputStream.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.alink.common.io.filesystem.copy.local; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.core.fs.FSDataOutputStream; + +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; + +/** + * The LocalDataOutputStream class is a wrapper class for a data + * output stream to the local file system. + */ +@Internal +public class LocalDataOutputStream extends FSDataOutputStream { + + /** The file output stream used to write data.*/ + private final FileOutputStream fos; + + /** + * Constructs a new LocalDataOutputStream object from a given {@link File} object. 
+ * + * @param file + * the {@link File} object the data stream is read from + * @throws IOException + * thrown if the data output stream cannot be created + */ + public LocalDataOutputStream(final File file) throws IOException { + this.fos = new FileOutputStream(file); + } + + @Override + public void write(final int b) throws IOException { + fos.write(b); + } + + @Override + public void write(final byte[] b, final int off, final int len) throws IOException { + fos.write(b, off, len); + } + + @Override + public void close() throws IOException { + fos.close(); + } + + @Override + public void flush() throws IOException { + fos.flush(); + } + + @Override + public void sync() throws IOException { + fos.getFD().sync(); + } + + @Override + public long getPos() throws IOException { + return fos.getChannel().position(); + } +} diff --git a/core/src/main/java/com/alibaba/alink/common/io/filesystem/copy/local/LocalFileStatus.java b/core/src/main/java/com/alibaba/alink/common/io/filesystem/copy/local/LocalFileStatus.java new file mode 100644 index 000000000..53fd2a720 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/common/io/filesystem/copy/local/LocalFileStatus.java @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.alink.common.io.filesystem.copy.local; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.core.fs.FileStatus; +import org.apache.flink.core.fs.FileSystem; +import org.apache.flink.core.fs.Path; + +import java.io.File; + +/** + * The class LocalFileStatus provides an implementation of the {@link FileStatus} interface + * for the local file system. + */ +@Internal +public class LocalFileStatus implements FileStatus { + + /** + * The file this file status belongs to. + */ + private final File file; + + /** + * The path of this file this file status belongs to. + */ + private final Path path; + + /** + * Creates a LocalFileStatus object from a given {@link File} object. 
+ * + * @param f + * the {@link File} object this LocalFileStatus refers to + * @param fs + * the file system the corresponding file has been read from + */ + public LocalFileStatus(final File f, final FileSystem fs) { + this.file = f; + this.path = new Path(fs.getUri().getScheme() + ":" + f.toURI().getPath()); + } + + @Override + public long getAccessTime() { + return 0; // We don't have access files for local files + } + + @Override + public long getBlockSize() { + return this.file.length(); + } + + @Override + public long getLen() { + return this.file.length(); + } + + @Override + public long getModificationTime() { + return this.file.lastModified(); + } + + @Override + public short getReplication() { + return 1; // For local files replication is always 1 + } + + @Override + public boolean isDir() { + return this.file.isDirectory(); + } + + @Override + public Path getPath() { + return this.path; + } + + public File getFile() { + return this.file; + } + + @Override + public String toString() { + return "LocalFileStatus{" + + "file=" + file + + ", path=" + path + + '}'; + } +} diff --git a/core/src/main/java/com/alibaba/alink/common/io/filesystem/copy/local/LocalFileSystem.java b/core/src/main/java/com/alibaba/alink/common/io/filesystem/copy/local/LocalFileSystem.java new file mode 100644 index 000000000..5c6a11c61 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/common/io/filesystem/copy/local/LocalFileSystem.java @@ -0,0 +1,343 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Parts of earlier versions of this file were based on source code from the + * Hadoop Project (http://hadoop.apache.org/), licensed by the Apache Software Foundation (ASF) + * under the Apache License, Version 2.0. See the NOTICE file distributed with this work for + * additional information regarding copyright ownership. 
+ */ + +package com.alibaba.alink.common.io.filesystem.copy.local; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.core.fs.BlockLocation; +import org.apache.flink.core.fs.FSDataInputStream; +import org.apache.flink.core.fs.FSDataOutputStream; +import org.apache.flink.core.fs.FileStatus; +import org.apache.flink.core.fs.FileSystem; +import org.apache.flink.core.fs.FileSystemKind; +import org.apache.flink.core.fs.Path; +import org.apache.flink.util.OperatingSystem; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.net.InetAddress; +import java.net.URI; +import java.net.UnknownHostException; +import java.nio.file.AccessDeniedException; +import java.nio.file.DirectoryNotEmptyException; +import java.nio.file.FileAlreadyExistsException; +import java.nio.file.Files; +import java.nio.file.NoSuchFileException; +import java.nio.file.StandardCopyOption; + +import static org.apache.flink.util.Preconditions.checkNotNull; + +/** + * The class {@code LocalFileSystem} is an implementation of the {@link FileSystem} interface + * for the local file system of the machine where the JVM runs. + */ +@Internal +public class LocalFileSystem extends FileSystem { + + private static final Logger LOG = LoggerFactory.getLogger(LocalFileSystem.class); + + /** The URI representing the local file system. */ + private static final URI LOCAL_URI = OperatingSystem.isWindows() ? URI.create("file:/") : URI.create("file:///"); + + /** The shared instance of the local file system. */ + private static final LocalFileSystem + INSTANCE = new LocalFileSystem(); + + /** Path pointing to the current working directory. + * Because Paths are not immutable, we cannot cache the proper path here */ + private final URI workingDir; + + /** Path pointing to the current user home directory. + * Because Paths are not immutable, we cannot cache the proper path here. */ + private final URI homeDir; + + /** The host name of this machine. */ + private final String hostName = "unknownHost"; + + /** + * Constructs a new LocalFileSystem object. 
+ */ + public LocalFileSystem() { + this.workingDir = new File(System.getProperty("user.dir")).toURI(); + this.homeDir = new File(System.getProperty("user.home")).toURI(); + + //String tmp = "unknownHost"; + //try { + // tmp = InetAddress.getLocalHost().getHostName(); + //} catch (UnknownHostException e) { + // LOG.error("Could not resolve local host", e); + //} + //this.hostName = tmp; + } + + // ------------------------------------------------------------------------ + + @Override + public BlockLocation[] getFileBlockLocations(FileStatus file, long start, long len) throws IOException { + return new BlockLocation[] { + new LocalBlockLocation(hostName, file.getLen()) + }; + } + + @Override + public FileStatus getFileStatus(Path f) throws IOException { + final File path = pathToFile(f); + if (path.exists()) { + return new LocalFileStatus(path, this); + } + else { + throw new FileNotFoundException("File " + f + " does not exist or the user running " + + "Flink ('" + System.getProperty("user.name") + "') has insufficient permissions to access it."); + } + } + + @Override + public URI getUri() { + return LOCAL_URI; + } + + @Override + public Path getWorkingDirectory() { + return new Path(workingDir); + } + + @Override + public Path getHomeDirectory() { + return new Path(homeDir); + } + + @Override + public FSDataInputStream open(final Path f, final int bufferSize) throws IOException { + return open(f); + } + + @Override + public FSDataInputStream open(final Path f) throws IOException { + final File file = pathToFile(f); + return new LocalDataInputStream(file); + } + + @Override + public LocalRecoverableWriter createRecoverableWriter() throws IOException { + return new LocalRecoverableWriter(this); + } + + @Override + public boolean exists(Path f) throws IOException { + final File path = pathToFile(f); + return path.exists(); + } + + @Override + public FileStatus[] listStatus(final Path f) throws IOException { + + final File localf = pathToFile(f); + FileStatus[] results; + + if (!localf.exists()) { + return null; + } + if (localf.isFile()) { + return new FileStatus[] { new LocalFileStatus(localf, this) }; + } + + final String[] names = localf.list(); + if (names == null) { + return null; + } + results = new FileStatus[names.length]; + for (int i = 0; i < names.length; i++) { + results[i] = getFileStatus(new Path(f, names[i])); + } + + return results; + } + + @Override + public boolean delete(final Path f, final boolean recursive) throws IOException { + + final File file = pathToFile(f); + if (file.isFile()) { + return file.delete(); + } else if ((!recursive) && file.isDirectory()) { + File[] containedFiles = file.listFiles(); + if (containedFiles == null) { + throw new IOException("Directory " + file.toString() + " does not exist or an I/O error occurred"); + } else if (containedFiles.length != 0) { + throw new IOException("Directory " + file.toString() + " is not empty"); + } + } + + return delete(file); + } + + /** + * Deletes the given file or directory. 
+ * + * @param f + * the file to be deleted + * @return true if all files were deleted successfully, false otherwise + * @throws IOException + * thrown if an error occurred while deleting the files/directories + */ + private boolean delete(final File f) throws IOException { + + if (f.isDirectory()) { + final File[] files = f.listFiles(); + if (files != null) { + for (File file : files) { + final boolean del = delete(file); + if (!del) { + return false; + } + } + } + } else { + return f.delete(); + } + + // Now directory is empty + return f.delete(); + } + + /** + * Recursively creates the directory specified by the provided path. + * + * @return trueif the directories either already existed or have been created successfully, + * false otherwise + * @throws IOException + * thrown if an error occurred while creating the directory/directories + */ + @Override + public boolean mkdirs(final Path f) throws IOException { + checkNotNull(f, "path is null"); + return mkdirsInternal(pathToFile(f)); + } + + private boolean mkdirsInternal(File file) throws IOException { + if (file.isDirectory()) { + return true; + } + else if (file.exists() && !file.isDirectory()) { + // Important: The 'exists()' check above must come before the 'isDirectory()' check to + // be safe when multiple parallel instances try to create the directory + + // exists and is not a directory -> is a regular file + throw new FileAlreadyExistsException(file.getAbsolutePath()); + } + else { + File parent = file.getParentFile(); + return (parent == null || mkdirsInternal(parent)) && (file.mkdir() || file.isDirectory()); + } + } + + @Override + public FSDataOutputStream create(final Path filePath, final WriteMode overwrite) throws IOException { + checkNotNull(filePath, "filePath"); + + if (exists(filePath) && overwrite == WriteMode.NO_OVERWRITE) { + throw new FileAlreadyExistsException("File already exists: " + filePath); + } + + final Path parent = filePath.getParent(); + if (parent != null && !mkdirs(parent)) { + throw new IOException("Mkdirs failed to create " + parent); + } + + final File file = pathToFile(filePath); + return new LocalDataOutputStream(file); + } + + @Override + public boolean rename(final Path src, final Path dst) throws IOException { + final File srcFile = pathToFile(src); + final File dstFile = pathToFile(dst); + + final File dstParent = dstFile.getParentFile(); + + // Files.move fails if the destination directory doesn't exist + //noinspection ResultOfMethodCallIgnored -- we don't care if the directory existed or was created + dstParent.mkdirs(); + + try { + Files.move(srcFile.toPath(), dstFile.toPath(), StandardCopyOption.REPLACE_EXISTING); + return true; + } + catch (NoSuchFileException | AccessDeniedException | DirectoryNotEmptyException | SecurityException ex) { + // catch the errors that are regular "move failed" exceptions and return false + return false; + } + } + + @Override + public boolean isDistributedFS() { + return false; + } + + @Override + public FileSystemKind getKind() { + return FileSystemKind.FILE_SYSTEM; + } + + // ------------------------------------------------------------------------ + + /** + * Converts the given Path to a File for this file system. + * + *
If the path is not absolute, it is interpreted relative to this FileSystem's working directory. + */ + public File pathToFile(Path path) { + if (!path.isAbsolute()) { + path = new Path(getWorkingDirectory(), path); + } + return new File(path.toUri().getPath()); + } + + // ------------------------------------------------------------------------ + + /** + * Gets the URI that represents the local file system. + * That URI is {@code "file:/"} on Windows platforms and {@code "file:///"} on other + * UNIX family platforms. + * + * @return The URI that represents the local file system. + */ + public static URI getLocalFsURI() { + return LOCAL_URI; + } + + /** + * Gets the shared instance of this file system. + * + * @return The shared instance of this file system. + */ + public static LocalFileSystem getSharedInstance() { + return INSTANCE; + } +} diff --git a/core/src/main/java/com/alibaba/alink/common/io/filesystem/copy/local/LocalFileSystemFactory.java b/core/src/main/java/com/alibaba/alink/common/io/filesystem/copy/local/LocalFileSystemFactory.java new file mode 100644 index 000000000..d4dd4e51c --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/common/io/filesystem/copy/local/LocalFileSystemFactory.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.alink.common.io.filesystem.copy.local; + +import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.core.fs.FileSystem; +import org.apache.flink.core.fs.FileSystemFactory; + +import java.net.URI; + +/** + * A factory for the {@link LocalFileSystem}. + */ +@PublicEvolving +public class LocalFileSystemFactory implements FileSystemFactory { + + @Override + public String getScheme() { + return LocalFileSystem.getLocalFsURI().getScheme(); + } + + @Override + public FileSystem create(URI fsUri) { + return LocalFileSystem.getSharedInstance(); + } +} diff --git a/core/src/main/java/com/alibaba/alink/common/io/filesystem/copy/local/LocalRecoverable.java b/core/src/main/java/com/alibaba/alink/common/io/filesystem/copy/local/LocalRecoverable.java new file mode 100644 index 000000000..e9a7527ad --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/common/io/filesystem/copy/local/LocalRecoverable.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.alink.common.io.filesystem.copy.local; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.core.fs.RecoverableWriter.CommitRecoverable; +import org.apache.flink.core.fs.RecoverableWriter.ResumeRecoverable; + +import java.io.File; + +import static org.apache.flink.util.Preconditions.checkArgument; +import static org.apache.flink.util.Preconditions.checkNotNull; + +/** + * An implementation of the resume and commit descriptor objects for local recoverable streams. + */ +@Internal +class LocalRecoverable implements CommitRecoverable, ResumeRecoverable { + + /** The file path for the final result file. */ + private final File targetFile; + + /** The file path of the staging file. */ + private final File tempFile; + + /** The position to resume from. */ + private final long offset; + + /** + * Creates a resumable for the given file at the given position. + * + * @param targetFile The file to resume. + * @param offset The position to resume from. + */ + LocalRecoverable(File targetFile, File tempFile, long offset) { + checkArgument(offset >= 0, "offset must be >= 0"); + this.targetFile = checkNotNull(targetFile, "targetFile"); + this.tempFile = checkNotNull(tempFile, "tempFile"); + this.offset = offset; + } + + public File targetFile() { + return targetFile; + } + + public File tempFile() { + return tempFile; + } + + public long offset() { + return offset; + } + + @Override + public String toString() { + return "LocalRecoverable " + tempFile + " @ " + offset + " -> " + targetFile; + } +} diff --git a/core/src/main/java/com/alibaba/alink/common/io/filesystem/copy/local/LocalRecoverableFsDataOutputStream.java b/core/src/main/java/com/alibaba/alink/common/io/filesystem/copy/local/LocalRecoverableFsDataOutputStream.java new file mode 100644 index 000000000..a3dea0ff0 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/common/io/filesystem/copy/local/LocalRecoverableFsDataOutputStream.java @@ -0,0 +1,200 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.alibaba.alink.common.io.filesystem.copy.local; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.core.fs.RecoverableFsDataOutputStream; +import org.apache.flink.core.fs.RecoverableWriter.CommitRecoverable; +import org.apache.flink.core.fs.RecoverableWriter.ResumeRecoverable; + +import java.io.File; +import java.io.FileNotFoundException; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.OutputStream; +import java.nio.channels.Channels; +import java.nio.channels.FileChannel; +import java.nio.file.AtomicMoveNotSupportedException; +import java.nio.file.FileAlreadyExistsException; +import java.nio.file.Files; +import java.nio.file.StandardCopyOption; +import java.nio.file.StandardOpenOption; + +import static org.apache.flink.util.Preconditions.checkNotNull; + +/** + * A {@link RecoverableFsDataOutputStream} for the {@link LocalFileSystem}. + */ +@Internal +class LocalRecoverableFsDataOutputStream extends RecoverableFsDataOutputStream { + + private final File targetFile; + + private final File tempFile; + + private final FileChannel fileChannel; + + private final OutputStream fos; + + LocalRecoverableFsDataOutputStream(File targetFile, File tempFile) throws IOException { + this.targetFile = checkNotNull(targetFile); + this.tempFile = checkNotNull(tempFile); + + this.fileChannel = FileChannel.open(tempFile.toPath(), StandardOpenOption.WRITE, StandardOpenOption.CREATE_NEW); + this.fos = Channels.newOutputStream(fileChannel); + } + + LocalRecoverableFsDataOutputStream(LocalRecoverable resumable) throws IOException { + this.targetFile = checkNotNull(resumable.targetFile()); + this.tempFile = checkNotNull(resumable.tempFile()); + + if (!tempFile.exists()) { + throw new FileNotFoundException("File Not Found: " + tempFile.getAbsolutePath()); + } + + this.fileChannel = FileChannel.open(tempFile.toPath(), StandardOpenOption.WRITE, StandardOpenOption.APPEND); + if (this.fileChannel.position() < resumable.offset()) { + throw new IOException("Missing data in tmp file: " + tempFile.getAbsolutePath()); + } + this.fileChannel.truncate(resumable.offset()); + this.fos = Channels.newOutputStream(fileChannel); + } + + @Override + public void write(int b) throws IOException { + fos.write(b); + } + + @Override + public void write(byte[] b, int off, int len) throws IOException { + fos.write(b, off, len); + } + + @Override + public void flush() throws IOException { + fos.flush(); + } + + @Override + public void sync() throws IOException { + fileChannel.force(true); + } + + @Override + public long getPos() throws IOException { + return fileChannel.position(); + } + + @Override + public ResumeRecoverable persist() throws IOException { + // we call both flush and sync in order to ensure persistence on mounted + // file systems, like NFS, EBS, EFS, ... 
+ flush(); + sync(); + + return new LocalRecoverable(targetFile, tempFile, getPos()); + } + + @Override + public Committer closeForCommit() throws IOException { + final long pos = getPos(); + close(); + return new LocalCommitter(new LocalRecoverable(targetFile, tempFile, pos)); + } + + @Override + public void close() throws IOException { + fos.close(); + } + + // ------------------------------------------------------------------------ + + static class LocalCommitter implements Committer { + + private final LocalRecoverable recoverable; + + LocalCommitter(LocalRecoverable recoverable) { + this.recoverable = checkNotNull(recoverable); + } + + @Override + public void commit() throws IOException { + final File src = recoverable.tempFile(); + final File dest = recoverable.targetFile(); + + // sanity check + if (src.length() != recoverable.offset()) { + // something was done to this file since the committer was created. + // this is not the "clean" case + throw new IOException("Cannot clean commit: File has trailing junk data."); + } + + // rather than fall into default recovery, handle errors explicitly + // in order to improve error messages + try { + Files.move(src.toPath(), dest.toPath(), StandardCopyOption.ATOMIC_MOVE); + } + catch (UnsupportedOperationException | AtomicMoveNotSupportedException e) { + if (!src.renameTo(dest)) { + throw new IOException("Committing file failed, could not rename " + src + " -> " + dest); + } + } + catch (FileAlreadyExistsException e) { + throw new IOException("Committing file failed. Target file already exists: " + dest); + } + } + + @Override + public void commitAfterRecovery() throws IOException { + final File src = recoverable.tempFile(); + final File dest = recoverable.targetFile(); + final long expectedLength = recoverable.offset(); + + if (src.exists()) { + if (src.length() > expectedLength) { + // can happen if we co from persist to recovering for commit directly + // truncate the trailing junk away + try (FileOutputStream fos = new FileOutputStream(src, true)) { + fos.getChannel().truncate(expectedLength); + } + } else if (src.length() < expectedLength) { + throw new IOException("Missing data in tmp file: " + src); + } + + // source still exists, so no renaming happened yet. do it! + Files.move(src.toPath(), dest.toPath(), StandardCopyOption.ATOMIC_MOVE); + } + else if (!dest.exists()) { + // neither exists - that can be a sign of + // - (1) a serious problem (file system loss of data) + // - (2) a recovery of a savepoint that is some time old and the users + // removed the files in the meantime. + + // TODO how to handle this? + // We probably need an option for users whether this should log, + // or result in an exception or unrecoverable exception + } + } + + @Override + public CommitRecoverable getRecoverable() { + return recoverable; + } + } +} diff --git a/core/src/main/java/com/alibaba/alink/common/io/filesystem/copy/local/LocalRecoverableSerializer.java b/core/src/main/java/com/alibaba/alink/common/io/filesystem/copy/local/LocalRecoverableSerializer.java new file mode 100644 index 000000000..9f3068b3c --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/common/io/filesystem/copy/local/LocalRecoverableSerializer.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.alink.common.io.filesystem.copy.local; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.core.io.SimpleVersionedSerializer; + +import java.io.File; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; + +/** + * Simple serializer for the {@link LocalRecoverable}. + */ +@Internal +class LocalRecoverableSerializer implements SimpleVersionedSerializer { + + static final LocalRecoverableSerializer + INSTANCE = new LocalRecoverableSerializer(); + + private static final Charset CHARSET = StandardCharsets.UTF_8; + + private static final int MAGIC_NUMBER = 0x1e744b57; + + /** + * Do not instantiate, use reusable {@link #INSTANCE} instead. + */ + private LocalRecoverableSerializer() {} + + @Override + public int getVersion() { + return 1; + } + + @Override + public byte[] serialize(LocalRecoverable obj) throws IOException { + final byte[] targetFileBytes = obj.targetFile().getAbsolutePath().getBytes(CHARSET); + final byte[] tempFileBytes = obj.tempFile().getAbsolutePath().getBytes(CHARSET); + final byte[] targetBytes = new byte[20 + targetFileBytes.length + tempFileBytes.length]; + + ByteBuffer bb = ByteBuffer.wrap(targetBytes).order(ByteOrder.LITTLE_ENDIAN); + bb.putInt(MAGIC_NUMBER); + bb.putLong(obj.offset()); + bb.putInt(targetFileBytes.length); + bb.putInt(tempFileBytes.length); + bb.put(targetFileBytes); + bb.put(tempFileBytes); + + return targetBytes; + } + + @Override + public LocalRecoverable deserialize(int version, byte[] serialized) throws IOException { + switch (version) { + case 1: + return deserializeV1(serialized); + default: + throw new IOException("Unrecognized version or corrupt state: " + version); + } + } + + private static LocalRecoverable deserializeV1(byte[] serialized) throws IOException { + final ByteBuffer bb = ByteBuffer.wrap(serialized).order(ByteOrder.LITTLE_ENDIAN); + + if (bb.getInt() != MAGIC_NUMBER) { + throw new IOException("Corrupt data: Unexpected magic number."); + } + + final long offset = bb.getLong(); + final byte[] targetFileBytes = new byte[bb.getInt()]; + final byte[] tempFileBytes = new byte[bb.getInt()]; + bb.get(targetFileBytes); + bb.get(tempFileBytes); + + final String targetPath = new String(targetFileBytes, CHARSET); + final String tempPath = new String(tempFileBytes, CHARSET); + + return new LocalRecoverable(new File(targetPath), new File(tempPath), offset); + + } +} diff --git a/core/src/main/java/com/alibaba/alink/common/io/filesystem/copy/local/LocalRecoverableWriter.java b/core/src/main/java/com/alibaba/alink/common/io/filesystem/copy/local/LocalRecoverableWriter.java new file mode 100644 index 000000000..08ad487b0 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/common/io/filesystem/copy/local/LocalRecoverableWriter.java @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under 
one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.alink.common.io.filesystem.copy.local; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.core.fs.Path; +import org.apache.flink.core.fs.RecoverableFsDataOutputStream; +import org.apache.flink.core.fs.RecoverableFsDataOutputStream.Committer; +import org.apache.flink.core.fs.RecoverableWriter; +import org.apache.flink.core.io.SimpleVersionedSerializer; + +import java.io.File; +import java.io.IOException; +import java.util.UUID; + +import static org.apache.flink.util.Preconditions.checkArgument; +import static org.apache.flink.util.Preconditions.checkNotNull; + +/** + * A {@link RecoverableWriter} for the {@link LocalFileSystem}. + */ +@Internal +public class LocalRecoverableWriter implements RecoverableWriter { + + private final LocalFileSystem fs; + + public LocalRecoverableWriter(LocalFileSystem fs) { + this.fs = checkNotNull(fs); + } + + @Override + public RecoverableFsDataOutputStream open(Path filePath) throws IOException { + final File targetFile = fs.pathToFile(filePath); + final File tempFile = generateStagingTempFilePath(targetFile); + + // try to create the parent + final File parent = tempFile.getParentFile(); + if (parent != null && !parent.mkdirs() && !parent.exists()) { + throw new IOException("Failed to create the parent directory: " + parent); + } + + return new LocalRecoverableFsDataOutputStream(targetFile, tempFile); + } + + @Override + public RecoverableFsDataOutputStream recover(ResumeRecoverable recoverable) throws IOException { + if (recoverable instanceof LocalRecoverable) { + return new LocalRecoverableFsDataOutputStream((LocalRecoverable) recoverable); + } + else { + throw new IllegalArgumentException( + "LocalFileSystem cannot recover recoverable for other file system: " + recoverable); + } + } + + @Override + public boolean requiresCleanupOfRecoverableState() { + return false; + } + + @Override + public boolean cleanupRecoverableState(ResumeRecoverable resumable) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public Committer recoverForCommit(CommitRecoverable recoverable) throws IOException { + if (recoverable instanceof LocalRecoverable) { + return new LocalRecoverableFsDataOutputStream.LocalCommitter((LocalRecoverable) recoverable); + } + else { + throw new IllegalArgumentException( + "LocalFileSystem cannot recover recoverable for other file system: " + recoverable); + } + } + + @Override + public SimpleVersionedSerializer getCommitRecoverableSerializer() { + @SuppressWarnings("unchecked") + SimpleVersionedSerializer typedSerializer = (SimpleVersionedSerializer) + (SimpleVersionedSerializer) LocalRecoverableSerializer.INSTANCE; + + return typedSerializer; + } + + @Override + public SimpleVersionedSerializer getResumeRecoverableSerializer() 
{ + @SuppressWarnings("unchecked") + SimpleVersionedSerializer typedSerializer = (SimpleVersionedSerializer) + (SimpleVersionedSerializer) LocalRecoverableSerializer.INSTANCE; + + return typedSerializer; + } + + @Override + public boolean supportsResume() { + return true; + } + + static File generateStagingTempFilePath(File targetFile) { + checkArgument(targetFile.isAbsolute(), "targetFile must be absolute"); + checkArgument(!targetFile.isDirectory(), "targetFile must not be a directory"); + + final File parent = targetFile.getParentFile(); + final String name = targetFile.getName(); + + checkArgument(parent != null, "targetFile must not be the root directory"); + + while (true) { + File candidate = new File(parent, "." + name + ".inprogress." + UUID.randomUUID().toString()); + if (!candidate.exists()) { + return candidate; + } + } + } +} diff --git a/core/src/main/java/com/alibaba/alink/common/io/kafka/plugin/KafkaClassLoaderFactory.java b/core/src/main/java/com/alibaba/alink/common/io/kafka/plugin/KafkaClassLoaderFactory.java index f5ba42e50..3cbddc67d 100644 --- a/core/src/main/java/com/alibaba/alink/common/io/kafka/plugin/KafkaClassLoaderFactory.java +++ b/core/src/main/java/com/alibaba/alink/common/io/kafka/plugin/KafkaClassLoaderFactory.java @@ -9,6 +9,7 @@ import com.alibaba.alink.common.io.plugin.PluginDistributeCache; import com.alibaba.alink.common.io.plugin.RegisterKey; import com.alibaba.alink.common.io.plugin.TemporaryClassLoaderContext; +import com.alibaba.alink.operator.stream.sink.KafkaSourceSinkFactory; import java.util.Iterator; import java.util.ServiceLoader; diff --git a/core/src/main/java/com/alibaba/alink/common/io/plugin/PluginDownloader.java b/core/src/main/java/com/alibaba/alink/common/io/plugin/PluginDownloader.java index d4103433b..fe8076b2d 100644 --- a/core/src/main/java/com/alibaba/alink/common/io/plugin/PluginDownloader.java +++ b/core/src/main/java/com/alibaba/alink/common/io/plugin/PluginDownloader.java @@ -1,6 +1,5 @@ package com.alibaba.alink.common.io.plugin; -import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.core.fs.Path; import org.apache.flink.shaded.guava18.com.google.common.io.Files; import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.type.TypeReference; @@ -369,7 +368,6 @@ public void loadResourcePluginConfig() throws IOException { isResourcePluginConfigLoaded = true; } - @VisibleForTesting void loadConfigFromString(String jsonString) { if (!isJarsPluginConfigLoaded) { jarsPluginConfigs = JsonConverter.fromJson(jsonString, @@ -378,7 +376,6 @@ void loadConfigFromString(String jsonString) { isJarsPluginConfigLoaded = true; } - @VisibleForTesting void loadResourceConfigFromString(String jsonString) { if (!isResourcePluginConfigLoaded) { resourcePluginConfigs = JsonConverter.fromJson(jsonString, diff --git a/core/src/main/java/com/alibaba/alink/common/io/xls/XlsReader.java b/core/src/main/java/com/alibaba/alink/common/io/xls/XlsFile.java similarity index 67% rename from core/src/main/java/com/alibaba/alink/common/io/xls/XlsReader.java rename to core/src/main/java/com/alibaba/alink/common/io/xls/XlsFile.java index d4f1a003d..12dfa88c2 100644 --- a/core/src/main/java/com/alibaba/alink/common/io/xls/XlsReader.java +++ b/core/src/main/java/com/alibaba/alink/common/io/xls/XlsFile.java @@ -1,5 +1,6 @@ package com.alibaba.alink.common.io.xls; +import org.apache.flink.api.common.io.FileOutputFormat; import org.apache.flink.api.common.io.RichInputFormat; import org.apache.flink.api.java.tuple.Tuple2; import 
org.apache.flink.core.fs.FileInputSplit; @@ -7,6 +8,8 @@ import org.apache.flink.table.api.TableSchema; import org.apache.flink.types.Row; -public interface XlsReader { - Tuple2 , TableSchema> create(Params params); +public interface XlsFile { + Tuple2 , TableSchema> createInputFormat(Params params); + + FileOutputFormat createOutputFormat(Params params, TableSchema schema); } diff --git a/core/src/main/java/com/alibaba/alink/common/io/xls/XlsReaderClassLoader.java b/core/src/main/java/com/alibaba/alink/common/io/xls/XlsReaderClassLoader.java index f5228a642..0dc7e4b58 100644 --- a/core/src/main/java/com/alibaba/alink/common/io/xls/XlsReaderClassLoader.java +++ b/core/src/main/java/com/alibaba/alink/common/io/xls/XlsReaderClassLoader.java @@ -17,12 +17,12 @@ public XlsReaderClassLoader(String version) { super(new RegisterKey(XLS_NAME, version), PluginDistributeCache.createDistributeCache(XLS_NAME, version)); } - public static XlsReader create(XlsReaderClassLoader factory) { + public static XlsFile create(XlsReaderClassLoader factory) { ClassLoader classLoader = factory.create(); - + try (TemporaryClassLoaderContext context = TemporaryClassLoaderContext.of(classLoader)) { - Iterator iter = ServiceLoader - .load(XlsReader.class, classLoader) + Iterator iter = ServiceLoader + .load(XlsFile.class, classLoader) .iterator(); if (iter.hasNext()) { return iter.next(); @@ -37,8 +37,8 @@ public ClassLoader create() { return ClassLoaderContainer.getInstance().create( registerKey, distributeCache, - XlsReader.class, - xlsReader -> true, + XlsFile.class, + xlsFile -> true, descriptor -> registerKey.getVersion() ); } diff --git a/core/src/main/java/com/alibaba/alink/common/lazy/BasePMMLModelInfo.java b/core/src/main/java/com/alibaba/alink/common/lazy/BasePMMLModelInfo.java new file mode 100644 index 000000000..49b41f88f --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/common/lazy/BasePMMLModelInfo.java @@ -0,0 +1,27 @@ +package com.alibaba.alink.common.lazy; + +import javax.xml.bind.JAXBException; +import javax.xml.transform.stream.StreamResult; +import org.dmg.pmml.PMML; +import org.jpmml.model.JAXBUtil; + +import java.io.ByteArrayOutputStream; + +/** + * Base class for models who need to output PMML. 
+ */ +public interface BasePMMLModelInfo { + + default String getPMML() { + try { + ByteArrayOutputStream stream = new ByteArrayOutputStream(); + JAXBUtil.marshalPMML(toPMML(), new StreamResult(stream)); + return stream.toString(); + } catch (JAXBException e) { + throw new RuntimeException("PMML write stream error!"); + } + + } + + PMML toPMML(); +} diff --git a/core/src/main/java/com/alibaba/alink/common/lazy/LazyObjectsManager.java b/core/src/main/java/com/alibaba/alink/common/lazy/LazyObjectsManager.java index 32a1d6d30..23b65d73e 100644 --- a/core/src/main/java/com/alibaba/alink/common/lazy/LazyObjectsManager.java +++ b/core/src/main/java/com/alibaba/alink/common/lazy/LazyObjectsManager.java @@ -4,10 +4,12 @@ import com.alibaba.alink.common.MLEnvironmentFactory; import com.alibaba.alink.operator.batch.BatchOperator; +import com.alibaba.alink.operator.batch.utils.ExtractModelInfoBatchOp; import com.alibaba.alink.params.shared.HasMLEnvironmentId; import com.alibaba.alink.pipeline.EstimatorBase; import com.alibaba.alink.pipeline.ModelBase; import com.alibaba.alink.pipeline.Trainer; +import com.alibaba.alink.pipeline.TrainerLegacy; import com.alibaba.alink.pipeline.TransformerBase; import com.alibaba.alink.pipeline.tuning.BaseTuning; import com.alibaba.alink.pipeline.tuning.Report; @@ -77,10 +79,18 @@ public LazyEvaluation > genLazyTrainOp(Trainer trainer) return genLazyObject(trainer, lazyTrainOps); } + public LazyEvaluation > genLazyTrainOp(TrainerLegacy trainer) { + return genLazyObject(trainer, lazyTrainOps); + } + public LazyEvaluation > genLazyModel(Trainer trainer) { return genLazyObject(trainer, lazyModels); } + public LazyEvaluation > genLazyModel(TrainerLegacy trainer) { + return genLazyObject(trainer, lazyModels); + } + public LazyEvaluation > genLazyTransformResult(TransformerBase transformer) { return genLazyObject(transformer, lazyTransformResults); } diff --git a/core/src/main/java/com/alibaba/alink/common/lazy/PipelineLazyCallbackUtils.java b/core/src/main/java/com/alibaba/alink/common/lazy/PipelineLazyCallbackUtils.java index 1fc8f4e00..a2e1e3f87 100644 --- a/core/src/main/java/com/alibaba/alink/common/lazy/PipelineLazyCallbackUtils.java +++ b/core/src/main/java/com/alibaba/alink/common/lazy/PipelineLazyCallbackUtils.java @@ -3,8 +3,11 @@ import com.alibaba.alink.common.MLEnvironment; import com.alibaba.alink.common.MLEnvironmentFactory; import com.alibaba.alink.operator.batch.BatchOperator; +import com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp; +import com.alibaba.alink.operator.batch.utils.WithTrainInfo; import com.alibaba.alink.pipeline.ModelBase; import com.alibaba.alink.pipeline.Trainer; +import com.alibaba.alink.pipeline.TrainerLegacy; import com.alibaba.alink.pipeline.TransformerBase; import com.alibaba.alink.pipeline.tuning.BaseTuning; import com.alibaba.alink.pipeline.tuning.Report; @@ -48,6 +51,36 @@ public static void callbackForTrainerLazyTransformResult(Trainer trainer, }); } + @SuppressWarnings("unchecked") + public static void callbackForTrainerLazyTrainInfo(TrainerLegacy trainer, + List > callbacks) { + LazyObjectsManager lazyObjectsManager = LazyObjectsManager.getLazyObjectsManager(trainer); + LazyEvaluation > lazyTrainOp = lazyObjectsManager.genLazyTrainOp(trainer); + lazyTrainOp.addCallback(d -> { + ((WithTrainInfo ) d).lazyCollectTrainInfo(callbacks); + }); + } + + @SuppressWarnings("unchecked") + public static void callbackForTrainerLazyModelInfo(TrainerLegacy trainer, + List > callbacks) { + LazyObjectsManager lazyObjectsManager = 
LazyObjectsManager.getLazyObjectsManager(trainer); + LazyEvaluation > lazyTrainOp = lazyObjectsManager.genLazyTrainOp(trainer); + lazyTrainOp.addCallback(d -> { + ((WithModelInfoBatchOp ) d).lazyCollectModelInfo(callbacks); + }); + } + + public static void callbackForTrainerLazyTransformResult(TrainerLegacy trainer, + List >> callbacks) { + LazyObjectsManager lazyObjectsManager = LazyObjectsManager.getLazyObjectsManager(trainer); + LazyEvaluation > lazyModel = lazyObjectsManager.genLazyModel(trainer); + lazyModel.addCallback(model -> { + LazyEvaluation > lazyTransformResult = lazyObjectsManager.genLazyTransformResult(model); + lazyTransformResult.addCallbacks(callbacks); + }); + } + public static void callbackForTransformerLazyTransformResult(TransformerBase transformer, List >> callbacks) { MLEnvironment mlEnv = MLEnvironmentFactory.get(transformer.getMLEnvironmentId()); diff --git a/core/src/main/java/com/alibaba/alink/common/lazy/WithTrainInfoLocalOp.java b/core/src/main/java/com/alibaba/alink/common/lazy/WithTrainInfoLocalOp.java new file mode 100644 index 000000000..f71ee0aec --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/common/lazy/WithTrainInfoLocalOp.java @@ -0,0 +1,60 @@ +package com.alibaba.alink.common.lazy; + +import org.apache.flink.types.Row; + +import com.alibaba.alink.operator.local.LocalOperator; +import com.alibaba.alink.operator.local.lazy.LocalLazyObjectsManager; + +import java.util.Arrays; +import java.util.List; +import java.util.function.Consumer; + +public interface WithTrainInfoLocalOp> { + + S createTrainInfo(List rows); + + LocalOperator getSideOutputTrainInfo(); + + default T lazyPrintTrainInfo(String title) { + return lazyCollectTrainInfo(d -> { + if (null != title) { + System.out.println(title); + } + System.out.println(d); + }); + } + + default T lazyPrintTrainInfo() { + return lazyPrintTrainInfo(null); + } + + default T lazyCollectTrainInfo(List > callbacks) { + Consumer > consumer = new Consumer >() { + @Override + public void accept(LocalOperator op) { + ((WithTrainInfoLocalOp ) op).getSideOutputTrainInfo().lazyCollect(d -> { + S trainInfo = createTrainInfo(d); + for (Consumer callback : callbacks) { + callback.accept(trainInfo); + } + }); + } + }; + if (((T) this).isNullOutputTable()) { + LocalLazyObjectsManager lazyObjectsManager = LocalLazyObjectsManager.getLazyObjectsManager((T) this); + LazyEvaluation > lazyOpAfterLinked = lazyObjectsManager.genLazyOpAfterLinked((T) this); + lazyOpAfterLinked.addCallback(consumer); + } else { + consumer.accept((T) this); + } + return ((T) this); + } + + default T lazyCollectTrainInfo(Consumer ... 
callbacks) { + return lazyCollectTrainInfo(Arrays.asList(callbacks)); + } + + default S collectTrainInfo() { + return createTrainInfo(getSideOutputTrainInfo().collect()); + } +} diff --git a/core/src/main/java/com/alibaba/alink/common/linalg/DenseVector.java b/core/src/main/java/com/alibaba/alink/common/linalg/DenseVector.java index a509927e9..f882d26ce 100644 --- a/core/src/main/java/com/alibaba/alink/common/linalg/DenseVector.java +++ b/core/src/main/java/com/alibaba/alink/common/linalg/DenseVector.java @@ -1,6 +1,6 @@ package com.alibaba.alink.common.linalg; -import com.alibaba.alink.common.DataTypeDisplayInterface; +import com.alibaba.alink.common.viz.DataTypeDisplayInterface; import com.alibaba.alink.common.exceptions.AkUnclassifiedErrorException; import com.alibaba.alink.common.linalg.VectorUtil.VectorSerialType; diff --git a/core/src/main/java/com/alibaba/alink/common/linalg/SparseVector.java b/core/src/main/java/com/alibaba/alink/common/linalg/SparseVector.java index cac3fd27e..a3c348e5a 100644 --- a/core/src/main/java/com/alibaba/alink/common/linalg/SparseVector.java +++ b/core/src/main/java/com/alibaba/alink/common/linalg/SparseVector.java @@ -1,6 +1,6 @@ package com.alibaba.alink.common.linalg; -import com.alibaba.alink.common.DataTypeDisplayInterface; +import com.alibaba.alink.common.viz.DataTypeDisplayInterface; import com.alibaba.alink.common.exceptions.AkUnclassifiedErrorException; import com.alibaba.alink.common.exceptions.AkUnsupportedOperationException; import com.alibaba.alink.common.linalg.VectorUtil.VectorSerialType; diff --git a/core/src/main/java/com/alibaba/alink/common/linalg/Tensor.java b/core/src/main/java/com/alibaba/alink/common/linalg/Tensor.java new file mode 100644 index 000000000..cb52cc242 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/common/linalg/Tensor.java @@ -0,0 +1,470 @@ +package com.alibaba.alink.common.linalg; + +import org.apache.flink.api.java.tuple.Tuple2; + +import com.alibaba.alink.common.exceptions.AkIllegalArgumentException; +import com.alibaba.alink.common.exceptions.AkIllegalDataException; +import com.alibaba.alink.common.exceptions.AkUnclassifiedErrorException; +import org.apache.commons.lang.StringUtils; + +import java.io.Serializable; +import java.util.Arrays; + +/** + * We use the concept of 'tensor' to uniform: + * 1. vector (1-vectorSize tensor) + * 2. matrix (2-vectorSize tensor) + * 3. higher ordered tensors + *

+ * All of them can be in dense or sparse form. + *

+ * Some examples: + * -# vectors + * [1, 2, 0, 3, 0] is serialized as "1,2,0,3,0" in the dense case, "0:1,1:2,3:3" in the sparse case. + *
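A minimal sketch of the string format in the vector example above and the matrix example just below, using only the parse, getShapes, toDense, serialize and getData methods added in this file; the values in the comments are what this implementation is expected to print and are shown for illustration only.

import java.util.Arrays;

import com.alibaba.alink.common.linalg.Tensor;

public class TensorFormatSketch {
	public static void main(String[] args) {
		// Sparse 2x3 matrix: "$rows,cols$" shape prefix followed by "row:col:value" entries.
		Tensor sparse = Tensor.parse("$2,3$0:0:1,0:2:3,1:2:6");
		System.out.println(Arrays.toString(sparse.getShapes()));   // [2, 3]

		// Densify and re-serialize: all six entries in row-major order, keeping the shape prefix.
		System.out.println(sparse.toDense().serialize());          // $2,3$1.0,0.0,3.0,0.0,0.0,6.0

		// A 1-D dense tensor parses from a plain comma-separated list.
		Tensor vector = Tensor.parse("1,2,0,3,0");
		System.out.println(Arrays.toString(vector.getData()));     // [1.0, 2.0, 0.0, 3.0, 0.0]
	}
}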

+ * -# matrix + * [[1, 0, 3],[0, 0, 6]] is serialized as "$2,3$1,0,3,0,0,6" in the dense case, + * "0:0:1,0:2:3,1:2:6" or "$2,3$0:0:1,0:2:3,1:2:6" in the sparse case. + */ +public class Tensor implements Serializable { + private static final long serialVersionUID = 6054218217808496419L; + private int[] shapes = null; // shape[i] == -1 indicates that size at i-th vectorSize is unknown + private int[][] indices = null; + private double[] data = null; + private boolean isSparse; + + private Tensor() { + // an empty tensor is considered a sparse vector with all zeros + this.isSparse = true; + this.data = new double[0]; + this.shapes = new int[] {-1}; + } + + private Tensor(int[] shapes) { + // an empty tensor with shapes is considered a sparse tensor with all zeros + assert (shapes != null); + this.isSparse = true; + this.shapes = shapes; + this.data = new double[0]; + } + + public Tensor(double[] data) { + this(new int[] {data.length}, data); + } + + public Tensor(int[] shapes, double[] data) { + assert (shapes != null); + this.isSparse = false; + this.shapes = shapes; + this.data = data; + + // infer the size of leading dimension if not given + int stride = 1; + for (int i = 1; i < shapes.length; i++) { + if (shapes[i] < 0) { + throw new AkIllegalArgumentException("invalid shapes"); + } + stride *= shapes[i]; + } + + if (shapes.length >= 1) { + shapes[0] = data.length / stride; + } + } + + public Tensor(int[] shapes, int indices[][], double[] data) { + assert (shapes != null); + this.isSparse = true; + this.shapes = shapes; + this.indices = indices; + this.data = data; + } + + public static Tuple2 parseSparseTensor(String str) { + int numValues = 1; + for (int i = 0; i < str.length(); i++) { + if (str.charAt(i) == ',') { + numValues++; + } + } + int[] indices = new int[numValues]; + float[] values = new float[numValues]; + + int startPos = StringUtils.lastIndexOf(str, '$') + 1; + int endPos = -1; + int delimiterPos; + + for (int i = 0; i < numValues; i++) { + // extract the value string + endPos = StringUtils.indexOf(str, ',', startPos); + if (endPos == -1) { + endPos = str.length(); + } + delimiterPos = StringUtils.indexOf(str, ':', startPos); + if (delimiterPos == -1) { + throw new AkIllegalDataException("invalid data: " + str); + } + indices[i] = Integer.valueOf(StringUtils.substring(str, startPos, delimiterPos)); + values[i] = Float.valueOf(StringUtils.substring(str, delimiterPos + 1, endPos)); + startPos = endPos + 1; + } + + return Tuple2.of(indices, values); + } + + public static Tensor parse(String str) { + try { + str = StringUtils.trim(str); + + if (str.isEmpty()) { + return new Tensor(); + } + + int[] shapes = null; + if (str.charAt(0) == '$') { + int lastPos = StringUtils.lastIndexOf(str, '$'); + String shapeInfo = StringUtils.substring(str, 1, lastPos); + String[] shapesStr = StringUtils.split(shapeInfo, ','); + shapes = new int[shapesStr.length]; + for (int i = 0; i < shapes.length; i++) { + shapes[i] = Integer.valueOf(shapesStr[i].trim()); + } + str = StringUtils.substring(str, lastPos + 1); + str = StringUtils.trim(str); + } + + if (str.isEmpty()) { + return new Tensor(shapes); + } + + int numValues = StringUtils.countMatches(str, ",") + 1; + + // check dense or sparse + boolean isSparse = (StringUtils.indexOf(str, ':') != -1); + + if (isSparse) { + int ndim = -1; + if (null != shapes) { + ndim = shapes.length; + } + double[] data = new double[numValues]; + int[][] indices = null; + int startPos = 0; + int endPos = -1; + for (int i = 0; i < numValues; i++) { + // extract the 
value string + endPos = StringUtils.indexOf(str, ",", startPos); + if (endPos == -1) { + endPos = str.length(); + } + String valueStr = StringUtils.substring(str, startPos, endPos); + startPos = endPos + 1; + + if (ndim == -1) { + ndim = 0; + for (int j = 0; j < valueStr.length(); j++) { + if (valueStr.charAt(j) == ':') { + ndim++; + } + } + } + if (indices == null) { + indices = new int[numValues][ndim]; + } + if (shapes == null) { + shapes = new int[ndim]; + Arrays.fill(shapes, -1); + } + String[] kvStr = StringUtils.split(valueStr, ':'); + if (kvStr.length != ndim + 1) { + throw new AkIllegalDataException("mismatched size of tensor"); + } + for (int j = 0; j < kvStr.length - 1; j++) { + indices[i][j] = Integer.valueOf(kvStr[j].trim()); + } + data[i] = Double.valueOf(kvStr[ndim].trim()); + } + return new Tensor(shapes, indices, data); + } else { + if (shapes == null) { + shapes = new int[] {numValues}; + } + double[] data = new double[numValues]; + + int startPos = 0; + int endPos = -1; + for (int i = 0; i < numValues; i++) { + // extract the value string + endPos = StringUtils.indexOf(str, ",", startPos); + if (endPos == -1) { + endPos = str.length(); + } + String valueStr = StringUtils.substring(str, startPos, endPos); + startPos = endPos + 1; + + data[i] = Double.valueOf(valueStr); + } + return new Tensor(shapes, data); + } + } catch (Exception e) { + e.printStackTrace(); + throw new AkIllegalDataException("fail to getVector tensor \"" + str + "\""); + } + } + + public Tensor expandDim(int axis) { + if (isSparse) { + throw new AkIllegalArgumentException("expand vectorSize for sparse tensor not implemented."); + } + + int ndim = this.shapes.length; + if (axis > ndim || axis < -1 - ndim) { + throw new AkIllegalArgumentException("invalid axis: " + axis); + } + + if (axis < 0) { + axis = ndim + 1 + axis; + } + + int[] newShapes = new int[ndim + 1]; + + int i = 0; + for (; i < axis; i++) { + newShapes[i] = this.shapes[i]; + } + newShapes[i] = 1; + for (; i < ndim; i++) { + newShapes[i + 1] = this.shapes[i]; + } + + this.shapes = newShapes; + return this; + } + + public Tensor reshape(int[] newshapes) { + if (isSparse) { + int[] stride = new int[shapes.length]; + stride[stride.length - 1] = 1; + for (int i = 0; i < stride.length - 1; i++) { + stride[stride.length - 1 - 1 - i] = stride[stride.length - 1 - i] * shapes[shapes.length - 1 - i]; + } + + int[] newstride = new int[newshapes.length]; + newstride[newstride.length - 1] = 1; + for (int i = 0; i < newstride.length - 1; i++) { + newstride[newstride.length - 1 - 1 - i] = newstride[newstride.length - 1 - i] * newshapes[ + newshapes.length - 1 - i]; + } + + int[][] newIndices = new int[indices.length][newshapes.length]; + + for (int i = 0; i < indices.length; i++) { + int pos = 0; + for (int j = 0; j < indices[i].length; j++) { + pos += indices[i][j] * stride[j]; + } + + for (int j = 0; j < newIndices[i].length; j++) { + newIndices[i][j] = pos / newstride[j]; + pos = pos % newstride[j]; + } + } + + this.indices = newIndices; + this.shapes = newshapes; + + } else { + this.shapes = newshapes; + } + return this; + } + + public int[] getShapes() { + return shapes; + } + + public double[] getData() { + return data; + } + + public int[][] getIndices() { + return indices; + } + + public boolean isSparse() { + return isSparse; + } + + public Tensor standard(double[] mean, double[] stdvar) { + assert (mean.length == stdvar.length); + + if (isSparse) { + for (int i = 0; i < indices.length; i++) { + int which = mean.length == 1 ? 
0 : indices[i][0]; + data[i] -= mean[which]; + data[i] *= (1.0 / stdvar[which]); + } + } else { + int size = data.length; + int stride = size / mean.length; + + for (int i = 0; i < size; i++) { + int which = i / stride; + data[i] -= mean[which]; + data[i] *= (1.0 / stdvar[which]); + } + } + return this; + } + + public Tensor normalize(double[] min, double[] max) { + assert (min.length == max.length); + + if (isSparse) { + for (int i = 0; i < indices.length; i++) { + int which = min.length == 1 ? 0 : indices[i][0]; + data[i] -= min[which]; + data[i] *= 1.0 / (max[which] - min[which]); + } + } else { + int size = data.length; + int stride = size / min.length; + + for (int i = 0; i < size; i++) { + int which = i / stride; + data[i] -= min[which]; + data[i] *= 1.0 / (max[which] - min[which]); + } + } + return this; + } + + public Tensor toDense() { + if (isSparse) { + for (int i = 0; i < shapes.length; i++) { + if (shapes[i] == -1) { + throw new AkUnclassifiedErrorException("can't convert to dense tensor because shapes is unknown"); + } + } + + int size = 1; + for (int i = 0; i < shapes.length; i++) { + size *= shapes[i]; + } + + int[] stride = new int[shapes.length]; + stride[stride.length - 1] = 1; + for (int i = 0; i < stride.length - 1; i++) { + stride[stride.length - 1 - 1 - i] = stride[stride.length - 1 - i] * shapes[shapes.length - 1 - i]; + } + + double[] newdata = new double[size]; + Arrays.fill(newdata, 0.); + for (int i = 0; i < indices.length; i++) { + int pos = 0; + for (int j = 0; j < indices[i].length; j++) { + pos += indices[i][j] * stride[j]; + } + newdata[pos] = data[i]; + } + data = newdata; + indices = null; + isSparse = false; + return this; + } else { + return this; + } + } + + public String serialize() { + boolean withShape = false; + + if (shapes != null) { + if (isSparse || shapes.length > 1) { + for (int i = 0; i < shapes.length; i++) { + if (shapes[i] != -1) { + withShape = true; + } + } + } + } + + StringBuilder sbd = new StringBuilder(); + + if (withShape) { + sbd.append("$"); + for (int i = 0; i < shapes.length; i++) { + sbd.append(shapes[i]); + if (i < shapes.length - 1) { + sbd.append(","); + } + } + sbd.append("$"); + } + + if (isSparse) { + if (null != indices) { + assert (indices.length == data.length); + for (int i = 0; i < indices.length; i++) { + for (int j = 0; j < indices[i].length; j++) { + sbd.append(indices[i][j] + ":"); + } + sbd.append(data[i]); + if (i < indices.length - 1) { + sbd.append(","); + } + } + } + } else { + for (int i = 0; i < data.length; i++) { + sbd.append(data[i]); + if (i < data.length - 1) { + sbd.append(","); + } + } + } + + return sbd.toString(); + } + + public DenseVector toDenseVector() { + int dim = shapes.length; + if (dim != 1) { + throw new AkUnclassifiedErrorException("the data can't be converted to a vector because of dimension error"); + } + + if (isSparse) { + if (shapes[0] < 0) { + throw new AkUnclassifiedErrorException("the data can't be converted to a dense vector because the " + + "data is in sparse format and its' size is not specified"); + } + return toSparseVector().toDenseVector(); + } else { + return new DenseVector(data); + } + } + + public SparseVector toSparseVector() { + if (isSparse) { + int dim = shapes.length; + if (dim != 1) { + throw new AkUnclassifiedErrorException("the data can't be converted to sparse vector"); + } + if (null == indices) { + return new SparseVector(shapes[0]); + } else { + int[] idx = new int[indices.length]; + for (int i = 0; i < idx.length; i++) { + idx[i] = indices[i][0]; + } + 
return new SparseVector(shapes[0], idx, data); + } + } else { + int idx[] = new int[data.length]; + for (int i = 0; i < data.length; i++) { + idx[i] = i; + } + return new SparseVector(idx.length, idx, data); + } + } +} diff --git a/core/src/main/java/com/alibaba/alink/common/linalg/jama/CholeskyDecomposition.java b/core/src/main/java/com/alibaba/alink/common/linalg/jama/CholeskyDecomposition.java new file mode 100644 index 000000000..c9bc2a068 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/common/linalg/jama/CholeskyDecomposition.java @@ -0,0 +1,213 @@ +package com.alibaba.alink.common.linalg.jama; + +/** + * Cholesky Decomposition. + *

+ * For a symmetric, positive definite matrix A, the Cholesky decomposition + * is a lower triangular matrix L so that A = L*L'. + *
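CholeskyDecomposition is package-private and is normally reached through JaMatrix.chol() or the JMatrixFunc facade added later in this change. The following same-package sketch (CholeskySketch is a hypothetical illustration class, not part of this diff) factors a small symmetric positive definite matrix and solves A x = b; the values in the comments are indicative only.

package com.alibaba.alink.common.linalg.jama;

public class CholeskySketch {
	public static void main(String[] args) {
		// A small symmetric positive definite matrix.
		JaMatrix a = new JaMatrix(new double[][] {{4., 2.}, {2., 3.}});

		CholeskyDecomposition chol = new CholeskyDecomposition(a);
		System.out.println(chol.isSPD());           // true for this matrix

		// L is lower triangular with A = L * L'.
		JaMatrix l = chol.getL();
		System.out.println(l.getArray()[0][0]);     // 2.0, i.e. sqrt(4)

		// Solve A * x = b by forward and backward substitution on L.
		JaMatrix b = new JaMatrix(new double[][] {{1.}, {2.}});
		JaMatrix x = chol.solve(b);
		System.out.println(x.getArray()[0][0]);     // ~ -0.125
		System.out.println(x.getArray()[1][0]);     // ~ 0.75
	}
}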

+ * If the matrix is not symmetric or positive definite, the constructor + * returns a partial decomposition and sets an internal flag that may + * be queried by the isSPD() method. + */ + +class CholeskyDecomposition implements java.io.Serializable { + private static final long serialVersionUID = -693958715904680364L; + +/* ------------------------ + Class variables + * ------------------------ */ + + /** + * Array for internal storage of decomposition. + * + * @serial internal array storage. + */ + private double[][] L; + + /** + * Row and column dimension (square matrix). + * + * @serial matrix dimension. + */ + private int n; + + /** + * Symmetric and positive definite flag. + * + * @serial is symmetric and positive definite flag. + */ + private boolean isspd; + +/* ------------------------ + Constructor + * ------------------------ */ + + /** + * Cholesky algorithm for symmetric and positive definite matrix. + * + * @param Arg Square, symmetric matrix. + * @return Structure to access L and isspd flag. + */ + + public CholeskyDecomposition(JaMatrix Arg) { + + // Initialize. + double[][] A = Arg.getArray(); + n = Arg.getRowDimension(); + L = new double[n][n]; + isspd = (Arg.getColumnDimension() == n); + // Main loop. + for (int j = 0; j < n; j++) { + double[] Lrowj = L[j]; + double d = 0.0; + for (int k = 0; k < j; k++) { + double[] Lrowk = L[k]; + double s = 0.0; + for (int i = 0; i < k; i++) { + s += Lrowk[i] * Lrowj[i]; + } + Lrowj[k] = s = (A[j][k] - s) / L[k][k]; + d = d + s * s; + isspd = isspd & (A[k][j] == A[j][k]); + } + d = A[j][j] - d; + isspd = isspd & (d > 0.0); + L[j][j] = Math.sqrt(Math.max(d, 0.0)); + for (int k = j + 1; k < n; k++) { + L[j][k] = 0.0; + } + } + } + +/* ------------------------ + Temporary, experimental code. + * ------------------------ *\ + + \** Right Triangular Cholesky Decomposition. +

+ For a symmetric, positive definite matrix A, the Right Cholesky + decomposition is an upper triangular matrix R so that A = R'*R. + This constructor computes R with the Fortran inspired column oriented + algorithm used in LINPACK and MATLAB. In Java, we suspect a row oriented, + lower triangular decomposition is faster. We have temporarily included + this constructor here until timing experiments confirm this suspicion. + *\ + + \** Array for internal storage of right triangular decomposition. **\ + private transient double[][] R; + + \** Cholesky algorithm for symmetric and positive definite matrix. + @param A Square, symmetric matrix. + @param rightflag Actual value ignored. + @return Structure to access R and isspd flag. + *\ + + public CholeskyDecomposition (Matrix Arg, int rightflag) { + // Initialize. + double[][] A = Arg.getArray(); + n = Arg.getColumnDimension(); + R = new double[n][n]; + isspd = (Arg.getColumnDimension() == n); + // Main loop. + for (int j = 0; j < n; j++) { + double d = 0.0; + for (int k = 0; k < j; k++) { + double s = A[k][j]; + for (int i = 0; i < k; i++) { + s = s - R[i][k]*R[i][j]; + } + R[k][j] = s = s/R[k][k]; + d = d + s*s; + isspd = isspd & (A[k][j] == A[j][k]); + } + d = A[j][j] - d; + isspd = isspd & (d > 0.0); + R[j][j] = Math.Sqrt(Math.max(d,0.0)); + for (int k = j+1; k < n; k++) { + R[k][j] = 0.0; + } + } + } + + \** Return upper triangular factor. + @return R + *\ + + public Matrix getR () { + return new Matrix(R,n,n); + } + +\* ------------------------ + End of temporary code. + * ------------------------ */ + +/* ------------------------ + Public Methods + * ------------------------ */ + + /** + * Is the matrix symmetric and positive definite? + * + * @return true if A is symmetric and positive definite. + */ + + public boolean isSPD() { + return isspd; + } + + /** + * Return triangular factor. + * + * @return L + */ + + public JaMatrix getL() { + return new JaMatrix(L, n, n); + } + + /** + * Solve A*X = B + * + * @param B A Matrix with as many rows as A and any number of columns. + * @return X so that L*L'*X = B + * @throws IllegalArgumentException Matrix row dimensions must agree. + * @throws RuntimeException Matrix is not symmetric positive definite. + */ + + public JaMatrix solve(JaMatrix B) { + if (B.getRowDimension() != n) { + throw new IllegalArgumentException("Matrix row dimensions must agree."); + } + if (!isspd) { + throw new RuntimeException("Matrix is not symmetric positive definite."); + } + + // Copy right hand side. + double[][] X = B.getArrayCopy(); + int nx = B.getColumnDimension(); + + // Solve L*Y = B; + for (int k = 0; k < n; k++) { + for (int j = 0; j < nx; j++) { + for (int i = 0; i < k; i++) { + X[k][j] -= X[i][j] * L[k][i]; + } + X[k][j] /= L[k][k]; + } + } + + // Solve L'*X = Y; + for (int k = n - 1; k >= 0; k--) { + for (int j = 0; j < nx; j++) { + for (int i = k + 1; i < n; i++) { + X[k][j] -= X[i][j] * L[i][k]; + } + X[k][j] /= L[k][k]; + } + } + + return new JaMatrix(X, n, nx); + } +} + diff --git a/core/src/main/java/com/alibaba/alink/common/linalg/jama/EigenvalueDecomposition.java b/core/src/main/java/com/alibaba/alink/common/linalg/jama/EigenvalueDecomposition.java new file mode 100644 index 000000000..ed916917c --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/common/linalg/jama/EigenvalueDecomposition.java @@ -0,0 +1,974 @@ +package com.alibaba.alink.common.linalg.jama; + +/** + * Eigenvalues and eigenvectors of a real matrix. + *

+ * If A is symmetric, then A = V*D*V' where the eigenvalue matrix D is + * diagonal and the eigenvector matrix V is orthogonal. + * I.e. A = V.times(D.times(V.transpose())) and + * V.times(V.transpose()) equals the identity matrix. + *
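EigenvalueDecomposition is likewise package-private; JMatrixFunc.eig() exposes it on DenseMatrix further down. Below is a same-package sketch of the symmetric case described here (the nonsymmetric case is covered just after); EigSketch is a hypothetical illustration class, not part of this diff, and the commented values are indicative only.

package com.alibaba.alink.common.linalg.jama;

import java.util.Arrays;

public class EigSketch {
	public static void main(String[] args) {
		// Symmetric matrix with eigenvalues 1 and 3.
		JaMatrix a = new JaMatrix(new double[][] {{2., 1.}, {1., 2.}});

		EigenvalueDecomposition eig = new EigenvalueDecomposition(a);
		// The symmetric path sorts eigenvalues in ascending order.
		System.out.println(Arrays.toString(eig.getRealEigenvalues()));   // ~[1.0, 3.0]

		// For symmetric A, V is orthogonal and A = V * D * V', so A*V - V*D should vanish.
		JaMatrix v = eig.getV();
		JaMatrix d = eig.getD();
		System.out.println(a.times(v).minus(v.times(d)).normInf());      // close to 0
	}
}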

+ * If A is not symmetric, then the eigenvalue matrix D is block diagonal + * with the real eigenvalues in 1-by-1 blocks and any complex eigenvalues, + * lambda + i*mu, in 2-by-2 blocks, [lambda, mu; -mu, lambda]. The + * columns of V represent the eigenvectors in the sense that A*V = V*D, + * i.e. A.times(V) equals V.times(D). The matrix V may be badly + * conditioned, or even singular, so the validity of the equation + * A = V*D*inverse(V) depends upon V.cond(). + **/ + +class EigenvalueDecomposition implements java.io.Serializable { + private static final long serialVersionUID = -3744107756621089965L; + +/* ------------------------ + Class variables + * ------------------------ */ + + /** + * Row and column dimension (square matrix). + * + * @serial matrix dimension. + */ + private int n; + + /** + * Symmetry flag. + * + * @serial internal symmetry flag. + */ + private boolean issymmetric; + + /** + * Arrays for internal storage of eigenvalues. + * + * @serial internal storage of eigenvalues. + */ + private double[] d, e; + + /** + * Array for internal storage of eigenvectors. + * + * @serial internal storage of eigenvectors. + */ + private double[][] V; + + /** + * Array for internal storage of nonsymmetric Hessenberg form. + * + * @serial internal storage of nonsymmetric Hessenberg form. + */ + private double[][] H; + + /** + * Working storage for nonsymmetric algorithm. + * + * @serial working storage for nonsymmetric algorithm. + */ + private double[] ort; + +/* ------------------------ + Private Methods + * ------------------------ */ + + // Symmetric Householder reduction to tridiagonal form. + private transient double cdivr, cdivi; + + // Symmetric tridiagonal QL algorithm. + + /** + * Check for symmetry, then construct the eigenvalue decomposition + * + * @param Arg Square matrix + * @return Structure to access D and V. + */ + + public EigenvalueDecomposition(JaMatrix Arg) { + double[][] A = Arg.getArray(); + n = Arg.getColumnDimension(); + V = new double[n][n]; + d = new double[n]; + e = new double[n]; + + issymmetric = true; + for (int j = 0; (j < n) & issymmetric; j++) { + for (int i = 0; (i < n) & issymmetric; i++) { + issymmetric = (A[i][j] == A[j][i]); + } + } + + if (issymmetric) { + for (int i = 0; i < n; i++) { + for (int j = 0; j < n; j++) { + V[i][j] = A[i][j]; + } + } + + // Tridiagonalize. + tred2(); + + // Diagonalize. + tql2(); + + } else { + H = new double[n][n]; + ort = new double[n]; + + for (int j = 0; j < n; j++) { + for (int i = 0; i < n; i++) { + H[i][j] = A[i][j]; + } + } + + // Reduce to Hessenberg form. + orthes(); + + // Reduce Hessenberg to real Schur form. + hqr2(); + } + } + + // Nonsymmetric reduction to Hessenberg form. + + private void tred2() { + + // This is derived from the Algol procedures tred2 by + // Bowdler, Martin, Reinsch, and Wilkinson, Handbook for + // Auto. Comp., Vol.ii-Linear Algebra, and the corresponding + // Fortran subroutine in EISPACK. + + for (int j = 0; j < n; j++) { + d[j] = V[n - 1][j]; + } + + // Householder reduction to tridiagonal form. + + for (int i = n - 1; i > 0; i--) { + + // Scale to avoid under/overflow. + + double scale = 0.0; + double h = 0.0; + for (int k = 0; k < i; k++) { + scale = scale + Math.abs(d[k]); + } + if (scale == 0.0) { + e[i] = d[i - 1]; + for (int j = 0; j < i; j++) { + d[j] = V[i - 1][j]; + V[i][j] = 0.0; + V[j][i] = 0.0; + } + } else { + + // Generate Householder vector. 
+ + for (int k = 0; k < i; k++) { + d[k] /= scale; + h += d[k] * d[k]; + } + double f = d[i - 1]; + double g = Math.sqrt(h); + if (f > 0) { + g = -g; + } + e[i] = scale * g; + h = h - f * g; + d[i - 1] = f - g; + for (int j = 0; j < i; j++) { + e[j] = 0.0; + } + + // Apply calc transformation to remaining columns. + + for (int j = 0; j < i; j++) { + f = d[j]; + V[j][i] = f; + g = e[j] + V[j][j] * f; + for (int k = j + 1; k <= i - 1; k++) { + g += V[k][j] * d[k]; + e[k] += V[k][j] * f; + } + e[j] = g; + } + f = 0.0; + for (int j = 0; j < i; j++) { + e[j] /= h; + f += e[j] * d[j]; + } + double hh = f / (h + h); + for (int j = 0; j < i; j++) { + e[j] -= hh * d[j]; + } + for (int j = 0; j < i; j++) { + f = d[j]; + g = e[j]; + for (int k = j; k <= i - 1; k++) { + V[k][j] -= (f * e[k] + g * d[k]); + } + d[j] = V[i - 1][j]; + V[i][j] = 0.0; + } + } + d[i] = h; + } + + // Accumulate transformations. + + for (int i = 0; i < n - 1; i++) { + V[n - 1][i] = V[i][i]; + V[i][i] = 1.0; + double h = d[i + 1]; + if (h != 0.0) { + for (int k = 0; k <= i; k++) { + d[k] = V[k][i + 1] / h; + } + for (int j = 0; j <= i; j++) { + double g = 0.0; + for (int k = 0; k <= i; k++) { + g += V[k][i + 1] * V[k][j]; + } + for (int k = 0; k <= i; k++) { + V[k][j] -= g * d[k]; + } + } + } + for (int k = 0; k <= i; k++) { + V[k][i + 1] = 0.0; + } + } + for (int j = 0; j < n; j++) { + d[j] = V[n - 1][j]; + V[n - 1][j] = 0.0; + } + V[n - 1][n - 1] = 1.0; + e[0] = 0.0; + } + + // Complex scalar division. + + private void tql2() { + + // This is derived from the Algol procedures tql2, by + // Bowdler, Martin, Reinsch, and Wilkinson, Handbook for + // Auto. Comp., Vol.ii-Linear Algebra, and the corresponding + // Fortran subroutine in EISPACK. + + for (int i = 1; i < n; i++) { + e[i - 1] = e[i]; + } + e[n - 1] = 0.0; + + double f = 0.0; + double tst1 = 0.0; + double eps = Math.pow(2.0, -52.0); + for (int l = 0; l < n; l++) { + + // Find small subdiagonal element + + tst1 = Math.max(tst1, Math.abs(d[l]) + Math.abs(e[l])); + int m = l; + while (m < n) { + if (Math.abs(e[m]) <= eps * tst1) { + break; + } + m++; + } + + // If m == l, d[l] is an eigenvalue, + // otherwise, iterate. + + if (m > l) { + int iter = 0; + do { + iter = iter + 1; // (Could check iteration count here.) + + // Compute implicit shift + + double g = d[l]; + double p = (d[l + 1] - g) / (2.0 * e[l]); + double r = Maths.hypot(p, 1.0); + if (p < 0) { + r = -r; + } + d[l] = e[l] / (p + r); + d[l + 1] = e[l] * (p + r); + double dl1 = d[l + 1]; + double h = g - d[l]; + for (int i = l + 2; i < n; i++) { + d[i] -= h; + } + f = f + h; + + // Implicit QL transformation. + + p = d[m]; + double c = 1.0; + double c2 = c; + double c3 = c; + double el1 = e[l + 1]; + double s = 0.0; + double s2 = 0.0; + for (int i = m - 1; i >= l; i--) { + c3 = c2; + c2 = c; + s2 = s; + g = c * e[i]; + h = c * p; + r = Maths.hypot(p, e[i]); + e[i + 1] = s * r; + s = e[i] / r; + c = p / r; + p = c * d[i] - s * g; + d[i + 1] = h + s * (c * g + s * d[i]); + + // Accumulate transformation. + + for (int k = 0; k < n; k++) { + h = V[k][i + 1]; + V[k][i + 1] = s * V[k][i] + c * h; + V[k][i] = c * V[k][i] - s * h; + } + } + p = -s * s2 * c3 * el1 * e[l] / dl1; + e[l] = s * p; + d[l] = c * p; + + // Check for convergence. + + } while (Math.abs(e[l]) > eps * tst1); + } + d[l] = d[l] + f; + e[l] = 0.0; + } + + // Sort eigenvalues and corresponding vectors. 
+ + for (int i = 0; i < n - 1; i++) { + int k = i; + double p = d[i]; + for (int j = i + 1; j < n; j++) { + if (d[j] < p) { + k = j; + p = d[j]; + } + } + if (k != i) { + d[k] = d[i]; + d[i] = p; + for (int j = 0; j < n; j++) { + p = V[j][i]; + V[j][i] = V[j][k]; + V[j][k] = p; + } + } + } + } + + private void orthes() { + + // This is derived from the Algol procedures orthes and ortran, + // by Martin and Wilkinson, Handbook for Auto. Comp., + // Vol.ii-Linear Algebra, and the corresponding + // Fortran subroutines in EISPACK. + + int low = 0; + int high = n - 1; + + for (int m = low + 1; m <= high - 1; m++) { + + // Scale column. + + double scale = 0.0; + for (int i = m; i <= high; i++) { + scale = scale + Math.abs(H[i][m - 1]); + } + if (scale != 0.0) { + + // Compute Householder transformation. + + double h = 0.0; + for (int i = high; i >= m; i--) { + ort[i] = H[i][m - 1] / scale; + h += ort[i] * ort[i]; + } + double g = Math.sqrt(h); + if (ort[m] > 0) { + g = -g; + } + h = h - ort[m] * g; + ort[m] = ort[m] - g; + + // Apply Householder calc transformation + // H = (I-u*u'/h)*H*(I-u*u')/h) + + for (int j = m; j < n; j++) { + double f = 0.0; + for (int i = high; i >= m; i--) { + f += ort[i] * H[i][j]; + } + f = f / h; + for (int i = m; i <= high; i++) { + H[i][j] -= f * ort[i]; + } + } + + for (int i = 0; i <= high; i++) { + double f = 0.0; + for (int j = high; j >= m; j--) { + f += ort[j] * H[i][j]; + } + f = f / h; + for (int j = m; j <= high; j++) { + H[i][j] -= f * ort[j]; + } + } + ort[m] = scale * ort[m]; + H[m][m - 1] = scale * g; + } + } + + // Accumulate transformations (Algol's ortran). + + for (int i = 0; i < n; i++) { + for (int j = 0; j < n; j++) { + V[i][j] = (i == j ? 1.0 : 0.0); + } + } + + for (int m = high - 1; m >= low + 1; m--) { + if (H[m][m - 1] != 0.0) { + for (int i = m + 1; i <= high; i++) { + ort[i] = H[i][m - 1]; + } + for (int j = m; j <= high; j++) { + double g = 0.0; + for (int i = m; i <= high; i++) { + g += ort[i] * V[i][j]; + } + // Double division avoids possible underflow + g = (g / ort[m]) / H[m][m - 1]; + for (int i = m; i <= high; i++) { + V[i][j] += g * ort[i]; + } + } + } + } + } + + // Nonsymmetric reduction from Hessenberg to real Schur form. + + private void cdiv(double xr, double xi, double yr, double yi) { + double r, d; + if (Math.abs(yr) > Math.abs(yi)) { + r = yi / yr; + d = yr + r * yi; + cdivr = (xr + r * xi) / d; + cdivi = (xi - r * xr) / d; + } else { + r = yr / yi; + d = yi + r * yr; + cdivr = (r * xr + xi) / d; + cdivi = (r * xi - xr) / d; + } + } + + +/* ------------------------ + Constructor + * ------------------------ */ + + private void hqr2() { + + // This is derived from the Algol procedure hqr2, + // by Martin and Wilkinson, Handbook for Auto. Comp., + // Vol.ii-Linear Algebra, and the corresponding + // Fortran subroutine in EISPACK. 
+ + // Initialize + + int nn = this.n; + int n = nn - 1; + int low = 0; + int high = nn - 1; + double eps = Math.pow(2.0, -52.0); + double exshift = 0.0; + double p = 0, q = 0, r = 0, s = 0, z = 0, t, w, x, y; + + // Store roots isolated by balanc and compute matrix norm + + double norm = 0.0; + for (int i = 0; i < nn; i++) { + if (i < low | i > high) { + d[i] = H[i][i]; + e[i] = 0.0; + } + for (int j = Math.max(i - 1, 0); j < nn; j++) { + norm = norm + Math.abs(H[i][j]); + } + } + + // Outer loop over eigenvalue index + + int iter = 0; + while (n >= low) { + + // Look for single small sub-diagonal element + + int l = n; + while (l > low) { + s = Math.abs(H[l - 1][l - 1]) + Math.abs(H[l][l]); + if (s == 0.0) { + s = norm; + } + if (Math.abs(H[l][l - 1]) < eps * s) { + break; + } + l--; + } + + // Check for convergence + // One root found + + if (l == n) { + H[n][n] = H[n][n] + exshift; + d[n] = H[n][n]; + e[n] = 0.0; + n--; + iter = 0; + + // Two roots found + + } else if (l == n - 1) { + w = H[n][n - 1] * H[n - 1][n]; + p = (H[n - 1][n - 1] - H[n][n]) / 2.0; + q = p * p + w; + z = Math.sqrt(Math.abs(q)); + H[n][n] = H[n][n] + exshift; + H[n - 1][n - 1] = H[n - 1][n - 1] + exshift; + x = H[n][n]; + + // Real pair + + if (q >= 0) { + if (p >= 0) { + z = p + z; + } else { + z = p - z; + } + d[n - 1] = x + z; + d[n] = d[n - 1]; + if (z != 0.0) { + d[n] = x - w / z; + } + e[n - 1] = 0.0; + e[n] = 0.0; + x = H[n][n - 1]; + s = Math.abs(x) + Math.abs(z); + p = x / s; + q = z / s; + r = Math.sqrt(p * p + q * q); + p = p / r; + q = q / r; + + // Row modification + + for (int j = n - 1; j < nn; j++) { + z = H[n - 1][j]; + H[n - 1][j] = q * z + p * H[n][j]; + H[n][j] = q * H[n][j] - p * z; + } + + // Column modification + + for (int i = 0; i <= n; i++) { + z = H[i][n - 1]; + H[i][n - 1] = q * z + p * H[i][n]; + H[i][n] = q * H[i][n] - p * z; + } + + // Accumulate transformations + + for (int i = low; i <= high; i++) { + z = V[i][n - 1]; + V[i][n - 1] = q * z + p * V[i][n]; + V[i][n] = q * V[i][n] - p * z; + } + + // Complex pair + + } else { + d[n - 1] = x + p; + d[n] = x + p; + e[n - 1] = z; + e[n] = -z; + } + n = n - 2; + iter = 0; + + // No convergence yet + + } else { + + // Form shift + + x = H[n][n]; + y = 0.0; + w = 0.0; + if (l < n) { + y = H[n - 1][n - 1]; + w = H[n][n - 1] * H[n - 1][n]; + } + + // Wilkinson's original ad hoc shift + + if (iter == 10) { + exshift += x; + for (int i = low; i <= n; i++) { + H[i][i] -= x; + } + s = Math.abs(H[n][n - 1]) + Math.abs(H[n - 1][n - 2]); + x = y = 0.75 * s; + w = -0.4375 * s * s; + } + + // MATLAB's new ad hoc shift + + if (iter == 30) { + s = (y - x) / 2.0; + s = s * s + w; + if (s > 0) { + s = Math.sqrt(s); + if (y < x) { + s = -s; + } + s = x - w / ((y - x) / 2.0 + s); + for (int i = low; i <= n; i++) { + H[i][i] -= s; + } + exshift += s; + x = y = w = 0.964; + } + } + + iter = iter + 1; // (Could check iteration count here.) 
+ + // Look for two consecutive small sub-diagonal elements + + int m = n - 2; + while (m >= l) { + z = H[m][m]; + r = x - z; + s = y - z; + p = (r * s - w) / H[m + 1][m] + H[m][m + 1]; + q = H[m + 1][m + 1] - z - r - s; + r = H[m + 2][m + 1]; + s = Math.abs(p) + Math.abs(q) + Math.abs(r); + p = p / s; + q = q / s; + r = r / s; + if (m == l) { + break; + } + if (Math.abs(H[m][m - 1]) * (Math.abs(q) + Math.abs(r)) < + eps * (Math.abs(p) * (Math.abs(H[m - 1][m - 1]) + Math.abs(z) + + Math.abs(H[m + 1][m + 1])))) { + break; + } + m--; + } + + for (int i = m + 2; i <= n; i++) { + H[i][i - 2] = 0.0; + if (i > m + 2) { + H[i][i - 3] = 0.0; + } + } + + // Double QR step involving rows l:n and columns m:n + + for (int k = m; k <= n - 1; k++) { + boolean notlast = (k != n - 1); + if (k != m) { + p = H[k][k - 1]; + q = H[k + 1][k - 1]; + r = (notlast ? H[k + 2][k - 1] : 0.0); + x = Math.abs(p) + Math.abs(q) + Math.abs(r); + if (x != 0.0) { + p = p / x; + q = q / x; + r = r / x; + } + } + if (x == 0.0) { + break; + } + s = Math.sqrt(p * p + q * q + r * r); + if (p < 0) { + s = -s; + } + if (s != 0) { + if (k != m) { + H[k][k - 1] = -s * x; + } else if (l != m) { + H[k][k - 1] = -H[k][k - 1]; + } + p = p + s; + x = p / s; + y = q / s; + z = r / s; + q = q / p; + r = r / p; + + // Row modification + + for (int j = k; j < nn; j++) { + p = H[k][j] + q * H[k + 1][j]; + if (notlast) { + p = p + r * H[k + 2][j]; + H[k + 2][j] = H[k + 2][j] - p * z; + } + H[k][j] = H[k][j] - p * x; + H[k + 1][j] = H[k + 1][j] - p * y; + } + + // Column modification + + for (int i = 0; i <= Math.min(n, k + 3); i++) { + p = x * H[i][k] + y * H[i][k + 1]; + if (notlast) { + p = p + z * H[i][k + 2]; + H[i][k + 2] = H[i][k + 2] - p * r; + } + H[i][k] = H[i][k] - p; + H[i][k + 1] = H[i][k + 1] - p * q; + } + + // Accumulate transformations + + for (int i = low; i <= high; i++) { + p = x * V[i][k] + y * V[i][k + 1]; + if (notlast) { + p = p + z * V[i][k + 2]; + V[i][k + 2] = V[i][k + 2] - p * r; + } + V[i][k] = V[i][k] - p; + V[i][k + 1] = V[i][k + 1] - p * q; + } + } // (s != 0) + } // k loop + } // check convergence + } // while (n >= low) + + // Backsubstitute to find vectors of upper triangular form + + if (norm == 0.0) { + return; + } + + for (n = nn - 1; n >= 0; n--) { + p = d[n]; + q = e[n]; + + // Real vector + + if (q == 0) { + int l = n; + H[n][n] = 1.0; + for (int i = n - 1; i >= 0; i--) { + w = H[i][i] - p; + r = 0.0; + for (int j = l; j <= n; j++) { + r = r + H[i][j] * H[j][n]; + } + if (e[i] < 0.0) { + z = w; + s = r; + } else { + l = i; + if (e[i] == 0.0) { + if (w != 0.0) { + H[i][n] = -r / w; + } else { + H[i][n] = -r / (eps * norm); + } + + // Solve real equations + + } else { + x = H[i][i + 1]; + y = H[i + 1][i]; + q = (d[i] - p) * (d[i] - p) + e[i] * e[i]; + t = (x * s - z * r) / q; + H[i][n] = t; + if (Math.abs(x) > Math.abs(z)) { + H[i + 1][n] = (-r - w * t) / x; + } else { + H[i + 1][n] = (-s - y * t) / z; + } + } + + // Overflow control + + t = Math.abs(H[i][n]); + if ((eps * t) * t > 1) { + for (int j = i; j <= n; j++) { + H[j][n] = H[j][n] / t; + } + } + } + } + + // Complex vector + + } else if (q < 0) { + int l = n - 1; + + // Last vector component imaginary so matrix is triangular + + if (Math.abs(H[n][n - 1]) > Math.abs(H[n - 1][n])) { + H[n - 1][n - 1] = q / H[n][n - 1]; + H[n - 1][n] = -(H[n][n] - p) / H[n][n - 1]; + } else { + cdiv(0.0, -H[n - 1][n], H[n - 1][n - 1] - p, q); + H[n - 1][n - 1] = cdivr; + H[n - 1][n] = cdivi; + } + H[n][n - 1] = 0.0; + H[n][n] = 1.0; + for (int i = n - 2; i >= 0; i--) 
{ + double ra, sa, vr, vi; + ra = 0.0; + sa = 0.0; + for (int j = l; j <= n; j++) { + ra = ra + H[i][j] * H[j][n - 1]; + sa = sa + H[i][j] * H[j][n]; + } + w = H[i][i] - p; + + if (e[i] < 0.0) { + z = w; + r = ra; + s = sa; + } else { + l = i; + if (e[i] == 0) { + cdiv(-ra, -sa, w, q); + H[i][n - 1] = cdivr; + H[i][n] = cdivi; + } else { + + // Solve complex equations + + x = H[i][i + 1]; + y = H[i + 1][i]; + vr = (d[i] - p) * (d[i] - p) + e[i] * e[i] - q * q; + vi = (d[i] - p) * 2.0 * q; + if (vr == 0.0 & vi == 0.0) { + vr = eps * norm * (Math.abs(w) + Math.abs(q) + + Math.abs(x) + Math.abs(y) + Math.abs(z)); + } + cdiv(x * r - z * ra + q * sa, x * s - z * sa - q * ra, vr, vi); + H[i][n - 1] = cdivr; + H[i][n] = cdivi; + if (Math.abs(x) > (Math.abs(z) + Math.abs(q))) { + H[i + 1][n - 1] = (-ra - w * H[i][n - 1] + q * H[i][n]) / x; + H[i + 1][n] = (-sa - w * H[i][n] - q * H[i][n - 1]) / x; + } else { + cdiv(-r - y * H[i][n - 1], -s - y * H[i][n], z, q); + H[i + 1][n - 1] = cdivr; + H[i + 1][n] = cdivi; + } + } + + // Overflow control + + t = Math.max(Math.abs(H[i][n - 1]), Math.abs(H[i][n])); + if ((eps * t) * t > 1) { + for (int j = i; j <= n; j++) { + H[j][n - 1] = H[j][n - 1] / t; + H[j][n] = H[j][n] / t; + } + } + } + } + } + } + + // Vectors of isolated roots + + for (int i = 0; i < nn; i++) { + if (i < low | i > high) { + for (int j = i; j < nn; j++) { + V[i][j] = H[i][j]; + } + } + } + + // Back transformation to get eigenvectors of original matrix + + for (int j = nn - 1; j >= low; j--) { + for (int i = low; i <= high; i++) { + z = 0.0; + for (int k = low; k <= Math.min(j, high); k++) { + z = z + V[i][k] * H[k][j]; + } + V[i][j] = z; + } + } + } + +/* ------------------------ + Public Methods + * ------------------------ */ + + /** + * Return the eigenvector matrix + * + * @return V + */ + + public JaMatrix getV() { + return new JaMatrix(V, n, n); + } + + /** + * Return the real parts of the eigenvalues + * + * @return real(diag ( D)) + */ + public double[] getRealEigenvalues() { + return d; + } + + /** + * Return the imaginary parts of the eigenvalues + * + * @return imag(diag ( D)) + */ + public double[] getImagEigenvalues() { + return e; + } + + /** + * Return the block diagonal eigenvalue matrix + * + * @return D + */ + + public JaMatrix getD() { + JaMatrix X = new JaMatrix(n, n); + double[][] D = X.getArray(); + for (int i = 0; i < n; i++) { + for (int j = 0; j < n; j++) { + D[i][j] = 0.0; + } + D[i][i] = d[i]; + if (e[i] > 0) { + D[i][i + 1] = e[i]; + } else if (e[i] < 0) { + D[i][i - 1] = e[i]; + } + } + return X; + } +} diff --git a/core/src/main/java/com/alibaba/alink/common/linalg/jama/JMatrixFunc.java b/core/src/main/java/com/alibaba/alink/common/linalg/jama/JMatrixFunc.java new file mode 100644 index 000000000..941eff278 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/common/linalg/jama/JMatrixFunc.java @@ -0,0 +1,206 @@ +/* + * To change this template, choose Tools | Templates + * and open the template in the editor. 
+ */ +package com.alibaba.alink.common.linalg.jama; + +import com.alibaba.alink.common.linalg.DenseMatrix; + +/** + * @author yangxu + */ +public class JMatrixFunc { + + public static DenseMatrix copy(DenseMatrix jm) { + int m = jm.numRows(); + int n = jm.numCols(); + double[][] A = jm.getArrayCopy2D(); + DenseMatrix X = new DenseMatrix(m, n); + double[][] C = X.getArrayCopy2D(); + for (int i = 0; i < m; i++) { + for (int j = 0; j < n; j++) { + C[i][j] = A[i][j]; + } + } + return X; + } + + public static DenseMatrix transpose(DenseMatrix jm) { + int m = jm.numRows(); + int n = jm.numCols(); + double[][] A = jm.getArrayCopy2D(); + DenseMatrix X = new DenseMatrix(n, m); + double[][] C = X.getArrayCopy2D(); + for (int i = 0; i < m; i++) { + for (int j = 0; j < n; j++) { + C[j][i] = A[i][j]; + } + } + return X; + } + + public static double norm1(DenseMatrix jm) { + return new JaMatrix(jm).norm1(); + } + + public static double norm2(DenseMatrix jm) { + return new JaMatrix(jm).norm2(); + } + + public static double normInf(DenseMatrix jm) { + return new JaMatrix(jm).normInf(); + } + + public static double normF(DenseMatrix jm) { + return new JaMatrix(jm).normF(); + } + + public static DenseMatrix uminus(DenseMatrix jm) { + return new JaMatrix(jm).uminus().toJMatrix(); + } + + public static DenseMatrix add(DenseMatrix jm1, DenseMatrix jm2) { + return new JaMatrix(jm1).plus(new JaMatrix(jm2)).toJMatrix(); + } + + public static DenseMatrix subtract(DenseMatrix jm1, DenseMatrix jm2) { + return new JaMatrix(jm1).minus(new JaMatrix(jm2)).toJMatrix(); + } + + public static DenseMatrix multiply(DenseMatrix jm, double d) { + return new JaMatrix(jm).times(d).toJMatrix(); + } + + public static DenseMatrix multiply(DenseMatrix jm, DenseMatrix jm1) { + return new JaMatrix(jm).times(new JaMatrix(jm1)).toJMatrix(); + } + + public static DenseMatrix[] lu(DenseMatrix jm) { + DenseMatrix[] jms = new DenseMatrix[3]; + LUDecomposition lud = new JaMatrix(jm).lu(); + jms[0] = lud.getL().toJMatrix(); + jms[1] = new JaMatrix(lud.getDoublePivot(), lud.getPivot().length).toJMatrix(); + jms[2] = lud.getU().toJMatrix(); + return jms; + } + + public static DenseMatrix[] qr(DenseMatrix jm) { + QRDecomposition qr = new JaMatrix(jm).qr(); + DenseMatrix[] jms = new DenseMatrix[2]; + jms[0] = qr.getQ().toJMatrix(); + jms[1] = qr.getR().toJMatrix(); + return jms; + } + + public static DenseMatrix chol(DenseMatrix jm) { + CholeskyDecomposition chol = new JaMatrix(jm).chol(); + return chol.getL().toJMatrix(); + } + + public static DenseMatrix[] svd(DenseMatrix jm) { + if (jm.numRows() < jm.numCols()) { + SingularValueDecomposition svd = new JaMatrix(jm.transpose()).svd(); + DenseMatrix[] jms = new DenseMatrix[3]; + jms[0] = svd.getV().toJMatrix(); + jms[1] = svd.getS().toJMatrix(); + jms[2] = svd.getU().toJMatrix(); + return jms; + } else { + SingularValueDecomposition svd = new JaMatrix(jm).svd(); + DenseMatrix[] jms = new DenseMatrix[3]; + jms[0] = svd.getU().toJMatrix(); + jms[1] = svd.getS().toJMatrix(); + jms[2] = svd.getV().toJMatrix(); + return jms; + } + } + + public static DenseMatrix[] eig(DenseMatrix jm) { + EigenvalueDecomposition ed = new JaMatrix(jm).eig(); + DenseMatrix[] jms = new DenseMatrix[2]; + jms[0] = ed.getV().toJMatrix(); + jms[1] = ed.getD().toJMatrix(); + return jms; + } + + public static DenseMatrix solve(DenseMatrix A, DenseMatrix B) { + return new JaMatrix(A).solve(new JaMatrix(B)).toJMatrix(); + } + + public static DenseMatrix solveTranspose(DenseMatrix A, DenseMatrix B) { + return new 
JaMatrix(A).solveTranspose(new JaMatrix(B)).toJMatrix(); + } + + public static DenseMatrix inverse(DenseMatrix jm) { + return new JaMatrix(jm).inverse().toJMatrix(); + } + + public static double det(DenseMatrix jm) { + double[][] Btmp = jm.getArrayCopy2D(); + boolean isZeroMatrix = true; + for (int i = 0; i < jm.numRows(); i++) { + for (int j = 0; j < jm.numCols(); j++) { + if (Btmp[i][j] != 0) { + isZeroMatrix = false; + } + } + } + if (isZeroMatrix) { + return 0; + } + return new JaMatrix(jm).det(); + } + + public static double detLog(DenseMatrix jm) { + double[][] Btmp = jm.getArrayCopy2D(); + boolean isZeroMatrix = true; + for (int i = 0; i < jm.numRows(); i++) { + for (int j = 0; j < jm.numCols(); j++) { + if (Btmp[i][j] != 0) { + isZeroMatrix = false; + } + } + } + if (isZeroMatrix) { + return Double.MIN_VALUE; + } + return new JaMatrix(jm).detLog(); + } + + public static int rank(DenseMatrix jm) { + return new JaMatrix(jm).rank(); + } + + public static double cond(DenseMatrix jm) { + return new JaMatrix(jm).cond(); + } + + public static double trace(DenseMatrix jm) { + return new JaMatrix(jm).trace(); + } + + public static DenseMatrix random(int m, int n) { + return JaMatrix.random(m, n).toJMatrix(); + } + + public static DenseMatrix identity(int m, int n) { + return JaMatrix.identity(m, n).toJMatrix(); + } + + public static DenseMatrix solveLS(DenseMatrix jm, DenseMatrix B) { + return new JaMatrix(jm).solveLS(new JaMatrix(B)).toJMatrix(); + } + + public static void print(DenseMatrix jm, int w, int v) { + new JaMatrix(jm).print(w, v); + } + + /** + * Is the matrix nonsingular? + * + * @return true if U, and hence A, is nonsingular. + */ + public static boolean isNonsingular(DenseMatrix jm) { + return new LUDecomposition(new JaMatrix(jm)).isNonsingular(); + } +} diff --git a/core/src/main/java/com/alibaba/alink/common/linalg/jama/JaMatrix.java b/core/src/main/java/com/alibaba/alink/common/linalg/jama/JaMatrix.java new file mode 100644 index 000000000..62c8a57a9 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/common/linalg/jama/JaMatrix.java @@ -0,0 +1,1189 @@ +package com.alibaba.alink.common.linalg.jama; + +import com.alibaba.alink.common.linalg.DenseMatrix; + +import java.io.BufferedReader; +import java.io.PrintWriter; +import java.io.StreamTokenizer; +import java.text.DecimalFormat; +import java.text.DecimalFormatSymbols; +import java.text.NumberFormat; +import java.util.Locale; + +/** + * Jama = Java Matrix class. + *

+ * The Java Matrix Class provides the fundamental operations of numerical linear + * algebra. Various constructors create Matrices from two dimensional arrays of + * double precision floating point numbers. Various "gets" and "sets" provide + * access to submatrices and matrix elements. Several methods implement basic + * matrix arithmetic, including matrix addition and multiplication, matrix + * norms, and element-by-element array operations. Methods for reading and + * printing matrices are also included. All the operations in this version of + * the Matrix Class involve real matrices. Complex matrices may be handled in a + * future version. + *
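The JMatrixFunc class added just above exposes these routines as static helpers over Alink's DenseMatrix. The sketch below assumes DenseMatrix has a public double[][] constructor, which this diff does not show; the commented values are indicative only.

import com.alibaba.alink.common.linalg.DenseMatrix;
import com.alibaba.alink.common.linalg.jama.JMatrixFunc;

public class JMatrixFuncSketch {
	public static void main(String[] args) {
		// Assumption: DenseMatrix(double[][]) exists outside this diff.
		DenseMatrix a = new DenseMatrix(new double[][] {{4., 2.}, {2., 3.}});

		System.out.println(JMatrixFunc.det(a));      // ~8.0
		System.out.println(JMatrixFunc.rank(a));     // 2
		System.out.println(JMatrixFunc.trace(a));    // 7.0

		// inverse() and multiply() round-trip back to (approximately) the identity matrix.
		DenseMatrix inv = JMatrixFunc.inverse(a);
		DenseMatrix id = JMatrixFunc.multiply(a, inv);
		System.out.println(JMatrixFunc.normF(JMatrixFunc.subtract(id, JMatrixFunc.identity(2, 2))));   // close to 0
	}
}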

+ * Five fundamental matrix decompositions, which consist of pairs or triples of + * matrices, permutation vectors, and the like, produce results in five + * decomposition classes. These decompositions are accessed by the Matrix class + * to compute solutions of simultaneous linear equations, determinants, inverses + * and other matrix functions. The five decompositions are: + *

+ *

    + *
  • Cholesky Decomposition of symmetric, positive definite matrices. + *
  • LU Decomposition of rectangular matrices. + *
  • QR Decomposition of rectangular matrices. + *
  • Singular Value Decomposition of rectangular matrices. + *
  • Eigenvalue Decomposition of both symmetric and nonsymmetric square + * matrices. + *
+ *
+ *
Example of use:
+ *

+ *

Solve a linear system A x = b and compute the residual norm, ||b - A x||. + *

+ * < + * PRE> + * double[][] vals = {{1.,2.,3},{4.,5.,6.},{7.,8.,10.}}; Matrix A = new + * Matrix(vals); Matrix b = Matrix.random(3,1); Matrix x = A.solve(b); Matrix r + * = A.times(x).minus(b); double rnorm = r.normInf(); + *

+ *
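+ *
+ * A minimal sketch of the same computation through the Alink DenseMatrix facade,
+ * using only methods defined in JMatrixFunc in this package:
+ * <pre>
+ * DenseMatrix A = new DenseMatrix(vals);
+ * DenseMatrix b = JMatrixFunc.random(3, 1);
+ * DenseMatrix x = JMatrixFunc.solve(A, b);
+ * double rnorm = JMatrixFunc.normInf(JMatrixFunc.subtract(JMatrixFunc.multiply(A, x), b));
+ * </pre>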
+ * + * @author The MathWorks, Inc. and the National Institute of Standards and + * Technology. + * @version 5 August 1998 + */ +public class JaMatrix implements Cloneable, java.io.Serializable { + private static final long serialVersionUID = 455469613748961890L; + + /* ------------------------ + Class variables + * ------------------------ */ + /** + * Array for internal storage of elements. + * + * @serial internal array storage. + */ + private double[][] A; + /** + * Row and column dimensions. + * + * @serial row dimension. + * @serial column dimension. + */ + private int m, n; + + /* ------------------------ + Constructors + * ------------------------ */ + public JaMatrix() { + } + + /** + * Construct an m-by-n matrix of zeros. + * + * @param m Number of rows. + * @param n Number of colums. + */ + public JaMatrix(int m, int n) { + this.m = m; + this.n = n; + A = new double[m][n]; + } + + /** + * Construct an m-by-n constant matrix. + * + * @param m Number of rows. + * @param n Number of colums. + * @param s Fill the matrix with this scalar value. + */ + public JaMatrix(int m, int n, double s) { + this.m = m; + this.n = n; + A = new double[m][n]; + for (int i = 0; i < m; i++) { + for (int j = 0; j < n; j++) { + A[i][j] = s; + } + } + } + + /** + * Construct a matrix from a 2-D array. + * + * @param A Two-dimensional array of doubles. + * @throws IllegalArgumentException All rows must have the same size + * @see #constructWithCopy + */ + public JaMatrix(double[][] A) { + m = A.length; + n = A[0].length; + for (int i = 0; i < m; i++) { + if (A[i].length != n) { + throw new IllegalArgumentException("All rows must have the same size."); + } + } + this.A = A; + } + + /** + * Construct a matrix quickly without checking arguments. + * + * @param A Two-dimensional array of doubles. + * @param m Number of rows. + * @param n Number of colums. + */ + public JaMatrix(double[][] A, int m, int n) { + this.A = A; + this.m = m; + this.n = n; + } + + /** + * Construct a matrix from a one-dimensional packed array + * + * @param vals One-dimensional array of doubles, packed by columns (ala + * Fortran). + * @param m Number of rows. + * @throws IllegalArgumentException Array size must be a multiple of m. + */ + public JaMatrix(double vals[], int m) { + this.m = m; + n = (m != 0 ? vals.length / m : 0); + if (m * n != vals.length) { + throw new IllegalArgumentException("Array size must be a multiple of m."); + } + A = new double[m][n]; + for (int i = 0; i < m; i++) { + for (int j = 0; j < n; j++) { + A[i][j] = vals[i + j * m]; + } + } + } + + public JaMatrix(DenseMatrix jm) { + this.A = jm.getArrayCopy2D(); + this.m = jm.numRows(); + this.n = jm.numCols(); + } + + /** + * Construct a matrix from a copy of a 2-D array. + * + * @param A Two-dimensional array of doubles. + * @throws IllegalArgumentException All rows must have the same size + */ + public static JaMatrix constructWithCopy(double[][] A) { + int m = A.length; + int n = A[0].length; + JaMatrix X = new JaMatrix(m, n); + double[][] C = X.getArray(); + for (int i = 0; i < m; i++) { + if (A[i].length != n) { + throw new IllegalArgumentException("All rows must have the same size."); + } + for (int j = 0; j < n; j++) { + C[i][j] = A[i][j]; + } + } + return X; + } + + /** + * Generate matrix with random elements + * + * @param m Number of rows. + * @param n Number of colums. + * @return An m-by-n matrix with uniformly distributed random elements. 
+ */ + public static JaMatrix random(int m, int n) { + JaMatrix A = new JaMatrix(m, n); + double[][] X = A.getArray(); + for (int i = 0; i < m; i++) { + for (int j = 0; j < n; j++) { + X[i][j] = Math.random(); + } + } + return A; + } + + /* ------------------------ + Public Methods + * ------------------------ */ + + /** + * Generate identity matrix + * + * @param m Number of rows. + * @param n Number of colums. + * @return An m-by-n matrix with ones on the diagonal and zeros elsewhere. + */ + public static JaMatrix identity(int m, int n) { + JaMatrix A = new JaMatrix(m, n); + double[][] X = A.getArray(); + for (int i = 0; i < m; i++) { + for (int j = 0; j < n; j++) { + X[i][j] = (i == j ? 1.0 : 0.0); + } + } + return A; + } + + /** + * Read a matrix from a stream. The format is the same the print method, so + * printed matrices can be read back in (provided they were printed using US + * Locale). Elements are separated by whitespace, all the elements for each + * row appear on a single line, the last row is followed by a blank line. + * + * @param input the input stream. + */ + public static JaMatrix read(BufferedReader input) throws java.io.IOException { + StreamTokenizer tokenizer = new StreamTokenizer(input); + + // Although StreamTokenizer will getVector numbers, it doesn't recognize + // scientific notation (E or D); however, Double.valueOf does. + // The strategy here is to disable StreamTokenizer's number parsing. + // We'll only get whitespace delimited words, EOL's and EOF's. + // These words should all be numbers, for Double.valueOf to getVector. + tokenizer.resetSyntax(); + tokenizer.wordChars(0, 255); + tokenizer.whitespaceChars(0, ' '); + tokenizer.eolIsSignificant(true); + java.util.Vector v = new java.util.Vector(); + + // Ignore initial empty lines + while (tokenizer.nextToken() == StreamTokenizer.TT_EOL) { ; } + if (tokenizer.ttype == StreamTokenizer.TT_EOF) { + throw new java.io.IOException("Unexpected EOF on matrix read."); + } + do { + v.addElement(Double.valueOf(tokenizer.sval)); // Read & store 1st row. + } while (tokenizer.nextToken() == StreamTokenizer.TT_WORD); + + int n = v.size(); // Now we've got the number of columns! + double row[] = new double[n]; + for (int j = 0; j < n; j++) // extract the elements of the 1st row. + { + row[j] = ((Double) v.elementAt(j)).doubleValue(); + } + v.removeAllElements(); + v.addElement(row); // Start storing rows instead of columns. + while (tokenizer.nextToken() == StreamTokenizer.TT_WORD) { + // While non-empty lines + v.addElement(row = new double[n]); + int j = 0; + do { + if (j >= n) { + throw new java.io.IOException("Row " + v.size() + " is too long."); + } + row[j++] = Double.valueOf(tokenizer.sval).doubleValue(); + } while (tokenizer.nextToken() == StreamTokenizer.TT_WORD); + if (j < n) { + throw new java.io.IOException("Row " + v.size() + " is too short."); + } + } + int m = v.size(); // Now we've got the number of rows. 
+ double[][] A = new double[m][]; + v.copyInto(A); // copy the rows out of the vector + return new JaMatrix(A); + } + + public DenseMatrix toJMatrix() { + return new DenseMatrix(A); + } + + public boolean Compare(JaMatrix other, double eps) { + if ((this.m != other.m) || (this.n != other.n)) { + return false; + } + for (int i = 0; i < m; i++) { + for (int j = 0; j < n; j++) { + if (Math.abs(this.A[i][j] - other.A[i][j]) > eps) { + return false; + } + } + } + return true; + } + + /** + * Make a deep copy of a matrix + */ + public JaMatrix copy() { + JaMatrix X = new JaMatrix(m, n); + double[][] C = X.getArray(); + for (int i = 0; i < m; i++) { + for (int j = 0; j < n; j++) { + C[i][j] = A[i][j]; + } + } + return X; + } + + /** + * Clone the Matrix object. + */ + public Object clone() { + return this.copy(); + } + + /** + * Access the internal two-dimensional array. + * + * @return Pointer to the two-dimensional array of matrix elements. + */ + public double[][] getArray() { + return A; + } + + /** + * Copy the internal two-dimensional array. + * + * @return Two-dimensional array copy of matrix elements. + */ + public double[][] getArrayCopy() { + double[][] C = new double[m][n]; + for (int i = 0; i < m; i++) { + for (int j = 0; j < n; j++) { + C[i][j] = A[i][j]; + } + } + return C; + } + + /** + * Make a one-dimensional column packed copy of the internal array. + * + * @return Matrix elements packed in a one-dimensional array by columns. + */ + public double[] getColumnPackedCopy() { + double[] vals = new double[m * n]; + for (int i = 0; i < m; i++) { + for (int j = 0; j < n; j++) { + vals[i + j * m] = A[i][j]; + } + } + return vals; + } + + /** + * Make a one-dimensional row packed copy of the internal array. + * + * @return Matrix elements packed in a one-dimensional array by rows. + */ + public double[] getRowPackedCopy() { + double[] vals = new double[m * n]; + for (int i = 0; i < m; i++) { + for (int j = 0; j < n; j++) { + vals[i * n + j] = A[i][j]; + } + } + return vals; + } + + /** + * Get row dimension. + * + * @return m, the number of rows. + */ + public int getRowDimension() { + return m; + } + + /** + * Get column dimension. + * + * @return n, the number of columns. + */ + public int getColumnDimension() { + return n; + } + + /** + * Get a single element. + * + * @param i Row index. + * @param j Column index. + * @return A(i, j) + * @throws ArrayIndexOutOfBoundsException + */ + public double get(int i, int j) { + return A[i][j]; + } + + /** + * Get a submatrix. + * + * @param i0 Initial row index + * @param i1 Final row index + * @param j0 Initial column index + * @param j1 Final column index + * @return A(i0 : i1, j0 : j1) + * @throws ArrayIndexOutOfBoundsException Submatrix indices + */ + public JaMatrix getMatrix(int i0, int i1, int j0, int j1) { + JaMatrix X = new JaMatrix(i1 - i0 + 1, j1 - j0 + 1); + double[][] B = X.getArray(); + try { + for (int i = i0; i <= i1; i++) { + for (int j = j0; j <= j1; j++) { + B[i - i0][j - j0] = A[i][j]; + } + } + } catch (ArrayIndexOutOfBoundsException e) { + throw new ArrayIndexOutOfBoundsException("Submatrix indices"); + } + return X; + } + + /** + * Get a submatrix. + * + * @param r Array of row indices. + * @param c Array of column indices. 
+ * @return A(r ( :), c(:)) + * @throws ArrayIndexOutOfBoundsException Submatrix indices + */ + public JaMatrix getMatrix(int[] r, int[] c) { + JaMatrix X = new JaMatrix(r.length, c.length); + double[][] B = X.getArray(); + try { + for (int i = 0; i < r.length; i++) { + for (int j = 0; j < c.length; j++) { + B[i][j] = A[r[i]][c[j]]; + } + } + } catch (ArrayIndexOutOfBoundsException e) { + throw new ArrayIndexOutOfBoundsException("Submatrix indices"); + } + return X; + } + + /** + * Get a submatrix. + * + * @param i0 Initial row index + * @param i1 Final row index + * @param c Array of column indices. + * @return A(i0 : i1, c ( :)) + * @throws ArrayIndexOutOfBoundsException Submatrix indices + */ + public JaMatrix getMatrix(int i0, int i1, int[] c) { + JaMatrix X = new JaMatrix(i1 - i0 + 1, c.length); + double[][] B = X.getArray(); + try { + for (int i = i0; i <= i1; i++) { + for (int j = 0; j < c.length; j++) { + B[i - i0][j] = A[i][c[j]]; + } + } + } catch (ArrayIndexOutOfBoundsException e) { + throw new ArrayIndexOutOfBoundsException("Submatrix indices"); + } + return X; + } + + /** + * Get a submatrix. + * + * @param r Array of row indices. + * @param j0 Initial column index + * @param j1 Final column index + * @return A(r ( :), j0:j1) + * @throws ArrayIndexOutOfBoundsException Submatrix indices + */ + public JaMatrix getMatrix(int[] r, int j0, int j1) { + JaMatrix X = new JaMatrix(r.length, j1 - j0 + 1); + double[][] B = X.getArray(); + try { + for (int i = 0; i < r.length; i++) { + for (int j = j0; j <= j1; j++) { + B[i][j - j0] = A[r[i]][j]; + } + } + } catch (ArrayIndexOutOfBoundsException e) { + throw new ArrayIndexOutOfBoundsException("Submatrix indices"); + } + return X; + } + + /** + * Set a single element. + * + * @param i Row index. + * @param j Column index. + * @param s A(i,j). + * @throws ArrayIndexOutOfBoundsException + */ + public void set(int i, int j, double s) { + A[i][j] = s; + } + + /** + * Set a submatrix. + * + * @param i0 Initial row index + * @param i1 Final row index + * @param j0 Initial column index + * @param j1 Final column index + * @param X A(i0:i1,j0:j1) + * @throws ArrayIndexOutOfBoundsException Submatrix indices + */ + public void setMatrix(int i0, int i1, int j0, int j1, JaMatrix X) { + try { + for (int i = i0; i <= i1; i++) { + for (int j = j0; j <= j1; j++) { + A[i][j] = X.get(i - i0, j - j0); + } + } + } catch (ArrayIndexOutOfBoundsException e) { + throw new ArrayIndexOutOfBoundsException("Submatrix indices"); + } + } + + /** + * Set a submatrix. + * + * @param r Array of row indices. + * @param c Array of column indices. + * @param X A(r(:),c(:)) + * @throws ArrayIndexOutOfBoundsException Submatrix indices + */ + public void setMatrix(int[] r, int[] c, JaMatrix X) { + try { + for (int i = 0; i < r.length; i++) { + for (int j = 0; j < c.length; j++) { + A[r[i]][c[j]] = X.get(i, j); + } + } + } catch (ArrayIndexOutOfBoundsException e) { + throw new ArrayIndexOutOfBoundsException("Submatrix indices"); + } + } + + /** + * Set a submatrix. + * + * @param r Array of row indices. 
+ * @param j0 Initial column index + * @param j1 Final column index + * @param X A(r(:),j0:j1) + * @throws ArrayIndexOutOfBoundsException Submatrix indices + */ + public void setMatrix(int[] r, int j0, int j1, JaMatrix X) { + try { + for (int i = 0; i < r.length; i++) { + for (int j = j0; j <= j1; j++) { + A[r[i]][j] = X.get(i, j - j0); + } + } + } catch (ArrayIndexOutOfBoundsException e) { + throw new ArrayIndexOutOfBoundsException("Submatrix indices"); + } + } + + /** + * Set a submatrix. + * + * @param i0 Initial row index + * @param i1 Final row index + * @param c Array of column indices. + * @param X A(i0:i1,c(:)) + * @throws ArrayIndexOutOfBoundsException Submatrix indices + */ + public void setMatrix(int i0, int i1, int[] c, JaMatrix X) { + try { + for (int i = i0; i <= i1; i++) { + for (int j = 0; j < c.length; j++) { + A[i][c[j]] = X.get(i - i0, j); + } + } + } catch (ArrayIndexOutOfBoundsException e) { + throw new ArrayIndexOutOfBoundsException("Submatrix indices"); + } + } + + /** + * Matrix transpose. + * + * @return A' + */ + public JaMatrix transpose() { + JaMatrix X = new JaMatrix(n, m); + double[][] C = X.getArray(); + for (int i = 0; i < m; i++) { + for (int j = 0; j < n; j++) { + C[j][i] = A[i][j]; + } + } + return X; + } + + /** + * One norm + * + * @return maximum column sum. + */ + public double norm1() { + double f = 0; + for (int j = 0; j < n; j++) { + double s = 0; + for (int i = 0; i < m; i++) { + s += Math.abs(A[i][j]); + } + f = Math.max(f, s); + } + return f; + } + + /** + * Two norm + * + * @return maximum singular value. + */ + public double norm2() { + return (new SingularValueDecomposition(this).norm2()); + } + + /** + * Infinity norm + * + * @return maximum row sum. + */ + public double normInf() { + double f = 0; + for (int i = 0; i < m; i++) { + double s = 0; + for (int j = 0; j < n; j++) { + s += Math.abs(A[i][j]); + } + f = Math.max(f, s); + } + return f; + } + + /** + * Frobenius norm + * + * @return Sqrt of sum of squares of all elements. 
+ */ + public double normF() { + double f = 0; + for (int i = 0; i < m; i++) { + for (int j = 0; j < n; j++) { + f = Maths.hypot(f, A[i][j]); + } + } + return f; + } + + /** + * Unary minus + * + * @return -A + */ + public JaMatrix uminus() { + JaMatrix X = new JaMatrix(m, n); + double[][] C = X.getArray(); + for (int i = 0; i < m; i++) { + for (int j = 0; j < n; j++) { + C[i][j] = -A[i][j]; + } + } + return X; + } + + /** + * C = A + B + * + * @param B another matrix + * @return A + B + */ + public JaMatrix plus(JaMatrix B) { + checkMatrixDimensions(B); + JaMatrix X = new JaMatrix(m, n); + double[][] C = X.getArray(); + for (int i = 0; i < m; i++) { + for (int j = 0; j < n; j++) { + C[i][j] = A[i][j] + B.A[i][j]; + } + } + return X; + } + + /** + * A = A + B + * + * @param B another matrix + * @return A + B + */ + public JaMatrix plusEquals(JaMatrix B) { + checkMatrixDimensions(B); + for (int i = 0; i < m; i++) { + for (int j = 0; j < n; j++) { + A[i][j] = A[i][j] + B.A[i][j]; + } + } + return this; + } + + /** + * C = A - B + * + * @param B another matrix + * @return A - B + */ + public JaMatrix minus(JaMatrix B) { + checkMatrixDimensions(B); + JaMatrix X = new JaMatrix(m, n); + double[][] C = X.getArray(); + for (int i = 0; i < m; i++) { + for (int j = 0; j < n; j++) { + C[i][j] = A[i][j] - B.A[i][j]; + } + } + return X; + } + + /** + * A = A - B + * + * @param B another matrix + * @return A - B + */ + public JaMatrix minusEquals(JaMatrix B) { + checkMatrixDimensions(B); + for (int i = 0; i < m; i++) { + for (int j = 0; j < n; j++) { + A[i][j] = A[i][j] - B.A[i][j]; + } + } + return this; + } + + /** + * Element-by-element multiplication, C = A.*B + * + * @param B another matrix + * @return A.*B + */ + public JaMatrix arrayTimes(JaMatrix B) { + checkMatrixDimensions(B); + JaMatrix X = new JaMatrix(m, n); + double[][] C = X.getArray(); + for (int i = 0; i < m; i++) { + for (int j = 0; j < n; j++) { + C[i][j] = A[i][j] * B.A[i][j]; + } + } + return X; + } + + /** + * Element-by-element multiplication in place, A = A.*B + * + * @param B another matrix + * @return A.*B + */ + public JaMatrix arrayTimesEquals(JaMatrix B) { + checkMatrixDimensions(B); + for (int i = 0; i < m; i++) { + for (int j = 0; j < n; j++) { + A[i][j] = A[i][j] * B.A[i][j]; + } + } + return this; + } + + /** + * Element-by-element right division, C = A./B + * + * @param B another matrix + * @return A./B + */ + public JaMatrix arrayRightDivide(JaMatrix B) { + checkMatrixDimensions(B); + JaMatrix X = new JaMatrix(m, n); + double[][] C = X.getArray(); + for (int i = 0; i < m; i++) { + for (int j = 0; j < n; j++) { + C[i][j] = A[i][j] / B.A[i][j]; + } + } + return X; + } + + /** + * Element-by-element right division in place, A = A./B + * + * @param B another matrix + * @return A./B + */ + public JaMatrix arrayRightDivideEquals(JaMatrix B) { + checkMatrixDimensions(B); + for (int i = 0; i < m; i++) { + for (int j = 0; j < n; j++) { + A[i][j] = A[i][j] / B.A[i][j]; + } + } + return this; + } + + /** + * Element-by-element left division, C = A.\B + * + * @param B another matrix + * @return A.\B + */ + public JaMatrix arrayLeftDivide(JaMatrix B) { + checkMatrixDimensions(B); + JaMatrix X = new JaMatrix(m, n); + double[][] C = X.getArray(); + for (int i = 0; i < m; i++) { + for (int j = 0; j < n; j++) { + C[i][j] = B.A[i][j] / A[i][j]; + } + } + return X; + } + + /** + * Element-by-element left division in place, A = A.\B + * + * @param B another matrix + * @return A.\B + */ + public JaMatrix arrayLeftDivideEquals(JaMatrix B) { 
+ checkMatrixDimensions(B); + for (int i = 0; i < m; i++) { + for (int j = 0; j < n; j++) { + A[i][j] = B.A[i][j] / A[i][j]; + } + } + return this; + } + + /** + * Multiply a matrix by a scalar, C = s*A + * + * @param s scalar + * @return s*A + */ + public JaMatrix times(double s) { + JaMatrix X = new JaMatrix(m, n); + double[][] C = X.getArray(); + for (int i = 0; i < m; i++) { + for (int j = 0; j < n; j++) { + C[i][j] = s * A[i][j]; + } + } + return X; + } + + /** + * Multiply a matrix by a scalar in place, A = s*A + * + * @param s scalar + * @return replace A by s*A + */ + public JaMatrix timesEquals(double s) { + for (int i = 0; i < m; i++) { + for (int j = 0; j < n; j++) { + A[i][j] = s * A[i][j]; + } + } + return this; + } + + /** + * Linear algebraic matrix multiplication, A * B + * + * @param B another matrix + * @return Matrix product, A * B + * @throws IllegalArgumentException Matrix inner dimensions must agree. + */ + public JaMatrix times(JaMatrix B) { + if (B.m != n) { + throw new IllegalArgumentException("Matrix inner dimensions must agree."); + } + JaMatrix X = new JaMatrix(m, B.n); + double[][] C = X.getArray(); + double[] Bcolj = new double[n]; + for (int j = 0; j < B.n; j++) { + for (int k = 0; k < n; k++) { + Bcolj[k] = B.A[k][j]; + } + for (int i = 0; i < m; i++) { + double[] Arowi = A[i]; + double s = 0; + for (int k = 0; k < n; k++) { + s += Arowi[k] * Bcolj[k]; + } + C[i][j] = s; + } + } + return X; + } + + /** + * LU Decomposition + * + * @return LUDecomposition + * @see LUDecomposition + */ + public LUDecomposition lu() { + return new LUDecomposition(this); + } + + /** + * QR Decomposition + * + * @return QRDecomposition + * @see QRDecomposition + */ + public QRDecomposition qr() { + return new QRDecomposition(this); + } + + /** + * Cholesky Decomposition + * + * @return CholeskyDecomposition + * @see CholeskyDecomposition + */ + public CholeskyDecomposition chol() { + return new CholeskyDecomposition(this); + } + + /** + * Singular Value Decomposition + * + * @return SingularValueDecomposition + * @see SingularValueDecomposition + */ + public SingularValueDecomposition svd() { + return new SingularValueDecomposition(this); + } + + /** + * Eigenvalue Decomposition + * + * @return EigenvalueDecomposition + * @see EigenvalueDecomposition + */ + public EigenvalueDecomposition eig() { + return new EigenvalueDecomposition(this); + } + + /** + * Solve A*X = B + * + * @param B right hand side + * @return solution if A is square, least squares solution otherwise + */ + public JaMatrix solve(JaMatrix B) { + return (m == n ? (new LUDecomposition(this)).solve(B) + : (new QRDecomposition(this)).solve(B)); + } + + /** + * Solve X*A = B, which is also A'*X' = B' + * + * @param B right hand side + * @return solution if A is square, least squares solution otherwise. 
+ */ + public JaMatrix solveTranspose(JaMatrix B) { + return transpose().solve(B.transpose()); + } + + public JaMatrix solveLS(JaMatrix B) { + double[][] Btmp = B.getArray(); + boolean isZeroMatrix = true; + for (int i = 0; i < B.getRowDimension(); i++) { + for (int j = 0; j < B.getColumnDimension(); j++) { + if (Btmp[i][j] != 0) { + isZeroMatrix = false; + } + } + } + if (isZeroMatrix) { + return new JaMatrix(JMatrixFunc.identity(n, B.getColumnDimension())); + } + SingularValueDecomposition svd = new SingularValueDecomposition(this); + double[][] s = svd.getS().getArray(); + double[][] si = new double[s.length][s[0].length]; + for (int i = 0; i < s.length; i++) { + for (int j = 0; j < s[0].length; j++) { + si[i][j] = 0; + } + } + for (int i = 0; i < s.length; i++) { + if (s[i][i] != 0) { + si[i][i] = 1 / s[i][i]; + } + } + return svd.getV().times(new JaMatrix(si)).times(svd.getU().transpose()).times(B); + + /* + JaMatrix AA = transpose().times(this); + JaMatrix AB = transpose().times(B); + AA.print(m, m); + AB.print(m, m); + return AA.solve(AB); + */ + } + + /** + * Matrix inverse or pseudoinverse + * + * @return inverse(A) if A is square, pseudoinverse otherwise. + */ + public JaMatrix inverse() { + return solve(identity(m, m)); + } + + /** + * Matrix determinant + * + * @return determinant + */ + public double det() { + return new LUDecomposition(this).det(); + } + + public double detLog() { + return new LUDecomposition(this).detLog(); + } + + /** + * Matrix rank + * + * @return effective numerical rank, obtained from SVD. + */ + public int rank() { + return new SingularValueDecomposition(this).rank(); + } + + /** + * Matrix condition (2 norm) + * + * @return ratio of largest to smallest singular value. + */ + public double cond() { + return new SingularValueDecomposition(this).cond(); + } + + /** + * Matrix trace. + * + * @return sum of the diagonal elements. + */ + public double trace() { + double t = 0; + for (int i = 0; i < Math.min(m, n); i++) { + t += A[i][i]; + } + return t; + } + + /** + * Print the matrix to stdout. Line the elements up in columns with a + * Fortran-like 'Fw.d' style format. + * + * @param w Column width. + * @param d Number of digits after the decimal. + */ + public String repr(int w, int d) { + if (m * n > 1024 * 1024) { + return "matrix is too big to print on screen"; + } + java.io.CharArrayWriter cw = new java.io.CharArrayWriter(); + java.io.PrintWriter pw = new java.io.PrintWriter(cw, true); + print(pw, w, d); + return cw.toString(); + } + + public void print(int w, int d) { + print(new PrintWriter(System.out, true), w, d); + } + + /** + * Print the matrix to the output stream. Line the elements up in columns + * with a Fortran-like 'Fw.d' style format. + * + * @param output Output stream. + * @param w Column width. + * @param d Number of digits after the decimal. + */ + public void print(PrintWriter output, int w, int d) { + DecimalFormat format = new DecimalFormat(); + format.setDecimalFormatSymbols(new DecimalFormatSymbols(Locale.US)); + format.setMinimumIntegerDigits(1); + format.setMaximumFractionDigits(d); + format.setMinimumFractionDigits(d); + format.setGroupingUsed(false); + print(output, format, w + 2); + } + + // DecimalFormat is a little disappointing coming from Fortran or C's printf. + // Since it doesn't pad on the left, the elements will come out different + // widths. Consequently, we'll pass the desired column width in as an + // argument and do the extra padding ourselves. + + /** + * Print the matrix to stdout. 
Line the elements up in columns. Use the + * format object, and right justify within columns of width characters. Note + * that is the matrix is to be read back in, you probably will want to use a + * NumberFormat that is set to US Locale. + * + * @param format A Formatting object for individual elements. + * @param width Field width for each column. + * @see java.text.DecimalFormat#setDecimalFormatSymbols + */ + public void print(NumberFormat format, int width) { + print(new PrintWriter(System.out, true), format, width); + } + + /** + * Print the matrix to the output stream. Line the elements up in columns. + * Use the format object, and right justify within columns of width + * characters. Note that is the matrix is to be read back in, you probably + * will want to use a NumberFormat that is set to US Locale. + * + * @param output the output stream. + * @param format A formatting object to format the matrix elements + * @param width Column width. + * @see java.text.DecimalFormat#setDecimalFormatSymbols + */ + public void print(PrintWriter output, NumberFormat format, int width) { + output.println(); // start on new line. + for (int i = 0; i < m; i++) { + for (int j = 0; j < n; j++) { + String s = format.format(A[i][j]); // format the number + int padding = Math.max(1, width - s.length()); // At _least_ 1 space + for (int k = 0; k < padding; k++) { + output.print(' '); + } + output.print(s); + } + output.println(); + } + output.println(); // end with blank line. + } + + + /* ------------------------ + Private Methods + * ------------------------ */ + + /** + * Check if size(A) == size(B) * + */ + private void checkMatrixDimensions(JaMatrix B) { + if (B.m != m || B.n != n) { + throw new IllegalArgumentException("Matrix dimensions must agree."); + } + } +} diff --git a/core/src/main/java/com/alibaba/alink/common/linalg/jama/LUDecomposition.java b/core/src/main/java/com/alibaba/alink/common/linalg/jama/LUDecomposition.java new file mode 100644 index 000000000..798f98bb6 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/common/linalg/jama/LUDecomposition.java @@ -0,0 +1,346 @@ +package com.alibaba.alink.common.linalg.jama; + +/** + * LU Decomposition. + *
+ * For an m-by-n matrix A with m >= n, the LU decomposition is an m-by-n unit + * lower triangular matrix L, an n-by-n upper triangular matrix U, and a + * permutation vector piv of length m so that A(piv,:) = L*U. If m < n, then L + * is m-by-m and U is m-by-n.
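+ *
+ * A minimal usage sketch, using the accessors declared later in this class:
+ * <pre>
+ * LUDecomposition lu = new LUDecomposition(A);  // A is a JaMatrix
+ * JaMatrix L = lu.getL();
+ * JaMatrix U = lu.getU();
+ * int[] piv = lu.getPivot();
+ * </pre>
+ *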
+ * The LU decompostion with pivoting always exists, even if the matrix is + * singular, so the constructor will never fail. The primary use of the LU + * decomposition is in the solution of square systems of simultaneous linear + * equations. This will fail if isNonsingular() returns false. + */ +class LUDecomposition implements java.io.Serializable { + private static final long serialVersionUID = -4234096950824402122L; + /* ------------------------ + Class variables + * ------------------------ */ + + /** + * Array for internal storage of decomposition. + * + * @serial internal array storage. + */ + private double[][] LU; + /** + * Row and column dimensions, and pivot sign. + * + * @serial column dimension. + * @serial row dimension. + * @serial pivot sign. + */ + private int m, n, pivsign; + /** + * Internal storage of pivot vector. + * + * @serial pivot vector. + */ + private int[] piv; + /* ------------------------ + Constructor + * ------------------------ */ + + /** + * LU Decomposition + * + * @param A Rectangular matrix + * @return Structure to access L, U and piv. + */ + public LUDecomposition(JaMatrix A) { + + // Use a "left-looking", dot-product, Crout/Doolittle algorithm. + LU = A.getArrayCopy(); + m = A.getRowDimension(); + n = A.getColumnDimension(); + piv = new int[m]; + for (int i = 0; i < m; i++) { + piv[i] = i; + } + pivsign = 1; + double[] LUrowi; + double[] LUcolj = new double[m]; + + // Outer loop. + for (int j = 0; j < n; j++) { + + // Make a copy of the j-th column to localize references. + for (int i = 0; i < m; i++) { + LUcolj[i] = LU[i][j]; + } + + // Apply previous transformations. + for (int i = 0; i < m; i++) { + LUrowi = LU[i]; + + // Most of the time is spent in the following dot product. + int kmax = Math.min(i, j); + double s = 0.0; + for (int k = 0; k < kmax; k++) { + s += LUrowi[k] * LUcolj[k]; + } + + LUrowi[j] = LUcolj[i] -= s; + } + + // Find pivot and exchange if necessary. + int p = j; + for (int i = j + 1; i < m; i++) { + if (Math.abs(LUcolj[i]) > Math.abs(LUcolj[p])) { + p = i; + } + } + if (p != j) { + for (int k = 0; k < n; k++) { + double t = LU[p][k]; + LU[p][k] = LU[j][k]; + LU[j][k] = t; + } + int k = piv[p]; + piv[p] = piv[j]; + piv[j] = k; + pivsign = -pivsign; + } + + // Compute multipliers. + if (j < m & LU[j][j] != 0.0) { + for (int i = j + 1; i < m; i++) { + LU[i][j] /= LU[j][j]; + } + } + } + } + + /* ------------------------ + Temporary, experimental code. + ------------------------ *\ + + \** LU Decomposition, computed by Gaussian elimination. +
+ This constructor computes L and U with the "daxpy"-based elimination + algorithm used in LINPACK and MATLAB. In Java, we suspect the dot-product, + Crout algorithm will be faster. We have temporarily included this + constructor until timing experiments confirm this suspicion. +
+ @param A Rectangular matrix + @param linpackflag Use Gaussian elimination. Actual value ignored. + @return Structure to access L, U and piv. + *\ + + public LUDecomposition (Matrix A, int linpackflag) { + // Initialize. + LU = A.getArrayCopy(); + m = A.getRowDimension(); + n = A.getColumnDimension(); + piv = new int[m]; + for (int i = 0; i < m; i++) { + piv[i] = i; + } + pivsign = 1; + // Main loop. + for (int k = 0; k < n; k++) { + // Find pivot. + int p = k; + for (int i = k+1; i < m; i++) { + if (Math.Abs(LU[i][k]) > Math.Abs(LU[p][k])) { + p = i; + } + } + // Exchange if necessary. + if (p != k) { + for (int j = 0; j < n; j++) { + double t = LU[p][j]; LU[p][j] = LU[k][j]; LU[k][j] = t; + } + int t = piv[p]; piv[p] = piv[k]; piv[k] = t; + pivsign = -pivsign; + } + // Compute multipliers and eliminate k-th column. + if (LU[k][k] != 0.0) { + for (int i = k+1; i < m; i++) { + LU[i][k] /= LU[k][k]; + for (int j = k+1; j < n; j++) { + LU[i][j] -= LU[i][k]*LU[k][j]; + } + } + } + } + } + + \* ------------------------ + End of temporary code. + * ------------------------ */ + + /* ------------------------ + Public Methods + * ------------------------ */ + + /** + * Is the matrix nonsingular? + * + * @return true if U, and hence A, is nonsingular. + */ + public boolean isNonsingular() { + for (int j = 0; j < n; j++) { + if (LU[j][j] == 0) { + return false; + } + } + return true; + } + + /** + * Return lower triangular factor + * + * @return L + */ + public JaMatrix getL() { + JaMatrix X = new JaMatrix(m, n); + double[][] L = X.getArray(); + for (int i = 0; i < m; i++) { + for (int j = 0; j < n; j++) { + if (i > j) { + L[i][j] = LU[i][j]; + } else if (i == j) { + L[i][j] = 1.0; + } else { + L[i][j] = 0.0; + } + } + } + return X; + } + + /** + * Return upper triangular factor + * + * @return U + */ + public JaMatrix getU() { + JaMatrix X = new JaMatrix(n, n); + double[][] U = X.getArray(); + for (int i = 0; i < n; i++) { + for (int j = 0; j < n; j++) { + if (i <= j) { + U[i][j] = LU[i][j]; + } else { + U[i][j] = 0.0; + } + } + } + return X; + } + + /** + * Return pivot permutation vector + * + * @return piv + */ + public int[] getPivot() { + int[] p = new int[m]; + for (int i = 0; i < m; i++) { + p[i] = piv[i]; + } + return p; + } + + /** + * Return pivot permutation vector as a one-dimensional double array + * + * @return (double) piv + */ + public double[] getDoublePivot() { + double[] vals = new double[m]; + for (int i = 0; i < m; i++) { + vals[i] = (double) piv[i]; + } + return vals; + } + + /** + * Determinant + * + * @return det(A) + * @throws IllegalArgumentException Matrix must be square + */ + public double det() { + if (m != n) { + throw new IllegalArgumentException("Matrix must be square."); + } + double d = (double) pivsign; + for (int j = 0; j < n; j++) { + d *= LU[j][j]; + } + return d; + } + + public double detLog() { + if (m != n) { + throw new IllegalArgumentException("Matrix must be square."); + } + double d = 0; + int sysbolCount = 0; + for (int j = 0; j < n; j++) { + if (LU[j][j] < 0) { + sysbolCount++; + } + if (LU[j][j] == 0) { + return Double.MIN_VALUE; + } + d += Math.log(Math.abs(LU[j][j])); + } + if (pivsign == 0) { + return Double.MIN_VALUE; + } + if (pivsign < 0) { + sysbolCount++; + } + if (sysbolCount % 2 == 1) { + return Double.MIN_VALUE; + } + return d; + } + + /** + * Solve A*X = B + * + * @param B A Matrix with as many rows as A and any number of columns. 
+ * @return X so that L*U*X = B(piv,:) + * @throws IllegalArgumentException Matrix row dimensions must agree. + * @throws RuntimeException Matrix is singular. + */ + public JaMatrix solve(JaMatrix B) { + if (B.getRowDimension() != m) { + throw new IllegalArgumentException("Matrix row dimensions must agree."); + } + + if (!this.isNonsingular()) { + throw new RuntimeException("Matrix is singular."); + } + // Copy right hand side with pivoting + int nx = B.getColumnDimension(); + JaMatrix Xmat = B.getMatrix(piv, 0, nx - 1); + double[][] X1 = Xmat.getArray(); + + // Solve L*Y = B(piv,:) + for (int k = 0; k < n; k++) { + for (int i = k + 1; i < n; i++) { + for (int j = 0; j < nx; j++) { + X1[i][j] -= X1[k][j] * LU[i][k]; + } + } + } + // Solve U*X = Y; + for (int k = n - 1; k >= 0; k--) { + for (int j = 0; j < nx; j++) { + X1[k][j] /= LU[k][k]; + } + for (int i = 0; i < k; i++) { + for (int j = 0; j < nx; j++) { + X1[i][j] -= X1[k][j] * LU[i][k]; + } + } + } + return Xmat; + } +} diff --git a/core/src/main/java/com/alibaba/alink/common/linalg/jama/Maths.java b/core/src/main/java/com/alibaba/alink/common/linalg/jama/Maths.java new file mode 100644 index 000000000..20a32fc89 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/common/linalg/jama/Maths.java @@ -0,0 +1 @@ +package com.alibaba.alink.common.linalg.jama; class Maths { /** * Sqrt(a^2 + b^2) without under/overflow. **/ public static double hypot(double a, double b) { double r; if (Math.abs(a) > Math.abs(b)) { r = b / a; r = Math.abs(a) * Math.sqrt(1 + r * r); } else if (b != 0) { r = a / b; r = Math.abs(b) * Math.sqrt(1 + r * r); } else { r = 0.0; } return r; } } \ No newline at end of file diff --git a/core/src/main/java/com/alibaba/alink/common/linalg/jama/QRDecomposition.java b/core/src/main/java/com/alibaba/alink/common/linalg/jama/QRDecomposition.java new file mode 100644 index 000000000..8a22e4203 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/common/linalg/jama/QRDecomposition.java @@ -0,0 +1,228 @@ +package com.alibaba.alink.common.linalg.jama; + +/** + * QR Decomposition. + *
+ * For an m-by-n matrix A with m >= n, the QR decomposition is an m-by-n + * orthogonal matrix Q and an n-by-n upper triangular matrix R so that + * A = Q*R. + *
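+ * A minimal usage sketch of this class, using the accessors declared below:
+ * <pre>
+ * QRDecomposition qr = new QRDecomposition(A);  // A is a JaMatrix
+ * JaMatrix Q = qr.getQ();
+ * JaMatrix R = qr.getR();
+ * </pre>
+ *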
+ * The QR decompostion always exists, even if the matrix does not have + * full rank, so the constructor will never fail. The primary use of the + * QR decomposition is in the least squares solution of nonsquare systems + * of simultaneous linear equations. This will fail if isFullRank() + * returns false. + */ +class QRDecomposition implements java.io.Serializable { + private static final long serialVersionUID = -6821106953011796474L; + + /* ------------------------ + Class variables + * ------------------------ */ + /** + * Array for internal storage of decomposition. + * + * @serial internal array storage. + */ + private double[][] QR; + /** + * Row and column dimensions. + * + * @serial column dimension. + * @serial row dimension. + */ + private int m, n; + /** + * Array for internal storage of diagonal of R. + * + * @serial diagonal of R. + */ + private double[] Rdiag; + + /* ------------------------ + Constructor + * ------------------------ */ + + /** + * QR Decomposition, computed by Householder reflections. + * + * @param A Rectangular matrix + * @return Structure to access R and the Householder vectors and compute Q. + */ + public QRDecomposition(JaMatrix A) { + // Initialize. + QR = A.getArrayCopy(); + m = A.getRowDimension(); + n = A.getColumnDimension(); + Rdiag = new double[n]; + + // Main loop. + for (int k = 0; k < n; k++) { + // Compute 2-norm of k-th column without under/overflow. + double nrm = 0; + for (int i = k; i < m; i++) { + nrm = Maths.hypot(nrm, QR[i][k]); + } + + if (nrm != 0.0) { + // Form k-th Householder vector. + if (QR[k][k] < 0) { + nrm = -nrm; + } + for (int i = k; i < m; i++) { + QR[i][k] /= nrm; + } + QR[k][k] += 1.0; + + // Apply transformation to remaining columns. + for (int j = k + 1; j < n; j++) { + double s = 0.0; + for (int i = k; i < m; i++) { + s += QR[i][k] * QR[i][j]; + } + s = -s / QR[k][k]; + for (int i = k; i < m; i++) { + QR[i][j] += s * QR[i][k]; + } + } + } + Rdiag[k] = -nrm; + } + } + + /* ------------------------ + Public Methods + * ------------------------ */ + + /** + * Is the matrix full rank? + * + * @return true if R, and hence A, has full rank. 
+ */ + public boolean isFullRank() { + for (int j = 0; j < n; j++) { + if (Rdiag[j] == 0) { + return false; + } + } + return true; + } + + /** + * Return the Householder vectors + * + * @return Lower trapezoidal matrix whose columns define the reflections + */ + public JaMatrix getH() { + JaMatrix X = new JaMatrix(m, n); + double[][] H = X.getArray(); + for (int i = 0; i < m; i++) { + for (int j = 0; j < n; j++) { + if (i >= j) { + H[i][j] = QR[i][j]; + } else { + H[i][j] = 0.0; + } + } + } + return X; + } + + /** + * Return the upper triangular factor + * + * @return R + */ + public JaMatrix getR() { + JaMatrix X = new JaMatrix(n, n); + double[][] R = X.getArray(); + for (int i = 0; i < n; i++) { + for (int j = 0; j < n; j++) { + if (i < j) { + R[i][j] = QR[i][j]; + } else if (i == j) { + R[i][j] = Rdiag[i]; + } else { + R[i][j] = 0.0; + } + } + } + return X; + } + + /** + * Generate and return the (economy-sized) orthogonal factor + * + * @return Q + */ + public JaMatrix getQ() { + JaMatrix X = new JaMatrix(m, n); + double[][] Q = X.getArray(); + for (int k = n - 1; k >= 0; k--) { + for (int i = 0; i < m; i++) { + Q[i][k] = 0.0; + } + Q[k][k] = 1.0; + for (int j = k; j < n; j++) { + if (QR[k][k] != 0) { + double s = 0.0; + for (int i = k; i < m; i++) { + s += QR[i][k] * Q[i][j]; + } + s = -s / QR[k][k]; + for (int i = k; i < m; i++) { + Q[i][j] += s * QR[i][k]; + } + } + } + } + return X; + } + + /** + * Least squares solution of A*X = B + * + * @param B A Matrix with as many rows as A and any number of columns. + * @return X that minimizes the two norm of Q*R*X-B. + * @throws IllegalArgumentException Matrix row dimensions must agree. + * @throws RuntimeException Matrix is rank deficient. + */ + public JaMatrix solve(JaMatrix B) { + if (B.getRowDimension() != m) { + throw new IllegalArgumentException("Matrix row dimensions must agree."); + } + if (!this.isFullRank()) { + throw new RuntimeException("Matrix is rank deficient."); + } + + // Copy right hand side + int nx = B.getColumnDimension(); + double[][] X = B.getArrayCopy(); + + // Compute Y = transpose(Q)*B + for (int k = 0; k < n; k++) { + for (int j = 0; j < nx; j++) { + double s = 0.0; + for (int i = k; i < m; i++) { + s += QR[i][k] * X[i][j]; + } + s = -s / QR[k][k]; + for (int i = k; i < m; i++) { + X[i][j] += s * QR[i][k]; + } + } + } + // Solve R*X = Y; + for (int k = n - 1; k >= 0; k--) { + for (int j = 0; j < nx; j++) { + X[k][j] /= Rdiag[k]; + } + for (int i = 0; i < k; i++) { + for (int j = 0; j < nx; j++) { + X[i][j] -= X[k][j] * QR[i][k]; + } + } + } + return (new JaMatrix(X, n, nx).getMatrix(0, n - 1, 0, nx - 1)); + } +} diff --git a/core/src/main/java/com/alibaba/alink/common/linalg/jama/SingularValueDecomposition.java b/core/src/main/java/com/alibaba/alink/common/linalg/jama/SingularValueDecomposition.java new file mode 100644 index 000000000..f5e3b0c28 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/common/linalg/jama/SingularValueDecomposition.java @@ -0,0 +1,562 @@ +package com.alibaba.alink.common.linalg.jama; + +/** + * Singular Value Decomposition. + *
+ * For an m-by-n matrix A with m >= n, the singular value decomposition is + * an m-by-n orthogonal matrix U, an n-by-n diagonal matrix S, and + * an n-by-n orthogonal matrix V so that A = U*S*V'. + *
+ * The singular values, sigma[k] = S[k][k], are ordered so that + * sigma[0] >= sigma[1] >= ... >= sigma[n-1]. + *
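+ * A minimal usage sketch, using the accessors declared later in this class:
+ * <pre>
+ * SingularValueDecomposition svd = new SingularValueDecomposition(A);  // A is a JaMatrix
+ * JaMatrix U = svd.getU();
+ * JaMatrix S = svd.getS();
+ * JaMatrix V = svd.getV();
+ * double[] sigma = svd.getSingularValues();
+ * </pre>
+ *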
+ * The singular value decompostion always exists, so the constructor will + * never fail. The matrix condition number and the effective numerical + * rank can be computed from this decomposition. + */ +class SingularValueDecomposition implements java.io.Serializable { + private static final long serialVersionUID = -1998948033552340189L; + + /* ------------------------ + Class variables + * ------------------------ */ + /** + * Arrays for internal storage of U and V. + * + * @serial internal storage of U. + * @serial internal storage of V. + */ + private double[][] U, V; + /** + * Array for internal storage of singular values. + * + * @serial internal storage of singular values. + */ + private double[] s; + /** + * Row and column dimensions. + * + * @serial row dimension. + * @serial column dimension. + */ + private int m, n; + + /* ------------------------ + Constructor + * ------------------------ */ + + /** + * Construct the singular value decomposition + * + * @param Arg Rectangular matrix + * @return Structure to access U, S and V. + */ + public SingularValueDecomposition(JaMatrix Arg) { + + // Derived from LINPACK code. + // Initialize. + double[][] A = Arg.getArrayCopy(); + m = Arg.getRowDimension(); + n = Arg.getColumnDimension(); + + /* Apparently the failing cases are only a proper subset of (m= n"); } + */ + int nu = Math.min(m, n); + s = new double[Math.min(m + 1, n)]; + U = new double[m][nu]; + V = new double[n][n]; + double[] e = new double[n]; + double[] work = new double[m]; + boolean wantu = true; + boolean wantv = true; + + // Reduce A to bidiagonal form, storing the diagonal elements + // in s and the super-diagonal elements in e. + + int nct = Math.min(m - 1, n); + int nrt = Math.max(0, Math.min(n - 2, m)); + for (int k = 0; k < Math.max(nct, nrt); k++) { + if (k < nct) { + + // Compute the transformation for the k-th column and + // place the k-th diagonal in s[k]. + // Compute 2-norm of k-th column without under/overflow. + s[k] = 0; + for (int i = k; i < m; i++) { + s[k] = Maths.hypot(s[k], A[i][k]); + } + if (s[k] != 0.0) { + if (A[k][k] < 0.0) { + s[k] = -s[k]; + } + for (int i = k; i < m; i++) { + A[i][k] /= s[k]; + } + A[k][k] += 1.0; + } + s[k] = -s[k]; + } + for (int j = k + 1; j < n; j++) { + if ((k < nct) & (s[k] != 0.0)) { + + // Apply the transformation. + + double t = 0; + for (int i = k; i < m; i++) { + t += A[i][k] * A[i][j]; + } + t = -t / A[k][k]; + for (int i = k; i < m; i++) { + A[i][j] += t * A[i][k]; + } + } + + // Place the k-th row of A into e for the + // subsequent calculation of the row transformation. + + e[j] = A[k][j]; + } + if (wantu & (k < nct)) { + + // Place the transformation in U for subsequent back + // multiplication. + + for (int i = k; i < m; i++) { + U[i][k] = A[i][k]; + } + } + if (k < nrt) { + + // Compute the k-th row transformation and place the + // k-th super-diagonal in e[k]. + // Compute 2-norm without under/overflow. + e[k] = 0; + for (int i = k + 1; i < n; i++) { + e[k] = Maths.hypot(e[k], e[i]); + } + if (e[k] != 0.0) { + if (e[k + 1] < 0.0) { + e[k] = -e[k]; + } + for (int i = k + 1; i < n; i++) { + e[i] /= e[k]; + } + e[k + 1] += 1.0; + } + e[k] = -e[k]; + if ((k + 1 < m) & (e[k] != 0.0)) { + + // Apply the transformation. 
+ + for (int i = k + 1; i < m; i++) { + work[i] = 0.0; + } + for (int j = k + 1; j < n; j++) { + for (int i = k + 1; i < m; i++) { + work[i] += e[j] * A[i][j]; + } + } + for (int j = k + 1; j < n; j++) { + double t = -e[j] / e[k + 1]; + for (int i = k + 1; i < m; i++) { + A[i][j] += t * work[i]; + } + } + } + if (wantv) { + + // Place the transformation in V for subsequent + // back multiplication. + + for (int i = k + 1; i < n; i++) { + V[i][k] = e[i]; + } + } + } + } + + // Set up the final bidiagonal matrix or order p. + + int p = Math.min(n, m + 1); + if (nct < n) { + s[nct] = A[nct][nct]; + } + if (m < p) { + s[p - 1] = 0.0; + } + if (nrt + 1 < p) { + e[nrt] = A[nrt][p - 1]; + } + e[p - 1] = 0.0; + + // If required, generate U. + + if (wantu) { + for (int j = nct; j < nu; j++) { + for (int i = 0; i < m; i++) { + U[i][j] = 0.0; + } + U[j][j] = 1.0; + } + for (int k = nct - 1; k >= 0; k--) { + if (s[k] != 0.0) { + for (int j = k + 1; j < nu; j++) { + double t = 0; + for (int i = k; i < m; i++) { + t += U[i][k] * U[i][j]; + } + t = -t / U[k][k]; + for (int i = k; i < m; i++) { + U[i][j] += t * U[i][k]; + } + } + for (int i = k; i < m; i++) { + U[i][k] = -U[i][k]; + } + U[k][k] = 1.0 + U[k][k]; + for (int i = 0; i < k - 1; i++) { + U[i][k] = 0.0; + } + } else { + for (int i = 0; i < m; i++) { + U[i][k] = 0.0; + } + U[k][k] = 1.0; + } + } + } + + // If required, generate V. + + if (wantv) { + for (int k = n - 1; k >= 0; k--) { + if ((k < nrt) & (e[k] != 0.0)) { + for (int j = k + 1; j < nu; j++) { + double t = 0; + for (int i = k + 1; i < n; i++) { + t += V[i][k] * V[i][j]; + } + t = -t / V[k + 1][k]; + for (int i = k + 1; i < n; i++) { + V[i][j] += t * V[i][k]; + } + } + } + for (int i = 0; i < n; i++) { + V[i][k] = 0.0; + } + V[k][k] = 1.0; + } + } + + // Main iteration loop for the singular values. + + int pp = p - 1; + int iter = 0; + double eps = Math.pow(2.0, -52.0); + double tiny = Math.pow(2.0, -966.0); + while (p > 0) { + int k, kase; + + // Here is where a test for too many iterations would go. + + // This section of the program inspects for + // negligible elements in the s and e arrays. On + // completion the variables kase and k are set as follows. + + // kase = 1 if s(p) and e[k-1] are negligible and k
<p + // kase = 2 if s(k) is negligible and k<p + // kase = 3 if e[k-1] is negligible, k<p, and + // s(k), ..., s(p) are not negligible (qr step). + // kase = 4 if e(p-1) is negligible (convergence). + + for (k = p - 2; k >
= -1; k--) { + if (k == -1) { + break; + } + if (Math.abs(e[k]) + <= tiny + eps * (Math.abs(s[k]) + Math.abs(s[k + 1]))) { + e[k] = 0.0; + break; + } + } + if (k == p - 2) { + kase = 4; + } else { + int ks; + for (ks = p - 1; ks >= k; ks--) { + if (ks == k) { + break; + } + double t = (ks != p ? Math.abs(e[ks]) : 0.) + + (ks != k + 1 ? Math.abs(e[ks - 1]) : 0.); + if (Math.abs(s[ks]) <= tiny + eps * t) { + s[ks] = 0.0; + break; + } + } + if (ks == k) { + kase = 3; + } else if (ks == p - 1) { + kase = 1; + } else { + kase = 2; + k = ks; + } + } + k++; + + // Perform the task indicated by kase. + + switch (kase) { + + // Deflate negligible s(p). + + case 1: { + double f = e[p - 2]; + e[p - 2] = 0.0; + for (int j = p - 2; j >= k; j--) { + double t = Maths.hypot(s[j], f); + double cs = s[j] / t; + double sn = f / t; + s[j] = t; + if (j != k) { + f = -sn * e[j - 1]; + e[j - 1] = cs * e[j - 1]; + } + if (wantv) { + for (int i = 0; i < n; i++) { + t = cs * V[i][j] + sn * V[i][p - 1]; + V[i][p - 1] = -sn * V[i][j] + cs * V[i][p - 1]; + V[i][j] = t; + } + } + } + } + break; + + // Split at negligible s(k). + + case 2: { + double f = e[k - 1]; + e[k - 1] = 0.0; + for (int j = k; j < p; j++) { + double t = Maths.hypot(s[j], f); + double cs = s[j] / t; + double sn = f / t; + s[j] = t; + f = -sn * e[j]; + e[j] = cs * e[j]; + if (wantu) { + for (int i = 0; i < m; i++) { + t = cs * U[i][j] + sn * U[i][k - 1]; + U[i][k - 1] = -sn * U[i][j] + cs * U[i][k - 1]; + U[i][j] = t; + } + } + } + } + break; + + // Perform one qr step. + + case 3: { + + // Calculate the shift. + + double scale = Math.max(Math.max(Math.max(Math.max( + Math.abs(s[p - 1]), Math.abs(s[p - 2])), Math.abs(e[p - 2])), + Math.abs(s[k])), Math.abs(e[k])); + double sp = s[p - 1] / scale; + double spm1 = s[p - 2] / scale; + double epm1 = e[p - 2] / scale; + double sk = s[k] / scale; + double ek = e[k] / scale; + double b = ((spm1 + sp) * (spm1 - sp) + epm1 * epm1) / 2.0; + double c = (sp * epm1) * (sp * epm1); + double shift = 0.0; + if ((b != 0.0) | (c != 0.0)) { + shift = Math.sqrt(b * b + c); + if (b < 0.0) { + shift = -shift; + } + shift = c / (b + shift); + } + double f = (sk + sp) * (sk - sp) + shift; + double g = sk * ek; + + // Chase zeros. + + for (int j = k; j < p - 1; j++) { + double t = Maths.hypot(f, g); + double cs = f / t; + double sn = g / t; + if (j != k) { + e[j - 1] = t; + } + f = cs * s[j] + sn * e[j]; + e[j] = cs * e[j] - sn * s[j]; + g = sn * s[j + 1]; + s[j + 1] = cs * s[j + 1]; + if (wantv) { + for (int i = 0; i < n; i++) { + t = cs * V[i][j] + sn * V[i][j + 1]; + V[i][j + 1] = -sn * V[i][j] + cs * V[i][j + 1]; + V[i][j] = t; + } + } + t = Maths.hypot(f, g); + cs = f / t; + sn = g / t; + s[j] = t; + f = cs * e[j] + sn * s[j + 1]; + s[j + 1] = -sn * e[j] + cs * s[j + 1]; + g = sn * e[j + 1]; + e[j + 1] = cs * e[j + 1]; + if (wantu && (j < m - 1)) { + for (int i = 0; i < m; i++) { + t = cs * U[i][j] + sn * U[i][j + 1]; + U[i][j + 1] = -sn * U[i][j] + cs * U[i][j + 1]; + U[i][j] = t; + } + } + } + e[p - 2] = f; + iter = iter + 1; + } + break; + + // Convergence. + + case 4: { + + // Make the singular values positive. + + if (s[k] <= 0.0) { + s[k] = (s[k] < 0.0 ? -s[k] : 0.0); + if (wantv) { + for (int i = 0; i <= pp; i++) { + V[i][k] = -V[i][k]; + } + } + } + + // Order the singular values. 
+ + while (k < pp) { + if (s[k] >= s[k + 1]) { + break; + } + double t = s[k]; + s[k] = s[k + 1]; + s[k + 1] = t; + if (wantv && (k < n - 1)) { + for (int i = 0; i < n; i++) { + t = V[i][k + 1]; + V[i][k + 1] = V[i][k]; + V[i][k] = t; + } + } + if (wantu && (k < m - 1)) { + for (int i = 0; i < m; i++) { + t = U[i][k + 1]; + U[i][k + 1] = U[i][k]; + U[i][k] = t; + } + } + k++; + } + iter = 0; + p--; + } + break; + } + } + } + + /* ------------------------ + Public Methods + * ------------------------ */ + + /** + * Return the left singular vectors + * + * @return U + */ + public JaMatrix getU() { + return new JaMatrix(U, m, Math.min(m + 1, n)); + } + + /** + * Return the right singular vectors + * + * @return V + */ + public JaMatrix getV() { + return new JaMatrix(V, n, n); + } + + /** + * Return the one-dimensional array of singular values + * + * @return diagonal of S. + */ + public double[] getSingularValues() { + return s; + } + + /** + * Return the diagonal matrix of singular values + * + * @return S + */ + public JaMatrix getS() { + JaMatrix X = new JaMatrix(n, n); + double[][] S = X.getArray(); + for (int i = 0; i < n; i++) { + for (int j = 0; j < n; j++) { + S[i][j] = 0.0; + } + S[i][i] = this.s[i]; + } + return X; + } + + /** + * Two norm + * + * @return max(S) + */ + public double norm2() { + return s[0]; + } + + /** + * Two norm condition number + * + * @return max(S)/min(S) + */ + public double cond() { + return s[0] / s[Math.min(m, n) - 1]; + } + + /** + * Effective numerical matrix rank + * + * @return Number of nonnegligible singular values. + */ + public int rank() { + double eps = Math.pow(2.0, -52.0); + double tol = Math.max(m, n) * s[0] * eps; + int r = 0; + for (int i = 0; i < s.length; i++) { + if (s[i] > tol) { + r++; + } + } + return r; + } +} diff --git a/core/src/main/java/com/alibaba/alink/common/linalg/tensor/Tensor.java b/core/src/main/java/com/alibaba/alink/common/linalg/tensor/Tensor.java index e5c18c991..e24875eb5 100644 --- a/core/src/main/java/com/alibaba/alink/common/linalg/tensor/Tensor.java +++ b/core/src/main/java/com/alibaba/alink/common/linalg/tensor/Tensor.java @@ -1,6 +1,6 @@ package com.alibaba.alink.common.linalg.tensor; -import com.alibaba.alink.common.DataTypeDisplayInterface; +import com.alibaba.alink.common.viz.DataTypeDisplayInterface; import com.alibaba.alink.common.linalg.tensor.TensorUtil.CoordInc; import com.alibaba.alink.operator.common.utils.PrettyDisplayUtils; import org.apache.commons.lang3.ArrayUtils; diff --git a/core/src/main/java/com/alibaba/alink/common/mapper/ModelStreamModelMapperAdapter.java b/core/src/main/java/com/alibaba/alink/common/mapper/ModelStreamModelMapperAdapter.java index f1bbe46f3..75aa79272 100644 --- a/core/src/main/java/com/alibaba/alink/common/mapper/ModelStreamModelMapperAdapter.java +++ b/core/src/main/java/com/alibaba/alink/common/mapper/ModelStreamModelMapperAdapter.java @@ -9,9 +9,9 @@ import com.alibaba.alink.common.exceptions.AkIllegalOperatorParameterException; import com.alibaba.alink.common.io.filesystem.FilePath; -import com.alibaba.alink.operator.common.stream.model.ModelStreamFileScanner; -import com.alibaba.alink.operator.common.stream.model.ModelStreamFileScanner.ScanTask; -import com.alibaba.alink.operator.common.stream.model.ModelStreamUtils; +import com.alibaba.alink.operator.common.modelstream.ModelStreamFileScanner; +import com.alibaba.alink.operator.common.modelstream.ModelStreamFileScanner.ScanTask; +import com.alibaba.alink.operator.common.modelstream.ModelStreamUtils; import 
com.alibaba.alink.params.ModelStreamScanParams; import java.sql.Timestamp; diff --git a/core/src/main/java/com/alibaba/alink/common/pyrunner/fn/DataConversionUtils.java b/core/src/main/java/com/alibaba/alink/common/pyrunner/fn/DataConversionUtils.java index 6978eb136..6f8e5d4ba 100644 --- a/core/src/main/java/com/alibaba/alink/common/pyrunner/fn/DataConversionUtils.java +++ b/core/src/main/java/com/alibaba/alink/common/pyrunner/fn/DataConversionUtils.java @@ -2,7 +2,7 @@ import org.apache.flink.api.common.typeinfo.TypeInformation; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.MTable; import com.alibaba.alink.common.linalg.DenseVector; import com.alibaba.alink.common.linalg.SparseVector; diff --git a/core/src/main/java/com/alibaba/alink/common/pyrunner/fn/PyTableFn.java b/core/src/main/java/com/alibaba/alink/common/pyrunner/fn/PyTableFn.java index 854522a65..793706329 100644 --- a/core/src/main/java/com/alibaba/alink/common/pyrunner/fn/PyTableFn.java +++ b/core/src/main/java/com/alibaba/alink/common/pyrunner/fn/PyTableFn.java @@ -9,7 +9,7 @@ import org.apache.flink.util.Collector; import com.alibaba.alink.common.AlinkGlobalConfiguration; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.exceptions.AkIllegalOperationException; import com.alibaba.alink.common.exceptions.AkPreconditions; import com.alibaba.alink.common.utils.Functional.SerializableBiFunction; diff --git a/core/src/main/java/com/alibaba/alink/common/pyrunner/fn/impl/PyMTableScalarFn.java b/core/src/main/java/com/alibaba/alink/common/pyrunner/fn/impl/PyMTableScalarFn.java index 39c909c1a..b727bb9c0 100644 --- a/core/src/main/java/com/alibaba/alink/common/pyrunner/fn/impl/PyMTableScalarFn.java +++ b/core/src/main/java/com/alibaba/alink/common/pyrunner/fn/impl/PyMTableScalarFn.java @@ -2,7 +2,7 @@ import org.apache.flink.api.common.typeinfo.TypeInformation; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.MTable; import com.alibaba.alink.common.pyrunner.fn.BasePyScalarFn; import com.alibaba.alink.common.pyrunner.fn.PyScalarFnHandle; diff --git a/core/src/main/java/com/alibaba/alink/common/pyrunner/fn/impl/PyTensorScalarFn.java b/core/src/main/java/com/alibaba/alink/common/pyrunner/fn/impl/PyTensorScalarFn.java index c074cfc67..3e2754ba8 100644 --- a/core/src/main/java/com/alibaba/alink/common/pyrunner/fn/impl/PyTensorScalarFn.java +++ b/core/src/main/java/com/alibaba/alink/common/pyrunner/fn/impl/PyTensorScalarFn.java @@ -2,7 +2,7 @@ import org.apache.flink.api.common.typeinfo.TypeInformation; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.linalg.tensor.Tensor; import com.alibaba.alink.common.pyrunner.fn.BasePyScalarFn; import com.alibaba.alink.common.pyrunner.fn.PyScalarFnHandle; diff --git a/core/src/main/java/com/alibaba/alink/common/pyrunner/fn/impl/PyVectorScalarFn.java b/core/src/main/java/com/alibaba/alink/common/pyrunner/fn/impl/PyVectorScalarFn.java index 9927714c5..aad21b3ac 100644 --- a/core/src/main/java/com/alibaba/alink/common/pyrunner/fn/impl/PyVectorScalarFn.java +++ b/core/src/main/java/com/alibaba/alink/common/pyrunner/fn/impl/PyVectorScalarFn.java @@ -2,7 +2,7 @@ import org.apache.flink.api.common.typeinfo.TypeInformation; -import com.alibaba.alink.common.AlinkTypes; +import 
com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.linalg.Vector; import com.alibaba.alink.common.pyrunner.fn.BasePyScalarFn; import com.alibaba.alink.common.pyrunner.fn.PyScalarFnHandle; diff --git a/core/src/main/java/com/alibaba/alink/common/sql/builtin/BuildInAggRegister.java b/core/src/main/java/com/alibaba/alink/common/sql/builtin/BuiltInAggRegister.java similarity index 97% rename from core/src/main/java/com/alibaba/alink/common/sql/builtin/BuildInAggRegister.java rename to core/src/main/java/com/alibaba/alink/common/sql/builtin/BuiltInAggRegister.java index df1219022..3d7d9af5b 100644 --- a/core/src/main/java/com/alibaba/alink/common/sql/builtin/BuildInAggRegister.java +++ b/core/src/main/java/com/alibaba/alink/common/sql/builtin/BuiltInAggRegister.java @@ -35,7 +35,7 @@ import com.alibaba.alink.common.sql.builtin.time.ToTimeStamp; import com.alibaba.alink.common.sql.builtin.time.UnixTimeStamp; -public class BuildInAggRegister { +public class BuiltInAggRegister { public static final String EXTEND = "_preceding"; public static final String CONSIDER_NULL_EXTEND = "_including_null"; @@ -79,9 +79,11 @@ public static void registerUdaf(StreamTableEnvironment env) { env.registerFunction(UdafName.LAST_DISTINCT.name + CONSIDER_NULL_EXTEND, new LastDistinctValueUdaf(true)); env.registerFunction(UdafName.LAST_TIME.name, new LastTimeUdaf()); env.registerFunction(UdafName.LAST_VALUE.name, new LastValueUdaf()); - env.registerFunction(UdafName.LAST_VALUE.name + CONSIDER_NULL_EXTEND, new LastValueUdaf(true)); + env.registerFunction(UdafName.LAST_VALUE_CONSIDER_NULL.name, new LastValueUdaf(true)); env.registerFunction(UdafName.LISTAGG.name, new ListAggUdaf()); env.registerFunction(UdafName.LISTAGG.name + EXTEND, new ListAggUdaf(true)); + env.registerFunction(UdafName.CONCAT_AGG.name, new ListAggUdaf()); + env.registerFunction(UdafName.CONCAT_AGG.name + EXTEND, new ListAggUdaf(true)); env.registerFunction(UdafName.MODE.name, new ModeUdaf(false)); env.registerFunction(UdafName.MODE.name + EXTEND, new ModeUdaf(true)); env.registerFunction(UdafName.FREQ.name, new FreqUdaf(false)); @@ -89,7 +91,6 @@ public static void registerUdaf(StreamTableEnvironment env) { env.registerFunction(UdafName.IS_EXIST.name, new IsExistUdaf()); env.registerFunction(UdafName.TIMESERIES_AGG.name + EXTEND, new TimeSeriesAgg(true)); env.registerFunction(UdafName.TIMESERIES_AGG.name, new TimeSeriesAgg(false)); - env.registerFunction(UdafName.CONCAT_AGG.name, new ListAggUdaf()); } public static void registerUdaf(BatchTableEnvironment env) { diff --git a/core/src/main/java/com/alibaba/alink/common/sql/builtin/UdafName.java b/core/src/main/java/com/alibaba/alink/common/sql/builtin/UdafName.java index cb7c3f170..5d7e14f68 100644 --- a/core/src/main/java/com/alibaba/alink/common/sql/builtin/UdafName.java +++ b/core/src/main/java/com/alibaba/alink/common/sql/builtin/UdafName.java @@ -17,7 +17,10 @@ public enum UdafName { LAG("lag"), LAST_DISTINCT("last_distinct"), LAST_TIME("last_time"), + + //last_value is conflict in blink, so last_value_impl. and transform. 
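// Illustrative only, not part of the enum: once BuiltInAggRegister.registerUdaf(tEnv) has run,
// these names are callable from stream SQL, e.g.
//   tEnv.sqlQuery("SELECT last_value_impl(val) OVER (PARTITION BY id ORDER BY ts) FROM src")
// and likewise last_value_including_null for the null-aware variant; tEnv, src, val, id and ts
// are hypothetical names used only for this sketch.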
LAST_VALUE("last_value_impl"), + LAST_VALUE_CONSIDER_NULL("last_value_including_null"), LISTAGG("listagg"), MODE("mode"), diff --git a/core/src/main/java/com/alibaba/alink/common/sql/builtin/agg/BaseSummaryUdaf.java b/core/src/main/java/com/alibaba/alink/common/sql/builtin/agg/BaseSummaryUdaf.java index 89014331d..cb133da6d 100644 --- a/core/src/main/java/com/alibaba/alink/common/sql/builtin/agg/BaseSummaryUdaf.java +++ b/core/src/main/java/com/alibaba/alink/common/sql/builtin/agg/BaseSummaryUdaf.java @@ -1,6 +1,5 @@ package com.alibaba.alink.common.sql.builtin.agg; - import com.alibaba.alink.common.exceptions.AkIllegalDataException; import com.alibaba.alink.common.sql.builtin.agg.BaseSummaryUdaf.SummaryData; @@ -48,7 +47,6 @@ public SummaryData createAccumulator() { return new SummaryData(excludeLast); } - public static class SummaryData { public long count = 0; public double sum = 0; @@ -202,6 +200,9 @@ public void reset() { } public void merge(SummaryData data) { + if (this.handle == null) { + this.handle = data.handle; + } sum += data.sum; count += data.count; squareSum += data.squareSum; @@ -232,6 +233,5 @@ public boolean equals(Object o) { return true; } - } } diff --git a/core/src/main/java/com/alibaba/alink/common/sql/builtin/agg/BaseUdaf.java b/core/src/main/java/com/alibaba/alink/common/sql/builtin/agg/BaseUdaf.java index e302cabd9..00a77c6a6 100644 --- a/core/src/main/java/com/alibaba/alink/common/sql/builtin/agg/BaseUdaf.java +++ b/core/src/main/java/com/alibaba/alink/common/sql/builtin/agg/BaseUdaf.java @@ -2,7 +2,7 @@ import org.apache.flink.table.functions.AggregateFunction; -public abstract class BaseUdaf extends AggregateFunction { +public abstract class BaseUdaf extends AggregateFunction { public void accumulate(ACC acc, Object... values) {} @@ -10,14 +10,12 @@ public void retract(ACC acc, Object... values) {} public void resetAccumulator(ACC acc) {} - public void merge(ACC acc, Iterable it) {} + public void merge(ACC acc, Iterable it) {} ACC acc; public void accumulateBatch(Object... values) { - if (acc == null) { - acc = createAccumulator(); - } + createAccumulatorAndSet(); accumulate(acc, values); } @@ -25,4 +23,10 @@ public T getValueBatch() { return getValue(acc); } + private void createAccumulatorAndSet() { + if (acc == null) { + acc = createAccumulator(); + } + } + } diff --git a/core/src/main/java/com/alibaba/alink/common/sql/builtin/agg/CountUdaf.java b/core/src/main/java/com/alibaba/alink/common/sql/builtin/agg/CountUdaf.java index 67cdad02f..1118f0c2a 100644 --- a/core/src/main/java/com/alibaba/alink/common/sql/builtin/agg/CountUdaf.java +++ b/core/src/main/java/com/alibaba/alink/common/sql/builtin/agg/CountUdaf.java @@ -3,7 +3,6 @@ public class CountUdaf extends BaseSummaryUdaf { - public CountUdaf() { this(false); } @@ -44,4 +43,5 @@ public void retract(SummaryData acc, Object... 
values) { } acc.retractData(1); } + } diff --git a/core/src/main/java/com/alibaba/alink/common/sql/builtin/agg/LastDistinctValueUdaf.java b/core/src/main/java/com/alibaba/alink/common/sql/builtin/agg/LastDistinctValueUdaf.java index d084b5e59..7ab9560e5 100644 --- a/core/src/main/java/com/alibaba/alink/common/sql/builtin/agg/LastDistinctValueUdaf.java +++ b/core/src/main/java/com/alibaba/alink/common/sql/builtin/agg/LastDistinctValueUdaf.java @@ -2,13 +2,12 @@ import org.apache.flink.api.java.tuple.Tuple3; -import com.alibaba.alink.common.exceptions.AkIllegalDataException; +import com.alibaba.alink.common.exceptions.AkIllegalArgumentException; import com.alibaba.alink.common.sql.builtin.agg.LastDistinctValueUdaf.LastDistinctSaveFirst; import java.sql.Timestamp; - -public class LastDistinctValueUdaf extends BaseUdaf { +public class LastDistinctValueUdaf extends BaseUdaf { private double timeInterval = -1; private final boolean considerNull; @@ -25,7 +24,6 @@ public LastDistinctValueUdaf(double timeInterval) { this.considerNull = false; } - @Override public Object getValue(LastDistinctSaveFirst accumulator) { return accumulator.query(); @@ -39,15 +37,29 @@ public LastDistinctSaveFirst createAccumulator() { return new LastDistinctSaveFirst(timeInterval, considerNull); } + //key, valueCol, timeCol, timeInterval @Override public void accumulate(LastDistinctSaveFirst acc, Object... values) { - if (values.length != 4) { - throw new AkIllegalDataException(""); + if (values.length != 3 && values.length != 4) { + throw new AkIllegalArgumentException("values length must be 3 or 4."); + } + Object key = null; + Object value = null; + Object eventTime = null; + double timeInterval = 0; + if (4 == values.length) { + key = values[0]; + value = values[1]; + eventTime = values[2]; + timeInterval = Double.parseDouble(values[3].toString()); + } else { + key = values[0]; + value = values[0]; + eventTime = values[1]; + timeInterval = Double.parseDouble(values[2].toString()); } - Object key = values[0]; - Object value = values[1]; - Object eventTime = values[2]; - acc.setTimeInterval(Double.parseDouble(values[3].toString())); + + acc.setTimeInterval(timeInterval); acc.addOne(key, value, eventTime); acc.setLastKey(key); acc.setLastTime(eventTime); @@ -128,7 +140,7 @@ public LastDistinct(double timeInterval) { } void addOne(Object key, Object value, double currentTime) { - Tuple3 currentNode = Tuple3.of(key, value, currentTime); + Tuple3 currentNode = Tuple3.of(key, value, currentTime); if (firstObj == null) { firstObj = currentNode; } else if (secondObj == null) { @@ -168,7 +180,6 @@ Object query(Object key, double currentTime) { } } - public void addOne(Object key, Object value, Object currentTime) { if (currentTime instanceof Timestamp) { addOne(key, value, ((Timestamp) currentTime).getTime()); diff --git a/core/src/main/java/com/alibaba/alink/common/sql/builtin/agg/LastValueTypeData.java b/core/src/main/java/com/alibaba/alink/common/sql/builtin/agg/LastValueTypeData.java index 420b1f2bb..f757a78f7 100644 --- a/core/src/main/java/com/alibaba/alink/common/sql/builtin/agg/LastValueTypeData.java +++ b/core/src/main/java/com/alibaba/alink/common/sql/builtin/agg/LastValueTypeData.java @@ -1,454 +1,454 @@ package com.alibaba.alink.common.sql.builtin.agg; - import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeinfo.Types; import java.io.Serializable; +import java.math.BigDecimal; import java.sql.Timestamp; public class LastValueTypeData { - public static class LagData 
extends LastValueData { - - LagData(boolean considerNull) { - this.k = -1; - this.considerNull = considerNull; - } - - void addLagData(Object data, int k, Object defaultData) { - if (this.k == -1) { - this.k = k; - } - if (needDataType()) { - getType(data); - } - this.defaultData = defaultData; - thisValue = new Node(data, -1); - lastValue.append(thisValue); - if (considerNull) { - if (this.k == 0 && lastValue.dataSize > 1 || - this.k != 0 && lastValue.dataSize > this.k) { - lastValue.removeRoot(); - } - } else { - if (this.k == 0 && lastValue.noNullNum > 1 || - this.k != 0 && lastValue.noNullNum > this.k) { - lastValue.removeRoot(); - } - } - } - - void addLagData(Object data, int k) { - addLagData(data, k, null); - } - - Object getLagData() { - return transformData(lastValue.lastKData(k, defaultData)); - } - - Object getLagDataConsiderNull() { - return transformData(lastValue.lastKDataConsideringNull(k, defaultData)); - } - - private TypeInformation dataType = null; - private boolean passTrough = false; - - private boolean needDataType() { - if (passTrough) { - return false; - } - return dataType == null; - } - - private void getType(Object data) { - if (data == null) { - return; - } - if (data instanceof Double) { - dataType = Types.DOUBLE; - } else if (data instanceof Long) { - dataType = Types.LONG; - } else if (data instanceof Integer) { - dataType = Types.INT; - } else if (data instanceof Short) { - dataType = Types.SHORT; - } else if (data instanceof Float) { - dataType = Types.FLOAT; - } else if (data instanceof Byte) { - dataType = Types.BYTE; - } else { - passTrough = true; - } - } - - Object transformData(Object inputData) { - if (passTrough || inputData == null || dataType != null) { - return inputData; - } - - Double data = ((Number) inputData).doubleValue(); - if (Types.LONG.equals(dataType)) { - return data.longValue(); - } else if (Types.INT.equals(dataType)) { - return data.intValue(); - } else if (Types.SHORT.equals(dataType)) { - return data.shortValue(); - } else if (Types.FLOAT.equals(dataType)) { - return data.floatValue(); - } else if (Types.BYTE.equals(dataType)) { - return data.byteValue(); - } else { - return data; - } - } - } - - public static class LastValueData { - int k; - LinkedData lastValue = null; - Node thisValue = null; - public boolean hasParsed = false; - Object defaultData = null; - boolean considerNull; - - LastValueData(boolean considerNull) { - this.lastValue = new LinkedData(); - this.considerNull = considerNull; - this.k = -1; - } - - public double parseData(Object data, double timeInterval) { - if (!hasParsed) { - hasParsed = true; - return Double.parseDouble(data.toString()); - } - return timeInterval; - } - - public LastValueData() { - this(false); - } - - void reset() { - lastValue = new LinkedData(); - thisValue = null; - } - - void addData(Object data, Object timeStamp, int k, double timeInterval) { - if (this.k == -1) { - this.k = k; - } - long time; - if (timeStamp instanceof Timestamp) { - time = ((Timestamp) timeStamp).getTime(); - } else { - time = (long) timeStamp; - } - thisValue = new Node(data, time); - addLatestNode(lastValue, thisValue, this.k, time, timeInterval, considerNull); - } - - void retractData() { - lastValue.removeRoot(); - } - - Object getData() { - return lastValue.lastKData(k, null); - } - - Object getDataConsideringNull() { - return lastValue.lastKDataConsideringNull(k, null); - } - - - @Override - public boolean equals(Object o) { - if (!(o instanceof LastValueData)) { - return false; - } - LinkedData 
otherLastValue = ((LastValueData) o).lastValue; - return AggUtil.judgeNull(lastValue, otherLastValue, LastValueTypeData::judgeLinkData); - } - - static void addLatestNode(LinkedData linkedData, Node node, int k, long time, - double timeInterval, boolean considerNull) { - linkedData.append(node); - if (timeInterval != -1) { - if (considerNull) { - while (k == 0 && linkedData.dataSize > 1 || - k != 0 && linkedData.dataSize > k - || linkedData.root != null && linkedData.root.time > time + timeInterval) { - linkedData.removeRoot(); - } - } else { - while (k == 0 && linkedData.noNullNum > 1 || - k != 0 && linkedData.noNullNum > k || - linkedData.root != null && linkedData.root.time > time + timeInterval) { - linkedData.removeRoot(); - } - } - } - } - } - - public static boolean judgeLinkData(Object a, Object b) { - Node aNode = ((LinkedData) a).root; - Node bNode = ((LinkedData) b).root; - while (aNode != null || bNode != null) { - if (!aNode.getData().equals(bNode.getData())) { - return false; - } - } - if (null == aNode && null == bNode) { - return true; - } - return false; - } - - public static class LastTimeData extends LastValueData { - - @Override - Object getData() { - Long res = (Long) lastValue.lastTime(k); - if (res == null) { - return null; - } - return new Timestamp(res); - } - } - - public static class SumLastData extends LastValueData { - NumberTypeHandle handle = null; - - @Override - void addData(Object data, Object timeStamp, int k, double timeInterval) { - if (handle == null) { - handle = new NumberTypeHandle(data); - } - super.addData(data, timeStamp, k, timeInterval); - } - - @Override - Object getData() { - double res = (double) lastValue.sumLastK(k); - return handle.transformData(res); - } - } - - public static void merge(LastValueData acc, Iterable it) { - boolean firstData = acc.lastValue == null; - for (LastValueData lastValueData : it) { - if (lastValueData != null) { - if (firstData) { - acc = new LastValueData(); - firstData = false; - acc.lastValue.root = lastValueData.lastValue.root; - acc.lastValue.lastNode = lastValueData.lastValue.lastNode; - } else { - acc.lastValue.lastNode.setNext(lastValueData.lastValue.root); - acc.lastValue.lastNode = lastValueData.lastValue.lastNode; - } - acc.lastValue.dataSize += lastValueData.lastValue.dataSize; - } - } - } - - public static void merge(LagData acc, Iterable it) { - boolean firstData = acc.lastValue == null; - for (LagData lastValueData : it) { - if (lastValueData != null) { - if (firstData) { - acc = new LagData(lastValueData.considerNull); - firstData = false; - acc.lastValue.root = lastValueData.lastValue.root; - acc.lastValue.lastNode = lastValueData.lastValue.lastNode; - } else { - acc.lastValue.lastNode.setNext(lastValueData.lastValue.root); - acc.lastValue.lastNode = lastValueData.lastValue.lastNode; - } - acc.lastValue.dataSize += lastValueData.lastValue.dataSize; - } - } - } - - public static class LinkedData implements Serializable { - public Node root; - public Node lastNode; - public int dataSize;//the number of all the ip - public int noNullNum = 0; - - public LinkedData() { - } - - public LinkedData(Node node) { - addRoot(node); - } - - void addRoot(Node node) { - root = node; - lastNode = node; - dataSize = 1; - if (node.data != null) { - ++noNullNum; - } - } - - //remove the earliest data. 
- void removeRoot() { - if (root.data != null) { - --noNullNum; - } - if (dataSize > 1) { - dataSize -= 1; - root = root.next; - } else { - dataSize -= 1; - root = null; - lastNode = null; - } - } - - void append(Node node) { - if (root == null) { - addRoot(node); - } else { - if (node.data != null) { - ++noNullNum; - } - dataSize += 1; - lastNode.setNext(node); - node.setPrevious(lastNode); - lastNode = node; - } - } - - //return the last k data. - public Object lastKData(int k, Object defaultData) { - if (noNullNum == 0 || k > dataSize) { - return defaultData; - } - int index = k; - Node p = lastNode; - while (index > 0) { - p = p.getPrevious(); - while (p!=null&&p.data==null) { - p = p.getPrevious(); - } - if (p == null) { - return defaultData; - } - --index; - } - return p.data; - } - - public Object lastKDataConsideringNull(int k, Object defaultData) { - if (dataSize == 0 || k > dataSize) { - return defaultData; - } - int index = k; - Node p = lastNode; - while (index > 0) { - p = p.getPrevious(); - if (p == null) { - return defaultData; - } - --index; - } - return p.data; - } - - //return the time of last k data. - public Object lastTime(int k) { - if (dataSize == 0 || k > dataSize) { - return null; - } - int index = k; - Node p = lastNode; - while (index > 0) { - p = p.getPrevious(); - while (p!=null&&p.data==null) { - p = p.getPrevious(); - } - if (p == null) { - return null; - } - --index; - } - return p.time; - } - - //return the sum of last k data. - public Object sumLastK(int k) { - if (dataSize == 0) { - return 0.0; - } - double res = 0; - int index = k; - Node p = lastNode; - while (index > 0) { - res += ((Number) p.data).doubleValue(); - p = p.getPrevious(); - while (p!=null&&p.data==null) { - p = p.getPrevious(); - } - if (p == null) { - return res; - } - --index; - } - - return res; - } - } - - public static class Node implements Serializable { - private Object data; - private long time; - private Node next = null; - private Node previous = null; - - public Node(Object data, long time) { - setData(data, time); - } - - public void setData(Object data, long time) { - this.data = data; - this.time = time; - } - - public Object getData() { - return data; - } - - public double getTime() { - return time; - } - - public void setNext(Node next) { - this.next = next; - } - - public Node getNext() { - return next; - } - - public boolean hasNext() { - return !(next == null); - } - - public void setPrevious(Node previous) { - this.previous = previous; - } - - public Node getPrevious() { - return previous; - } - - public boolean hasPrevious() { - return !(previous == null); - } - } + public static class LagData extends LastValueData { + + LagData(boolean considerNull) { + this.k = -1; + this.considerNull = considerNull; + } + + void addLagData(Object data, int k, Object defaultData) { + if (this.k == -1) { + this.k = k; + } + if (needDataType()) { + getType(data); + } + this.defaultData = defaultData; + thisValue = new Node(data, -1); + lastValue.append(thisValue); + if (considerNull) { + if (this.k == 0 && lastValue.dataSize > 1 || this.k != 0 && lastValue.dataSize > this.k) { + lastValue.removeRoot(); + } + } else { + if (this.k == 0 && lastValue.noNullNum > 1 || this.k != 0 && lastValue.noNullNum > this.k) { + lastValue.removeRoot(); + } + } + } + + void addLagData(Object data, int k) { + addLagData(data, k, null); + } + + Object getLagData() { + return transformData(lastValue.lastKData(k, defaultData)); + } + + Object getLagDataConsiderNull() { + return 
transformData(lastValue.lastKDataConsideringNull(k, defaultData)); + } + + private TypeInformation dataType = null; + private boolean passTrough = false; + + private boolean needDataType() { + if (passTrough) { + return false; + } + return dataType == null; + } + + private void getType(Object data) { + if (data == null) { + return; + } + if (data instanceof Double) { + dataType = Types.DOUBLE; + } else if (data instanceof Long) { + dataType = Types.LONG; + } else if (data instanceof Integer) { + dataType = Types.INT; + } else if (data instanceof Short) { + dataType = Types.SHORT; + } else if (data instanceof Float) { + dataType = Types.FLOAT; + } else if (data instanceof Byte) { + dataType = Types.BYTE; + } else if (data instanceof BigDecimal) { + dataType = Types.BIG_DEC; + }else { + passTrough = true; + } + + } + + Object transformData(Object inputData) { + if (passTrough || inputData == null || dataType == null) { + return inputData; + } + + Double data = ((Number) inputData).doubleValue(); + if (Types.LONG.equals(dataType)) { + return data.longValue(); + } else if (Types.INT.equals(dataType)) { + return data.intValue(); + } else if (Types.SHORT.equals(dataType)) { + return data.shortValue(); + } else if (Types.FLOAT.equals(dataType)) { + return data.floatValue(); + } else if (Types.BYTE.equals(dataType)) { + return data.byteValue(); + } else if (Types.BIG_DEC.equals(dataType)) { + return inputData; + } else { + return data; + } + } + } + + public static class LastValueData { + int k; + LinkedData lastValue = null; + Node thisValue = null; + public boolean hasParsed = false; + Object defaultData = null; + boolean considerNull; + + LastValueData(boolean considerNull) { + this.lastValue = new LinkedData(); + this.considerNull = considerNull; + this.k = -1; + } + + public double parseData(Object data, double timeInterval) { + if (!hasParsed) { + hasParsed = true; + return Double.parseDouble(data.toString()); + } + return timeInterval; + } + + public LastValueData() { + this(false); + } + + void reset() { + lastValue = new LinkedData(); + thisValue = null; + } + + void addData(Object data, Object timeStamp, int k, double timeInterval) { + if (this.k == -1) { + this.k = k; + } + long time; + if (timeStamp instanceof Timestamp) { + time = ((Timestamp) timeStamp).getTime(); + } else { + time = (long) timeStamp; + } + thisValue = new Node(data, time); + addLatestNode(lastValue, thisValue, this.k, time, timeInterval, considerNull); + } + + void retractData() { + lastValue.removeRoot(); + } + + Object getData() { + return lastValue.lastKData(k, null); + } + + Object getDataConsideringNull() { + return lastValue.lastKDataConsideringNull(k, null); + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof LastValueData)) { + return false; + } + LinkedData otherLastValue = ((LastValueData) o).lastValue; + return AggUtil.judgeNull(lastValue, otherLastValue, LastValueTypeData::judgeLinkData); + } + + static void addLatestNode(LinkedData linkedData, Node node, int k, long time, double timeInterval, + boolean considerNull) { + linkedData.append(node); + if (timeInterval != -1) { + if (considerNull) { + while (k == 0 && linkedData.dataSize > 1 || k != 0 && linkedData.dataSize > k + || linkedData.root != null && linkedData.root.time > time + timeInterval) { + linkedData.removeRoot(); + } + } else { + while (k == 0 && linkedData.noNullNum > 1 || k != 0 && linkedData.noNullNum > k + || linkedData.root != null && linkedData.root.time > time + timeInterval) { + linkedData.removeRoot(); + } + } 
+ } + } + } + + public static boolean judgeLinkData(Object a, Object b) { + Node aNode = ((LinkedData) a).root; + Node bNode = ((LinkedData) b).root; + while (aNode != null || bNode != null) { + if (!aNode.getData().equals(bNode.getData())) { + return false; + } + } + if (null == aNode && null == bNode) { + return true; + } + return false; + } + + public static class LastTimeData extends LastValueData { + + @Override + Object getData() { + Long res = (Long) lastValue.lastTime(k); + if (res == null) { + return null; + } + return new Timestamp(res); + } + } + + public static class SumLastData extends LastValueData { + NumberTypeHandle handle = null; + + @Override + void addData(Object data, Object timeStamp, int k, double timeInterval) { + if (handle == null) { + handle = new NumberTypeHandle(data); + } + super.addData(data, timeStamp, k, timeInterval); + } + + @Override + Object getData() { + double res = (double) lastValue.sumLastK(k); + return handle.transformData(res); + } + } + + public static void merge(LastValueData acc, Iterable it) { + boolean firstData = acc.lastValue == null; + for (LastValueData lastValueData : it) { + if (lastValueData != null) { + if (firstData) { + acc = new LastValueData(); + firstData = false; + acc.lastValue.root = lastValueData.lastValue.root; + acc.lastValue.lastNode = lastValueData.lastValue.lastNode; + } else { + acc.lastValue.lastNode.setNext(lastValueData.lastValue.root); + acc.lastValue.lastNode = lastValueData.lastValue.lastNode; + } + acc.lastValue.dataSize += lastValueData.lastValue.dataSize; + } + } + } + + public static void merge(LagData acc, Iterable it) { + boolean firstData = acc.lastValue == null; + for (LagData lastValueData : it) { + if (lastValueData != null) { + if (firstData) { + acc = new LagData(lastValueData.considerNull); + firstData = false; + acc.lastValue.root = lastValueData.lastValue.root; + acc.lastValue.lastNode = lastValueData.lastValue.lastNode; + } else { + acc.lastValue.lastNode.setNext(lastValueData.lastValue.root); + acc.lastValue.lastNode = lastValueData.lastValue.lastNode; + } + acc.lastValue.dataSize += lastValueData.lastValue.dataSize; + } + } + } + + public static class LinkedData implements Serializable { + public Node root; + public Node lastNode; + public int dataSize;//the number of all the ip + public int noNullNum = 0; + + public LinkedData() { + } + + public LinkedData(Node node) { + addRoot(node); + } + + void addRoot(Node node) { + root = node; + lastNode = node; + dataSize = 1; + if (node.data != null) { + ++noNullNum; + } + } + + //remove the earliest data. + void removeRoot() { + if (root.data != null) { + --noNullNum; + } + if (dataSize > 1) { + dataSize -= 1; + root = root.next; + } else { + dataSize -= 1; + root = null; + lastNode = null; + } + } + + void append(Node node) { + if (root == null) { + addRoot(node); + } else { + if (node.data != null) { + ++noNullNum; + } + dataSize += 1; + lastNode.setNext(node); + node.setPrevious(lastNode); + lastNode = node; + } + } + + //return the last k data. 
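// Illustrative reading of lastKData below (example values are hypothetical): starting from the
// most recent node it walks back k steps while skipping null data, and returns defaultData when
// fewer than k non-null values exist. E.g. after appending 3, null, 5 (newest last),
// lastKData(1, -1) yields 3 and lastKData(2, -1) falls back to -1.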
+ public Object lastKData(int k, Object defaultData) { + if (noNullNum == 0 || k > dataSize) { + return defaultData; + } + int index = k; + Node p = lastNode; + while (index > 0) { + p = p.getPrevious(); + while (p != null && p.data == null) { + p = p.getPrevious(); + } + if (p == null) { + return defaultData; + } + --index; + } + return p.data; + } + + public Object lastKDataConsideringNull(int k, Object defaultData) { + if (dataSize == 0 || k > dataSize) { + return defaultData; + } + int index = k; + Node p = lastNode; + while (index > 0) { + p = p.getPrevious(); + if (p == null) { + return defaultData; + } + --index; + } + return p.data; + } + + //return the time of last k data. + public Object lastTime(int k) { + if (dataSize == 0 || k > dataSize) { + return null; + } + int index = k; + Node p = lastNode; + while (index > 0) { + p = p.getPrevious(); + while (p != null && p.data == null) { + p = p.getPrevious(); + } + if (p == null) { + return null; + } + --index; + } + return p.time; + } + + //return the sum of last k data. + public Object sumLastK(int k) { + if (dataSize == 0) { + return 0.0; + } + double res = 0; + int index = k; + Node p = lastNode; + while (index > 0) { + res += ((Number) p.data).doubleValue(); + p = p.getPrevious(); + while (p != null && p.data == null) { + p = p.getPrevious(); + } + if (p == null) { + return res; + } + --index; + } + + return res; + } + } + + public static class Node implements Serializable { + private Object data; + private long time; + private Node next = null; + private Node previous = null; + + public Node(Object data, long time) { + setData(data, time); + } + + public void setData(Object data, long time) { + this.data = data; + this.time = time; + } + + public Object getData() { + return data; + } + + public double getTime() { + return time; + } + + public void setNext(Node next) { + this.next = next; + } + + public Node getNext() { + return next; + } + + public boolean hasNext() { + return !(next == null); + } + + public void setPrevious(Node previous) { + this.previous = previous; + } + + public Node getPrevious() { + return previous; + } + + public boolean hasPrevious() { + return !(previous == null); + } + } } diff --git a/core/src/main/java/com/alibaba/alink/common/sql/builtin/agg/MTableAgg.java b/core/src/main/java/com/alibaba/alink/common/sql/builtin/agg/MTableAgg.java index 13097e843..345813b26 100644 --- a/core/src/main/java/com/alibaba/alink/common/sql/builtin/agg/MTableAgg.java +++ b/core/src/main/java/com/alibaba/alink/common/sql/builtin/agg/MTableAgg.java @@ -1,6 +1,7 @@ package com.alibaba.alink.common.sql.builtin.agg; -import org.apache.flink.types.Row; +import + org.apache.flink.types.Row; import com.alibaba.alink.common.MTable; import com.alibaba.alink.common.utils.TableUtil; @@ -17,6 +18,11 @@ public class MTableAgg extends BaseUdaf > { private final boolean dropLast; private int sortColIdx; + //for json converter. 
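// Presumably the reflection-based JSON converter needs a public no-argument constructor to
// re-instantiate the aggregator; it delegates to MTableAgg(false), so dropLast stays disabled
// unless the one-argument constructor is used explicitly.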
+ public MTableAgg() { + this(false); + } + public MTableAgg(boolean dropLast) { this.dropLast = dropLast; } diff --git a/core/src/main/java/com/alibaba/alink/common/sql/builtin/agg/NumberTypeHandle.java b/core/src/main/java/com/alibaba/alink/common/sql/builtin/agg/NumberTypeHandle.java index 0901c4128..1b8f5645c 100644 --- a/core/src/main/java/com/alibaba/alink/common/sql/builtin/agg/NumberTypeHandle.java +++ b/core/src/main/java/com/alibaba/alink/common/sql/builtin/agg/NumberTypeHandle.java @@ -5,28 +5,37 @@ import com.alibaba.alink.common.exceptions.AkUnsupportedOperationException; -public class NumberTypeHandle { +import java.math.BigDecimal; +public class NumberTypeHandle { private TypeInformation dataType = null; + NumberTypeHandle(Object data) { getType(data); } Number transformData(Number inputData) { - Double data = inputData.doubleValue(); - if (Types.LONG.equals(dataType)) { - return data.longValue(); + if (Types.DOUBLE.equals(dataType)) { + return inputData.doubleValue(); + } else if (Types.LONG.equals(dataType)) { + return inputData.longValue(); } else if (Types.INT.equals(dataType)) { - return data.intValue(); - } else if (Types.SHORT.equals(dataType)) { - return data.shortValue(); + return inputData.intValue(); } else if (Types.FLOAT.equals(dataType)) { - return data.floatValue(); + return inputData.floatValue(); + } else if (Types.SHORT.equals(dataType)) { + return inputData.shortValue(); } else if (Types.BYTE.equals(dataType)) { - return data.byteValue(); - } else if (Types.DOUBLE.equals(dataType)) { - return data; + return inputData.byteValue(); + } else if (Types.BIG_DEC.equals(dataType)) { + if (inputData instanceof BigDecimal) { + return inputData; + } else if (inputData instanceof Double || inputData instanceof Float) { + return new BigDecimal(inputData.doubleValue()); + } else { + return new BigDecimal(inputData.longValue()); + } } throw new AkUnsupportedOperationException("Do not support this type: " + dataType); } @@ -44,6 +53,8 @@ private void getType(Object data) { dataType = Types.FLOAT; } else if (data instanceof Byte) { dataType = Types.BYTE; + } else if (data instanceof BigDecimal) { + dataType = Types.BIG_DEC; } else { throw new AkUnsupportedOperationException("We only support double, long, int, float, short, byte."); } diff --git a/core/src/main/java/com/alibaba/alink/common/AlinkTypes.java b/core/src/main/java/com/alibaba/alink/common/type/AlinkTypes.java similarity index 90% rename from core/src/main/java/com/alibaba/alink/common/AlinkTypes.java rename to core/src/main/java/com/alibaba/alink/common/type/AlinkTypes.java index 0dd8c204b..5fc3e60cf 100644 --- a/core/src/main/java/com/alibaba/alink/common/AlinkTypes.java +++ b/core/src/main/java/com/alibaba/alink/common/type/AlinkTypes.java @@ -1,11 +1,13 @@ -package com.alibaba.alink.common; +package com.alibaba.alink.common.type; +import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo; import org.apache.flink.api.common.typeinfo.TypeHint; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeinfo.Types; import org.apache.flink.api.java.typeutils.RowTypeInfo; import org.apache.flink.types.Row; +import com.alibaba.alink.common.MTable; import com.alibaba.alink.common.linalg.DenseVector; import com.alibaba.alink.common.linalg.SparseVector; import com.alibaba.alink.common.linalg.Vector; @@ -30,7 +32,7 @@ * This class contains bi-direction mapping between TypeInformations and their names. 
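 *
 * <p>A minimal lookup sketch (illustrative only, using entries registered in the static block below):
 * <pre>{@code
 * TypeInformation<?> binary = AlinkTypes.getTypeInformation("VARBINARY"); // byte[] type info
 * String name = AlinkTypes.getTypeName(AlinkTypes.DENSE_VECTOR);          // "DENSE_VECTOR"
 * }</pre>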
*/ public class AlinkTypes extends Types { - private static final HashBiMap > TYPES = HashBiMap.create(); + private static final HashBiMap > TYPES = HashBiMap.create(); /** * MTable type information. @@ -71,6 +73,8 @@ public class AlinkTypes extends Types { STRING_TENSOR )); + public static final TypeInformation VARBINARY = PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO; + static { TYPES.put("M_TABLE", M_TABLE); TYPES.put("DENSE_VECTOR", DENSE_VECTOR); @@ -85,6 +89,7 @@ public class AlinkTypes extends Types { TYPES.put("INT_TENSOR", INT_TENSOR); TYPES.put("LONG_TENSOR", LONG_TENSOR); TYPES.put("STRING_TENSOR", STRING_TENSOR); + TYPES.put("VARBINARY", VARBINARY); } /** @@ -93,7 +98,7 @@ public class AlinkTypes extends Types { * @param type TypeInformation * @return Corresponding type name, or null if not found. */ - public static String getTypeName(TypeInformation type) { + public static String getTypeName(TypeInformation type) { return TYPES.inverse().get(type); } @@ -107,7 +112,7 @@ public static TypeInformation getRowType(TypeInformation ... types) { * @param name type name string. * @return Corresponding TypeInformation, or null if not found. */ - public static TypeInformation getTypeInformation(String name) { + public static TypeInformation getTypeInformation(String name) { if (TYPES.containsKey(name)) { return TYPES.get(name); } else { diff --git a/core/src/main/java/com/alibaba/alink/common/type/BigDecimalTypeInfo.java b/core/src/main/java/com/alibaba/alink/common/type/BigDecimalTypeInfo.java new file mode 100644 index 000000000..62cabad53 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/common/type/BigDecimalTypeInfo.java @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.alink.common.type; + +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.typeinfo.AtomicType; +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeutils.TypeComparator; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.common.typeutils.base.BigDecComparator; +import org.apache.flink.api.common.typeutils.base.BigDecSerializer; + +import java.lang.reflect.Constructor; +import java.math.BigDecimal; +import java.util.Arrays; + +/** + * {@link TypeInformation} for {@link BigDecimal}. + * + *

<p>It differs from {@link BasicTypeInfo#BIG_DEC_TYPE_INFO} in that: This type includes + * `precision` and `scale`, similar to SQL DECIMAL. + *
+ * NOTE: This class is copied and modified from a higher version of Flink source code. This class is modified to NOT + * inherit {@link BasicTypeInfo}, otherwise + * {@link org.apache.flink.table.calcite.FlinkTypeFactory#createTypeFromTypeInfo} will + */ +public class BigDecimalTypeInfo extends TypeInformation implements AtomicType { + + private static final long serialVersionUID = 1L; + + public static BigDecimalTypeInfo of(int precision, int scale) { + return new BigDecimalTypeInfo(precision, scale); + } + + public static BigDecimalTypeInfo of(BigDecimal value) { + return of(value.precision(), value.scale()); + } + + private final int precision; + + private final int scale; + + public BigDecimalTypeInfo(int precision, int scale) { + this.precision = precision; + this.scale = scale; + } + + @Override + public String toString() { + return String.format("Decimal(%d,%d)", precision(), scale()); + } + + @Override + public boolean equals(Object obj) { + if (!(obj instanceof BigDecimalTypeInfo)) { + return false; + } + BigDecimalTypeInfo that = (BigDecimalTypeInfo) obj; + return this.precision() == that.precision() && this.scale() == that.scale(); + } + + @Override + public int hashCode() { + int h0 = this.getClass().getCanonicalName().hashCode(); + return Arrays.hashCode(new int[] {h0, precision(), scale()}); + } + + @Override + public boolean canEqual(Object obj) { + return false; + } + + public int precision() { + return precision; + } + + public int scale() { + return scale; + } + + @Override + public boolean isBasicType() { + return true; + } + + @Override + public boolean isTupleType() { + return false; + } + + @Override + public int getArity() { + return 1; + } + + @Override + public int getTotalFields() { + return 1; + } + + @Override + public Class getTypeClass() { + return BigDecimal.class; + } + + @Override + public boolean isKeyType() { + return true; + } + + @Override + public TypeSerializer createSerializer(ExecutionConfig config) { + return BigDecSerializer.INSTANCE; + } + + private static TypeComparator instantiateComparator(Class > comparatorClass, + boolean ascendingOrder) { + try { + Constructor > constructor = comparatorClass.getConstructor(Boolean.TYPE); + return constructor.newInstance(ascendingOrder); + } catch (Exception e) { + throw new RuntimeException("Could not initialize basic comparator " + comparatorClass.getName(), e); + } + } + + @Override + public TypeComparator createComparator(boolean sortOrderAscending, ExecutionConfig executionConfig) { + return instantiateComparator(BigDecComparator.class, sortOrderAscending); + } +} diff --git a/core/src/main/java/com/alibaba/alink/common/VectorTypes.java b/core/src/main/java/com/alibaba/alink/common/type/VectorTypes.java similarity index 98% rename from core/src/main/java/com/alibaba/alink/common/VectorTypes.java rename to core/src/main/java/com/alibaba/alink/common/type/VectorTypes.java index ac253c4b9..0d66f29ed 100644 --- a/core/src/main/java/com/alibaba/alink/common/VectorTypes.java +++ b/core/src/main/java/com/alibaba/alink/common/type/VectorTypes.java @@ -1,4 +1,4 @@ -package com.alibaba.alink.common; +package com.alibaba.alink.common.type; import org.apache.flink.api.common.typeinfo.TypeInformation; diff --git a/core/src/main/java/com/alibaba/alink/common/utils/DataStreamUtil.java b/core/src/main/java/com/alibaba/alink/common/utils/DataStreamUtil.java deleted file mode 100644 index 360b324a4..000000000 --- a/core/src/main/java/com/alibaba/alink/common/utils/DataStreamUtil.java +++ /dev/null @@ -1,57 +0,0 @@ 
-package com.alibaba.alink.common.utils; - -import org.apache.flink.api.common.functions.MapFunction; -import org.apache.flink.api.common.functions.Partitioner; -import org.apache.flink.api.common.functions.RichFlatMapFunction; -import org.apache.flink.api.common.functions.RichMapFunction; -import org.apache.flink.api.java.tuple.Tuple2; -import org.apache.flink.configuration.Configuration; -import org.apache.flink.streaming.api.datastream.DataStream; -import org.apache.flink.types.Row; -import org.apache.flink.util.Collector; - -import java.util.ArrayList; -import java.util.List; - -/** - * Utils for handling datastream. - */ -@SuppressWarnings("unchecked") -public class DataStreamUtil { - - /** - * Stack a datastream of rows - */ - public static DataStream > stack(DataStream input, final int size) { - return input - .flatMap(new RichFlatMapFunction >() { - private static final long serialVersionUID = -2909825492775487009L; - transient Collector > collector; - transient List buffer; - - @Override - public void open(Configuration parameters) throws Exception { - this.buffer = new ArrayList <>(); - } - - @Override - public void close() throws Exception { - if (this.buffer.size() > 0) { - this.collector.collect(this.buffer); - this.buffer.clear(); - } - } - - @Override - public void flatMap(Row value, Collector > out) throws Exception { - this.collector = out; - this.buffer.add(value); - if (this.buffer.size() >= size) { - this.collector.collect(this.buffer); - this.buffer.clear(); - } - } - }); - } - -} diff --git a/core/src/main/java/com/alibaba/alink/common/utils/TableUtil.java b/core/src/main/java/com/alibaba/alink/common/utils/TableUtil.java index bacd6a22d..db67cd4de 100644 --- a/core/src/main/java/com/alibaba/alink/common/utils/TableUtil.java +++ b/core/src/main/java/com/alibaba/alink/common/utils/TableUtil.java @@ -1,30 +1,35 @@ package com.alibaba.alink.common.utils; -import org.apache.flink.api.common.functions.RichMapFunction; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeinfo.Types; -import org.apache.flink.api.java.DataSet; -import org.apache.flink.configuration.Configuration; -import org.apache.flink.table.api.Table; +import org.apache.flink.ml.api.misc.param.Params; import org.apache.flink.table.api.TableSchema; import org.apache.flink.types.Row; -import com.alibaba.alink.common.AlinkTypes; -import com.alibaba.alink.common.DataTypeDisplayInterface; +import com.alibaba.alink.common.type.AlinkTypes; +import com.alibaba.alink.common.viz.DataTypeDisplayInterface; import com.alibaba.alink.common.exceptions.AkColumnNotFoundException; import com.alibaba.alink.common.exceptions.AkIllegalArgumentException; import com.alibaba.alink.common.exceptions.AkIllegalOperatorParameterException; import com.alibaba.alink.common.exceptions.AkPreconditions; import com.alibaba.alink.common.linalg.DenseVector; import com.alibaba.alink.common.linalg.SparseVector; -import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.common.io.types.FlinkTypeConverter; import com.alibaba.alink.operator.common.similarity.similarity.LevenshteinSimilarity; +import com.alibaba.alink.params.shared.colname.HasFeatureColsDefaultAsNull; +import com.alibaba.alink.params.shared.colname.HasGroupColDefaultAsNull; +import com.alibaba.alink.params.shared.colname.HasLabelCol; +import com.alibaba.alink.params.shared.colname.HasLabelColDefaultAsNull; +import com.alibaba.alink.params.shared.colname.HasSelectedColsDefaultAsNull; +import 
com.alibaba.alink.params.shared.colname.HasWeightCol; +import com.alibaba.alink.params.shared.colname.HasWeightColDefaultAsNull; import com.google.common.base.Joiner; import com.google.common.collect.BiMap; import com.google.common.collect.HashBiMap; +import org.apache.commons.lang3.ArrayUtils; import org.apache.commons.lang3.StringUtils; +import java.math.BigDecimal; import java.util.ArrayList; import java.util.Arrays; import java.util.HashSet; @@ -32,8 +37,6 @@ import java.util.Set; import java.util.UUID; -import static org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO; - /** * Utility to operator to interact with Table contents, such as rows and columns. */ @@ -42,9 +45,10 @@ public class TableUtil { public static final Set STRING_TYPE_SET = new HashSet <>(); private static final LevenshteinSimilarity levenshteinSimilarity = new LevenshteinSimilarity(); public static final int DISPLAY_SIZE = 6; + public static final String HEX = "0123456789abcdef"; static { - STRING_TYPE_MAP.put("VARBINARY", BYTE_PRIMITIVE_ARRAY_TYPE_INFO); + STRING_TYPE_MAP.put("VARBINARY", AlinkTypes.VARBINARY); STRING_TYPE_MAP.put("VECTOR", AlinkTypes.VECTOR); STRING_TYPE_MAP.put("DENSE_VECTOR", AlinkTypes.DENSE_VECTOR); STRING_TYPE_MAP.put("SPARSE_VECTOR", AlinkTypes.SPARSE_VECTOR); @@ -396,6 +400,25 @@ public static boolean isSupportedNumericType(TypeInformation dataType) { || Types.BIG_DEC.equals(dataType); } + /** + * Determine whether it is date type. + * @param dataType the dataType to determine. + * @return whether it is date type. + */ + public static boolean isSupportedDateType(TypeInformation dataType) { + return Types.SQL_DATE.equals(dataType) + || Types.SQL_TIMESTAMP.equals(dataType) + || Types.SQL_TIME.equals(dataType); + } + + /** + * Determine whether it is boolean type. + * @param dataType the dataType to determine. + * @return whether it is boolean type. + */ + public static boolean isSupportedBoolType(TypeInformation dataType) { + return Types.BOOLEAN.equals(dataType); + } /** * Determine whether it is a string type. * @@ -602,6 +625,40 @@ public static String[] getCategoricalCols( return res.toArray(new String[0]); } + public static String[] getOptionalFeatureCols(TableSchema tableSchema, Params params) { + if (params.contains(HasFeatureColsDefaultAsNull.FEATURE_COLS)) { + return params.get(HasFeatureColsDefaultAsNull.FEATURE_COLS); + } + + if (params.contains(HasSelectedColsDefaultAsNull.SELECTED_COLS)) { + return params.get(HasSelectedColsDefaultAsNull.SELECTED_COLS); + } + + String[] featureCols = ArrayUtils.clone(tableSchema.getFieldNames()); + + if (params.contains(HasWeightColDefaultAsNull.WEIGHT_COL)) { + featureCols = ArrayUtils.removeElements(featureCols, params.get(HasWeightColDefaultAsNull.WEIGHT_COL)); + } + + if (params.contains(HasWeightCol.WEIGHT_COL)) { + featureCols = ArrayUtils.removeElements(featureCols, params.get(HasWeightCol.WEIGHT_COL)); + } + + if (params.contains(HasGroupColDefaultAsNull.GROUP_COL)) { + featureCols = ArrayUtils.removeElements(featureCols, params.get(HasGroupColDefaultAsNull.GROUP_COL)); + } + + if (params.contains(HasLabelCol.LABEL_COL)) { + featureCols = ArrayUtils.removeElements(featureCols, params.get(HasLabelCol.LABEL_COL)); + } + + if (params.contains(HasLabelColDefaultAsNull.LABEL_COL)) { + featureCols = ArrayUtils.removeElements(featureCols, params.get(HasLabelColDefaultAsNull.LABEL_COL)); + } + + return featureCols; + } + /** * format the column names as header of markdown. 
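 * (Inferred from the return statement in the hunk below: the result is the column-name row followed
 * by a newline and a separator row, i.e. a two-line markdown table header meant to be printed above
 * the per-row output of formatRows.)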
*/ @@ -628,7 +685,7 @@ public static String formatTitle(String[] colNames) { } } - return sbd.toString() + "\n" + sbdSplitter.toString(); + return sbd + "\n" + sbdSplitter; } /** @@ -644,7 +701,7 @@ public static String formatRows(Row row) { sbd.append("|"); } Object obj = row.getField(i); - if (obj instanceof Double || obj instanceof Float) { + if (obj instanceof Double || obj instanceof Float || obj instanceof BigDecimal) { sbd.append(String.format("%.4f", ((Number) obj).doubleValue())); } else if (obj instanceof DataTypeDisplayInterface) { if (obj instanceof DenseVector || obj instanceof SparseVector) { @@ -702,6 +759,16 @@ public static String formatRows(Row row) { sbd.append(" "); } } + } else if (obj instanceof byte[]) { + int byteSize = ((byte[]) obj).length; + sbd.append("byte[").append(byteSize).append("] "); + byte[] byteArray = byteSize > DISPLAY_SIZE ? Arrays.copyOfRange((byte[]) obj, 0, DISPLAY_SIZE) + : (byte[]) obj; + for (byte b : byteArray) { + sbd.append(HEX.charAt((b >> 4) & 0x0f)); + sbd.append(HEX.charAt(b & 0x0f)); + } + sbd.append((byteSize > DISPLAY_SIZE ? "..." : "")); } else { sbd.append(obj); } @@ -744,120 +811,51 @@ public static String columnsToSqlClause(String[] colNames) { return Joiner.on("`,`").appendTo(new StringBuilder("`"), colNames).append("`").toString(); } - /** - * open ends here - **/ - public static Table concatTables(Table[] tables, Long sessionId) { - final int[] numCols = new int[tables.length]; - final List allColNames = new ArrayList <>(); - final List > allColTypes = new ArrayList <>(); - allColNames.add("table_id"); - allColTypes.add(Types.LONG); - for (int i = 0; i < tables.length; i++) { - if (tables[i] == null) { - numCols[i] = 0; - } else { - numCols[i] = tables[i].getSchema().getFieldNames().length; - String[] prefixedColNames = tables[i].getSchema().getFieldNames().clone(); - for (int j = 0; j < prefixedColNames.length; j++) { - prefixedColNames[j] = String.format("t%d_%s", i, prefixedColNames[j]); - } - allColNames.addAll(Arrays.asList(prefixedColNames)); - allColTypes.addAll(Arrays.asList(tables[i].getSchema().getFieldTypes())); - } - } - - if (allColNames.size() == 1) { - return null; - } - - DataSet allRows = null; - int startCol = 1; - final int numAllCols = allColNames.size(); - for (int i = 0; i < tables.length; i++) { - if (tables[i] == null) { - continue; - } - final int constStartCol = startCol; - final int iTable = i; - DataSet rows = BatchOperator.fromTable(tables[i]).setMLEnvironmentId(sessionId).getDataSet(); - rows = rows.map(new RichMapFunction () { - private static final long serialVersionUID = -8085823678072944808L; - transient Row reused; - - @Override - public void open(Configuration parameters) { - reused = new Row(numAllCols); - } - - @Override - public Row map(Row value) { - for (int i = 0; i < numAllCols; i++) { - reused.setField(i, null); - } - reused.setField(0, (long) iTable); - for (int i = 0; i < numCols[iTable]; i++) { - reused.setField(constStartCol + i, value.getField(i)); - } - return reused; - } - }); - if (allRows == null) { - allRows = rows; - } else { - allRows = allRows.union(rows); - } - startCol += numCols[i]; - } - return DataSetConversionUtil.toTable(sessionId, allRows, allColNames.toArray(new String[0]), - allColTypes.toArray(new TypeInformation[0])); - } - - public static Table[] splitTable(Table table) { - TableSchema schema = table.getSchema(); - final String[] colNames = schema.getFieldNames(); - String idCol = colNames[0]; - if (!idCol.equalsIgnoreCase("table_id")) { - throw new 
AkIllegalArgumentException("The table can't be splited."); - } - - String lastCol = colNames[colNames.length - 1]; - int maxTableId = Integer.parseInt(lastCol.substring(1, lastCol.indexOf('_'))); - int numTables = maxTableId + 1; - - int[] numColsOfEachTable = new int[numTables]; - for (int i = 1; i < colNames.length; i++) { - int tableId = Integer.parseInt(colNames[i].substring(1, lastCol.indexOf('_'))); - numColsOfEachTable[tableId]++; - } - - Table[] splited = new Table[numTables]; - int startCol = 1; - for (int i = 0; i < numTables; i++) { - if (numColsOfEachTable[i] == 0) { - continue; - } - String[] selectedCols = Arrays.copyOfRange(colNames, startCol, startCol + numColsOfEachTable[i]); - BatchOperator sub = BatchOperator.fromTable(table) - .where(String.format("%s=%d", "table_id", i)) - .select(selectedCols); - - // recover the col names - String prefix = String.format("t%d_", i); - StringBuilder sbd = new StringBuilder(); - for (int j = 0; j < selectedCols.length; j++) { - if (j > 0) { - sbd.append(","); - } - sbd.append(selectedCols[j].substring(prefix.length())); - } - sub = sub.as(sbd.toString()); - splited[i] = sub.getOutputTable(); - startCol += numColsOfEachTable[i]; - } - return splited; - } + //public static Table[] splitTable(Table table) { + // TableSchema schema = table.getSchema(); + // final String[] colNames = schema.getFieldNames(); + // String idCol = colNames[0]; + // if (!idCol.equalsIgnoreCase("table_id")) { + // throw new AkIllegalArgumentException("The table can't be splited."); + // } + // + // String lastCol = colNames[colNames.length - 1]; + // int maxTableId = Integer.parseInt(lastCol.substring(1, lastCol.indexOf('_'))); + // int numTables = maxTableId + 1; + // + // int[] numColsOfEachTable = new int[numTables]; + // for (int i = 1; i < colNames.length; i++) { + // int tableId = Integer.parseInt(colNames[i].substring(1, lastCol.indexOf('_'))); + // numColsOfEachTable[tableId]++; + // } + // + // Table[] splited = new Table[numTables]; + // int startCol = 1; + // for (int i = 0; i < numTables; i++) { + // if (numColsOfEachTable[i] == 0) { + // continue; + // } + // String[] selectedCols = Arrays.copyOfRange(colNames, startCol, startCol + numColsOfEachTable[i]); + // BatchOperator sub = BatchOperator.fromTable(table) + // .where(String.format("%s=%d", "table_id", i)) + // .select(selectedCols); + // + // // recover the col names + // String prefix = String.format("t%d_", i); + // StringBuilder sbd = new StringBuilder(); + // for (int j = 0; j < selectedCols.length; j++) { + // if (j > 0) { + // sbd.append(","); + // } + // sbd.append(selectedCols[j].substring(prefix.length())); + // } + // sub = sub.as(sbd.toString()); + // splited[i] = sub.getOutputTable(); + // startCol += numColsOfEachTable[i]; + // } + // return splited; + //} public static Row getRow(Row row, int... 
keepIdxs) { Row res = null; diff --git a/core/src/main/java/com/alibaba/alink/common/viz/AlinkViz.java b/core/src/main/java/com/alibaba/alink/common/viz/AlinkViz.java new file mode 100644 index 000000000..3ddbdeda9 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/common/viz/AlinkViz.java @@ -0,0 +1,26 @@ +package com.alibaba.alink.common.viz; + +import org.apache.flink.ml.api.misc.param.ParamInfo; +import org.apache.flink.ml.api.misc.param.ParamInfoFactory; +import org.apache.flink.ml.api.misc.param.WithParams; + +public interface AlinkViz extends WithParams { + + ParamInfo VIZ_NAME = ParamInfoFactory + .createParamInfo("vizName", String.class) + .setDescription("Name of Visualization") + .setOptional() + .build(); + + default VizDataWriterInterface getVizDataWriter() { + return ScreenManager.getScreenManager().getVizDataWriter(getParams()); + } + + default String getVizName() { + return getParams().get(VIZ_NAME); + } + + default T setVizName(String value) { + return set(VIZ_NAME, value); + } +} diff --git a/core/src/main/java/com/alibaba/alink/common/DataTypeDisplayInterface.java b/core/src/main/java/com/alibaba/alink/common/viz/DataTypeDisplayInterface.java similarity index 94% rename from core/src/main/java/com/alibaba/alink/common/DataTypeDisplayInterface.java rename to core/src/main/java/com/alibaba/alink/common/viz/DataTypeDisplayInterface.java index 86179b879..b15a43c04 100644 --- a/core/src/main/java/com/alibaba/alink/common/DataTypeDisplayInterface.java +++ b/core/src/main/java/com/alibaba/alink/common/viz/DataTypeDisplayInterface.java @@ -1,4 +1,4 @@ -package com.alibaba.alink.common; +package com.alibaba.alink.common.viz; public interface DataTypeDisplayInterface { diff --git a/core/src/main/java/com/alibaba/alink/common/viz/DummyVizDataWriter.java b/core/src/main/java/com/alibaba/alink/common/viz/DummyVizDataWriter.java new file mode 100644 index 000000000..0572f67d5 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/common/viz/DummyVizDataWriter.java @@ -0,0 +1,38 @@ +package com.alibaba.alink.common.viz; + +import java.util.List; + +public class DummyVizDataWriter implements VizDataWriterInterface { + private static final long serialVersionUID = -3525933894792429368L; + + @Override + public void writeBatchData(long dataType, String data, long timeStamp) { + + } + + @Override + public void writeStreamData(long dataType, String data, long timeStamp) { + + } + + @Override + public void writeBatchData(List data) { + + } + + @Override + public void writeStreamData(List data) { + + } + + @Override + public void writeBatchMeta(VizOpMeta meta) { + + } + + @Override + public void writeStreamMeta(VizOpMeta meta) { + + } + +} diff --git a/core/src/main/java/com/alibaba/alink/common/viz/DummyVizManager.java b/core/src/main/java/com/alibaba/alink/common/viz/DummyVizManager.java new file mode 100644 index 000000000..3879d4b75 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/common/viz/DummyVizManager.java @@ -0,0 +1,18 @@ +package com.alibaba.alink.common.viz; + +import org.apache.flink.ml.api.misc.param.Params; + +public class DummyVizManager implements VizManagerInterface { + + private static final long serialVersionUID = 8639356595306220989L; + + @Override + public VizDataWriterInterface getVizDataWriter(Params params) { + return new DummyVizDataWriter(); + } + + @Override + public String getVizName() { + return null; + } +} diff --git a/core/src/main/java/com/alibaba/alink/common/viz/ScreenManager.java 
b/core/src/main/java/com/alibaba/alink/common/viz/ScreenManager.java new file mode 100644 index 000000000..f329c0eb3 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/common/viz/ScreenManager.java @@ -0,0 +1,17 @@ +package com.alibaba.alink.common.viz; + +import com.alibaba.alink.common.viz.DummyVizManager; +import com.alibaba.alink.common.viz.VizManagerInterface; + +public class ScreenManager { + //public static Boolean ApsEnableLogging = Boolean.FALSE; + static VizManagerInterface screenManager = new DummyVizManager(); + + public static VizManagerInterface getScreenManager() { + return screenManager; + } + + public static void setScreenManager(VizManagerInterface manger) { + screenManager = manger; + } +} diff --git a/core/src/main/java/com/alibaba/alink/common/viz/VizData.java b/core/src/main/java/com/alibaba/alink/common/viz/VizData.java new file mode 100644 index 000000000..862f54692 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/common/viz/VizData.java @@ -0,0 +1,30 @@ +package com.alibaba.alink.common.viz; + +import com.alibaba.alink.common.utils.AlinkSerializable; + +import java.io.Serializable; + +public class VizData implements AlinkSerializable, Serializable { + private static final long serialVersionUID = -3722895114138878009L; + public long dataId; + public String data; + public long timeStamp; + + public VizData(long dataId, String data, long timeStamp) { + this.dataId = dataId; + this.data = data; + this.timeStamp = timeStamp; + } + + public long getDataId() { + return dataId; + } + + public String getData() { + return data; + } + + public long getTimeStamp() { + return timeStamp; + } +} diff --git a/core/src/main/java/com/alibaba/alink/common/viz/VizDataWriterForModelInfo.java b/core/src/main/java/com/alibaba/alink/common/viz/VizDataWriterForModelInfo.java new file mode 100644 index 000000000..8e83055b5 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/common/viz/VizDataWriterForModelInfo.java @@ -0,0 +1,101 @@ +package com.alibaba.alink.common.viz; + +import org.apache.flink.api.common.functions.RichGroupReduceFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.DataSet; +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.table.api.TableSchema; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; + +import com.alibaba.alink.operator.batch.utils.DataSetUtil; +import com.alibaba.alink.operator.common.tree.Node; +import com.alibaba.alink.operator.common.tree.TreeModelDataConverter; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; + +import static com.alibaba.alink.common.utils.JsonConverter.gson; + +public class VizDataWriterForModelInfo { + private final static Logger LOG = LoggerFactory.getLogger(VizDataWriterForModelInfo.class); + + static public void writeModelInfo(VizDataWriterInterface writer, String opName, TableSchema schema, + DataSet model, Params params) { + VizOpMeta vizOpMeta = new VizOpMeta(); + + vizOpMeta.opName = opName; + + vizOpMeta.dataInfos = new VizOpDataInfo[1]; + vizOpMeta.dataInfos[0] = new VizOpDataInfo(0, VizOpDataInfo.WriteVizDataType.OnlyOnce); + + vizOpMeta.cascades = new HashMap <>(); + vizOpMeta.cascades.put(gson.toJson(new String[] {"model"}), new VizOpChartData(0)); + + vizOpMeta.setSchema(schema); + vizOpMeta.params = params; + vizOpMeta.isOutput = false; + + writer.writeBatchMeta(vizOpMeta); + + DataSet result = 
model.mapPartition(new VizDataWriterMapperForTable( + writer, 0, schema.getFieldNames(), schema.getFieldTypes())) + .setParallelism(1); + DataSetUtil.linkDummySink(result); + } + + static public void writeTreeModelInfo(VizDataWriterInterface writer, String opName, TableSchema schema, + DataSet model, Params params) { + TypeInformation [] types = schema.getFieldTypes(); + TypeInformation labelType = types[types.length - 1]; + DataSet processedModel = model + .reduceGroup(new PruneTreeMapper(labelType)); + writeModelInfo(writer, opName, schema, processedModel, params); + } + + private static void pruneTree(Node node, int depth, int maxDepthAllowed) { + if (depth + 1 >= maxDepthAllowed) { + if (node.getNextNodes() != null) { + node.setNextNodes(new Node[] {}); + } + return; + } + if (node.getNextNodes() == null) { + return; + } + for (Node child : node.getNextNodes()) { + pruneTree(child, depth + 1, maxDepthAllowed); + } + } + + public static class PruneTreeMapper extends RichGroupReduceFunction { + private static final long serialVersionUID = -4197833778656802284L; + static int MAX_DEPTH_ALLOWED_FOR_VIZ = 14; + private final TypeInformation labelType; + + public PruneTreeMapper(TypeInformation labelType) { + this.labelType = labelType; + } + + @Override + public void reduce(Iterable values, Collector out) throws Exception { + LOG.info("PruneTreeMapper start"); + TreeModelDataConverter model = new TreeModelDataConverter(this.labelType); + + List modelRows = new ArrayList <>(); + for (Row row : values) { + modelRows.add(row); + } + model.load(modelRows); + + for (Node node : model.roots) { + pruneTree(node, 0, MAX_DEPTH_ALLOWED_FOR_VIZ); + } + model.save(model, out); + LOG.info("PruneTreeMapper end"); + } + } +} diff --git a/core/src/main/java/com/alibaba/alink/common/viz/VizDataWriterInterface.java b/core/src/main/java/com/alibaba/alink/common/viz/VizDataWriterInterface.java new file mode 100644 index 000000000..3fd28b52e --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/common/viz/VizDataWriterInterface.java @@ -0,0 +1,20 @@ +package com.alibaba.alink.common.viz; + +import java.io.Serializable; +import java.util.List; + +public interface VizDataWriterInterface extends Serializable { + + void writeBatchData(long dataId, String data, long timeStamp); + + void writeStreamData(long dataId, String data, long timeStamp); + + void writeBatchData(List data); + + void writeStreamData(List data); + + void writeStreamMeta(VizOpMeta meta); + + void writeBatchMeta(VizOpMeta meta); +} + diff --git a/core/src/main/java/com/alibaba/alink/common/viz/VizDataWriterMapperForTable.java b/core/src/main/java/com/alibaba/alink/common/viz/VizDataWriterMapperForTable.java new file mode 100644 index 000000000..665519d8e --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/common/viz/VizDataWriterMapperForTable.java @@ -0,0 +1,59 @@ +package com.alibaba.alink.common.viz; + +import org.apache.flink.api.common.functions.MapPartitionFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; + +import com.alibaba.alink.common.utils.AlinkSerializable; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + +import static com.alibaba.alink.common.utils.JsonConverter.gson; + +/** + * Use `kv` to write `table` data. + * Writing `table` data directly could be un-parsable due to special characters in tables, like `,`, `\n`, and so on. 
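+ *
+ * <p>For illustration only (the column names and values below are hypothetical), a two-column table with one
+ * row is serialized roughly as:
+ * <pre>{@code
+ * {
+ *   "colNamesJson": "[\"name\",\"score\"]",
+ *   "colTypesJson": "[\"String\",\"Double\"]",
+ *   "fieldsJson":   ["[\"a, b\",1.0]"]
+ * }
+ * }</pre>
+ * so separators inside field values (here the comma in {@code "a, b"}) stay wrapped in JSON strings and the
+ * payload remains parsable.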
+ */ +public class VizDataWriterMapperForTable implements MapPartitionFunction { + private static final long serialVersionUID = -8765896344566463568L; + VizDataWriterInterface writer; + String[] colNames; + TypeInformation[] colTypes; + int dataId; + + VizDataWriterMapperForTable(VizDataWriterInterface writer, int dataId, String[] colNames, + TypeInformation[] colTypes) { + this.writer = writer; + this.colNames = colNames; + this.colTypes = colTypes; + this.dataId = dataId; + } + + @Override + public void mapPartition(Iterable iterable, Collector collector) throws Exception { + Table result = new Table(); + result.colNamesJson = gson.toJson(colNames); + result.colTypesJson = gson.toJson(Arrays.stream(colTypes).map(d -> d.toString()).collect(Collectors.toList())); + for (Row row : iterable) { + int n = row.getArity(); + Object[] fields = new Object[n]; + for (int i = 0; i < n; i += 1) { + fields[i] = row.getField(i); + } + result.fieldsJson.add(gson.toJson(fields)); + } + String resultJson = gson.toJson(result); + // System.err.println(resultJson); + writer.writeBatchData(this.dataId, resultJson, System.currentTimeMillis()); + } + + public static class Table implements AlinkSerializable { + public String colNamesJson; + public String colTypesJson; + public List fieldsJson = new ArrayList <>(); + } +} diff --git a/core/src/main/java/com/alibaba/alink/common/viz/VizManagerInterface.java b/core/src/main/java/com/alibaba/alink/common/viz/VizManagerInterface.java new file mode 100644 index 000000000..da12e3cad --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/common/viz/VizManagerInterface.java @@ -0,0 +1,13 @@ +package com.alibaba.alink.common.viz; + +import org.apache.flink.ml.api.misc.param.Params; + +import java.io.Serializable; + +public interface VizManagerInterface extends Serializable { + + VizDataWriterInterface getVizDataWriter(Params params); + + String getVizName(); + +} diff --git a/core/src/main/java/com/alibaba/alink/common/viz/VizOpChartData.java b/core/src/main/java/com/alibaba/alink/common/viz/VizOpChartData.java new file mode 100644 index 000000000..158ead145 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/common/viz/VizOpChartData.java @@ -0,0 +1,27 @@ +package com.alibaba.alink.common.viz; + +import com.alibaba.alink.common.utils.AlinkSerializable; + +import java.io.Serializable; + +public class VizOpChartData implements AlinkSerializable, Serializable { + + public int dataId; + + public String[] keys = null; //option, json value + + public VizOpChartData(int dataId) { + this.dataId = dataId; + } + + public VizOpChartData(int dataId, String key) { + this.dataId = dataId; + this.keys = new String[] {key}; + } + + public VizOpChartData(int dataId, String[] keys) { + this.dataId = dataId; + this.keys = keys; + } + +} diff --git a/core/src/main/java/com/alibaba/alink/common/viz/VizOpDataInfo.java b/core/src/main/java/com/alibaba/alink/common/viz/VizOpDataInfo.java new file mode 100644 index 000000000..abb135f2b --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/common/viz/VizOpDataInfo.java @@ -0,0 +1,46 @@ +package com.alibaba.alink.common.viz; + +import com.alibaba.alink.common.utils.AlinkSerializable; + +import java.io.Serializable; + +public class VizOpDataInfo implements AlinkSerializable, Serializable { + + public int dataId; + public String dataType; //table, kv + public String colNames = null;//col1,col2,col3, only for table + public String writeDataType = "continuous"; //continuous or onlyOnce + + public VizOpDataInfo(int dataId) { + this.dataId = 
dataId; + this.dataType = "kv"; + } + + public VizOpDataInfo(int dataId, String colNames) { + this.dataId = dataId; + this.colNames = colNames; + this.dataType = "table"; + } + + public VizOpDataInfo(int dataId, WriteVizDataType writeDataType) { + this.dataId = dataId; + this.dataType = "kv"; + if (writeDataType == WriteVizDataType.OnlyOnce) { + this.writeDataType = "onlyOnce"; + } + } + + public VizOpDataInfo(int dataId, String colNames, WriteVizDataType writeDataType) { + this.dataId = dataId; + this.colNames = colNames; + this.dataType = "table"; + if (writeDataType == WriteVizDataType.OnlyOnce) { + this.writeDataType = "onlyOnce"; + } + } + + public enum WriteVizDataType { + Continuous, + OnlyOnce + } +} diff --git a/core/src/main/java/com/alibaba/alink/common/viz/VizOpMeta.java b/core/src/main/java/com/alibaba/alink/common/viz/VizOpMeta.java new file mode 100644 index 000000000..c7b974333 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/common/viz/VizOpMeta.java @@ -0,0 +1,59 @@ +package com.alibaba.alink.common.viz; + +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.table.api.TableSchema; + +import com.alibaba.alink.common.utils.AlinkSerializable; + +import java.io.Serializable; +import java.util.HashMap; +import java.util.Map; + +public class VizOpMeta implements AlinkSerializable, Serializable { + + private static final long serialVersionUID = 3940407105766723917L; + public String opName; //package.op AllStatBatchOp if withStat / Op + + public VizOpDataInfo[] dataInfos; + public Map cascades = new HashMap <>(); + + public String opType; //stream or batch + + //screen param info + public VizOpTableSchema[] schemas; + public boolean isOutput = true; + public Params params; + + public void setSchema(TableSchema s) { + setSchemas(new TableSchema[] {s}); + } + + public void setSchema(String[] colNames, TypeInformation[] colTypes) { + setSchemas(new TableSchema[] {new TableSchema(colNames, colTypes)}); + } + + public void setSchemas(TableSchema[] s) { + schemas = new VizOpTableSchema[s.length]; + for (int i = 0; i < s.length; i++) { + schemas[i] = new VizOpTableSchema(); + int len = s[i].getFieldNames().length; + schemas[i].colNames = new String[len]; + for (int j = 0; j < len; j++) { + schemas[i].colNames[j] = s[i].getFieldName(j).get(); + } + schemas[i].colTypes = new String[len]; + for (int j = 0; j < len; j++) { + schemas[i].colTypes[j] = s[i].getFieldType(j).get().toString().toUpperCase(); + } + } + + } + + static class VizOpTableSchema implements AlinkSerializable, Serializable { + private static final long serialVersionUID = 8371261277649665050L; + public String[] colNames; + public String[] colTypes; + } +} + diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/BatchOperator.java b/core/src/main/java/com/alibaba/alink/operator/batch/BatchOperator.java index 430e4ac34..13176b5f2 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/BatchOperator.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/BatchOperator.java @@ -18,6 +18,7 @@ import com.alibaba.alink.common.MLEnvironment; import com.alibaba.alink.common.MLEnvironmentFactory; import com.alibaba.alink.common.MTable; +import com.alibaba.alink.common.annotation.Internal; import com.alibaba.alink.common.exceptions.AkFlinkExecutionErrorException; import com.alibaba.alink.common.exceptions.AkIllegalOperationException; import com.alibaba.alink.common.exceptions.AkPreconditions; @@ -27,7 +28,7 @@ import 
com.alibaba.alink.common.io.annotations.IoOpAnnotation; import com.alibaba.alink.common.lazy.LazyEvaluation; import com.alibaba.alink.common.lazy.LazyObjectsManager; -import com.alibaba.alink.common.utils.DataSetConversionUtil; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; import com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.operator.AlgoOperator; import com.alibaba.alink.operator.batch.dataproc.FirstNBatchOp; @@ -45,7 +46,7 @@ import com.alibaba.alink.operator.batch.utils.DiveVisualizer.DiveVisualizerConsumer; import com.alibaba.alink.operator.batch.utils.UDFBatchOp; import com.alibaba.alink.operator.batch.utils.UDTFBatchOp; -import com.alibaba.alink.operator.common.sql.BatchSqlOperators; +import com.alibaba.alink.operator.batch.sql.BatchSqlOperators; import com.alibaba.alink.operator.common.statistics.basicstatistic.TableSummary; import org.apache.commons.lang3.tuple.Pair; @@ -541,6 +542,7 @@ private static ExecutionEnvironment getExecutionEnvironment( return env; } + @Internal @IoOpAnnotation(name = "mem_batch_sink", ioType = IOType.SinkBatch) private static class MemSinkBatchOp extends BaseSinkBatchOp { private static final long serialVersionUID = -2595920715328848084L; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/PipelinePredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/PipelinePredictBatchOp.java new file mode 100644 index 000000000..667e43dda --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/batch/PipelinePredictBatchOp.java @@ -0,0 +1,44 @@ +package com.alibaba.alink.operator.batch; + +import org.apache.flink.ml.api.misc.param.Params; + +import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.common.exceptions.AkIllegalDataException; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; +import com.alibaba.alink.params.PipelineModelBatchPredictParams; +import com.alibaba.alink.pipeline.PipelineModel; + +/** + * Pipeline prediction. + */ + +@NameCn("Pipeline 预测") +@NameEn("Pipeline prediction") +public final class PipelinePredictBatchOp extends BatchOperator + implements PipelineModelBatchPredictParams { + + public PipelinePredictBatchOp() { + super(new Params()); + } + + public PipelinePredictBatchOp(Params params) { + super(params); + } + + @Override + public PipelinePredictBatchOp linkFrom(BatchOperator ... 
inputs) { + try { + BatchOperator data = checkAndGetFirst(inputs); + final PipelineModel pipelineModel = PipelineModel.load(getModelFilePath()) + .setMLEnvironmentId(data.getMLEnvironmentId()); + BatchOperator result = pipelineModel.transform(data); + this.setOutput(DataSetConversionUtil.toTable(data.getMLEnvironmentId(), + result.getDataSet(), result.getSchema())); + return this; + } catch (Exception ex) { + ex.printStackTrace(); + throw new AkIllegalDataException(ex.toString()); + } + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/associationrule/ApplyAssociationRuleBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/associationrule/ApplyAssociationRuleBatchOp.java new file mode 100644 index 000000000..b22ed4df8 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/batch/associationrule/ApplyAssociationRuleBatchOp.java @@ -0,0 +1,36 @@ +package com.alibaba.alink.operator.batch.associationrule; + +import org.apache.flink.ml.api.misc.param.Params; + +import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; +import com.alibaba.alink.common.annotation.TypeCollections; +import com.alibaba.alink.operator.batch.utils.ModelMapBatchOp; +import com.alibaba.alink.operator.common.associationrule.ApplyAssociationRuleModelMapper; +import com.alibaba.alink.params.mapper.SISOMapperParams; + +/** + * The batch op of applying the Association Rules. + */ +@ParamSelectColumnSpec(name = "selectedCol", allowedTypeCollections = TypeCollections.STRING_TYPES) +@NameCn("关联规则预测") +@NameEn("Association Rule Prediction") +public class ApplyAssociationRuleBatchOp extends ModelMapBatchOp + implements SISOMapperParams { + + private static final long serialVersionUID = 674848671578909834L; + + public ApplyAssociationRuleBatchOp() { + this(null); + } + + /** + * constructor. + * + * @param params the parameters set. + */ + public ApplyAssociationRuleBatchOp(Params params) { + super(ApplyAssociationRuleModelMapper::new, params); + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/associationrule/ApplySequenceRuleBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/associationrule/ApplySequenceRuleBatchOp.java new file mode 100644 index 000000000..5829e7526 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/batch/associationrule/ApplySequenceRuleBatchOp.java @@ -0,0 +1,36 @@ +package com.alibaba.alink.operator.batch.associationrule; + +import org.apache.flink.ml.api.misc.param.Params; + +import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; +import com.alibaba.alink.common.annotation.TypeCollections; +import com.alibaba.alink.operator.batch.utils.ModelMapBatchOp; +import com.alibaba.alink.operator.common.associationrule.ApplySequenceRuleModelMapper; +import com.alibaba.alink.params.mapper.SISOMapperParams; + +/** + * The batch op of applying Sequence Rules. + */ +@ParamSelectColumnSpec(name = "selectedCol", allowedTypeCollections = TypeCollections.STRING_TYPES) +@NameCn("序列规则预测") +@NameEn("Sequence Rule Prediction") +public class ApplySequenceRuleBatchOp extends ModelMapBatchOp + implements SISOMapperParams { + + private static final long serialVersionUID = -26090263164779243L; + + public ApplySequenceRuleBatchOp() { + this(null); + } + + /** + * constructor. + * + * @param params the parameters set. 
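+ *
+ * <p>A minimal, hypothetical usage sketch ({@code rules} stands for a sequence-rule model, e.g. the rule
+ * side output of {@code PrefixSpanBatchOp}, {@code data} is the input to score, and the column names are
+ * assumptions):
+ * <pre>{@code
+ * BatchOperator<?> result = new ApplySequenceRuleBatchOp()
+ *     .setSelectedCol("sequence")
+ *     .setOutputCol("prediction")
+ *     .linkFrom(rules, data);
+ * }</pre>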
+ */ + public ApplySequenceRuleBatchOp(Params params) { + super(ApplySequenceRuleModelMapper::new, params); + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/associationrule/FpGrowthBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/associationrule/FpGrowthBatchOp.java index 76a1e186f..e0f9406a2 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/associationrule/FpGrowthBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/associationrule/FpGrowthBatchOp.java @@ -20,6 +20,7 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortDesc; @@ -27,7 +28,7 @@ import com.alibaba.alink.common.annotation.PortSpec.OpType; import com.alibaba.alink.common.annotation.PortType; import com.alibaba.alink.common.annotation.TypeCollections; -import com.alibaba.alink.common.utils.DataSetConversionUtil; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; import com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.common.associationrule.AssociationRule; @@ -58,6 +59,7 @@ @ParamSelectColumnSpec(name = "itemsCol", allowedTypeCollections = TypeCollections.STRING_TYPE) @NameCn("FpGrowth") +@NameEn("FpGrowth") public final class FpGrowthBatchOp extends BatchOperator implements FpGrowthParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/associationrule/GroupedFpGrowthBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/associationrule/GroupedFpGrowthBatchOp.java new file mode 100644 index 000000000..aae86e621 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/batch/associationrule/GroupedFpGrowthBatchOp.java @@ -0,0 +1,364 @@ +package com.alibaba.alink.operator.batch.associationrule; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.ReduceFunction; +import org.apache.flink.api.common.functions.RichGroupReduceFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.table.api.Table; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; +import org.apache.flink.util.StringUtils; + +import com.alibaba.alink.common.annotation.InputPorts; +import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.common.annotation.OutputPorts; +import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; +import com.alibaba.alink.common.annotation.PortDesc; +import com.alibaba.alink.common.annotation.PortSpec; +import com.alibaba.alink.common.annotation.PortSpec.OpType; +import com.alibaba.alink.common.annotation.PortType; +import com.alibaba.alink.common.annotation.TypeCollections; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; +import com.alibaba.alink.operator.batch.BatchOperator; +import com.alibaba.alink.operator.common.associationrule.FpTree; +import 
com.alibaba.alink.operator.common.associationrule.FpTreeImpl; +import com.alibaba.alink.params.associationrule.GroupedFpGrowthParams; +import org.apache.commons.lang3.ArrayUtils; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Comparator; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +/* + * batch op of grouped fp-growth + */ +@InputPorts(values = @PortSpec(value = PortType.DATA, opType = OpType.BATCH)) +@OutputPorts(values = { + @PortSpec(value = PortType.MODEL, desc = PortDesc.ASSOCIATION_PATTERNS), + @PortSpec(value = PortType.MODEL, desc = PortDesc.ASSOCIATION_RULES), +}) +@ParamSelectColumnSpec(name = "itemsCol", + allowedTypeCollections = TypeCollections.STRING_TYPE) +@ParamSelectColumnSpec(name = "groupCol") +@NameCn("分组FPGrowth训练") +@NameEn("Grouped FpGrowth Training") +public final class GroupedFpGrowthBatchOp + extends BatchOperator + implements GroupedFpGrowthParams { + + private static final long serialVersionUID = -3434563610385164063L; + + public GroupedFpGrowthBatchOp() { + } + + public GroupedFpGrowthBatchOp(Params params) { + super(params); + } + + @Override + public GroupedFpGrowthBatchOp linkFrom(BatchOperator ... inputs) { + BatchOperator in = checkAndGetFirst(inputs); + final String itemsCol = getItemsCol(); + final String groupedCol = getGroupCol(); + final int minSupportCnt = getMinSupportCount(); + final double minSupportPercent = getMinSupportPercent(); + final int maxPatternLength = getMaxPatternLength(); + final double minLift = getMinLift(); + final double minConfidence = getMinConfidence(); + + in = in.select(new String[] {groupedCol, itemsCol}); + DataSet rows = in.getDataSet(); + DataSet > withGroupId = rows + .map(new MapFunction >() { + private static final long serialVersionUID = -2408050110594809903L; + + @Override + public Tuple2 map(Row value) throws Exception { + String key = String.valueOf(value.getField(0)); + return Tuple2.of(key, value); + } + }); + + DataSet > groupCount = withGroupId + .map(new MapFunction , Tuple2 >() { + private static final long serialVersionUID = -9076532747012786151L; + + @Override + public Tuple2 map(Tuple2 value) throws Exception { + return Tuple2.of(value.f0, 1L); + } + }) + .groupBy(0) + .reduce(new ReduceFunction >() { + private static final long serialVersionUID = 6511332413000462792L; + + @Override + public Tuple2 reduce(Tuple2 value1, Tuple2 value2) + throws Exception { + value1.f1 += value2.f1; + return value1; + } + }); + + DataSet > patterns = withGroupId + .groupBy(0) + .reduceGroup(new RichGroupReduceFunction , Tuple2 >() { + private static final long serialVersionUID = -2758244844753267506L; + + @Override + public void reduce(Iterable > values, Collector > out) + throws Exception { + Object key = null; + List transactions = new ArrayList <>(); + for (Tuple2 v : values) { + key = v.f1.getField(0); + transactions.add((String) v.f1.getField(1)); + } + final Object constKey = key; + final int minSupportCount = decideMinSupportCount(minSupportCnt, minSupportPercent, + transactions.size()); + final Map itemCounts = getItemCounts(transactions); + final Tuple2 , List > ordered = orderItems(itemCounts); + final Map tokenToIndex = ordered.f0; + final List orderedItems = ordered.f1; + final int[] qualifiedItemIndices = getQualifiedItemIndices(itemCounts, tokenToIndex, + minSupportCount); + + FpTree fpTree = new FpTreeImpl(); + fpTree.createTree(); + for (String transaction : transactions) { + if 
(StringUtils.isNullOrWhitespaceOnly(transaction)) { + continue; + } + String[] items = transaction.split(FpGrowthBatchOp.ITEM_SEPARATOR); + Set qualifiedItems = new HashSet <>(items.length); + for (String item : items) { + if (itemCounts.get(item) >= minSupportCount) { + qualifiedItems.add(tokenToIndex.get(item)); + } + } + int[] t = toArray(qualifiedItems); + Arrays.sort(t); + fpTree.addTransaction(t); + } + + // System.out.println("key: " + key); + fpTree.initialize(); + fpTree.printProfile(); + + fpTree.extractAll(qualifiedItemIndices, minSupportCount, maxPatternLength, + new Collector >() { + @Override + public void collect(Tuple2 record) { + String itemset = indicesToTokens(record.f0, orderedItems); + long supportCount = record.f1; + long itemCount = record.f0.length; + Row row = new Row(FpGrowthBatchOp.ITEMSETS_COL_NAMES.length + 1); + row.setField(0, constKey); + row.setField(1, itemset); + row.setField(2, supportCount); + row.setField(3, itemCount); + out.collect(Tuple2.of(String.valueOf(constKey), row)); + } + + @Override + public void close() { + } + }); + + fpTree.destroyTree(); + } + }) + .name("gen_patterns"); + + DataSet rules = patterns + .groupBy(0) + .reduceGroup(new RichGroupReduceFunction , Row>() { + private static final long serialVersionUID = 3848228188758749261L; + transient List > bc; + + @Override + public void open(Configuration parameters) throws Exception { + bc = getRuntimeContext().getBroadcastVariable("groupCount"); + } + + @Override + public void reduce(Iterable > values, Collector out) throws Exception { + Map patterns = new HashMap <>(); + String key = null; + for (Tuple2 t2 : values) { + key = t2.f0; + patterns.put((String) t2.f1.getField(1), (Long) t2.f1.getField(2)); + } + final String constKey = key; + Long tranCnt = null; + for (Tuple2 c : bc) { + if (c.f0.equals(key)) { + tranCnt = c.f1; + break; + } + } + Preconditions.checkArgument(tranCnt != null); + final Long transactionCnt = tranCnt; + + patterns.forEach((k, v) -> { + String[] items = k.split(FpGrowthBatchOp.ITEM_SEPARATOR); + if (items.length > 1) { + for (int i = 0; i < items.length; i++) { + int n = 0; + StringBuilder sbd = new StringBuilder(); + for (int j = 0; j < items.length; j++) { + if (j == i) { + continue; + } + if (n > 0) { + sbd.append(FpGrowthBatchOp.ITEM_SEPARATOR); + } + sbd.append(items[j]); + n++; + } + String ante = sbd.toString(); + String conseq = items[i]; + Long supportXY = v; + Long supportX = patterns.get(ante); + Long supportY = patterns.get(conseq); + Preconditions.checkArgument(supportX != null); + Preconditions.checkArgument(supportY != null); + Preconditions.checkArgument(supportXY != null); + double confidence = supportXY.doubleValue() / supportX.doubleValue(); + double lift = supportXY.doubleValue() * transactionCnt.doubleValue() / ( + supportX.doubleValue() + * supportY.doubleValue()); + double support = supportXY.doubleValue() / transactionCnt.doubleValue(); + + Row ruleOutput = new Row(7); + ruleOutput.setField(0, constKey); + ruleOutput.setField(1, ante + "=>" + conseq); + ruleOutput.setField(2, (long) items.length); + ruleOutput.setField(3, lift); + ruleOutput.setField(4, support); + ruleOutput.setField(5, confidence); + ruleOutput.setField(6, supportXY); + if (lift >= minLift && confidence >= minConfidence) { + out.collect(ruleOutput); + } + } + } + }); + } + }) + .withBroadcastSet(groupCount, "groupCount") + .name("gen_rules"); + + DataSet outputPatterns = patterns + .map(new MapFunction , Row>() { + private static final long serialVersionUID = 
-4247869441801301592L; + + @Override + public Row map(Tuple2 value) throws Exception { + return value.f1; + } + }); + + Table patternsTable = DataSetConversionUtil.toTable(getMLEnvironmentId(), outputPatterns, + ArrayUtils.addAll(new String[] {in.getColNames()[0]}, FpGrowthBatchOp.ITEMSETS_COL_NAMES), + ArrayUtils.addAll(new TypeInformation[] {in.getColTypes()[0]}, FpGrowthBatchOp.ITEMSETS_COL_TYPES)); + + Table rulesTable = DataSetConversionUtil.toTable(getMLEnvironmentId(), rules, + ArrayUtils.addAll(new String[] {in.getColNames()[0]}, FpGrowthBatchOp.RULES_COL_NAMES), + ArrayUtils.addAll(new TypeInformation[] {in.getColTypes()[0]}, FpGrowthBatchOp.RULES_COL_TYPES)); + + this.setOutputTable(patternsTable); + this.setSideOutputTables(new Table[] {rulesTable}); + return this; + } + + private static int decideMinSupportCount(int minSupportCnt, double minSupportPercent, int transactionCount) { + if (minSupportCnt >= 0) { + return minSupportCnt; + } + return (int) Math.floor(transactionCount * minSupportPercent); + } + + private static Map getItemCounts(List transactions) { + Map itemCounts = new HashMap <>(); + for (String transaction : transactions) { + if (StringUtils.isNullOrWhitespaceOnly(transaction)) { + continue; + } + String[] items = transaction.split(FpGrowthBatchOp.ITEM_SEPARATOR); + Set itemSet = new HashSet <>(); + itemSet.addAll(Arrays.asList(items)); + for (String item : itemSet) { + itemCounts.merge(item, 1, ((a, b) -> a + b)); + } + } + return itemCounts; + } + + private static Tuple2 , List > orderItems(Map itemCounts) { + List allItems = new ArrayList <>(itemCounts.size()); + itemCounts.forEach((k, v) -> { + allItems.add(k); + }); + allItems.sort(new Comparator () { + @Override + public int compare(String o1, String o2) { + return Integer.compare(itemCounts.get(o2), itemCounts.get(o1)); + } + }); + Map tokenToIndex = new HashMap <>(itemCounts.size()); + for (int i = 0; i < allItems.size(); i++) { + tokenToIndex.put(allItems.get(i), i); + } + return Tuple2.of(tokenToIndex, allItems); + } + + private static int[] toArray(Set list) { + int[] array = new int[list.size()]; + int n = 0; + for (Integer i : list) { + array[n++] = i; + } + return array; + } + + private static int[] getQualifiedItemIndices(Map itemCounts, Map tokenToIndex, + int minSupportCount) { + List qualified = new ArrayList <>(); + itemCounts.forEach((k, v) -> { + if (v >= minSupportCount) { + qualified.add(k); + } + }); + int[] indices = new int[qualified.size()]; + for (int i = 0; i < qualified.size(); i++) { + indices[i] = tokenToIndex.get(qualified.get(i)); + } + return indices; + } + + private static String indicesToTokens(int[] items, List orderedItems) { + StringBuilder sbd = new StringBuilder(); + for (int i = 0; i < items.length; i++) { + if (i > 0) { + sbd.append(FpGrowthBatchOp.ITEM_SEPARATOR); + } + sbd.append(orderedItems.get(items[i])); + } + return sbd.toString(); + } + +} + diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/associationrule/PrefixSpanBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/associationrule/PrefixSpanBatchOp.java index a70022105..9abf4ec1c 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/associationrule/PrefixSpanBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/associationrule/PrefixSpanBatchOp.java @@ -23,13 +23,14 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import 
com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortSpec; import com.alibaba.alink.common.annotation.PortSpec.OpType; import com.alibaba.alink.common.annotation.PortType; import com.alibaba.alink.common.annotation.TypeCollections; -import com.alibaba.alink.common.utils.DataSetConversionUtil; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; import com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.common.associationrule.ParallelPrefixSpan; @@ -53,6 +54,7 @@ @OutputPorts(values = @PortSpec(value = PortType.MODEL)) @ParamSelectColumnSpec(name = "itemsCol", allowedTypeCollections = TypeCollections.STRING_TYPES) @NameCn("PrefixSpan") +@NameEn("PrefixSpan") public final class PrefixSpanBatchOp extends BatchOperator implements PrefixSpanParams { private static final Logger LOG = LoggerFactory.getLogger(PrefixSpanBatchOp.class); diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/audio/ExtractMfccFeatureBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/audio/ExtractMfccFeatureBatchOp.java index 1ff0db55a..4e31c28b5 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/audio/ExtractMfccFeatureBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/audio/ExtractMfccFeatureBatchOp.java @@ -3,11 +3,13 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.batch.utils.MapBatchOp; import com.alibaba.alink.operator.common.audio.ExtractMfccFeatureMapper; import com.alibaba.alink.params.audio.ExtractMfccFeatureParams; @NameCn("MFCC特征提取") +@NameEn("MFCC Feature Extraction") public class ExtractMfccFeatureBatchOp extends MapBatchOp implements ExtractMfccFeatureParams { public ExtractMfccFeatureBatchOp() { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/audio/ReadAudioToTensorBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/audio/ReadAudioToTensorBatchOp.java index e84224335..eec5c9a28 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/audio/ReadAudioToTensorBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/audio/ReadAudioToTensorBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.operator.batch.utils.MapBatchOp; @@ -11,6 +12,7 @@ @ParamSelectColumnSpec(name="relativeFilePathCol",allowedTypeCollections = TypeCollections.STRING_TYPES) @NameCn("音频转张量") +@NameEn("Audio To Tensor") public class ReadAudioToTensorBatchOp extends MapBatchOp implements ReadAudioToTensorParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/classification/BertTextClassifierPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/classification/BertTextClassifierPredictBatchOp.java index 15605019e..dca5f00e1 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/classification/BertTextClassifierPredictBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/classification/BertTextClassifierPredictBatchOp.java @@ -3,11 +3,13 @@ import 
org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; /** * Prediction with text classifier using Bert models. */ @NameCn("Bert文本分类预测") +@NameEn("Bert Text Classification Prediction") public class BertTextClassifierPredictBatchOp extends TFTableModelClassifierPredictBatchOp { public BertTextClassifierPredictBatchOp() { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/classification/BertTextClassifierTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/classification/BertTextClassifierTrainBatchOp.java index d1dd12023..38e4ce4c9 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/classification/BertTextClassifierTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/classification/BertTextClassifierTrainBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.common.dl.BaseEasyTransferTrainBatchOp; @@ -11,6 +12,7 @@ import com.alibaba.alink.params.dl.HasTaskType; import com.alibaba.alink.params.tensorflow.bert.BertTextTrainParams; import com.alibaba.alink.params.tensorflow.bert.HasTaskName; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; /** * Train a text classifier using Bert models. @@ -18,6 +20,8 @@ @ParamSelectColumnSpec(name = "textCol", allowedTypeCollections = TypeCollections.STRING_TYPES) @ParamSelectColumnSpec(name = "labelCol") @NameCn("Bert文本分类训练") +@NameEn("Bert Text Classifier Training") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.classification.BertTextClassifier") public class BertTextClassifierTrainBatchOp extends BaseEasyTransferTrainBatchOp implements BertTextTrainParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/classification/BertTextPairClassifierPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/classification/BertTextPairClassifierPredictBatchOp.java index cfca58832..47a310654 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/classification/BertTextPairClassifierPredictBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/classification/BertTextPairClassifierPredictBatchOp.java @@ -3,11 +3,13 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; /** * Prediction with text pair classifier using Bert models. 
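+ *
+ * <p>A hypothetical usage sketch ({@code model} is the output of {@code BertTextPairClassifierTrainBatchOp},
+ * {@code data} holds the text pairs, and the prediction column name is an assumption):
+ * <pre>{@code
+ * BatchOperator<?> predicted = new BertTextPairClassifierPredictBatchOp()
+ *     .setPredictionCol("pred")
+ *     .linkFrom(model, data);
+ * }</pre>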
*/ @NameCn("Bert文本对分类预测") +@NameEn("Bert Text Pair Classifier Prediction") public class BertTextPairClassifierPredictBatchOp extends TFTableModelClassifierPredictBatchOp { public BertTextPairClassifierPredictBatchOp() { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/classification/BertTextPairClassifierTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/classification/BertTextPairClassifierTrainBatchOp.java index 7ed173464..e4800d1a5 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/classification/BertTextPairClassifierTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/classification/BertTextPairClassifierTrainBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.common.dl.BaseEasyTransferTrainBatchOp; @@ -11,6 +12,7 @@ import com.alibaba.alink.params.dl.HasTaskType; import com.alibaba.alink.params.tensorflow.bert.BertTextPairTrainParams; import com.alibaba.alink.params.tensorflow.bert.HasTaskName; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; /** * Train a text pair classifier using Bert models. @@ -19,6 +21,8 @@ @ParamSelectColumnSpec(name = "textPairCol", allowedTypeCollections = TypeCollections.STRING_TYPES) @ParamSelectColumnSpec(name = "labelCol") @NameCn("Bert文本对分类训练") +@NameEn("Bert Text Pair Classifier Training") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.classification.BertTextPairClassifier") public class BertTextPairClassifierTrainBatchOp extends BaseEasyTransferTrainBatchOp implements BertTextPairTrainParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/classification/C45PredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/classification/C45PredictBatchOp.java index 3ae6214a3..922113a9f 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/classification/C45PredictBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/classification/C45PredictBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.batch.utils.ModelMapBatchOp; import com.alibaba.alink.operator.common.tree.predictors.RandomForestModelMapper; import com.alibaba.alink.params.classification.C45PredictParams; @@ -11,6 +12,7 @@ * The batch operator that predict the data using the c45 model. 
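+ *
+ * <p>A minimal, hypothetical sketch of training followed by prediction (column names are assumptions):
+ * <pre>{@code
+ * BatchOperator<?> model = new C45TrainBatchOp()
+ *     .setFeatureCols("f0", "f1")
+ *     .setLabelCol("label")
+ *     .linkFrom(trainData);
+ * BatchOperator<?> predicted = new C45PredictBatchOp()
+ *     .setPredictionCol("pred")
+ *     .linkFrom(model, testData);
+ * }</pre>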
*/ @NameCn("C45决策树分类预测") +@NameEn("C45 Decision Tree Prediction") public final class C45PredictBatchOp extends ModelMapBatchOp implements C45PredictParams { private static final long serialVersionUID = -3642003580227332493L; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/classification/C45TrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/classification/C45TrainBatchOp.java index 3a05bbc46..ae2950e59 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/classification/C45TrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/classification/C45TrainBatchOp.java @@ -3,7 +3,8 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; -import com.alibaba.alink.common.lazy.WithModelInfoBatchOp; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp; import com.alibaba.alink.operator.common.tree.BaseRandomForestTrainBatchOp; import com.alibaba.alink.operator.common.tree.TreeModelInfo; import com.alibaba.alink.operator.common.tree.TreeUtil; @@ -11,12 +12,15 @@ import com.alibaba.alink.params.shared.tree.HasFeatureSubsamplingRatio; import com.alibaba.alink.params.shared.tree.HasNumTreesDefaltAs10; import com.alibaba.alink.params.shared.tree.HasSubsamplingRatio; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; /** * Fit a c45 model. */ @NameCn("C45决策树分类训练") -public final class C45TrainBatchOp extends BaseRandomForestTrainBatchOp +@NameEn("C45 Decision Tree Training") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.classification.C45") +public class C45TrainBatchOp extends BaseRandomForestTrainBatchOp implements C45TrainParams , WithModelInfoBatchOp { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/classification/CartPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/classification/CartPredictBatchOp.java index 981f0fd14..db6e7b38f 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/classification/CartPredictBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/classification/CartPredictBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.batch.utils.ModelMapBatchOp; import com.alibaba.alink.operator.common.tree.predictors.RandomForestModelMapper; import com.alibaba.alink.params.classification.CartPredictParams; @@ -11,6 +12,7 @@ * The batch operator that predict the data using the cart model. 
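+ *
+ * <p>As with the other tree predictors in this package, linking expects the model produced by
+ * {@code CartTrainBatchOp} as the first input and the data to score as the second; usage mirrors the C45
+ * example above.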
*/ @NameCn("CART决策树分类预测") +@NameEn("CART Decision Tree Prediction") public final class CartPredictBatchOp extends ModelMapBatchOp implements CartPredictParams { private static final long serialVersionUID = 2672681021424392380L; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/classification/CartTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/classification/CartTrainBatchOp.java index b09a29080..fb206e399 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/classification/CartTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/classification/CartTrainBatchOp.java @@ -3,7 +3,8 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; -import com.alibaba.alink.common.lazy.WithModelInfoBatchOp; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp; import com.alibaba.alink.operator.common.tree.BaseRandomForestTrainBatchOp; import com.alibaba.alink.operator.common.tree.TreeModelInfo; import com.alibaba.alink.operator.common.tree.TreeUtil; @@ -11,12 +12,15 @@ import com.alibaba.alink.params.shared.tree.HasFeatureSubsamplingRatio; import com.alibaba.alink.params.shared.tree.HasNumTreesDefaltAs10; import com.alibaba.alink.params.shared.tree.HasSubsamplingRatio; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; /** * Fit a cart model. */ @NameCn("CART决策树分类训练") -public final class CartTrainBatchOp extends BaseRandomForestTrainBatchOp +@NameEn("CART Decision Tree Training") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.classification.Cart") +public class CartTrainBatchOp extends BaseRandomForestTrainBatchOp implements CartTrainParams , WithModelInfoBatchOp { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/classification/DecisionTreePredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/classification/DecisionTreePredictBatchOp.java index dbff1cb66..544069f30 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/classification/DecisionTreePredictBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/classification/DecisionTreePredictBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.batch.utils.ModelMapBatchOp; import com.alibaba.alink.operator.common.tree.predictors.RandomForestModelMapper; import com.alibaba.alink.params.classification.DecisionTreePredictParams; @@ -28,6 +29,7 @@ * @see Random_forest */ @NameCn("决策树预测") +@NameEn("Decision Tree Prediction") public final class DecisionTreePredictBatchOp extends ModelMapBatchOp implements DecisionTreePredictParams { private static final long serialVersionUID = 3664269451746168314L; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/classification/DecisionTreeTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/classification/DecisionTreeTrainBatchOp.java index 793eee50c..6e8619da2 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/classification/DecisionTreeTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/classification/DecisionTreeTrainBatchOp.java @@ -3,7 +3,8 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; -import com.alibaba.alink.common.lazy.WithModelInfoBatchOp; +import 
com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp; import com.alibaba.alink.operator.common.tree.BaseRandomForestTrainBatchOp; import com.alibaba.alink.operator.common.tree.TreeModelInfo; import com.alibaba.alink.operator.common.tree.TreeUtil; @@ -11,6 +12,7 @@ import com.alibaba.alink.params.classification.RandomForestTrainParams; import com.alibaba.alink.params.shared.tree.HasFeatureSubsamplingRatio; import com.alibaba.alink.params.shared.tree.HasSubsamplingRatio; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; /** * The random forest use the bagging to prevent the overfitting. @@ -33,7 +35,9 @@ * @see Random_forest */ @NameCn("决策树训练") -public final class DecisionTreeTrainBatchOp extends BaseRandomForestTrainBatchOp +@NameEn("Decision Tree Training") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.classification.DecisionTreeClassifier") +public class DecisionTreeTrainBatchOp extends BaseRandomForestTrainBatchOp implements DecisionTreeTrainParams , WithModelInfoBatchOp { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/classification/FmClassifierModelInfoBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/classification/FmClassifierModelInfoBatchOp.java index 5d0b13385..3059eaa78 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/classification/FmClassifierModelInfoBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/classification/FmClassifierModelInfoBatchOp.java @@ -3,8 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import org.apache.flink.types.Row; -import com.alibaba.alink.common.lazy.ExtractModelInfoBatchOp; -import com.alibaba.alink.operator.batch.BatchOperator; +import com.alibaba.alink.operator.batch.utils.ExtractModelInfoBatchOp; import com.alibaba.alink.operator.common.fm.FmClassifierModelInfo; import java.util.List; @@ -34,8 +33,4 @@ protected FmClassifierModelInfo createModelInfo(List rows) { return new FmClassifierModelInfo(rows); } - @Override - protected BatchOperator processModel() { - return this; - } } diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/classification/FmClassifierTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/classification/FmClassifierTrainBatchOp.java index f06768c7e..a712876e8 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/classification/FmClassifierTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/classification/FmClassifierTrainBatchOp.java @@ -5,13 +5,14 @@ import com.alibaba.alink.common.annotation.NameCn; import com.alibaba.alink.common.annotation.NameEn; -import com.alibaba.alink.common.lazy.WithModelInfoBatchOp; -import com.alibaba.alink.common.lazy.WithTrainInfo; +import com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp; +import com.alibaba.alink.operator.batch.utils.WithTrainInfo; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.common.fm.FmClassifierModelInfo; import com.alibaba.alink.operator.common.fm.FmClassifierModelTrainInfo; import com.alibaba.alink.operator.common.fm.FmTrainBatchOp; import com.alibaba.alink.params.recommendation.FmTrainParams; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; import java.util.List; @@ -20,6 +21,7 @@ */ @NameCn("FM分类训练") @NameEn("FM Classification Training") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.classification.FmClassifier") public class 
FmClassifierTrainBatchOp extends FmTrainBatchOp implements FmTrainParams , WithModelInfoBatchOp , diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/classification/GbdtPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/classification/GbdtPredictBatchOp.java index 9dfcf7d7e..920dfcbef 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/classification/GbdtPredictBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/classification/GbdtPredictBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.operator.batch.utils.ModelMapBatchOp; @@ -14,6 +15,7 @@ */ @ParamSelectColumnSpec(name = "vectorCol", allowedTypeCollections = TypeCollections.VECTOR_TYPES) @NameCn("GBDT分类器预测") +@NameEn("GBDT Classifier Prediction") public final class GbdtPredictBatchOp extends ModelMapBatchOp implements GbdtPredictParams { private static final long serialVersionUID = 2801048935862838531L; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/classification/GbdtTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/classification/GbdtTrainBatchOp.java index fa31dc8bf..79790f9b0 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/classification/GbdtTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/classification/GbdtTrainBatchOp.java @@ -3,12 +3,14 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; -import com.alibaba.alink.common.lazy.WithModelInfoBatchOp; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp; import com.alibaba.alink.operator.common.tree.TreeModelInfo.GbdtModelInfo; import com.alibaba.alink.operator.common.tree.parallelcart.BaseGbdtTrainBatchOp; import com.alibaba.alink.operator.common.tree.parallelcart.loss.LossType; import com.alibaba.alink.operator.common.tree.parallelcart.loss.LossUtils; import com.alibaba.alink.params.classification.GbdtTrainParams; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; /** * Fit a binary classfication model. 
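+ *
+ * <p>A hypothetical training sketch (column names are assumptions; see {@code BaseGbdtTrainBatchOp} and
+ * {@code GbdtTrainParams} for the full parameter list):
+ * <pre>{@code
+ * BatchOperator<?> model = new GbdtTrainBatchOp()
+ *     .setFeatureCols("f0", "f1")
+ *     .setLabelCol("label")
+ *     .linkFrom(trainData);
+ * }</pre>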
@@ -16,7 +18,9 @@ * @see BaseGbdtTrainBatchOp */ @NameCn("GBDT分类器训练") -public final class GbdtTrainBatchOp extends BaseGbdtTrainBatchOp +@NameEn("GBDT Classifier Training") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.classification.GbdtClassifier") +public class GbdtTrainBatchOp extends BaseGbdtTrainBatchOp implements GbdtTrainParams , WithModelInfoBatchOp { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/classification/Id3PredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/classification/Id3PredictBatchOp.java index 0b4932377..d78c5a6a1 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/classification/Id3PredictBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/classification/Id3PredictBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.batch.utils.ModelMapBatchOp; import com.alibaba.alink.operator.common.tree.predictors.RandomForestModelMapper; import com.alibaba.alink.params.classification.Id3PredictParams; @@ -11,6 +12,7 @@ * The batch operator that predict the data using the id3 model. */ @NameCn("ID3决策树分类预测") +@NameEn("ID3 Decision Tree Prediction") public final class Id3PredictBatchOp extends ModelMapBatchOp implements Id3PredictParams { private static final long serialVersionUID = 7494960454797499134L; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/classification/Id3TrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/classification/Id3TrainBatchOp.java index 280c4f62f..62ae8633c 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/classification/Id3TrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/classification/Id3TrainBatchOp.java @@ -3,7 +3,8 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; -import com.alibaba.alink.common.lazy.WithModelInfoBatchOp; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp; import com.alibaba.alink.operator.common.tree.BaseRandomForestTrainBatchOp; import com.alibaba.alink.operator.common.tree.TreeModelInfo; import com.alibaba.alink.operator.common.tree.TreeUtil; @@ -11,12 +12,15 @@ import com.alibaba.alink.params.shared.tree.HasFeatureSubsamplingRatio; import com.alibaba.alink.params.shared.tree.HasNumTreesDefaltAs10; import com.alibaba.alink.params.shared.tree.HasSubsamplingRatio; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; /** * Fit a id3 model. 
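+ *
+ * <p>Because the op implements {@code WithModelInfoBatchOp}, the fitted tree can be inspected lazily, e.g.
+ * (a hypothetical sketch with assumed column names):
+ * <pre>{@code
+ * new Id3TrainBatchOp()
+ *     .setFeatureCols("f0", "f1")
+ *     .setLabelCol("label")
+ *     .linkFrom(trainData)
+ *     .lazyPrintModelInfo();
+ * }</pre>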
*/ @NameCn("ID3决策树分类训练") -public final class Id3TrainBatchOp extends BaseRandomForestTrainBatchOp +@NameEn("ID3 Decision Tree Training") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.classification.Id3") +public class Id3TrainBatchOp extends BaseRandomForestTrainBatchOp implements Id3TrainParams , WithModelInfoBatchOp { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/classification/KerasSequentialClassifierPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/classification/KerasSequentialClassifierPredictBatchOp.java index f8d3e8ee8..4f674faa9 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/classification/KerasSequentialClassifierPredictBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/classification/KerasSequentialClassifierPredictBatchOp.java @@ -3,11 +3,13 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; /** * Prediction with a classifier using a Keras Sequential model. */ @NameCn("KerasSequential分类预测") +@NameEn("KerasSequential Classifier Prediction") public class KerasSequentialClassifierPredictBatchOp extends TFTableModelClassifierPredictBatchOp { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/classification/KerasSequentialClassifierTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/classification/KerasSequentialClassifierTrainBatchOp.java index 3cbe3a4a4..8775ef289 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/classification/KerasSequentialClassifierTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/classification/KerasSequentialClassifierTrainBatchOp.java @@ -3,14 +3,18 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.dl.BaseKerasSequentialTrainBatchOp; import com.alibaba.alink.common.dl.TaskType; import com.alibaba.alink.params.dl.HasTaskType; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; /** * Train a classifier using a Keras Sequential model. 
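+ *
+ * <p>A hypothetical sketch; the layer definitions, column names, and setter names below are illustrative
+ * assumptions rather than part of this change:
+ * <pre>{@code
+ * BatchOperator<?> model = new KerasSequentialClassifierTrainBatchOp()
+ *     .setTensorCol("tensor")
+ *     .setLabelCol("label")
+ *     .setLayers("Dense(64, activation='relu')", "Dense(32)")
+ *     .linkFrom(trainData);
+ * }</pre>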
*/ @NameCn("KerasSequential分类训练") +@NameEn("KerasSequential Classifier Training") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.classification.KerasSequentialClassifier") public class KerasSequentialClassifierTrainBatchOp extends BaseKerasSequentialTrainBatchOp { public KerasSequentialClassifierTrainBatchOp() { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/classification/KnnPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/classification/KnnPredictBatchOp.java index 427953496..e5bbbb951 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/classification/KnnPredictBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/classification/KnnPredictBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.operator.batch.utils.ModelMapBatchOp; @@ -15,6 +16,7 @@ @ParamSelectColumnSpec(name = "vectorCol", allowedTypeCollections = TypeCollections.VECTOR_TYPES) @NameCn("最近邻分类预测") +@NameEn("Knn Prediction") public final class KnnPredictBatchOp extends ModelMapBatchOp implements KnnPredictParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/classification/KnnTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/classification/KnnTrainBatchOp.java index 4b0c1d07a..44608ade7 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/classification/KnnTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/classification/KnnTrainBatchOp.java @@ -5,6 +5,7 @@ import com.alibaba.alink.common.annotation.FeatureColsVectorColMutexRule; import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortSpec; @@ -15,6 +16,7 @@ import com.alibaba.alink.operator.batch.dataproc.vector.VectorAssemblerBatchOp; import com.alibaba.alink.operator.batch.similarity.VectorNearestNeighborTrainBatchOp; import com.alibaba.alink.params.classification.KnnTrainParams; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; /** * KNN is to classify unlabeled observations by assigning them to the class of the most similar labeled examples. 
@@ -26,6 +28,8 @@ @FeatureColsVectorColMutexRule @ParamSelectColumnSpec(name = "labelCol") @NameCn("最近邻分类训练") +@NameEn("Knn Training") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.classification.KnnClassifier") public final class KnnTrainBatchOp extends BatchOperator implements KnnTrainParams { private static final long serialVersionUID = -3118065094037473283L; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/classification/LinearSvmModelInfoBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/classification/LinearSvmModelInfoBatchOp.java index b886f2959..ef295aed8 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/classification/LinearSvmModelInfoBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/classification/LinearSvmModelInfoBatchOp.java @@ -3,8 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import org.apache.flink.types.Row; -import com.alibaba.alink.common.lazy.ExtractModelInfoBatchOp; -import com.alibaba.alink.operator.batch.BatchOperator; +import com.alibaba.alink.operator.batch.utils.ExtractModelInfoBatchOp; import com.alibaba.alink.operator.common.linear.LinearClassifierModelInfo; import java.util.List; @@ -30,8 +29,4 @@ protected LinearClassifierModelInfo createModelInfo(List rows) { return new LinearClassifierModelInfo(rows); } - @Override - protected BatchOperator processModel() { - return this; - } } diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/classification/LinearSvmTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/classification/LinearSvmTrainBatchOp.java index 0137f2a95..71e6edc78 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/classification/LinearSvmTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/classification/LinearSvmTrainBatchOp.java @@ -4,11 +4,12 @@ import com.alibaba.alink.common.annotation.NameCn; import com.alibaba.alink.common.annotation.NameEn; -import com.alibaba.alink.common.lazy.WithModelInfoBatchOp; +import com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp; import com.alibaba.alink.operator.common.linear.BaseLinearModelTrainBatchOp; import com.alibaba.alink.operator.common.linear.LinearClassifierModelInfo; import com.alibaba.alink.operator.common.linear.LinearModelType; import com.alibaba.alink.params.classification.LinearBinaryClassTrainParams; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; /** * Linear svm train batch operator. 
it uses hinge loss func by setting LinearModelType = SVM and model name = "linear @@ -16,6 +17,7 @@ */ @NameCn("线性支持向量机训练") @NameEn("Linear SVM Training") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.classification.LinearSvm") public final class LinearSvmTrainBatchOp extends BaseLinearModelTrainBatchOp implements LinearBinaryClassTrainParams , WithModelInfoBatchOp { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/classification/LogisticRegressionModelInfoBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/classification/LogisticRegressionModelInfoBatchOp.java index 500bcb92b..2b36cbfee 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/classification/LogisticRegressionModelInfoBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/classification/LogisticRegressionModelInfoBatchOp.java @@ -3,8 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import org.apache.flink.types.Row; -import com.alibaba.alink.common.lazy.ExtractModelInfoBatchOp; -import com.alibaba.alink.operator.batch.BatchOperator; +import com.alibaba.alink.operator.batch.utils.ExtractModelInfoBatchOp; import com.alibaba.alink.operator.common.linear.LinearClassifierModelInfo; import java.util.List; @@ -30,8 +29,4 @@ protected LinearClassifierModelInfo createModelInfo(List rows) { return new LinearClassifierModelInfo(rows); } - @Override - protected BatchOperator processModel() { - return this; - } } diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/classification/LogisticRegressionTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/classification/LogisticRegressionTrainBatchOp.java index 7060e40fe..79c68238d 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/classification/LogisticRegressionTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/classification/LogisticRegressionTrainBatchOp.java @@ -4,11 +4,12 @@ import com.alibaba.alink.common.annotation.NameCn; import com.alibaba.alink.common.annotation.NameEn; -import com.alibaba.alink.common.lazy.WithModelInfoBatchOp; +import com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp; import com.alibaba.alink.operator.common.linear.BaseLinearModelTrainBatchOp; import com.alibaba.alink.operator.common.linear.LinearClassifierModelInfo; import com.alibaba.alink.operator.common.linear.LinearModelType; import com.alibaba.alink.params.classification.LinearBinaryClassTrainParams; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; /** * Logistic regression train batch operator. 
we use log loss func by setting LinearModelType = LR and model @@ -16,6 +17,7 @@ */ @NameCn("逻辑回归训练") @NameEn("Logistic Regression Training") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.classification.LogisticRegression") public final class LogisticRegressionTrainBatchOp extends BaseLinearModelTrainBatchOp implements LinearBinaryClassTrainParams , WithModelInfoBatchOp implements MultilayerPerceptronTrainParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/classification/NaiveBayesModelInfoBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/classification/NaiveBayesModelInfoBatchOp.java index 4a6b7b1b8..8fe0b75de 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/classification/NaiveBayesModelInfoBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/classification/NaiveBayesModelInfoBatchOp.java @@ -3,7 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import org.apache.flink.types.Row; -import com.alibaba.alink.common.lazy.ExtractModelInfoBatchOp; +import com.alibaba.alink.operator.batch.utils.ExtractModelInfoBatchOp; import com.alibaba.alink.operator.common.classification.NaiveBayesModelData; import com.alibaba.alink.operator.common.classification.NaiveBayesModelDataConverter; @@ -23,7 +23,7 @@ public NaiveBayesModelInfoBatchOp(Params params) { } @Override - protected NaiveBayesModelInfo createModelInfo(List rows) { + public NaiveBayesModelInfo createModelInfo(List rows) { NaiveBayesModelData modelData = new NaiveBayesModelDataConverter().load(rows); NaiveBayesModelInfo modelInfo = new NaiveBayesModelInfo(modelData.featureNames, modelData.isCate, diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/classification/NaiveBayesPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/classification/NaiveBayesPredictBatchOp.java index 0286ebe63..d6ceeb80d 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/classification/NaiveBayesPredictBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/classification/NaiveBayesPredictBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.batch.utils.ModelMapBatchOp; import com.alibaba.alink.operator.common.classification.NaiveBayesModelMapper; import com.alibaba.alink.params.classification.NaiveBayesPredictParams; @@ -11,6 +12,7 @@ * Naive Bayes Predictor. 
*/ @NameCn("朴素贝叶斯预测") +@NameEn("Naive Bayes Prediction") public class NaiveBayesPredictBatchOp extends ModelMapBatchOp implements NaiveBayesPredictParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/classification/NaiveBayesTextModelInfoBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/classification/NaiveBayesTextModelInfoBatchOp.java index 58ce71257..a29e30482 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/classification/NaiveBayesTextModelInfoBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/classification/NaiveBayesTextModelInfoBatchOp.java @@ -3,7 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import org.apache.flink.types.Row; -import com.alibaba.alink.common.lazy.ExtractModelInfoBatchOp; +import com.alibaba.alink.operator.batch.utils.ExtractModelInfoBatchOp; import com.alibaba.alink.operator.common.classification.NaiveBayesTextModelData; import com.alibaba.alink.operator.common.classification.NaiveBayesTextModelDataConverter; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/classification/NaiveBayesTextPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/classification/NaiveBayesTextPredictBatchOp.java index ee21581c4..f825611f9 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/classification/NaiveBayesTextPredictBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/classification/NaiveBayesTextPredictBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.operator.batch.utils.ModelMapBatchOp; @@ -20,6 +21,7 @@ */ @ParamSelectColumnSpec(name = "vectorCol", allowedTypeCollections = TypeCollections.VECTOR_TYPES) @NameCn("朴素贝叶斯文本分类预测") +@NameEn("Naive Bayes Text Prediction") public final class NaiveBayesTextPredictBatchOp extends ModelMapBatchOp implements NaiveBayesTextPredictParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/classification/NaiveBayesTextTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/classification/NaiveBayesTextTrainBatchOp.java index 1c51af0ed..8691ff080 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/classification/NaiveBayesTextTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/classification/NaiveBayesTextTrainBatchOp.java @@ -16,12 +16,13 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortSpec; import com.alibaba.alink.common.annotation.PortType; import com.alibaba.alink.common.annotation.TypeCollections; -import com.alibaba.alink.common.lazy.WithModelInfoBatchOp; +import com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp; import com.alibaba.alink.common.linalg.DenseMatrix; import com.alibaba.alink.common.linalg.DenseVector; import com.alibaba.alink.common.linalg.SparseVector; @@ -30,10 +31,11 @@ import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.common.classification.NaiveBayesTextModelData; import 
com.alibaba.alink.operator.common.classification.NaiveBayesTextModelDataConverter; -import com.alibaba.alink.operator.common.statistics.StatisticsHelper; +import com.alibaba.alink.operator.batch.statistics.utils.StatisticsHelper; import com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary; import com.alibaba.alink.params.classification.NaiveBayesTextTrainParams; import com.alibaba.alink.params.shared.colname.HasFeatureCols; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; import java.util.ArrayList; @@ -52,6 +54,8 @@ @ParamSelectColumnSpec(name = "vectorCol", allowedTypeCollections = TypeCollections.VECTOR_TYPES) @ParamSelectColumnSpec(name = "weightCol", allowedTypeCollections = TypeCollections.NUMERIC_TYPES) @NameCn("朴素贝叶斯文本分类训练") +@NameEn("Naive Bayes Text Training") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.classification.NaiveBayesTextClassifier") public class NaiveBayesTextTrainBatchOp extends BatchOperator implements NaiveBayesTextTrainParams , diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/classification/NaiveBayesTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/classification/NaiveBayesTrainBatchOp.java index 3e4f5b4b6..dc99ef163 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/classification/NaiveBayesTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/classification/NaiveBayesTrainBatchOp.java @@ -16,6 +16,7 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortSpec; @@ -23,7 +24,7 @@ import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.common.exceptions.AkIllegalOperatorParameterException; import com.alibaba.alink.common.exceptions.AkUnsupportedOperationException; -import com.alibaba.alink.common.lazy.WithModelInfoBatchOp; +import com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp; import com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.common.classification.NaiveBayesModelData; @@ -31,6 +32,7 @@ import com.alibaba.alink.operator.common.tree.Preprocessing; import com.alibaba.alink.params.classification.NaiveBayesTrainParams; import com.alibaba.alink.params.shared.colname.HasCategoricalCols; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; import java.util.ArrayList; import java.util.Arrays; @@ -50,6 +52,8 @@ @ParamSelectColumnSpec(name = "featureCols") @ParamSelectColumnSpec(name = "weightCol", allowedTypeCollections = TypeCollections.DOUBLE_TYPE) @NameCn("朴素贝叶斯训练") +@NameEn("Naive Bayes Training") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.classification.NaiveBayes") public class NaiveBayesTrainBatchOp extends BatchOperator implements NaiveBayesTrainParams , WithModelInfoBatchOp { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/classification/RandomForestPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/classification/RandomForestPredictBatchOp.java index 19589b8e7..dd1da5b74 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/classification/RandomForestPredictBatchOp.java +++ 
b/core/src/main/java/com/alibaba/alink/operator/batch/classification/RandomForestPredictBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.batch.utils.ModelMapBatchOp; import com.alibaba.alink.operator.common.tree.predictors.RandomForestModelMapper; import com.alibaba.alink.params.classification.RandomForestPredictParams; @@ -11,6 +12,7 @@ * The batch operator that predict the data using the random forest model. */ @NameCn("随机森林预测") +@NameEn("Random Forest Prediction") public final class RandomForestPredictBatchOp extends ModelMapBatchOp implements RandomForestPredictParams { private static final long serialVersionUID = -4391732102873972774L; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/classification/RandomForestTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/classification/RandomForestTrainBatchOp.java index 1ccac55ba..2f97e0a61 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/classification/RandomForestTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/classification/RandomForestTrainBatchOp.java @@ -3,10 +3,12 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; -import com.alibaba.alink.common.lazy.WithModelInfoBatchOp; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp; import com.alibaba.alink.operator.common.tree.BaseRandomForestTrainBatchOp; import com.alibaba.alink.operator.common.tree.TreeModelInfo; import com.alibaba.alink.params.classification.RandomForestTrainParams; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; /** * Fit a random forest classification model. 
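/*
 * A minimal sketch of how the estimatorName metadata added in these hunks could be consumed,
 * assuming the annotation is retained at runtime (see the annotation sketch after the KNN hunks).
 * This lookup is illustrative only and is not part of this change.
 */
import com.alibaba.alink.operator.batch.classification.RandomForestTrainBatchOp;
import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation;

public class EstimatorLookupSketch {
	public static void main(String[] args) throws Exception {
		// Read the estimator class name declared on the train op ...
		EstimatorTrainerAnnotation anno =
			RandomForestTrainBatchOp.class.getAnnotation(EstimatorTrainerAnnotation.class);
		if (anno != null) {
			// ... and resolve it, e.g. com.alibaba.alink.pipeline.classification.RandomForestClassifier.
			Class<?> estimator = Class.forName(anno.estimatorName());
			System.out.println(estimator.getName());
		}
	}
}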
@@ -14,7 +16,9 @@ * @see BaseRandomForestTrainBatchOp */ @NameCn("随机森林训练") -public final class RandomForestTrainBatchOp extends BaseRandomForestTrainBatchOp implements +@NameEn("Random Forest Training") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.classification.RandomForestClassifier") +public class RandomForestTrainBatchOp extends BaseRandomForestTrainBatchOp implements RandomForestTrainParams , WithModelInfoBatchOp { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/classification/SoftmaxModelInfoBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/classification/SoftmaxModelInfoBatchOp.java index 0098b46f7..20e5e3297 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/classification/SoftmaxModelInfoBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/classification/SoftmaxModelInfoBatchOp.java @@ -3,8 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import org.apache.flink.types.Row; -import com.alibaba.alink.common.lazy.ExtractModelInfoBatchOp; -import com.alibaba.alink.operator.batch.BatchOperator; +import com.alibaba.alink.operator.batch.utils.ExtractModelInfoBatchOp; import com.alibaba.alink.operator.common.linear.SoftmaxModelInfo; import java.util.List; @@ -27,8 +26,4 @@ protected SoftmaxModelInfo createModelInfo(List rows) { return new SoftmaxModelInfo(rows); } - @Override - protected BatchOperator processModel() { - return this; - } } diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/classification/SoftmaxTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/classification/SoftmaxTrainBatchOp.java index 430930a0c..e2fd1e3cc 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/classification/SoftmaxTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/classification/SoftmaxTrainBatchOp.java @@ -11,6 +11,7 @@ import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeinfo.Types; import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.tuple.Tuple1; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.tuple.Tuple3; import org.apache.flink.configuration.Configuration; @@ -34,13 +35,13 @@ import com.alibaba.alink.common.exceptions.AkIllegalArgumentException; import com.alibaba.alink.common.exceptions.AkIllegalDataException; import com.alibaba.alink.common.exceptions.AkIllegalModelException; -import com.alibaba.alink.common.lazy.WithModelInfoBatchOp; -import com.alibaba.alink.common.lazy.WithTrainInfo; +import com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp; +import com.alibaba.alink.operator.batch.utils.WithTrainInfo; import com.alibaba.alink.common.linalg.DenseVector; import com.alibaba.alink.common.linalg.SparseVector; import com.alibaba.alink.common.linalg.Vector; import com.alibaba.alink.common.model.ModelParamName; -import com.alibaba.alink.common.utils.DataSetConversionUtil; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; import com.alibaba.alink.common.utils.JsonConverter; import com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.operator.batch.BatchOperator; @@ -61,6 +62,7 @@ import com.alibaba.alink.params.shared.colname.HasVectorCol; import com.alibaba.alink.params.shared.linear.HasL1; import com.alibaba.alink.params.shared.linear.LinearTrainParams; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; import java.util.ArrayList; import java.util.HashMap; @@ -89,8 
+91,10 @@ @NameCn("Softmax训练") @NameEn("Softmax Training") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.classification.Softmax") public final class SoftmaxTrainBatchOp extends BatchOperator - implements SoftmaxTrainParams , WithTrainInfo , + implements SoftmaxTrainParams , + WithTrainInfo , WithModelInfoBatchOp { private static final long serialVersionUID = 2291776467437931890L; @@ -255,7 +259,7 @@ public void reduce(Iterable values, Collector out) { this.setOutput(modelRows, new LinearModelDataConverter(labelType).getModelSchema()); - this.setSideOutputTables(getSideTablesOfCoefficient(modelRows)); + this.setSideOutputTables(getSideTablesOfCoefficient(modelRows, coefs.project(1))); return this; } @@ -483,9 +487,7 @@ public void mapPartition(Iterable > iterable, Collector collector) throws Exception { List coefVectors = new ArrayList <>(); boolean hasIntercept = this.meta.get(ModelParamName.HAS_INTERCEPT_ITEM); - double[] convInfo = null; for (Tuple2 coefVector : iterable) { - convInfo = coefVector.f1; this.meta.set(ModelParamName.VECTOR_SIZE, coefVector.f0.size() / (labelSize - 1) - (hasIntercept ? 1 : 0)); this.meta.set(ModelParamName.LOSS_CURVE, coefVector.f1); @@ -514,12 +516,11 @@ public void mapPartition(Iterable > iterable, } LinearModelData modelData = new LinearModelData(labelType, meta, featureNames, coefVectors.get(0)); - modelData.convergenceInfo = convInfo; new LinearModelDataConverter(this.labelType).save(modelData, collector); } } - private Table[] getSideTablesOfCoefficient(DataSet modelRow) { + private Table[] getSideTablesOfCoefficient(DataSet modelRow, DataSet> convergenceInfo) { DataSet model = modelRow.mapPartition(new MapPartitionFunction () { private static final long serialVersionUID = 2063366042018382802L; @@ -542,12 +543,14 @@ public void mapPartition(Iterable values, Collector out) public void mapPartition(Iterable values, Collector out) { LinearModelData model = values.iterator().next(); - double[] cinfo = model.convergenceInfo; + double[] cinfo = ((Tuple1 )getRuntimeContext().getBroadcastVariable("cinfo").get(0)).f0; out.collect(Row.of(0L, JsonConverter.toJson(model.getMetaInfo()))); out.collect(Row.of(4L, JsonConverter.toJson(cinfo))); } - }).setParallelism(1).withBroadcastSet(model, "model"); + }).setParallelism(1) + .withBroadcastSet(model, "model") + .withBroadcastSet(convergenceInfo, "cinfo"); Table summaryTable = DataSetConversionUtil.toTable(getMLEnvironmentId(), summary, new TableSchema( new String[] {"title", "info"}, new TypeInformation[] {Types.LONG, Types.STRING})); diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/classification/XGBoostTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/classification/XGBoostTrainBatchOp.java index 0a92bbb1d..f89d6f99b 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/classification/XGBoostTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/classification/XGBoostTrainBatchOp.java @@ -6,9 +6,11 @@ import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.common.tree.BaseXGBoostTrainBatchOp; import com.alibaba.alink.params.xgboost.XGBoostTrainParams; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; @NameCn("XGBoost二分类训练") @NameEn("XGBoost Binary Classification Training") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.classification.XGBoostClassifier") public final class XGBoostTrainBatchOp extends BaseXGBoostTrainBatchOp implements XGBoostTrainParams 
{ diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/clustering/AgnesBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/clustering/AgnesBatchOp.java new file mode 100644 index 000000000..2998b044d --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/batch/clustering/AgnesBatchOp.java @@ -0,0 +1,188 @@ +package com.alibaba.alink.operator.batch.clustering; + +import org.apache.flink.api.common.functions.FlatMapFunction; +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.MapPartitionFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.DataSet; +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.TableSchema; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; + +import com.alibaba.alink.common.type.AlinkTypes; +import com.alibaba.alink.common.annotation.InputPorts; +import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.common.annotation.OutputPorts; +import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; +import com.alibaba.alink.common.annotation.PortDesc; +import com.alibaba.alink.common.annotation.PortSpec; +import com.alibaba.alink.common.annotation.PortType; +import com.alibaba.alink.common.annotation.ReservedColsWithSecondInputSpec; +import com.alibaba.alink.common.annotation.TypeCollections; +import com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp; +import com.alibaba.alink.common.linalg.VectorUtil; +//import com.alibaba.alink.common.utils.AlinkViz; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; +import com.alibaba.alink.common.utils.TableUtil; +import com.alibaba.alink.operator.batch.BatchOperator; +import com.alibaba.alink.operator.common.clustering.agnes.Agnes; +import com.alibaba.alink.operator.common.clustering.agnes.AgnesCluster; +import com.alibaba.alink.operator.common.clustering.agnes.AgnesModelInfoBatchOp; +import com.alibaba.alink.operator.common.clustering.agnes.AgnesSample; +import com.alibaba.alink.operator.common.clustering.agnes.Linkage; +import com.alibaba.alink.operator.common.distance.ContinuousDistance; +import com.alibaba.alink.operator.common.evaluation.EvaluationUtil; +import com.alibaba.alink.params.clustering.AgnesParams; + +import java.util.ArrayList; +import java.util.List; + +/** + * @author guotao.gt + */ +@InputPorts(values = {@PortSpec(PortType.DATA)}) +@OutputPorts(values = { + @PortSpec(value = PortType.DATA, desc = PortDesc.OUTPUT_RESULT), + @PortSpec(value = PortType.EVAL_METRICS), +}) +@ReservedColsWithSecondInputSpec +@ParamSelectColumnSpec(name = "vectorCol", portIndices = 0, allowedTypeCollections = {TypeCollections.VECTOR_TYPES}) +@ParamSelectColumnSpec(name = "idCol", portIndices = 0) +@NameCn("Agnes") +@NameEn("Agnes") +public final class AgnesBatchOp extends BatchOperator + implements AgnesParams , + //AlinkViz , + WithModelInfoBatchOp { + + private static final long serialVersionUID = -7069169801410116405L; + + public AgnesBatchOp() { + super(null); + } + + public AgnesBatchOp(Params params) { + super(params); + } + + @Override + public AgnesBatchOp linkFrom(BatchOperator ... 
inputs) { + BatchOperator in = checkAndGetFirst(inputs); + + final int k = this.getParams().get(K); + final double distanceThreshold = getParams().get(DISTANCE_THRESHOLD); + final DistanceType distanceType = get(DISTANCE_TYPE); + final Linkage linkage = getParams().get(LINKAGE); + ContinuousDistance distance = distanceType.getFastDistance(); + + if (k <= 1 && distanceThreshold == Double.MAX_VALUE) { + throw new RuntimeException("k should be larger than 1, or distanceThreshold should be set"); + } + TypeInformation idType = TableUtil.findColTypeWithAssertAndHint(in.getSchema(), this.getIdCol()); + + DataSet data = in.select(new String[] {this.getIdCol(), this.getVectorCol()}).getDataSet() + .map(new MapFunction () { + private static final long serialVersionUID = -4667000522433310128L; + @Override + public AgnesSample map(Row row) throws Exception { + String idColValue = row.getField(0).toString(); + // the default clusterID is set as 0, will be set later. + return new AgnesSample(idColValue, 0, VectorUtil.getDenseVector(row.getField(1)), 1.0); + } + }); + + DataSet clusters = data.mapPartition( + new AgnesKernel(distanceThreshold, k, distance, linkage)) + .setParallelism(1); + // extract the cluster assignment result + DataSet dataRow = clusters.flatMap( + new TransferClusterResult(idType)); + + DataSet mergeInfo = clusters.flatMap(new MergeInfo(idType)); + + TableSchema outputSchema = new TableSchema(new String[] {this.getIdCol(), this.getPredictionCol()}, + new TypeInformation[] { + idType, AlinkTypes.LONG}); + + this.setOutput(dataRow, outputSchema); + this.setSideOutputTables(new Table[] { + DataSetConversionUtil.toTable(getMLEnvironmentId(), mergeInfo, + new TableSchema(new String[] {"NodeId", "MergeIteration", "ParentId"}, new TypeInformation[] { + idType, AlinkTypes.LONG, idType})) + }); + + return this; + } + + public static class AgnesKernel implements MapPartitionFunction { + private static final long serialVersionUID = 886248302149838023L; + private double distanceThreshold; + private int k; + private ContinuousDistance distance; + private Linkage linkage; + + public AgnesKernel(double distanceThreshold, int k, ContinuousDistance distance, Linkage linkage) { + this.distanceThreshold = distanceThreshold; + this.k = k; + this.distance = distance; + this.linkage = linkage; + } + + @Override + public void mapPartition(Iterable values, Collector out) throws Exception { + List samples = new ArrayList <>(); + for (AgnesSample sample : values) { + samples.add(sample); + } + + List clusters = Agnes.startAnalysis(samples, k, distanceThreshold, linkage, distance); + for (AgnesCluster cluster : clusters) { + out.collect(cluster); + } + } + } + + public static class TransferClusterResult implements FlatMapFunction { + private static final long serialVersionUID = 531203134457473817L; + private long clusterId = 0; + private TypeInformation idType; + + public TransferClusterResult(TypeInformation idType) { + this.idType = idType; + } + + @Override + public void flatMap(AgnesCluster cluster, Collector out) throws Exception { + for (AgnesSample dp : cluster.getAgnesSamples()) { + out.collect(Row.of(EvaluationUtil.castTo(dp.getSampleId(), idType), clusterId)); + } + clusterId++; + } + } + + public static class MergeInfo implements FlatMapFunction { + private static final long serialVersionUID = 531203134457473817L; + private TypeInformation idType; + + public MergeInfo(TypeInformation idType) { + this.idType = idType; + } + + @Override + public void flatMap(AgnesCluster
cluster, Collector out) throws Exception { + for (AgnesSample dp : cluster.getAgnesSamples()) { + out.collect( + Row.of(EvaluationUtil.castTo(dp.getSampleId(), idType), dp.getMergeIter(), dp.getParentId())); + } + } + } + + @Override + public AgnesModelInfoBatchOp getModelInfoBatchOp() { + return new AgnesModelInfoBatchOp(this.getParams()).linkFrom(this); + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/clustering/BisectingKMeansModelInfoBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/clustering/BisectingKMeansModelInfoBatchOp.java index 20b60c947..cef397c21 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/clustering/BisectingKMeansModelInfoBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/clustering/BisectingKMeansModelInfoBatchOp.java @@ -3,7 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import org.apache.flink.types.Row; -import com.alibaba.alink.common.lazy.ExtractModelInfoBatchOp; +import com.alibaba.alink.operator.batch.utils.ExtractModelInfoBatchOp; import com.alibaba.alink.common.linalg.DenseVector; import com.alibaba.alink.operator.common.clustering.BisectingKMeansModelData; import com.alibaba.alink.operator.common.clustering.BisectingKMeansModelDataConverter; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/clustering/BisectingKMeansPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/clustering/BisectingKMeansPredictBatchOp.java index ba5e6ed03..5bb90edca 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/clustering/BisectingKMeansPredictBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/clustering/BisectingKMeansPredictBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.batch.utils.ModelMapBatchOp; import com.alibaba.alink.operator.common.clustering.BisectingKMeansModelMapper; import com.alibaba.alink.params.clustering.BisectingKMeansPredictParams; @@ -11,6 +12,7 @@ * Bisecting KMeans prediction based on the model fitted by BisectingKMeansTrainBatchOp. 
*/ @NameCn("二分K均值聚类预测") +@NameEn("Bisecting KMeans Prediction") public final class BisectingKMeansPredictBatchOp extends ModelMapBatchOp implements BisectingKMeansPredictParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/clustering/BisectingKMeansTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/clustering/BisectingKMeansTrainBatchOp.java index 6f621cf05..8be237f1a 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/clustering/BisectingKMeansTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/clustering/BisectingKMeansTrainBatchOp.java @@ -24,6 +24,7 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortSpec; @@ -32,7 +33,7 @@ import com.alibaba.alink.common.exceptions.AkIllegalDataException; import com.alibaba.alink.common.exceptions.AkIllegalStateException; import com.alibaba.alink.common.exceptions.AkPreconditions; -import com.alibaba.alink.common.lazy.WithModelInfoBatchOp; +import com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp; import com.alibaba.alink.common.linalg.BLAS; import com.alibaba.alink.common.linalg.DenseVector; import com.alibaba.alink.common.linalg.MatVecOp; @@ -46,9 +47,10 @@ import com.alibaba.alink.operator.common.dataproc.FirstReducer; import com.alibaba.alink.operator.common.distance.ContinuousDistance; import com.alibaba.alink.operator.common.distance.EuclideanDistance; -import com.alibaba.alink.operator.common.statistics.StatisticsHelper; +import com.alibaba.alink.operator.batch.statistics.utils.StatisticsHelper; import com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary; import com.alibaba.alink.params.clustering.BisectingKMeansTrainParams; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -78,6 +80,8 @@ @ParamSelectColumnSpec(name = "vectorCol", portIndices = 0, allowedTypeCollections = {TypeCollections.VECTOR_TYPES}) @NameCn("二分K均值聚类训练") +@NameEn("Bisecting KMeans Training") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.clustering.BisectingKMeans") public final class BisectingKMeansTrainBatchOp extends BatchOperator implements BisectingKMeansTrainParams , WithModelInfoBatchOp + implements DbscanParams , + WithModelInfoBatchOp { + + private static final long serialVersionUID = 2680984884814078788L; + + public DbscanBatchOp() { + super(new Params()); + } + + public DbscanBatchOp(Params params) { + super(params); + } + + @Override + public DbscanBatchOp linkFrom(BatchOperator ... 
inputs) { + BatchOperator in = checkAndGetFirst(inputs); + + String idColName = this.getIdCol(); + String vectorColName = this.getVectorCol(); + + DataSet > idLabelVector = DataSetUtils + .zipWithIndex(in.select(new String[] {idColName, vectorColName}).getDataSet()) + .map(new RichMapFunction , Tuple3 >() { + private static final long serialVersionUID = -4516567863938069544L; + + @Override + public Tuple3 map(Tuple2 value) { + return Tuple3.of(value.f0.intValue(), value.f1.getField(0), + VectorUtil.getVector(value.f1.getField(1))); + } + }); + + BatchOperator data = new DataSetWrapperBatchOp( + idLabelVector.map(new MapFunction , Row>() { + private static final long serialVersionUID = 672726382584730805L; + + @Override + public Row map(Tuple3 t) { + return Row.of(t.f0, t.f2); + } + }), new String[] {"alink_unique_id", "vector"}, new TypeInformation[] {AlinkTypes.INT, AlinkTypes.VECTOR}); + + VectorNearestNeighborTrainBatchOp train = new VectorNearestNeighborTrainBatchOp() + .setIdCol("alink_unique_id") + .setSelectedCol("vector") + .setMetric(this.getDistanceType().name()) + .linkFrom(data); + + VectorNearestNeighborPredictBatchOp predict = new VectorNearestNeighborPredictBatchOp() + .setSelectedCol("vector") + .setRadius(this.getEpsilon()) + .setReservedCols("alink_unique_id") + .linkFrom(train, data); + + DataSet > dataSet = predict + .select(new String[] {"alink_unique_id", "vector"}) + .getDataSet() + .mapPartition(new GetCorePoints(this.getMinPoints())); + + DataSet > taskIdLabel = dataSet + //.rebalance() + .mapPartition(new ReduceLocal()); + + IterativeDataSet > loop = taskIdLabel.iterate(Integer.MAX_VALUE); + + //得到全局的label + DataSet > globalMaxLabel = loop.flatMap( + new FlatMapFunction , Tuple2 >() { + private static final long serialVersionUID = -4049782728006537532L; + + @Override + public void flatMap(Tuple2 value, Collector > out) { + int[] keys = value.f1.getKeys(); + int[] clusterIds = value.f1.getClusterIds(); + for (int i = 0; i < keys.length; i++) { + out.collect(Tuple2.of(keys[i], clusterIds[i])); + } + } + }).groupBy(0) + .aggregate(Aggregations.MAX, 1); + + DataSet > update = taskIdLabel + .mapPartition(new LocalMerge()) + .withBroadcastSet(globalMaxLabel, "global"); + + DataSet > feedBack = update.project(0, 1); + + DataSet > filter = update.filter( + new FilterFunction >() { + private static final long serialVersionUID = -1489238776369734510L; + + @Override + public boolean filter(Tuple3 value) { + return !value.f2; + } + }); + + DataSet > idNeighborFinalLabel = loop.closeWith(feedBack, filter); + + DataSet > idClusterId = idNeighborFinalLabel + .mapPartition(new AssignContinuousClusterId()); + + DataSet > idTypeClusterId = dataSet + .leftOuterJoin(idClusterId, BROADCAST_HASH_SECOND) + .where(0) + .equalTo(0) + .with(new AssignAllClusterId()); + + DataSet out = idLabelVector + .join(idTypeClusterId) + .where(0) + .equalTo(0) + .with(new JoinFunction , Tuple3 , Row>() { + private static final long serialVersionUID = 7638483527530324994L; + + @Override + public Row join(Tuple3 first, Tuple3 second) { + return Row.of(first.f1, second.f1.name(), (long) second.f2); + } + }); + + DataSet > idLongClusterId = idTypeClusterId.flatMap( + new FlatMapFunction , Tuple2 >() { + private static final long serialVersionUID = -4449631564554949600L; + + @Override + public void flatMap(Tuple3 value, Collector > out) { + if (value.f1.equals(Type.CORE)) { + out.collect(Tuple2.of(value.f0, (long) value.f2)); + } + } + }); + + DataSet > coreVectorClusterId = 
idLongClusterId.join(idLabelVector) + .where(0) + .equalTo(0) + .with(new JoinFunction , Tuple3 , Tuple2 > + () { + private static final long serialVersionUID = -1388744253754875541L; + + @Override + public Tuple2 join(Tuple2 first, + Tuple3 second) { + return Tuple2.of(second.f2, first.f1); + } + }); + + DataSet modelRows = coreVectorClusterId.mapPartition( + new SaveModel(vectorColName, getEpsilon(), getDistanceType())).setParallelism(1); + + TypeInformation type = in.getColTypes()[TableUtil.findColIndexWithAssertAndHint(in.getColNames(), idColName)]; + this.setOutput(out, new String[] {idColName, DbscanConstant.TYPE, this.getPredictionCol()}, + new TypeInformation[] { + type, AlinkTypes.STRING, AlinkTypes.LONG}); + + this.setSideOutputTables(new Table[] { + DataSetConversionUtil.toTable(getMLEnvironmentId(), modelRows, + new DbscanModelDataConverter().getModelSchema())}); + return this; + } + + static class SaveModel implements MapPartitionFunction , Row> { + private static final long serialVersionUID = 7638276873515252678L; + private String vectorColName; + private double epsilon; + private DistanceType distanceType; + + public SaveModel(String vectorColName, double epsilon, DistanceType distanceType) { + this.vectorColName = vectorColName; + this.epsilon = epsilon; + this.distanceType = distanceType; + } + + @Override + public void mapPartition(Iterable > values, Collector out) throws Exception { + DbscanModelTrainData modelData = new DbscanModelTrainData(); + modelData.coreObjects = values; + modelData.vectorColName = vectorColName; + modelData.epsilon = epsilon; + modelData.distanceType = distanceType; + + new DbscanModelDataConverter().save(modelData, out); + } + } + + static class AssignAllClusterId implements + JoinFunction , Tuple2 , Tuple3 > { + private static final long serialVersionUID = -3817364693537340163L; + + @Override + public Tuple3 join(Tuple3 first, + Tuple2 second) { + if (null == second) { + return Tuple3.of(first.f0, first.f1, Integer.MIN_VALUE); + } else { + if (first.f1.equals(Type.NOISE)) { + first.f1 = Type.LINKED; + } + return Tuple3.of(first.f0, first.f1, second.f1); + } + } + } + + public static class AssignContinuousClusterId + extends RichMapPartitionFunction , Tuple2 > { + private static final long serialVersionUID = -503944144706407009L; + + @Override + public void mapPartition(Iterable > values, + Collector > out) throws Exception { + if (getRuntimeContext().getIndexOfThisSubtask() == 0) { + for (Tuple2 value : values) { + int[] keys = value.f1.getKeys(); + int[] clusterIds = value.f1.getClusterIds(); + Map hashMap = new HashMap <>(); + int cnt = 0; + for (int i = 0; i < keys.length; i++) { + Integer clusterId = hashMap.get(clusterIds[i]); + if (null == clusterId) { + clusterId = cnt++; + hashMap.put(clusterIds[i], clusterId); + } + out.collect(Tuple2.of(keys[i], clusterId)); + } + return; + } + } + } + } + + public static class GetCorePoints implements MapPartitionFunction > { + private int minPoints; + + public GetCorePoints(int minPoints) { + this.minPoints = minPoints; + } + + @Override + public void mapPartition(Iterable values, Collector > collector) { + for (Row t : values) { + Integer id = (int) t.getField(0); + Tuple2 , List > tuple = NearestNeighborsMapper.extractKObject( + (String) t.getField(1), Integer.class); + if (null != tuple.f0 && tuple.f0.size() >= minPoints) { + int[] keys = new int[tuple.f0.size()]; + for (int i = 0; i < tuple.f0.size(); i++) { + keys[i] = (int) tuple.f0.get(i); + } + Arrays.sort(keys); + 
collector.collect(Tuple3.of(id, Type.CORE, keys)); + } else { + collector.collect(Tuple3.of(id, Type.NOISE, new int[0])); + } + } + } + } + + public static class ReduceLocal + extends RichMapPartitionFunction , Tuple2 > { + + private static final long serialVersionUID = -6516417460468964305L; + + @Override + public void mapPartition(Iterable > tuples, + Collector > collector) { + TreeMap treeMap = new TreeMap(); + int taskId = getRuntimeContext().getIndexOfThisSubtask(); + for (Tuple3 t : tuples) { + Preconditions.checkArgument(t.f1.equals(Type.NOISE) ^ t.f2.length > 0, "Noise must be empty!"); + if (t.f2.length > 0) { + updateTreeMap(treeMap, t.f2); + } + } + + collector.collect(Tuple2.of(taskId, treeMapToLocalCluster(treeMap))); + } + } + + public static void updateTreeMap(TreeMap map, int[] keys) { + int max = keys[keys.length - 1]; + for (int key : keys) { + max = Math.max(map.getOrDefault(key, max), max); + } + for (int key : keys) { + Integer value = map.get(key); + if (null == value) { + map.put(key, max); + } else { + if (max > value) { + map.put(map.get(value), max); + } + } + } + DbscanBatchOp.reduceTreeMap(map); + } + + public static boolean updateTreeMap(TreeMap map, LocalCluster sparseVector) { + int[] keys = sparseVector.getKeys(); + int[] clusterIds = sparseVector.getClusterIds(); + boolean isFinished = true; + + for (int i = 0; i < keys.length; i++) { + int parent = clusterIds[i]; + if (map.get(keys[i]) > map.get(parent)) { + isFinished = false; + map.put(parent, map.get(keys[i])); + } + } + reduceTreeMap(map); + return isFinished; + } + + public static void reduceTreeMap(TreeMap map) { + for (int key : map.descendingKeySet()) { + int parent = map.get(key); + int parentClusterId = map.get(parent); + if (parentClusterId > parent) { + map.put(key, parentClusterId); + } + } + } + + public static LocalCluster treeMapToLocalCluster(TreeMap treeMap) { + int[] clusterIds = new int[treeMap.size()]; + int[] keys = new int[treeMap.size()]; + int cnt = 0; + for (Map.Entry entry : treeMap.entrySet()) { + keys[cnt] = entry.getKey(); + clusterIds[cnt++] = entry.getValue(); + } + return new LocalCluster(keys, clusterIds); + } + + public static class LocalMerge + extends RichMapPartitionFunction , Tuple3 > { + private static final long serialVersionUID = -7265351591038567537L; + private TreeMap global; + + @Override + public void open(Configuration parameters) { + List > list = getRuntimeContext().getBroadcastVariable("global"); + global = new TreeMap <>(); + for (Tuple2 t : list) { + global.put(t.f0, t.f1); + } + reduceTreeMap(global); + } + + @Override + public void mapPartition(Iterable > tuples, + Collector > collector) { + boolean isFinished = true; + Integer taskId = null; + for (Tuple2 t : tuples) { + if (null == taskId) { + taskId = t.f0; + } + isFinished = updateTreeMap(global, t.f1); + } + + if (null != taskId) { + collector.collect(Tuple3.of(taskId, treeMapToLocalCluster(global), isFinished)); + } + } + } + + @Override + public DbscanModelInfoBatchOp getModelInfoBatchOp() { + return new DbscanModelInfoBatchOp(this.getParams()).linkFrom(this); + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/clustering/DbscanPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/clustering/DbscanPredictBatchOp.java new file mode 100644 index 000000000..b1db21089 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/batch/clustering/DbscanPredictBatchOp.java @@ -0,0 +1,29 @@ +package com.alibaba.alink.operator.batch.clustering; + +import 
org.apache.flink.ml.api.misc.param.Params; + +import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.operator.batch.utils.ModelMapBatchOp; +import com.alibaba.alink.operator.common.clustering.dbscan.DbscanModelMapper; +import com.alibaba.alink.params.clustering.ClusteringPredictParams; + +/** + * input parameters: -# predResultColName: required + */ + +@NameCn("DBSCAN预测") +@NameEn("DBSCAN Prediction") +public final class DbscanPredictBatchOp extends ModelMapBatchOp + implements ClusteringPredictParams { + + private static final long serialVersionUID = -7841302650523879193L; + + public DbscanPredictBatchOp() { + this(new Params()); + } + + public DbscanPredictBatchOp(Params params) { + super(DbscanModelMapper::new, params); + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/clustering/GeoKMeansPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/clustering/GeoKMeansPredictBatchOp.java index b66a2110e..6284d86a9 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/clustering/GeoKMeansPredictBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/clustering/GeoKMeansPredictBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.batch.utils.ModelMapBatchOp; import com.alibaba.alink.operator.common.clustering.kmeans.KMeansModelMapper; import com.alibaba.alink.params.clustering.GeoKMeansPredictParams; @@ -11,6 +12,7 @@ * GeoKMeans prediction based on the model fitted by GeoKMeansTrainBatchOp. */ @NameCn("经纬度K均值聚类预测") +@NameEn("Geo KMeans Prediction") public final class GeoKMeansPredictBatchOp extends ModelMapBatchOp implements GeoKMeansPredictParams { private static final long serialVersionUID = 2666865290298374230L; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/clustering/GeoKMeansTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/clustering/GeoKMeansTrainBatchOp.java index d8c2992cc..bd2998d24 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/clustering/GeoKMeansTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/clustering/GeoKMeansTrainBatchOp.java @@ -8,6 +8,7 @@ import com.alibaba.alink.common.MLEnvironmentFactory; import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortDesc; @@ -24,6 +25,7 @@ import com.alibaba.alink.operator.common.distance.HaversineDistance; import com.alibaba.alink.params.clustering.GeoKMeansTrainParams; import com.alibaba.alink.params.shared.clustering.HasKMeansWithHaversineDistanceType; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; import static com.alibaba.alink.operator.batch.clustering.KMeansTrainBatchOp.iterateICQ; import static com.alibaba.alink.operator.common.clustering.kmeans.KMeansInitCentroids.initKmeansCentroids; @@ -38,6 +40,8 @@ @ParamSelectColumnSpec(name = "latitudeCol", portIndices = 0, allowedTypeCollections = {TypeCollections.NUMERIC_TYPES}) @ParamSelectColumnSpec(name = "longitudeCol", portIndices = 0, allowedTypeCollections = {TypeCollections.NUMERIC_TYPES}) @NameCn("经纬度K均值聚类训练") 
+@NameEn("Geo KMeans Training") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.clustering.GeoKMeans") public final class GeoKMeansTrainBatchOp extends BatchOperator implements GeoKMeansTrainParams { private static final long serialVersionUID = 1190784726768283432L; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/clustering/GmmModelInfoBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/clustering/GmmModelInfoBatchOp.java index afced3fd1..a27d6b9d7 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/clustering/GmmModelInfoBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/clustering/GmmModelInfoBatchOp.java @@ -3,7 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import org.apache.flink.types.Row; -import com.alibaba.alink.common.lazy.ExtractModelInfoBatchOp; +import com.alibaba.alink.operator.batch.utils.ExtractModelInfoBatchOp; import com.alibaba.alink.common.linalg.DenseMatrix; import com.alibaba.alink.common.linalg.DenseVector; import com.alibaba.alink.operator.common.clustering.ClusteringModelInfo; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/clustering/GmmTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/clustering/GmmTrainBatchOp.java index d24ed2223..de03be1b7 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/clustering/GmmTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/clustering/GmmTrainBatchOp.java @@ -27,7 +27,7 @@ import com.alibaba.alink.common.annotation.PortType; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.common.exceptions.AkIllegalStateException; -import com.alibaba.alink.common.lazy.WithModelInfoBatchOp; +import com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp; import com.alibaba.alink.common.linalg.DenseMatrix; import com.alibaba.alink.common.linalg.DenseVector; import com.alibaba.alink.common.linalg.SparseVector; @@ -39,10 +39,11 @@ import com.alibaba.alink.operator.common.clustering.GmmModelDataConverter; import com.alibaba.alink.operator.common.clustering.kmeans.KMeansInitCentroids; import com.alibaba.alink.operator.common.dataproc.FirstReducer; -import com.alibaba.alink.operator.common.statistics.StatisticsHelper; +import com.alibaba.alink.operator.batch.statistics.utils.StatisticsHelper; import com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary; import com.alibaba.alink.operator.common.statistics.basicstatistic.MultivariateGaussian; import com.alibaba.alink.params.clustering.GmmTrainParams; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -69,6 +70,7 @@ @ParamSelectColumnSpec(name = "vectorCol", portIndices = 0, allowedTypeCollections = {TypeCollections.VECTOR_TYPES}) @NameCn("高斯混合模型训练") @NameEn("GMM Training") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.clustering.GaussianMixture") public final class GmmTrainBatchOp extends BatchOperator implements GmmTrainParams , WithModelInfoBatchOp { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/clustering/GroupDbscanBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/clustering/GroupDbscanBatchOp.java new file mode 100644 index 000000000..70e77597b --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/batch/clustering/GroupDbscanBatchOp.java @@ -0,0 +1,221 @@ +package com.alibaba.alink.operator.batch.clustering; + +import 
org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.table.api.TableSchema; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import com.alibaba.alink.common.annotation.InputPorts; +import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.common.annotation.OutputPorts; +import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; +import com.alibaba.alink.common.annotation.PortDesc; +import com.alibaba.alink.common.annotation.PortSpec; +import com.alibaba.alink.common.annotation.PortType; +import com.alibaba.alink.common.annotation.ReservedColsWithSecondInputSpec; +import com.alibaba.alink.common.annotation.TypeCollections; +import com.alibaba.alink.common.linalg.DenseVector; +import com.alibaba.alink.common.utils.RowUtil; +import com.alibaba.alink.common.utils.TableUtil; +import com.alibaba.alink.operator.batch.BatchOperator; +import com.alibaba.alink.operator.batch.dataproc.AppendIdBatchOp; +import com.alibaba.alink.operator.common.clustering.DistanceType; +import com.alibaba.alink.operator.common.clustering.dbscan.DbscanConstant; +import com.alibaba.alink.operator.common.clustering.dbscan.DbscanNewSample; +import com.alibaba.alink.operator.common.distance.FastDistance; +import com.alibaba.alink.operator.common.distance.FastDistanceVectorData; +import com.alibaba.alink.params.clustering.GroupDbscanParams; +import com.alibaba.alink.params.clustering.HasLatitudeCol; +import com.alibaba.alink.params.clustering.HasLongitudeCol; +import com.alibaba.alink.params.shared.colname.HasPredictionCol; +import org.apache.commons.lang3.ArrayUtils; + +import java.util.Arrays; + +/** + * @author guotao.gt + */ +@InputPorts(values = {@PortSpec(PortType.DATA)}) +@OutputPorts(values = { + @PortSpec(value = PortType.DATA, desc = PortDesc.OUTPUT_RESULT), +}) +@ReservedColsWithSecondInputSpec +@ParamSelectColumnSpec(name = "featureCols", portIndices = 0, allowedTypeCollections = {TypeCollections.NUMERIC_TYPES}) +@ParamSelectColumnSpec(name = "groupCols", portIndices = 0) +@ParamSelectColumnSpec(name = "idCol", portIndices = 0) +@NameCn("分组Dbscan") +@NameEn("Group Dbscan") +public class GroupDbscanBatchOp extends BatchOperator + implements GroupDbscanParams { + + private static final long serialVersionUID = 2259660296918166445L; + + public GroupDbscanBatchOp() { + this(new Params()); + } + + public GroupDbscanBatchOp(Params params) { + super(params); + } + + @Override + public GroupDbscanBatchOp linkFrom(BatchOperator ... inputs) { + BatchOperator in = checkAndGetFirst(inputs); + + if (!this.getParams().contains(HasPredictionCol.PREDICTION_COL)) { + this.setPredictionCol("cluster_id"); + } + if (!this.getParams().contains(GroupDbscanParams.ID_COL)) { + this.setIdCol(AppendIdBatchOp.appendIdColName); + } + final Boolean isOutputVector = getParams().get(IS_OUTPUT_VECTOR); + final DistanceType distanceType = this.getDistanceType(); + final String latitudeColName = getParams().contains(HasLatitudeCol.LATITUDE_COL) ? + getParams().get(HasLatitudeCol.LATITUDE_COL) : null; + final String longitudeColName = getParams().contains(HasLongitudeCol.LONGITUDE_COL) ? 
+ getParams().get(HasLongitudeCol.LONGITUDE_COL) : null; + final String[] featureColNames = DistanceType.HAVERSINE.equals(distanceType) + && (latitudeColName != null) + && (longitudeColName != null) ? + new String[] {latitudeColName, longitudeColName} : this.get(FEATURE_COLS); + + final int minPoints = getParams().get(MIN_POINTS); + final Double epsilon = getParams().get(EPSILON); + final String idCol = getParams().get(ID_COL); + final String predResultColName = this.getPredictionCol(); + + // groupColNames + final String[] groupColNames = this.getGroupCols(); + + Preconditions.checkArgument(distanceType != DistanceType.JACCARD, "Not support Jaccard Distance!"); + FastDistance distance = distanceType.getFastDistance(); + + if (distanceType.equals(DistanceType.HAVERSINE)) { + if (!(featureColNames != null && featureColNames.length == 2)) { + if ((latitudeColName == null || longitudeColName == null || latitudeColName.isEmpty() + || longitudeColName.isEmpty())) { + throw new RuntimeException("latitudeColName and longitudeColName should be set !"); + } + } + } else { + if ((featureColNames == null || featureColNames.length == 0)) { + throw new RuntimeException("featureColNames should be set !"); + } + } + + for (String groupColName : groupColNames) { + if (TableUtil.findColIndex(featureColNames, groupColName) >= 0) { + throw new RuntimeException("groupColNames should NOT be included in featureColNames!"); + } + } + + if (null == idCol || "".equals(idCol)) { + throw new RuntimeException("idCol column should be set!"); + } else if (TableUtil.findColIndex(featureColNames, idCol) >= 0) { + throw new RuntimeException("idCol column should NOT be included in featureColNames !"); + } else if (TableUtil.findColIndex(groupColNames, idCol) >= 0) { + throw new RuntimeException("idCol column should NOT be included in groupColNames !"); + } + final int dim = featureColNames.length; + + String[] selectedColNames = ArrayUtils.addAll(ArrayUtils.addAll(groupColNames, idCol), featureColNames); + String[] outputColNames = ArrayUtils.addAll( + ArrayUtils.addAll(groupColNames, idCol, DbscanConstant.TYPE, predResultColName)); + if (isOutputVector) { + outputColNames = ArrayUtils.add(outputColNames, DbscanConstant.FEATURE_COL_NAMES); + } else { + outputColNames = ArrayUtils.addAll(outputColNames, featureColNames); + } + + TypeInformation[] outputColTypes = new TypeInformation[outputColNames.length]; + Arrays.fill(outputColTypes, Types.STRING); + outputColTypes[groupColNames.length + 2] = Types.LONG; + if (!isOutputVector) { + Arrays.fill(outputColTypes, groupColNames.length + 3, outputColTypes.length, Types.DOUBLE); + } + + final TableSchema outputSchema = new TableSchema(outputColNames, outputColTypes); + + final int groupMaxSamples = getGroupMaxSamples(); + final boolean skip = getSkip(); + + DataSet rowDataSet = in.select(selectedColNames).getDataSet() + .map(new mapToDataVectorSample(dim, groupColNames.length, distance)) + .groupBy(new GroupGeoDbscanBatchOp.SelectGroup()) + .reduceGroup(new GroupGeoDbscanBatchOp.Clustering(epsilon, minPoints, distance, groupMaxSamples, skip)) + .map(new MapToRow(isOutputVector, outputColNames.length, groupColNames.length)); + + this.setOutput(rowDataSet, outputSchema); + return this; + } + + public static class mapToDataVectorSample extends RichMapFunction { + private static final long serialVersionUID = -6733405177253139009L; + private int dim; + private int groupColNamesSize; + private FastDistance distance; + + public mapToDataVectorSample(int dim, int groupColNamesSize, 
FastDistance distance) { + this.dim = dim; + this.groupColNamesSize = groupColNamesSize; + this.distance = distance; + } + + @Override + public DbscanNewSample map(Row row) throws Exception { + Row keep = new Row(row.getArity() - dim); + + String[] groupColNames = new String[groupColNamesSize]; + for (int i = 0; i < groupColNamesSize; i++) { + groupColNames[i] = row.getField(i).toString(); + keep.setField(i, groupColNames[i]); + } + keep.setField(groupColNamesSize, row.getField(groupColNamesSize).toString()); + + double[] values = new double[dim]; + for (int i = 0; i < values.length; i++) { + values[i] = ((Number) row.getField(i + groupColNamesSize + 1)).doubleValue(); + } + DenseVector vec = new DenseVector(values); + FastDistanceVectorData vector = distance.prepareVectorData(Tuple2.of(vec, keep)); + return new DbscanNewSample(vector, groupColNames); + } + } + + public static class MapToRow extends RichMapFunction { + private static final long serialVersionUID = 4213429941831979236L; + private int rowArity; + private int groupColNamesSize; + private Boolean isOutputVector; + + public MapToRow(Boolean isOutputVector, int rowArity, int groupColNamesSize) { + this.isOutputVector = isOutputVector; + this.rowArity = rowArity; + this.groupColNamesSize = groupColNamesSize; + } + + @Override + public Row map(DbscanNewSample value) throws Exception { + Row row = new Row(rowArity - groupColNamesSize - 1); + row.setField(0, value.getType().name()); + row.setField(1, value.getClusterId()); + + DenseVector v = (DenseVector) value.getVec().getVector(); + if (isOutputVector) { + row.setField(2, v.toString()); + } else { + for (int i = 0; i < v.size(); i++) { + row.setField(i + 2, v.get(i)); + } + } + return RowUtil.merge(value.getVec().getRows()[0], row); + } + } + +} diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/clustering/GroupDbscanModelBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/clustering/GroupDbscanModelBatchOp.java new file mode 100644 index 000000000..eb6c712f8 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/batch/clustering/GroupDbscanModelBatchOp.java @@ -0,0 +1,231 @@ +package com.alibaba.alink.operator.batch.clustering; + +import org.apache.flink.api.common.functions.GroupReduceFunction; +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.table.api.TableSchema; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +import com.alibaba.alink.common.annotation.InputPorts; +import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.common.annotation.OutputPorts; +import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; +import com.alibaba.alink.common.annotation.PortDesc; +import com.alibaba.alink.common.annotation.PortSpec; +import com.alibaba.alink.common.annotation.PortType; +import com.alibaba.alink.common.annotation.ReservedColsWithSecondInputSpec; +import com.alibaba.alink.common.annotation.TypeCollections; +import com.alibaba.alink.common.linalg.DenseVector; +import com.alibaba.alink.common.utils.RowUtil; +import 
com.alibaba.alink.common.utils.TableUtil; +import com.alibaba.alink.operator.batch.BatchOperator; +import com.alibaba.alink.operator.common.clustering.DistanceType; +import com.alibaba.alink.operator.common.clustering.dbscan.DbscanCenter; +import com.alibaba.alink.operator.common.clustering.dbscan.DbscanConstant; +import com.alibaba.alink.operator.common.clustering.dbscan.DbscanNewSample; +import com.alibaba.alink.operator.common.distance.FastDistance; +import com.alibaba.alink.operator.common.distance.FastDistanceVectorData; +import com.alibaba.alink.params.clustering.GroupDbscanModelParams; +import com.alibaba.alink.params.shared.colname.HasPredictionCol; +import org.apache.commons.lang3.ArrayUtils; +import scala.util.hashing.MurmurHash3; + +import java.util.Iterator; + +/** + * @author guotao.gt + */ +@InputPorts(values = {@PortSpec(PortType.DATA)}) +@OutputPorts(values = { + @PortSpec(value = PortType.MODEL, desc = PortDesc.MODEL_INFO), +}) +@ReservedColsWithSecondInputSpec +@ParamSelectColumnSpec(name = "featureCols", portIndices = 0, allowedTypeCollections = {TypeCollections.NUMERIC_TYPES}) +@ParamSelectColumnSpec(name = "groupCols", portIndices = 0) +@NameCn("分组Dbscan模型") +@NameEn("Group Dbscan Model") +public final class GroupDbscanModelBatchOp extends BatchOperator + implements GroupDbscanModelParams { + + private static final long serialVersionUID = 5788206252024914272L; + + public GroupDbscanModelBatchOp() { + this(new Params()); + } + + public GroupDbscanModelBatchOp(Params params) { + super(params); + } + + @Override + public GroupDbscanModelBatchOp linkFrom(BatchOperator ... inputs) { + BatchOperator in = checkAndGetFirst(inputs); + + if (!this.getParams().contains(HasPredictionCol.PREDICTION_COL)) { + this.setPredictionCol("cluster_id"); + } + final DistanceType distanceType = getDistanceType(); + final String[] featureColNames = this.getParams().get(FEATURE_COLS); + + final int minPoints = getParams().get(MIN_POINTS); + final Double epsilon = getParams().get(EPSILON); + final String predResultColName = this.getPredictionCol(); + + // groupColNames + final String[] groupColNames = getParams().get(GROUP_COLS); + + for (String groupColName : groupColNames) { + if (TableUtil.findColIndex(featureColNames, groupColName) >= 0) { + throw new RuntimeException("groupColNames should NOT be included in featureColNames!"); + } + } + + String[] selectedColNames = ArrayUtils.addAll(groupColNames, featureColNames); + + Preconditions.checkArgument(distanceType != DistanceType.JACCARD, "Not support %s!", distanceType.name()); + FastDistance distance = distanceType.getFastDistance(); + + final int dim = featureColNames.length; + + String[] outputColNames = ArrayUtils.addAll( + ArrayUtils.addAll(groupColNames, predResultColName, DbscanConstant.COUNT), featureColNames); + TypeInformation[] outputColTypes = ArrayUtils.addAll( + ArrayUtils.addAll(TableUtil.findColTypesWithAssert(in.getSchema(), groupColNames), Types.LONG, + Types.LONG), TableUtil.findColTypesWithAssertAndHint(in.getSchema(), featureColNames)); + + final TableSchema outputSchema = new TableSchema(outputColNames, outputColTypes); + + final int groupMaxSamples = getGroupMaxSamples(); + final boolean skip = getSkip(); + + DataSet rowDataSet = in.select(selectedColNames).getDataSet() + .map(new mapToDataSample(dim, groupColNames.length, distance)) + .groupBy(new GroupGeoDbscanBatchOp.SelectGroup()) + .reduceGroup(new GroupGeoDbscanBatchOp.Clustering(epsilon, minPoints, distance, groupMaxSamples, skip)) + .groupBy(new 
SelectGroupAndClusterID()) + .reduceGroup(new getClusteringCenter(dim, distanceType)) + .map(new MapToRow(outputColNames.length, groupColNames.length)); + + this.setOutput(rowDataSet, outputSchema); + + return this; + } + + public static class mapToDataSample implements MapFunction { + private static final long serialVersionUID = 1491814462425438888L; + private int dim; + private int groupColNamesSize; + private FastDistance distance; + + public mapToDataSample(int dim, int groupColNamesSize, FastDistance distance) { + this.dim = dim; + this.groupColNamesSize = groupColNamesSize; + this.distance = distance; + } + + @Override + public DbscanNewSample map(Row row) throws Exception { + String[] groupColNames = new String[groupColNamesSize]; + for (int i = 0; i < groupColNamesSize; i++) { + groupColNames[i] = row.getField(i).toString(); + } + + double[] values = new double[dim]; + for (int i = 0; i < values.length; i++) { + values[i] = (Double) row.getField(i + groupColNamesSize); + } + DenseVector vec = new DenseVector(values); + + Row keep = new Row(groupColNamesSize); + for (int i = 0; i < keep.getArity(); i++) { + keep.setField(i, row.getField(i)); + } + FastDistanceVectorData vector = distance.prepareVectorData(Tuple2.of(vec, keep)); + + return new DbscanNewSample(vector, groupColNames); + } + } + + public static class getClusteringCenter + implements GroupReduceFunction > { + private static final long serialVersionUID = 6317085010066332931L; + private int dim; + private DistanceType distanceType; + + public getClusteringCenter(int dim, DistanceType distanceType) { + this.dim = dim; + this.distanceType = distanceType; + } + + @Override + public void reduce(Iterable values, Collector > out) + throws Exception { + Iterator iterator = values.iterator(); + long clusterId = 0; + Row groupColNames = null; + int count = 0; + DenseVector vector = new DenseVector(dim); + if (iterator.hasNext()) { + DbscanNewSample sample = iterator.next(); + clusterId = sample.getClusterId(); + groupColNames = sample.getVec().getRows()[0]; + vector.plusEqual(sample.getVec().getVector()); + count++; + } + // exclude the NOISE + if (clusterId > Integer.MIN_VALUE) { + while (iterator.hasNext()) { + vector.plusEqual(iterator.next().getVec().getVector()); + count++; + } + + vector.scaleEqual(1.0 / count); + + DbscanCenter dbscanCenter = new DbscanCenter (groupColNames, clusterId, + count, vector); + out.collect(dbscanCenter); + } + } + } + + public static class MapToRow implements MapFunction , Row> { + private static final long serialVersionUID = -5480092592936407825L; + private int rowArity; + private int groupColNamesSize; + + public MapToRow(int rowArity, int groupColNamesSize) { + this.rowArity = rowArity; + this.groupColNamesSize = groupColNamesSize; + } + + @Override + public Row map(DbscanCenter value) throws Exception { + Row row = new Row(rowArity - groupColNamesSize); + DenseVector denseVector = value.getValue(); + row.setField(0, value.getClusterId()); + row.setField(1, value.getCount()); + for (int i = 0; i < denseVector.size(); i++) { + row.setField(i + 2, denseVector.get(i)); + } + return RowUtil.merge(value.getGroupColNames(), row); + } + } + + public class SelectGroupAndClusterID implements KeySelector { + private static final long serialVersionUID = -8204871256389225863L; + + @Override + public Integer getKey(DbscanNewSample w) { + return new MurmurHash3().arrayHash(new Integer[] {(int) w.getClusterId(), w.getGroupHashKey()}, 0); + } + } + +} diff --git 
a/core/src/main/java/com/alibaba/alink/operator/batch/clustering/GroupEmBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/clustering/GroupEmBatchOp.java new file mode 100644 index 000000000..ba749e583 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/batch/clustering/GroupEmBatchOp.java @@ -0,0 +1,244 @@ +package com.alibaba.alink.operator.batch.clustering; + +import com.alibaba.alink.common.annotation.InputPorts; +import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.common.annotation.OutputPorts; +import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; +import com.alibaba.alink.common.annotation.PortDesc; +import com.alibaba.alink.common.annotation.PortSpec; +import com.alibaba.alink.common.annotation.PortType; +import com.alibaba.alink.common.annotation.TypeCollections; +import com.alibaba.alink.common.linalg.DenseVector; +import com.alibaba.alink.common.utils.TableUtil; +import com.alibaba.alink.operator.batch.BatchOperator; +import com.alibaba.alink.operator.common.clustering.LocalKMeans; +import com.alibaba.alink.operator.common.clustering.common.Sample; +import com.alibaba.alink.operator.common.distance.ContinuousDistance; +import com.alibaba.alink.params.clustering.GroupKMeansParams; +import com.alibaba.alink.params.shared.colname.HasPredictionCol; +import org.apache.flink.api.common.functions.GroupReduceFunction; +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.table.api.TableSchema; +import org.apache.flink.table.api.Types; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; + +import java.util.ArrayList; +import java.util.List; + +@InputPorts(values = {@PortSpec(PortType.DATA)}) +@OutputPorts(values = { + @PortSpec(value = PortType.DATA, desc = PortDesc.OUTPUT_RESULT), +}) +@ParamSelectColumnSpec(name = "featureCols", portIndices = 0, allowedTypeCollections = {TypeCollections.NUMERIC_TYPES}) +@ParamSelectColumnSpec(name = "groupCols", portIndices = 0) +@ParamSelectColumnSpec(name = "idCol", portIndices = 0) +@NameCn("分组EM") +@NameEn("Group EM") +public final class GroupEmBatchOp extends BatchOperator + implements GroupKMeansParams { + + private static final long serialVersionUID = 2403292854593151120L; + + public GroupEmBatchOp() { + super(null); + } + + public GroupEmBatchOp(Params params) { + super(params); + } + + @Override + public GroupEmBatchOp linkFrom(BatchOperator... inputs) { + BatchOperator in = checkAndGetFirst(inputs); + + if(!this.getParams().contains(HasPredictionCol.PREDICTION_COL)){ + this.setPredictionCol("cluster_id"); + } + final String[] featureColNames = (this.getParams().contains(FEATURE_COLS) + && this.getFeatureCols() != null + && this.getFeatureCols().length > 0) ? 
+ this.getFeatureCols() : TableUtil.getNumericCols(in.getSchema()); + final int k = this.getK(); + final double epsilon = this.getEpsilon(); + final int maxIter = this.getMaxIter(); + final DistanceType distanceType = getDistanceType(); + final String[] groupColNames = this.getGroupCols(); + final String idCol = this.getIdCol(); + final String predResultColName = this.getPredictionCol(); + + ContinuousDistance distance = distanceType.getFastDistance(); + + for (String groupColName : groupColNames) { + if (TableUtil.findColIndex(featureColNames, groupColName) >= 0) { + throw new RuntimeException("groupColNames should NOT be included in featureColNames!"); + } + } + + if (null == idCol || "".equals(idCol)) { + throw new RuntimeException("idCol column should be set!"); + } else if (TableUtil.findColIndex(featureColNames, idCol) >= 0) { + throw new RuntimeException("idCol column should NOT be included in featureColNames !"); + } else if (TableUtil.findColIndex(groupColNames, idCol) >= 0) { + throw new RuntimeException("idCol column should NOT be included in groupColNames !"); + } + + StringBuilder sbd = new StringBuilder(); + for (String groupColName : groupColNames) { + sbd.append("cast(`").append(groupColName).append("` as VARCHAR) as `").append(groupColName).append("`, "); + } + sbd.append("cast(`").append(idCol).append("` as VARCHAR) as `").append(idCol).append("`, "); + for (int i = 0; i < featureColNames.length; i++) { + if (i > 0) { + sbd.append(", "); + } + sbd.append("cast(`") + .append(featureColNames[i]) + .append("` as double) as `") + .append(featureColNames[i]) + .append("`"); + } + + final int dim = featureColNames.length; + + List columnNames = new ArrayList <>(); + for (String groupColName : groupColNames) { + columnNames.add(groupColName); + } + columnNames.add(idCol); + columnNames.add(predResultColName); + for (String col : featureColNames) { + columnNames.add(col); + } + + List columnTypes = new ArrayList <>(); + for (String groupColName : groupColNames) { + columnTypes.add(Types.STRING()); + } + columnTypes.add(Types.STRING()); + columnTypes.add(Types.LONG()); + for (String col : featureColNames) { + columnTypes.add(Types.DOUBLE()); + } + + final TableSchema outputSchema = new TableSchema( + columnNames.toArray(new String[columnNames.size()]), + columnTypes.toArray(new TypeInformation[columnTypes.size()]) + ); + + try { + DataSet rowDataSet = in.select(sbd.toString()).getDataSet() + .map(new mapToDataSample(dim, groupColNames.length)) + .groupBy(new SelectGroup()) + .reduceGroup(new Clustering(k, epsilon, maxIter, dim, distance)) + .map(new MapToRow(columnNames.size(), groupColNames.length)); + + this.setOutput(rowDataSet, outputSchema); + } catch (Exception ex) { + ex.printStackTrace(); + throw new RuntimeException(ex); + } + + return this; + } + + public static class mapToDataSample implements MapFunction { + private static final long serialVersionUID = -1252574650762802849L; + private int dim; + private int groupColNamesSize; + + public mapToDataSample(int dim, int groupColNamesSize) { + this.dim = dim; + this.groupColNamesSize = groupColNamesSize; + } + + @Override + public Sample map(Row row) throws Exception { + List groupColNames = new ArrayList <>(); + for (int i = 0; i < groupColNamesSize; i++) { + if(null == row.getField(i)){ + throw new RuntimeException("There is NULL value in group col!"); + } + groupColNames.add((String) row.getField(i)); + } + + if(null == row.getField(groupColNamesSize)){ + throw new RuntimeException("There is NULL value in id col!"); + } + 
String idColValue = (String) row.getField(groupColNamesSize); + + double[] values = new double[dim]; + for (int i = 0; i < values.length; i++) { + if(null == row.getField(i + groupColNamesSize + 1)){ + throw new RuntimeException("There is NULL value in value col!"); + } + values[i] = (Double) row.getField(i + groupColNamesSize + 1); + } + return new Sample(idColValue, new DenseVector(values), -1, + groupColNames.toArray(new String[groupColNamesSize])); + } + } + + public static class Clustering implements GroupReduceFunction { + private static final long serialVersionUID = -6401148777324895859L; + private int k; + private double epsilon; + private ContinuousDistance distance; + private int maxIter; + private int dim; + + public Clustering(int k, double epsilon, int maxIter, int dim, ContinuousDistance distance) { + this.epsilon = epsilon; + this.k = k; + this.distance = distance; + this.maxIter = maxIter; + this.dim = dim; + } + + @Override + public void reduce(Iterable values, Collector out) throws Exception { + LocalKMeans.clustering(values, out, k, epsilon, maxIter, distance); + } + } + + public static class MapToRow implements MapFunction { + private static final long serialVersionUID = 5045205789035382392L; + private int rowArity; + private int groupColNamesSize; + + public MapToRow(int rowArity, int groupColNamesSize) { + this.rowArity = rowArity; + this.groupColNamesSize = groupColNamesSize; + } + + @Override + public Row map(Sample value) throws Exception { + Row row = new Row(rowArity); + DenseVector denseVector = value.getVector(); + for (int i = 0; i < groupColNamesSize; i++) { + row.setField(i, value.getGroupColNames()[i]); + } + row.setField(groupColNamesSize, value.getSampleId()); + row.setField(groupColNamesSize + 1, value.getClusterId()); + for (int i = 0; i < denseVector.size(); i++) { + row.setField(i + groupColNamesSize + 2, denseVector.get(i)); + } + return row; + } + } + + public class SelectGroup implements KeySelector { + private static final long serialVersionUID = 4582197026301874450L; + + @Override + public String getKey(Sample w) { + return w.getGroupColNamesString(); + } + } + +} diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/clustering/GroupGeoDbscanBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/clustering/GroupGeoDbscanBatchOp.java new file mode 100644 index 000000000..322f5a3b4 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/batch/clustering/GroupGeoDbscanBatchOp.java @@ -0,0 +1,257 @@ +package com.alibaba.alink.operator.batch.clustering; + +import org.apache.flink.api.common.functions.Partitioner; +import org.apache.flink.api.common.functions.RichGroupReduceFunction; +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.table.api.TableSchema; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; + +import com.alibaba.alink.common.annotation.InputPorts; +import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.common.annotation.OutputPorts; +import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; +import com.alibaba.alink.common.annotation.PortDesc; +import 
com.alibaba.alink.common.annotation.PortSpec; +import com.alibaba.alink.common.annotation.PortType; +import com.alibaba.alink.common.annotation.ReservedColsWithFirstInputSpec; +import com.alibaba.alink.common.annotation.TypeCollections; +import com.alibaba.alink.common.linalg.DenseVector; +import com.alibaba.alink.common.utils.RowUtil; +import com.alibaba.alink.common.utils.TableUtil; +import com.alibaba.alink.operator.batch.BatchOperator; +import com.alibaba.alink.operator.common.clustering.dbscan.DbscanConstant; +import com.alibaba.alink.operator.common.clustering.dbscan.DbscanNewSample; +import com.alibaba.alink.operator.common.clustering.dbscan.Type; +import com.alibaba.alink.operator.common.distance.FastDistance; +import com.alibaba.alink.operator.common.distance.FastDistanceVectorData; +import com.alibaba.alink.operator.common.distance.HaversineDistance; +import com.alibaba.alink.params.clustering.GroupGeoDbscanParams; +import org.apache.commons.lang3.ArrayUtils; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import static com.alibaba.alink.operator.common.clustering.dbscan.Dbscan.UNCLASSIFIED; +import static com.alibaba.alink.operator.common.clustering.dbscan.Dbscan.expandCluster; + +/** + * @author guotao.gt + */ +@InputPorts(values = {@PortSpec(PortType.DATA)}) +@OutputPorts(values = { + @PortSpec(value = PortType.DATA, desc = PortDesc.OUTPUT_RESULT), +}) +@ParamSelectColumnSpec(name = "idCol", portIndices = 0) +@ParamSelectColumnSpec(name = "groupCols", portIndices = 0) +@ParamSelectColumnSpec(name = "latitudeCol", portIndices = 0, allowedTypeCollections = {TypeCollections.NUMERIC_TYPES}) +@ParamSelectColumnSpec(name = "longitudeCol", portIndices = 0, allowedTypeCollections = {TypeCollections.NUMERIC_TYPES}) +@ReservedColsWithFirstInputSpec +@NameCn("分组经纬度Dbscan") +@NameEn("Group Geo Dbscan") + +public class GroupGeoDbscanBatchOp extends BatchOperator + implements GroupGeoDbscanParams { + + private static final long serialVersionUID = -1650606375272968610L; + + public GroupGeoDbscanBatchOp() { + this(null); + } + + public GroupGeoDbscanBatchOp(Params params) { + super(params); + } + + @Override + public GroupGeoDbscanBatchOp linkFrom(BatchOperator ... inputs) { + BatchOperator in = checkAndGetFirst(inputs); + final String latitudeColName = getLatitudeCol(); + final String longitudeColName = getLongitudeCol(); + final int minPoints = getMinPoints(); + final Double epsilon = getParams().get(EPSILON); + final String idCol = getIdCol(); + final String predResultColName = getPredictionCol(); + final int groupMaxSamples = getGroupMaxSamples(); + final boolean skip = getSkip(); + final String[] keepColNames = getReservedCols(); + + FastDistance distance = new HaversineDistance(); + // groupColNames + final String[] groupColNames = this.getGroupCols(); + + if (null == idCol || "".equals(idCol)) { + throw new RuntimeException("idCol column should be set!"); + } else if (TableUtil.findColIndex(groupColNames, idCol) >= 0) { + throw new RuntimeException("idCol column should NOT be included in groupColNames !"); + } + + String[] resNames = null == keepColNames ? new String[] {DbscanConstant.TYPE, predResultColName} : + ArrayUtils.addAll(new String[] {DbscanConstant.TYPE, predResultColName}, keepColNames); + TypeInformation[] resTypes = null == keepColNames ? 
new TypeInformation[] {Types.STRING, Types.LONG} : + ArrayUtils.addAll(new TypeInformation[] {Types.STRING, Types.LONG}, + TableUtil.findColTypesWithAssertAndHint(in.getSchema(), keepColNames)); + + String[] columnNames = ArrayUtils.addAll(groupColNames, latitudeColName, longitudeColName); + columnNames = null == keepColNames ? columnNames : ArrayUtils.addAll(columnNames, keepColNames); + + DataSet res = in.select(columnNames).getDataSet() + .map(new mapToDataVectorSample(groupColNames.length, distance)) + .groupBy(new SelectGroup()) + .withPartitioner(new WeightPartitioner()) + .reduceGroup(new Clustering(epsilon, minPoints, distance, groupMaxSamples, skip)) + .map(new MapToRow()); + this.setOutput(res, new TableSchema(resNames, resTypes)); + + return this; + } + + public static class WeightPartitioner implements Partitioner { + private static final long serialVersionUID = -4197634749052990621L; + + @Override + public int partition(Integer key, int numPartitions) { + return Math.abs(key) % numPartitions; + } + } + + public static class mapToDataVectorSample extends RichMapFunction { + private static final long serialVersionUID = -9186022939852072237L; + private int groupColNamesSize; + private FastDistance distance; + + public mapToDataVectorSample(int groupColNamesSize, FastDistance distance) { + this.groupColNamesSize = groupColNamesSize; + this.distance = distance; + } + + @Override + public DbscanNewSample map(Row row) throws Exception { + String[] groupColNames = new String[groupColNamesSize]; + for (int i = 0; i < groupColNamesSize; i++) { + groupColNames[i] = row.getField(i).toString(); + } + + DenseVector vector = new DenseVector(2); + vector.set(0, ((Number) row.getField(groupColNamesSize)).doubleValue()); + vector.set(1, ((Number) row.getField(groupColNamesSize + 1)).doubleValue()); + + Row keep = new Row(row.getArity() - groupColNamesSize - 2); + for (int i = 0; i < keep.getArity(); i++) { + keep.setField(i, row.getField(groupColNamesSize + 2 + i)); + } + + FastDistanceVectorData data = distance.prepareVectorData(Tuple2.of(vector, keep)); + + return new DbscanNewSample(data, groupColNames); + } + } + + public static class SelectGroup implements KeySelector { + private static final long serialVersionUID = 6268163376874147254L; + + @Override + public Integer getKey(DbscanNewSample w) { + return w.getGroupHashKey(); + } + } + + public static class Clustering extends RichGroupReduceFunction { + private static final long serialVersionUID = 3474119012459738732L; + private double epsilon; + private int minPoints; + private FastDistance baseDistance; + private int groupMaxSamples; + private boolean skip; + + public Clustering(double epsilon, int minPoints, FastDistance baseDistance, int groupMaxSamples, + boolean skip) { + this.epsilon = epsilon; + this.minPoints = minPoints; + this.baseDistance = baseDistance; + this.groupMaxSamples = groupMaxSamples; + this.skip = skip; + } + + @Override + public void reduce(Iterable values, Collector out) throws Exception { + int clusterId = 0; + List samples = new ArrayList <>(); + for (DbscanNewSample sample : values) { + samples.add(sample); + } + List abandon = null; + if (samples.size() >= groupMaxSamples) { + if (skip) { + return; + } else { + Collections.shuffle(samples); + List selected = new ArrayList <>(groupMaxSamples); + abandon = new ArrayList <>(); + for (int i = 0; i < samples.size(); i++) { + if (i < groupMaxSamples) { + selected.add(samples.get(i)); + } else { + abandon.add(samples.get(i)); + } + } + samples = selected; + } + } + + 
for (DbscanNewSample dbscanSample : samples) { + if (dbscanSample.getClusterId() == UNCLASSIFIED) { + if (expandCluster(samples, dbscanSample, clusterId, epsilon, minPoints, baseDistance)) { + clusterId++; + } + } + } + + for (DbscanNewSample dbscanSample : samples) { + out.collect(dbscanSample); + } + + //deal with the abandon sample + if (null != abandon) { + for (DbscanNewSample sample : abandon) { + double d = Double.POSITIVE_INFINITY; + + for (DbscanNewSample dbscanNewSample : samples) { + if (dbscanNewSample.getType().equals(Type.CORE)) { + double distance = baseDistance.calc(dbscanNewSample.getVec(), sample.getVec()).get(0, 0); + if (distance < d) { + sample.setClusterId(dbscanNewSample.getClusterId()); + d = distance; + } + } + } + if (d > epsilon) { + sample.setType(Type.NOISE); + sample.setClusterId(Integer.MIN_VALUE); + } else { + sample.setType(Type.LINKED); + } + out.collect(sample); + } + } + } + } + + public static class MapToRow extends RichMapFunction { + private static final long serialVersionUID = 5024255660037882136L; + + @Override + public Row map(DbscanNewSample value) throws Exception { + return RowUtil.merge(Row.of(value.getType().name(), value.getClusterId()), value.getVec().getRows()[0]); + } + } + +} diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/clustering/GroupGeoDbscanModelBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/clustering/GroupGeoDbscanModelBatchOp.java new file mode 100644 index 000000000..bd5555d99 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/batch/clustering/GroupGeoDbscanModelBatchOp.java @@ -0,0 +1,173 @@ +package com.alibaba.alink.operator.batch.clustering; + +import org.apache.flink.api.common.functions.GroupReduceFunction; +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.table.api.TableSchema; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; + +import com.alibaba.alink.common.type.AlinkTypes; +import com.alibaba.alink.common.annotation.InputPorts; +import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.common.annotation.OutputPorts; +import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; +import com.alibaba.alink.common.annotation.PortDesc; +import com.alibaba.alink.common.annotation.PortSpec; +import com.alibaba.alink.common.annotation.PortType; +import com.alibaba.alink.common.annotation.TypeCollections; +import com.alibaba.alink.common.linalg.DenseVector; +import com.alibaba.alink.common.utils.RowUtil; +import com.alibaba.alink.common.utils.TableUtil; +import com.alibaba.alink.operator.batch.BatchOperator; +import com.alibaba.alink.operator.common.clustering.dbscan.DbscanConstant; +import com.alibaba.alink.operator.common.clustering.dbscan.DbscanNewSample; +import com.alibaba.alink.operator.common.distance.FastDistance; +import com.alibaba.alink.operator.common.distance.FastDistanceVectorData; +import com.alibaba.alink.operator.common.distance.HaversineDistance; +import com.alibaba.alink.params.clustering.GroupGeoDbscanModelParams; +import org.apache.commons.lang3.ArrayUtils; +import scala.util.hashing.MurmurHash3; + +import java.util.Iterator; + +/** + * @author guotao.gt + */ 
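A minimal usage sketch for the GroupGeoDbscanModelBatchOp declared just below. The input data, column names and parameter values are illustrative assumptions, MemSourceBatchOp is used only as a convenient in-memory source, and the set* calls are assumed to be the Alink-generated param setters corresponding to the getters read in linkFrom():

import java.util.Arrays;
import java.util.List;

import org.apache.flink.types.Row;

import com.alibaba.alink.operator.batch.clustering.GroupGeoDbscanModelBatchOp;
import com.alibaba.alink.operator.batch.source.MemSourceBatchOp;

public class GroupGeoDbscanModelSketch {
    public static void main(String[] args) throws Exception {
        // Toy input: one row per point, grouped by "city", with latitude and longitude columns.
        List<Row> rows = Arrays.asList(
            Row.of("beijing", 39.90, 116.40),
            Row.of("beijing", 39.91, 116.41),
            Row.of("hangzhou", 30.27, 120.15));
        MemSourceBatchOp source = new MemSourceBatchOp(rows, new String[] {"city", "lat", "lon"});

        GroupGeoDbscanModelBatchOp geoDbscanModel = new GroupGeoDbscanModelBatchOp()
            .setGroupCols("city")
            .setLatitudeCol("lat")
            .setLongitudeCol("lon")
            .setMinPoints(2)
            // neighborhood radius, expressed in the unit returned by HaversineDistance
            .setEpsilon(2.0)
            .setGroupMaxSamples(10000)
            .setSkip(false);

        // One output row per (group, cluster): cluster id, member count, averaged latitude/longitude,
        // then the group columns; noise points are excluded from the model.
        geoDbscanModel.linkFrom(source).print();
    }
}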
+@InputPorts(values = {@PortSpec(PortType.DATA)}) +@OutputPorts(values = { + @PortSpec(value = PortType.MODEL, desc = PortDesc.MODEL_INFO), +}) +@ParamSelectColumnSpec(name = "groupCols", portIndices = 0) +@ParamSelectColumnSpec(name = "latitudeCol", portIndices = 0, allowedTypeCollections = {TypeCollections.NUMERIC_TYPES}) +@ParamSelectColumnSpec(name = "longitudeCol", portIndices = 0, allowedTypeCollections = {TypeCollections.NUMERIC_TYPES}) +@NameCn("分组经纬度Dbscan模型") +@NameEn("Group Geo Dbscan Model") +public class GroupGeoDbscanModelBatchOp extends BatchOperator + implements GroupGeoDbscanModelParams { + + private static final long serialVersionUID = 6424042392598453910L; + + public GroupGeoDbscanModelBatchOp() { + this(null); + } + + public GroupGeoDbscanModelBatchOp(Params params) { + super(params); + } + + @Override + public GroupGeoDbscanModelBatchOp linkFrom(BatchOperator ... inputs) { + BatchOperator in = checkAndGetFirst(inputs); + final String latitudeColName = getLatitudeCol(); + final String longitudeColName = getLongitudeCol(); + final int minPoints = getMinPoints(); + final Double epsilon = getParams().get(EPSILON); + final int groupMaxSamples = getGroupMaxSamples(); + final boolean skip = getSkip(); + + FastDistance distance = new HaversineDistance(); + // groupColNames + final String[] groupColNames = getParams().get(GROUP_COLS); + + if (null == groupColNames) { + throw new RuntimeException("groupColNames should not be null!"); + } + + String[] resNames = ArrayUtils.addAll( + new String[] {DbscanConstant.TYPE, "count", latitudeColName, longitudeColName}, groupColNames); + TypeInformation[] resTypes = ArrayUtils.addAll( + new TypeInformation[] {AlinkTypes.LONG, AlinkTypes.LONG, AlinkTypes.DOUBLE, AlinkTypes.DOUBLE}, + TableUtil.findColTypesWithAssertAndHint(in.getSchema(), groupColNames)); + + String[] columnNames = ArrayUtils.addAll(groupColNames, latitudeColName, longitudeColName); + + DataSet res = in.select(columnNames).getDataSet() + .map(new mapToDataVectorSample(groupColNames.length, distance)) + .groupBy(new GroupGeoDbscanBatchOp.SelectGroup()) + .withPartitioner(new GroupGeoDbscanBatchOp.WeightPartitioner()) + .reduceGroup( + new GroupGeoDbscanBatchOp.Clustering(epsilon, minPoints, distance, groupMaxSamples, skip)) + .groupBy(new SelectGroupCluster()) + .reduceGroup(new getClusteringCenter()); + + this.setOutput(res, new TableSchema(resNames, resTypes)); + + return this; + } + + public static class mapToDataVectorSample extends RichMapFunction { + private static final long serialVersionUID = -718882540657567670L; + private int groupColNamesSize; + private FastDistance distance; + + public mapToDataVectorSample(int groupColNamesSize, FastDistance distance) { + this.groupColNamesSize = groupColNamesSize; + this.distance = distance; + } + + @Override + public DbscanNewSample map(Row row) throws Exception { + String[] groupColNames = new String[groupColNamesSize]; + for (int i = 0; i < groupColNamesSize; i++) { + groupColNames[i] = row.getField(i).toString(); + } + DenseVector vector = new DenseVector(2); + vector.set(0, ((Number) row.getField(groupColNamesSize)).doubleValue()); + vector.set(1, ((Number) row.getField(groupColNamesSize + 1)).doubleValue()); + + Row keep = new Row(groupColNamesSize); + for (int i = 0; i < keep.getArity(); i++) { + keep.setField(i, row.getField(i)); + } + FastDistanceVectorData vec = distance.prepareVectorData(Tuple2.of(vector, keep)); + return new DbscanNewSample(vec, groupColNames); + } + } + + public static class SelectGroupCluster 
implements KeySelector { + private static final long serialVersionUID = 3160327441213761977L; + + @Override + public Integer getKey(DbscanNewSample w) { + String[] key = new String[] {String.valueOf(w.getGroupHashKey()), String.valueOf(w.getClusterId())}; + return new MurmurHash3().arrayHash(key, 0); + } + } + + public static class getClusteringCenter implements + GroupReduceFunction { + private static final long serialVersionUID = -1965967509192777460L; + + @Override + public void reduce(Iterable values, Collector out) throws Exception { + Iterator iterator = values.iterator(); + long clusterId = Integer.MIN_VALUE; + Row groupColNames = null; + long count = 0; + DenseVector vector = new DenseVector(2); + if (iterator.hasNext()) { + DbscanNewSample sample = iterator.next(); + groupColNames = sample.getVec().getRows()[0]; + clusterId = sample.getClusterId(); + vector.plusEqual((DenseVector) sample.getVec().getVector()); + count++; + } + // exclude the NOISE + if (clusterId > Integer.MIN_VALUE) { + while (iterator.hasNext()) { + vector.plusEqual((DenseVector) iterator.next().getVec().getVector()); + count++; + } + + vector.scaleEqual(1.0 / count); + out.collect(RowUtil.merge(Row.of(clusterId, count, vector.get(0), vector.get(1)), groupColNames)); + } + } + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/clustering/GroupKMeansBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/clustering/GroupKMeansBatchOp.java new file mode 100644 index 000000000..3a829a27a --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/batch/clustering/GroupKMeansBatchOp.java @@ -0,0 +1,725 @@ +package com.alibaba.alink.operator.batch.clustering; + +import org.apache.flink.annotation.VisibleForTesting; +import org.apache.flink.api.common.functions.GroupReduceFunction; +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.Partitioner; +import org.apache.flink.api.common.functions.RichMapPartitionFunction; +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.api.java.operators.IterativeDataSet; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.table.api.TableSchema; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; +import org.apache.flink.util.NumberSequenceIterator; +import org.apache.flink.util.Preconditions; + +import com.alibaba.alink.common.MLEnvironmentFactory; +import com.alibaba.alink.common.annotation.InputPorts; +import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.common.annotation.OutputPorts; +import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; +import com.alibaba.alink.common.annotation.PortDesc; +import com.alibaba.alink.common.annotation.PortSpec; +import com.alibaba.alink.common.annotation.PortType; +import com.alibaba.alink.common.annotation.TypeCollections; +import com.alibaba.alink.common.comqueue.IterTaskObjKeeper; +import com.alibaba.alink.common.linalg.DenseVector; +import com.alibaba.alink.common.utils.TableUtil; +import com.alibaba.alink.operator.batch.BatchOperator; +import 
com.alibaba.alink.operator.common.clustering.common.Sample; +import com.alibaba.alink.operator.common.distance.ContinuousDistance; +import com.alibaba.alink.operator.common.evaluation.EvaluationUtil; +import com.alibaba.alink.params.clustering.GroupKMeansParams; +import org.apache.commons.lang3.ArrayUtils; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.PriorityQueue; + +@InputPorts(values = {@PortSpec(PortType.DATA)}) +@OutputPorts(values = { + @PortSpec(value = PortType.DATA, desc = PortDesc.OUTPUT_RESULT), +}) +@ParamSelectColumnSpec(name = "featureCols", portIndices = 0, allowedTypeCollections = {TypeCollections.NUMERIC_TYPES}) +@ParamSelectColumnSpec(name = "groupCols", portIndices = 0) +@ParamSelectColumnSpec(name = "idCol", portIndices = 0) +@NameCn("分组Kmeans") +@NameEn("Group Kmeans") +public final class GroupKMeansBatchOp extends BatchOperator + implements GroupKMeansParams { + + public GroupKMeansBatchOp() { + super(null); + } + + public GroupKMeansBatchOp(Params params) { + super(params); + } + + @Override + public GroupKMeansBatchOp linkFrom(BatchOperator ... inputs) { + BatchOperator in = checkAndGetFirst(inputs); + final String[] featureColNames = this.getFeatureCols(); + final int k = this.getK(); + final double epsilon = this.getEpsilon(); + final int maxIter = this.getMaxIter(); + final DistanceType distanceType = getDistanceType(); + final String[] groupColNames = this.getGroupCols(); + final String idCol = this.getIdCol(); + final ContinuousDistance distance = distanceType.getFastDistance(); + + if ((featureColNames == null || featureColNames.length == 0)) { + throw new RuntimeException("featureColNames should be set !"); + } + for (String groupColName : groupColNames) { + if (TableUtil.findColIndex(featureColNames, groupColName) >= 0) { + throw new RuntimeException("groupColNames should NOT be included in featureColNames!"); + } + } + if (null == idCol || "".equals(idCol)) { + throw new RuntimeException("idCol column should be set!"); + } else if (TableUtil.findColIndex(featureColNames, idCol) >= 0) { + throw new RuntimeException("idCol column should NOT be included in featureColNames !"); + } else if (TableUtil.findColIndex(groupColNames, idCol) >= 0) { + throw new RuntimeException("idCol column should NOT be included in groupColNames !"); + } + + final String[] outputCols = ArrayUtils.addAll(groupColNames, idCol, getPredictionCol()); + final TypeInformation [] outputTypes = ArrayUtils.addAll( + TableUtil.findColTypesWithAssertAndHint(in.getSchema(), groupColNames), + TableUtil.findColTypeWithAssertAndHint(in.getSchema(), idCol), Types.LONG); + + String[] inputColNames = in.getColNames(); + final int[] groupNameIndices = TableUtil.findColIndices(inputColNames, groupColNames); + final int idColIndex = TableUtil.findColIndex(inputColNames, idCol); + final int[] featureColIndex = TableUtil.findColIndices(inputColNames, featureColNames); + DataSet inputSamples = in.getDataSet() + .map(new MapRowToSample(groupNameIndices, idColIndex, featureColIndex)); + + DataSet > groupsAndSizes = inputSamples.groupBy(new GroupNameKeySelector()).reduceGroup( + new ComputingGroupSizes()); + + final String broadcastGroupSizeKey = "groupAndSizeKey"; + final long partitionInfoHandle = IterTaskObjKeeper.getNewHandle(); + final long cacheDataHandle = IterTaskObjKeeper.getNewHandle(); + final long cacheModelHandle = IterTaskObjKeeper.getNewHandle(); + final long 
lossHandle = IterTaskObjKeeper.getNewHandle(); + + IterativeDataSet loopStart = MLEnvironmentFactory.get(getMLEnvironmentId()).getExecutionEnvironment() + .fromParallelCollection(new NumberSequenceIterator(1L, 2L), + BasicTypeInfo.LONG_TYPE_INFO) + .map(new MapLongToObject()) + .iterate(maxIter); + + DataSet > rebalancedSamples = inputSamples.mapPartition( + new RebalanceDataAndCachePartitionInfo(broadcastGroupSizeKey, partitionInfoHandle)) + .withBroadcastSet(loopStart, "loopStart") + .withBroadcastSet(groupsAndSizes, broadcastGroupSizeKey) + .name("rebalanceData"); + + DataSet > initModelWithTargetWorkerIds = rebalancedSamples + .partitionCustom( + new HashPartitioner(), 0) + .mapPartition(new CacheSamplesAndGenInitModel(cacheDataHandle, partitionInfoHandle, k)) + .name("cacheTrainingDataAndGetInitModel"); + + DataSet cachedInitModel = initModelWithTargetWorkerIds.partitionCustom( + new HashPartitioner(), 0) + .mapPartition(new CacheInitModel(cacheModelHandle, partitionInfoHandle, lossHandle)) + .name("cacheInitModel"); + + DataSet > modelUpdates = loopStart.mapPartition( + new ComputeUpdates(cacheDataHandle, cacheModelHandle, partitionInfoHandle, distance)) + .withBroadcastSet(cachedInitModel, "cacheInitModel") + .name("computeUpdates"); + + DataSet loopOutput = modelUpdates.partitionCustom( + new HashPartitioner(), 0) + .mapPartition(new UpdateModel(cacheModelHandle, lossHandle, epsilon)) + .name("updateModel"); + + DataSet loopEnd = loopStart.closeWith(loopOutput, loopOutput); + DataSet transformedDataset = loopEnd.mapPartition( + new OutputDataSamples(cacheDataHandle, partitionInfoHandle, cacheModelHandle, lossHandle)) + .withBroadcastSet(loopEnd, "iterationEnd") + .name("outputDataSamples"); + DataSet rowDataSet = transformedDataset + .map(new MapSampleToRow(groupColNames.length, outputTypes)); + this.setOutput(rowDataSet, new TableSchema(outputCols, outputTypes)); + return this; + } + + + private static class MapRowToSample implements MapFunction { + private final int[] groupNameIndices; + private final int idColIndex; + private final int[] featureColIndices; + + public MapRowToSample(int[] groupNameIndices, int idColIndex, int[] featureColIndices) { + this.groupNameIndices = groupNameIndices; + this.idColIndex = idColIndex; + this.featureColIndices = featureColIndices; + } + + @Override + public Sample map(Row row) throws Exception { + String[] groupColNames = new String[groupNameIndices.length]; + for (int i = 0; i < groupColNames.length; i ++) { + Object o = row.getField(groupNameIndices[i]); + Preconditions.checkNotNull(o, "There is NULL value in group col!"); + groupColNames[i] = o.toString(); + } + double[] values = new double[featureColIndices.length]; + for (int i = 0; i < values.length; i ++) { + Object o = row.getField(featureColIndices[i]); + Preconditions.checkNotNull(o, "There is NULL value in feature col!"); + values[i] = ((Number) o).doubleValue(); + } + Object o = row.getField(idColIndex); + Preconditions.checkNotNull(o, "There is NULL value in id col!"); + String idColValue = o.toString(); + return new Sample(idColValue, new DenseVector(values), -1, groupColNames); + } + } + + private static class GroupNameKeySelector implements KeySelector { + + @Override + public String getKey(Sample value) throws Exception { + return value.getGroupColNamesString(); + } + } + + /** + * Computes number of elements in each group. 
+ */ + private static class ComputingGroupSizes implements GroupReduceFunction > { + + @Override + public void reduce(Iterable values, Collector > out) throws Exception { + String groupName = null; + long groupSize = 0; + + Iterator iterator = values.iterator(); + if (iterator.hasNext()) { + Sample sample = iterator.next(); + groupName = sample.getGroupColNamesString(); + groupSize++; + } + while (iterator.hasNext()) { + groupSize++; + iterator.next(); + } + out.collect(Tuple2.of(groupName, groupSize)); + } + } + + private static class MapLongToObject implements MapFunction { + @Override + public Object map(Long value) throws Exception { + return new Object(); + } + } + + /** + * Caches the partition information on each TM at superStep-1 and re-balances data at superStep-2. + */ + private static class RebalanceDataAndCachePartitionInfo + extends RichMapPartitionFunction > { + + private final String broadcastGroupSizeKey; + + private final long partitionInfoHandle; + + public RebalanceDataAndCachePartitionInfo(String broadcastGroupSizeKey, long partitionInfoHandle) { + this.broadcastGroupSizeKey = broadcastGroupSizeKey; + this.partitionInfoHandle = partitionInfoHandle; + } + + @Override + public void mapPartition(Iterable values, Collector > out) throws Exception { + int superStepNum = getIterationRuntimeContext().getSuperstepNumber(); + if (superStepNum == 1) { + List > groupsAndSizes = getRuntimeContext().getBroadcastVariable( + broadcastGroupSizeKey); + Map sizeByGroup = new HashMap <>(groupsAndSizes.size()); + for (Tuple2 groupAndSize : groupsAndSizes) { + sizeByGroup.put(groupAndSize.f0, groupAndSize.f1); + } + Map partitionInfos = getPartitionInfo(sizeByGroup, + getRuntimeContext().getNumberOfParallelSubtasks()); + IterTaskObjKeeper.put(partitionInfoHandle, getRuntimeContext().getIndexOfThisSubtask(), + partitionInfos); + } else if (superStepNum == 2) { + Map partitionInfos = IterTaskObjKeeper.get(partitionInfoHandle, + getRuntimeContext().getIndexOfThisSubtask()); + HashMap offsetByName = new HashMap <>(partitionInfos.size()); + for (String groupName : partitionInfos.keySet()) { + offsetByName.put(groupName, -1); + } + for (Sample value : values) { + String groupName = value.getGroupColNamesString(); + int[] possibleWorkerIds = partitionInfos.get(groupName); + int offset = offsetByName.compute(groupName, (k, v) -> (v + 1) % possibleWorkerIds.length); + int workerId = possibleWorkerIds[offset]; + out.collect(Tuple2.of(workerId, value)); + } + } + } + } + + /** + * Caches samples in static memory and send initModel to corresponding workers at superStep-2. 
+ */ + private static class CacheSamplesAndGenInitModel + extends RichMapPartitionFunction , Tuple3 > { + + private final long cacheDataHandle; + + private final long partitionInfoHandle; + + private final int numClusters; + + public CacheSamplesAndGenInitModel(long cacheDataHandle, long partitionInfoHandle, int numClusters) { + this.cacheDataHandle = cacheDataHandle; + this.partitionInfoHandle = partitionInfoHandle; + this.numClusters = numClusters; + } + + @Override + public void mapPartition(Iterable > values, + Collector > out) + throws Exception { + int superStep = getIterationRuntimeContext().getSuperstepNumber(); + int taskId = getRuntimeContext().getIndexOfThisSubtask(); + if (superStep == 2) { + // cache training data + Map partitionInfos = IterTaskObjKeeper.get(partitionInfoHandle, taskId); + Preconditions.checkNotNull(partitionInfos); + List groupNamesToHandle = getGroupNames(partitionInfos, taskId); + Map > cachedData = new HashMap <>(); + for (String groupName : groupNamesToHandle) { + cachedData.put(groupName, new ArrayList <>()); + } + + for (Tuple2 value : values) { + cachedData.get(value.f1.getGroupColNamesString()).add(value.f1); + } + + Map cachedDataArray = new HashMap <>(); + for (Map.Entry > entry : cachedData.entrySet()) { + cachedDataArray.put(entry.getKey(), entry.getValue().toArray(new Sample[0])); + } + IterTaskObjKeeper.put(cacheDataHandle, taskId, cachedDataArray); + + // generate init model and send to corresponding workers. + // Note: We assume that when one group is partitioned to multiple workers, then the number on one + // worker is greater than or equal to k. + for (String groupName : groupNamesToHandle) { + if (partitionInfos.get(groupName)[0] == taskId) { + // this worker do the initialization + Sample[] samples = cachedDataArray.get(groupName); + int k = Math.min(numClusters, samples.length); + + int dataDim = samples[0].getVector().getData().length; + double[][] initCenter = new double[k][dataDim]; + for (int i = 0; i < k; i++) { + System.arraycopy(samples[i].getVector().getData(), 0, initCenter[i], 0, dataDim); + } + for (int targetWorkerId : partitionInfos.get(groupName)) { + out.collect(Tuple3.of(targetWorkerId, groupName, initCenter)); + } + } + } + } + } + } + + /** + * Caches the init model at superStep-2. 
+ */ + private static class CacheInitModel + extends RichMapPartitionFunction , Object> { + + private final long cacheModelHandler; + + private final long lossHandler; + + private final long partitionInfoHandle; + + public CacheInitModel(long cacheModelHandler, long partitionInfoHandle, long lossHandler) { + this.cacheModelHandler = cacheModelHandler; + this.partitionInfoHandle = partitionInfoHandle; + this.lossHandler = lossHandler; + } + + @Override + public void mapPartition(Iterable > values, Collector out) + throws Exception { + int superStep = getIterationRuntimeContext().getSuperstepNumber(); + int taskId = getRuntimeContext().getIndexOfThisSubtask(); + if (superStep == 2) { + Map groupAndOwners = IterTaskObjKeeper.get(partitionInfoHandle, taskId); + Preconditions.checkNotNull(groupAndOwners); + List groupsToHandle = getGroupNames(groupAndOwners, taskId); + + Map initModel = new HashMap <>(); + Map loss = new HashMap <>(); + for (Tuple3 val : values) { + initModel.put(val.f1, val.f2); + loss.put(val.f1, 0.); + } + + if (initModel.size() != groupsToHandle.size()) { + throw new RuntimeException("Illegal model size."); + } + IterTaskObjKeeper.put(cacheModelHandler, taskId, initModel); + IterTaskObjKeeper.put(lossHandler, taskId, loss); + } + } + } + + /** + * Computes model updates using the cache data and cache model from superStep-2. + */ + private static class ComputeUpdates + extends RichMapPartitionFunction > { + + private final long cacheDataHandle; + + private final long cacheModelHandle; + + private final ContinuousDistance distance; + + private final long partitionInfoHandle; + + public ComputeUpdates(long cacheDataHandle, long cacheModelHandle, long partitionInfoHandle, + ContinuousDistance distance) { + this.cacheDataHandle = cacheDataHandle; + this.cacheModelHandle = cacheModelHandle; + this.partitionInfoHandle = partitionInfoHandle; + this.distance = distance; + } + + @Override + public void mapPartition(Iterable values, Collector > out) + throws Exception { + int taskId = getRuntimeContext().getIndexOfThisSubtask(); + int superStepNum = getIterationRuntimeContext().getSuperstepNumber(); + if (superStepNum == 1) { + // Does nothing at step 1. 
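// (Nothing is cached yet in the first superstep: the training data and the initial centroids are only written to static memory at superStep-2, so there are no updates to compute.)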
+ return; + } + Map cacheModel = IterTaskObjKeeper.get(cacheModelHandle, taskId); + Map cacheData = IterTaskObjKeeper.get(cacheDataHandle, taskId); + Map partitionInfos = IterTaskObjKeeper.get(partitionInfoHandle, taskId); + Preconditions.checkNotNull(cacheData); + Preconditions.checkNotNull(cacheModel); + Preconditions.checkNotNull(partitionInfos); + + String[] groupsHandled = cacheData.keySet().toArray(new String[0]); + + for (String groupName : groupsHandled) { + double[][] model = cacheModel.get(groupName); + Sample[] trainData = cacheData.get(groupName); + // last two elements are: number of data points, sum distance + int featureDim = model[0].length; + double[][] updates = new double[model.length][featureDim + 2]; + for (Sample sample : trainData) { + Tuple2 closestIdAndDistance = findClosestCluster(sample.getVector().getData(), + model, distance); + sample.setClusterId(closestIdAndDistance.f0); + double[] sampleData = sample.getVector().getData(); + for (int i = 0; i < sampleData.length; i++) { + updates[closestIdAndDistance.f0][i] += sampleData[i]; + } + // weight + updates[closestIdAndDistance.f0][featureDim] += 1; + // distance + updates[closestIdAndDistance.f0][featureDim + 1] += closestIdAndDistance.f1; + } + int[] targetWorkerIds = partitionInfos.get(groupName); + for (int wId : targetWorkerIds) { + out.collect(Tuple3.of(wId, groupName, updates)); + } + } + + } + } + + /** + * Update Kmeans models on each partition. Note that one worker may maintain multiple groups from superStep-2. + */ + private static class UpdateModel + extends RichMapPartitionFunction , Object> { + + private final long cacheModelHandle; + + private final long lossHandle; + + private final double tol; + + public UpdateModel(long cacheModelHandle, long lossHandle, double tol) { + this.cacheModelHandle = cacheModelHandle; + this.lossHandle = lossHandle; + this.tol = tol; + } + + @Override + public void mapPartition(Iterable > values, Collector out) + throws Exception { + int taskId = getRuntimeContext().getIndexOfThisSubtask(); + int superStepNum = getIterationRuntimeContext().getSuperstepNumber(); + if (superStepNum == 1) { + // Does nothing at superStep - 1. + // Not converged by default. 
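// (Emitting a dummy element keeps the iteration's termination criterion non-empty, so the loop proceeds; the iteration stops early when no element is emitted here, see loopStart.closeWith(loopOutput, loopOutput) in linkFrom.)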
+ out.collect(new Object()); + return; + } + Map updates = new HashMap <>(); + for (Tuple3 value : values) { + String groupName = value.f1; + double[][] updateFromOneWorker = value.f2; + if (updates.containsKey(groupName)) { + double[][] accumulatedUpdates = updates.get(groupName); + for (int i = 0; i < accumulatedUpdates.length; i++) { + for (int j = 0; j < accumulatedUpdates[0].length; j++) { + accumulatedUpdates[i][j] += updateFromOneWorker[i][j]; + } + } + } else { + updates.put(groupName, updateFromOneWorker); + } + } + + boolean hasConverged = true; + Map cacheModel = IterTaskObjKeeper.get(cacheModelHandle, taskId); + Map cachedLoss = IterTaskObjKeeper.get(lossHandle, taskId); + Preconditions.checkNotNull(cachedLoss); + Preconditions.checkNotNull(cacheModel); + + for (Map.Entry entry : updates.entrySet()) { + String groupName = entry.getKey(); + double[][] accumulatedUpdates = entry.getValue(); + double[][] cachedGroupModel = cacheModel.get(groupName); + long numElements = 0; + double distance = 0; + for (int cId = 0; cId < accumulatedUpdates.length; cId++) { + double[] currentCluster = accumulatedUpdates[cId]; + numElements += currentCluster[currentCluster.length - 2]; + distance += currentCluster[currentCluster.length - 1]; + for (int i = 0; i < currentCluster.length - 2; i++) { + currentCluster[i] /= currentCluster[currentCluster.length - 2]; + } + System.arraycopy(currentCluster, 0, cachedGroupModel[cId], 0, cachedGroupModel[cId].length); + } + + double lossLastIteration = cachedLoss.get(groupName); + double currentLoss = distance / numElements; + cachedLoss.put(groupName, currentLoss); + if (Math.abs(currentLoss - lossLastIteration) > tol) { + hasConverged = false; + } + } + if (!hasConverged) { + out.collect(new Object()); + } + } + } + + /** + * Output data samples cached in memory and clear the all objects cached in static memory. + */ + private static class OutputDataSamples extends RichMapPartitionFunction { + + private final long cacheDataHandle; + + private final long partitionInfoHandle; + + private final long cacheModelHandle; + + private final long lossHandle; + + public OutputDataSamples(long cacheDataHandle, long partitionInfoHandle, long cacheModelHandle, + long lossHandle) { + this.cacheDataHandle = cacheDataHandle; + this.partitionInfoHandle = partitionInfoHandle; + this.cacheModelHandle = cacheModelHandle; + this.lossHandle = lossHandle; + } + + @Override + public void mapPartition(Iterable values, Collector out) throws Exception { + int numTasks = getRuntimeContext().getNumberOfParallelSubtasks(); + Map cachedData = null; + for (int i = 0; i < numTasks; i++) { + cachedData = IterTaskObjKeeper.containsAndRemoves(cacheDataHandle, i); + if (cachedData != null) { + break; + } + } + Preconditions.checkNotNull(cachedData); + for (Sample[] oneGroupData : cachedData.values()) { + for (Sample sample : oneGroupData) { + out.collect(sample); + } + } + for (int i = 0; i < numTasks; i++) { + IterTaskObjKeeper.remove(cacheModelHandle, i); + IterTaskObjKeeper.remove(partitionInfoHandle, i); + IterTaskObjKeeper.remove(lossHandle, i); + } + } + } + + /** + * Finds the closest cluster. + * + * @param dataPoint The input data point. + * @param centroids The centroids. + * @param distance The distance measure. + * @return The closest cluster Id and the corresponding distance. 
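+ * <p>Minimal usage sketch (illustrative only; {@code EuclideanDistance} stands in for any
+ * available {@code ContinuousDistance} implementation):
+ * <pre>{@code
+ * double[][] centroids = {{0.0, 0.0}, {10.0, 10.0}};
+ * Tuple2<Integer, Double> best =
+ *     findClosestCluster(new double[] {1.0, 1.0}, centroids, new EuclideanDistance());
+ * // best.f0 == 0 (index of the closest centroid), best.f1 is the distance to it
+ * }</pre>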
+ */ + private static Tuple2 findClosestCluster(double[] dataPoint, double[][] centroids, + ContinuousDistance distance) { + double minDistance = Double.MAX_VALUE; + int closestClusterId = Integer.MAX_VALUE; + for (int i = 0; i < centroids.length; i++) { + double tmpDistance = distance.calc(dataPoint, centroids[i]); + if (tmpDistance < minDistance) { + minDistance = tmpDistance; + closestClusterId = i; + } + } + return Tuple2.of(closestClusterId, minDistance); + } + + /** + * Gets the groups that this worker needs to handle. + * + * @param partitionInfos GroupId and the workerIds that needs to handle group `groupId`. + * @param workerId The worker id. + * @return The groups that this worker needs to handle. + */ + private static List getGroupNames(Map partitionInfos, int workerId) { + List res = new ArrayList <>(); + for (Map.Entry entry : partitionInfos.entrySet()) { + for (int idx : entry.getValue()) { + if (idx == workerId) { + res.add(entry.getKey()); + break; + } + } + } + return res; + } + + /** + * Computes the owner of each group of training data, i.e., should it be handled by one worker or multiple workers. + * Note that this result of this function should be deterministic and unique given the same input. Because each + * worker executes this function individually. + * + * @param sizeByGroupName Number of data points in each group. + * @param numWorkers Number of workers. + * @return The owner of each group. + */ + @VisibleForTesting + static Map getPartitionInfo(Map sizeByGroupName, int numWorkers) { + // workerId, number of elements assigned to this worker. + PriorityQueue > workerAndAssignedNumElements = + new PriorityQueue <>( + (o1, o2) -> Long.compare(o1.f1, o2.f1) == 0 ? Integer.compare(o1.f0, o2.f0) + : Long.compare(o1.f1, o2.f1)); + for (int i = 0; i < numWorkers; i++) { + workerAndAssignedNumElements.add(Tuple2.of(i, 0L)); + } + + // converts to tuple array + Tuple2 [] sizeAndGroupNameArray = new Tuple2[sizeByGroupName.size()]; + int idx = 0; + for (Map.Entry entry : sizeByGroupName.entrySet()) { + sizeAndGroupNameArray[idx] = Tuple2.of(entry.getKey(), entry.getValue()); + idx++; + } + Arrays.sort(sizeAndGroupNameArray, + (o1, o2) -> -Long.compare(o1.f1, o2.f1) == 0 ? o1.f0.compareTo(o2.f0) : -Long.compare(o1.f1, o2.f1)); + + // stores the maintained result + Map > groupAndWorkerIds = new HashMap <>(sizeAndGroupNameArray.length); + long averageNumElementsPerWorker = 0; + for (Tuple2 groupAndSize : sizeAndGroupNameArray) { + groupAndWorkerIds.put(groupAndSize.f0, new ArrayList <>()); + averageNumElementsPerWorker += groupAndSize.f1; + } + averageNumElementsPerWorker = averageNumElementsPerWorker / numWorkers + 1; + + for (Tuple2 stringLongTuple2 : sizeAndGroupNameArray) { + String groupName = stringLongTuple2.f0; + long numElementsInThisGroup = stringLongTuple2.f1; + // splits large groups. + long numWorkersNeeded = numElementsInThisGroup / averageNumElementsPerWorker; + // does not split small groups. 
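+ // e.g., with averageNumElementsPerWorker = 100, a group of 250 elements is assigned in
+ // numWorkersNeeded = 2 splits, each split going to the currently least-loaded worker,
+ // while a group of 40 elements falls through to the branch below and stays on one worker.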
+ if (numWorkersNeeded == 0) { + numWorkersNeeded = 1; + } + long numElementsInEachWorker = numElementsInThisGroup / numWorkersNeeded; + for (int splitId = 0; splitId < numWorkersNeeded; splitId++) { + Tuple2 smallestWorker = workerAndAssignedNumElements.remove(); + groupAndWorkerIds.get(groupName).add(smallestWorker.f0); + workerAndAssignedNumElements.add( + Tuple2.of(smallestWorker.f0, numElementsInEachWorker + smallestWorker.f1)); + } + } + + Map result = new HashMap <>(); + for (Map.Entry > entry : groupAndWorkerIds.entrySet()) { + int[] targetWorkerIds = entry.getValue().stream().mapToInt(Integer::intValue).toArray(); + Arrays.sort(targetWorkerIds); + result.put(entry.getKey(), targetWorkerIds); + } + return result; + } + + private static class MapSampleToRow implements MapFunction { + private final int groupColNamesSize; + private final TypeInformation [] outputTypes; + + public MapSampleToRow(int groupColNamesSize, TypeInformation [] outputTypes) { + this.groupColNamesSize = groupColNamesSize; + this.outputTypes = outputTypes; + } + + @Override + public Row map(Sample value) throws Exception { + Row row = new Row(groupColNamesSize + 2); + for (int i = 0; i < groupColNamesSize; i++) { + row.setField(i, EvaluationUtil.castTo(value.getGroupColNames()[i], outputTypes[i])); + } + row.setField(groupColNamesSize, EvaluationUtil.castTo(value.getSampleId(), + outputTypes[groupColNamesSize])); + row.setField(groupColNamesSize + 1, value.getClusterId()); + return row; + } + } + + private static class HashPartitioner implements Partitioner { + + @Override + public int partition(Integer key, int numPartitions) { + return key % numPartitions; + } + } + +} diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/clustering/KMeansModelInfoBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/clustering/KMeansModelInfoBatchOp.java index 6511ff09d..7a6e475c9 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/clustering/KMeansModelInfoBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/clustering/KMeansModelInfoBatchOp.java @@ -3,7 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import org.apache.flink.types.Row; -import com.alibaba.alink.common.lazy.ExtractModelInfoBatchOp; +import com.alibaba.alink.operator.batch.utils.ExtractModelInfoBatchOp; import com.alibaba.alink.common.linalg.DenseVector; import com.alibaba.alink.operator.common.clustering.ClusteringModelInfo; import com.alibaba.alink.operator.common.clustering.kmeans.KMeansModelDataConverter; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/clustering/KMeansPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/clustering/KMeansPredictBatchOp.java index e92556b05..d78904f78 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/clustering/KMeansPredictBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/clustering/KMeansPredictBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.batch.utils.ModelMapBatchOp; import com.alibaba.alink.operator.common.clustering.kmeans.KMeansModelMapper; import com.alibaba.alink.params.clustering.KMeansPredictParams; @@ -11,6 +12,7 @@ * KMeans prediction based on the model fitted by KMeansTrainBatchOp. 
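+ * <p>Typical linkage (sketch only; the input operator {@code data} and the column names are
+ * placeholders):
+ * <pre>{@code
+ * KMeansTrainBatchOp train = new KMeansTrainBatchOp()
+ *     .setVectorCol("vec").setK(3).linkFrom(data);
+ * BatchOperator<?> result = new KMeansPredictBatchOp()
+ *     .setPredictionCol("cluster_id")
+ *     .linkFrom(train, data);
+ * }</pre>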
*/ @NameCn("K均值聚类预测") +@NameEn("KMeans Prediction") public final class KMeansPredictBatchOp extends ModelMapBatchOp implements KMeansPredictParams { private static final long serialVersionUID = -4673084154965905629L; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/clustering/KMeansTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/clustering/KMeansTrainBatchOp.java index f70164efe..6860bb4c5 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/clustering/KMeansTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/clustering/KMeansTrainBatchOp.java @@ -10,6 +10,7 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortDesc; @@ -20,7 +21,7 @@ import com.alibaba.alink.common.comqueue.communication.AllReduce; import com.alibaba.alink.common.exceptions.AkIllegalDataException; import com.alibaba.alink.common.exceptions.AkPreconditions; -import com.alibaba.alink.common.lazy.WithModelInfoBatchOp; +import com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp; import com.alibaba.alink.common.linalg.SparseVector; import com.alibaba.alink.common.linalg.Vector; import com.alibaba.alink.operator.batch.BatchOperator; @@ -34,10 +35,11 @@ import com.alibaba.alink.operator.common.distance.FastDistance; import com.alibaba.alink.operator.common.distance.FastDistanceMatrixData; import com.alibaba.alink.operator.common.distance.FastDistanceVectorData; -import com.alibaba.alink.operator.common.statistics.StatisticsHelper; +import com.alibaba.alink.operator.batch.statistics.utils.StatisticsHelper; import com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary; import com.alibaba.alink.params.clustering.KMeansTrainParams; import com.alibaba.alink.params.shared.clustering.HasKMeansWithHaversineDistanceType; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; /** * k-mean clustering is a method of vector quantization, originally from signal processing, that is popular for cluster @@ -52,6 +54,8 @@ }) @ParamSelectColumnSpec(name = "vectorCol", portIndices = 0, allowedTypeCollections = {TypeCollections.VECTOR_TYPES}) @NameCn("K均值聚类训练") +@NameEn("KMeans Training") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.clustering.KMeans") public final class KMeansTrainBatchOp extends BatchOperator implements KMeansTrainParams , WithModelInfoBatchOp { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/clustering/KModesPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/clustering/KModesPredictBatchOp.java new file mode 100644 index 000000000..5ac248c0a --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/batch/clustering/KModesPredictBatchOp.java @@ -0,0 +1,33 @@ +package com.alibaba.alink.operator.batch.clustering; + +import org.apache.flink.ml.api.misc.param.Params; + +import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.operator.batch.utils.ModelMapBatchOp; +import com.alibaba.alink.operator.common.clustering.kmodes.KModesModelMapper; +import com.alibaba.alink.params.clustering.ClusteringPredictParams; + +/** + * @author guotao.gt + */ + +@NameCn("Kmodes预测") +@NameEn("KModes Prediction") +public 
final class KModesPredictBatchOp extends ModelMapBatchOp + implements ClusteringPredictParams { + + private static final long serialVersionUID = -893588697092734428L; + + /** + * null constructor + */ + public KModesPredictBatchOp() { + this(null); + } + + public KModesPredictBatchOp(Params params) { + super(KModesModelMapper::new, params); + } + +} diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/clustering/KModesTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/clustering/KModesTrainBatchOp.java new file mode 100644 index 000000000..e78b56b79 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/batch/clustering/KModesTrainBatchOp.java @@ -0,0 +1,380 @@ +package com.alibaba.alink.operator.batch.clustering; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.MapPartitionFunction; +import org.apache.flink.api.common.functions.ReduceFunction; +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.operators.IterativeDataSet; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.utils.DataSetUtils; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; + +import com.alibaba.alink.common.annotation.InputPorts; +import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.common.annotation.OutputPorts; +import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; +import com.alibaba.alink.common.annotation.PortSpec; +import com.alibaba.alink.common.annotation.PortType; +import com.alibaba.alink.common.annotation.TypeCollections; +import com.alibaba.alink.operator.batch.BatchOperator; +import com.alibaba.alink.operator.common.clustering.kmodes.KModesModel; +import com.alibaba.alink.operator.common.clustering.kmodes.KModesModelData; +import com.alibaba.alink.operator.common.distance.OneZeroDistance; +import com.alibaba.alink.params.clustering.KModesTrainParams; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * Partitioning a large set of objects into homogeneous clusters is a fundamental operation in data mining. The k-mean + * algorithm is best suited for implementing this operation because of its efficiency in clustering large data sets. + * However, working only on numeric values limits its use in data mining because data sets in data mining often contain + * categorical values. In this paper we present an algorithm, called k-modes, to extend the k-mean paradigm to + * categorical domains. We introduce new dissimilarity measures to deal with categorical objects, replace mean of + * clusters with modes, and use a frequency based method to update modes in the clustering process to minimise the + * clustering cost function. Tested with the well known soybean disease data set the algorithm has demonstrated a very + * good classification performance. Experiments on a very large health insurance data set consisting of half a million + * records and 34 categorical attributes show that the algorithm is scalable in terms of both the number of clusters and + * the number of records. + *
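+ * <p>A minimal sketch of the frequency-based mode update for one categorical attribute of one
+ * cluster (illustrative only, not the operator's actual implementation):
+ * <pre>{@code
+ * Map<String, Integer> counts = new HashMap<>();
+ * for (String value : new String[] {"red", "blue", "red"}) {
+ *     counts.merge(value, 1, Integer::sum);
+ * }
+ * // "red" has the highest frequency (2), so it becomes the new centroid value for this attribute
+ * }</pre>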

+ * Huang, Zhexue. "A fast clustering algorithm to cluster very large categorical data sets in data mining." DMKD 3.8 + * (1997): 34-39. + * + * @author guotao.gt + */ +@InputPorts(values = {@PortSpec(PortType.DATA)}) +@OutputPorts(values = {@PortSpec(value = PortType.MODEL)}) +@ParamSelectColumnSpec(name = "featureCols", portIndices = 0, allowedTypeCollections = {TypeCollections.NUMERIC_TYPES}) +@NameCn("Kmodes训练") +@NameEn("KModes Training") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.clustering.KModes") +public final class KModesTrainBatchOp extends BatchOperator + implements KModesTrainParams { + + private static final long serialVersionUID = 7392501340000162512L; + + /** + * null constructor. + */ + public KModesTrainBatchOp() { + super(new Params()); + } + + /** + * this constructor has all parameter + * + * @param params + */ + public KModesTrainBatchOp(Params params) { + super(params); + } + + /** + * union two map,the same key will plus the value + * + * @param kv1 the 1st kv map + * @param kv2 the 2nd kv map + * @return the map united + */ + private static Map unionMaps(Map kv1, Map kv2) { + Map kv = new HashMap <>(); + kv.putAll(kv1); + for (String k : kv2.keySet()) { + if (kv.containsKey(k)) { + kv.put(k, kv.get(k) + kv2.get(k)); + } else { + kv.put(k, kv2.get(k)); + } + } + return kv; + } + + /** + * union two map array + * + * @param kvArray1 the 1st kv array map + * @param kvArray2 the 2nd kv array map + * @param dim the array's length + * @return the map array united + */ + private static Map [] unionMaps(Map [] kvArray1, Map [] + kvArray2, + int dim) { + Map [] kvArray = new HashMap[dim]; + for (int i = 0; i < dim; i++) { + kvArray[i] = unionMaps(kvArray1[i], kvArray2[i]); + } + return kvArray; + } + + /** + * get the max key of map + * + * @param kv the kv map + * @return the max key of map + */ + private static String getKOfMaxV(Map kv) { + Integer tmp = Integer.MIN_VALUE; + String k = null; + for (Map.Entry entry : kv.entrySet()) { + if (entry.getValue() > tmp) { + tmp = entry.getValue(); + k = entry.getKey(); + } + } + return k; + } + + /** + * get the max keys of map array + * + * @param kvs the kv map array + * @return the max keys of map array + */ + private static String[] getKOfMaxV(Map kvs[]) { + String[] ks = new String[kvs.length]; + for (int i = 0; i < kvs.length; i++) { + ks[i] = getKOfMaxV(kvs[i]); + } + return ks; + } + + @Override + public KModesTrainBatchOp linkFrom(BatchOperator ... inputs) { + BatchOperator in = checkAndGetFirst(inputs); + // get the input parameter's value + final String[] featureColNames = (this.getParams().contains(FEATURE_COLS) + && this.getFeatureCols() != null + && this.getFeatureCols().length > 0) ? 
+ this.getFeatureCols() : in.getSchema().getFieldNames(); + + final int numIter = this.getNumIter(); + final int k = this.getK(); + + if ((featureColNames == null || featureColNames.length == 0)) { + throw new RuntimeException("featureColNames should be set !"); + } + + // construct the sql to get the input feature column name + StringBuilder sbd = new StringBuilder(); + for (int i = 0; i < featureColNames.length; i++) { + if (i > 0) { + sbd.append(", "); + } + sbd.append("cast(`") + .append(featureColNames[i]) + .append("` as VARCHAR) as `") + .append(featureColNames[i]) + .append("`"); + } + + // get the input data needed + DataSet data = in.select(sbd.toString()).getDataSet() + .map(new MapFunction () { + private static final long serialVersionUID = 8380190916941374707L; + + @Override + public String[] map(Row row) throws Exception { + String[] values = new String[row.getArity()]; + for (int i = 0; i < values.length; i++) { + values[i] = (String) row.getField(i); + } + return values; + } + }); + + /** + * initial the centroid + * Tuple3: clusterId, clusterWeight, clusterCentroid + */ + DataSet > initCentroid = DataSetUtils + .zipWithIndex(DataSetUtils.sampleWithSize(data, false, k)) + .map(new MapFunction , Tuple3 >() { + private static final long serialVersionUID = -6852532761276146862L; + + @Override + public Tuple3 map(Tuple2 v) + throws Exception { + return new Tuple3 <>(v.f0, 0., v.f1); + } + }) + .withForwardedFields("f0->f0;f1->f2"); + + IterativeDataSet > loop = initCentroid.iterate(numIter); + DataSet > samplesWithClusterId = assignClusterId(data, loop); + DataSet > updatedCentroid = updateCentroid(samplesWithClusterId, + k, featureColNames.length); + DataSet > finalCentroid = loop.closeWith(updatedCentroid); + + // map the final centroid to row type, plus with the meta info + DataSet modelRows = finalCentroid + .mapPartition(new MapPartitionFunction , Row>() { + private static final long serialVersionUID = -3961032097333930998L; + + @Override + public void mapPartition(Iterable > iterable, + Collector out) throws Exception { + KModesModelData modelData = new KModesModelData(); + modelData.centroids = new ArrayList <>(); + for (Tuple3 t : iterable) { + modelData.centroids.add(t); + } + modelData.featureColNames = featureColNames; + + // meta plus data + new KModesModel().save(modelData, out); + } + }) + .setParallelism(1); + + // store the clustering model to the table + this.setOutput(modelRows, new KModesModel().getModelSchema()); + + return this; + } + + /** + * assign clusterId to sample + * + * @param data the whole sample data + * @param centroids the centroids of clusters + * @return the DataSet of sample with clusterId + */ + private DataSet > assignClusterId( + DataSet data, + DataSet > centroids) { + + class FindClusterOp extends RichMapFunction > { + private static final long serialVersionUID = 6305282153314372806L; + List > centroids; + OneZeroDistance distance; + + @Override + public void open(Configuration parameters) throws Exception { + super.open(parameters); + centroids = getRuntimeContext().getBroadcastVariable("centroids"); + this.distance = new OneZeroDistance(); + } + + @Override + public Tuple2 map(String[] denseVector) throws Exception { + long clusterId = KModesModel.findCluster(centroids, denseVector, distance); + return new Tuple2 <>(clusterId, denseVector); + } + } + + return data.map(new FindClusterOp()) + .withBroadcastSet(centroids, "centroids"); + } + + /** + * update centroid of cluster + * + * @param samplesWithClusterId sample with clusterId + 
* @param k the number of clusters + * @param dim the vectorSize of featureColNames + * @return the DataSet of center with clusterId and the weight(the number of samples belong to the cluster) + */ + private DataSet > updateCentroid( + DataSet > samplesWithClusterId, final int k, final int dim) { + + // tuple3: clusterId, clusterWeight, clusterCentroid + DataSet []>> localAggregate = + samplesWithClusterId.mapPartition(new DataPartition(k, dim)); + + return localAggregate + .groupBy(0) + .reduce(new DataReduce(dim)) + .map( + new MapFunction []>, Tuple3 >() { + private static final long serialVersionUID = -6833217715929845251L; + + @Override + public Tuple3 map( + Tuple3 []> in) { + return new Tuple3 <>(in.f0, in.f1, getKOfMaxV(in.f2)); + } + }) + .withForwardedFields("f0;f1"); + } + + /** + * calc local centroids + */ + public static class DataPartition + implements MapPartitionFunction , Tuple3 []>> { + + private static final long serialVersionUID = 4053491536690724820L; + private int k; + private int dim; + + public DataPartition(int k, int dim) { + this.k = k; + this.dim = dim; + } + + @Override + public void mapPartition(Iterable > iterable, + Collector []>> collector) + throws Exception { + Map [][] localCentroids = new HashMap[k][dim]; + for (int i = 0; i < k; i++) { + for (int j = 0; j < dim; j++) { + localCentroids[i][j] = new HashMap (32); + } + } + double[] localCounts = new double[k]; + Arrays.fill(localCounts, 0.); + + for (Tuple2 point : iterable) { + int clusterId = point.f0.intValue(); + localCounts[clusterId] += 1.0; + + for (int j = 0; j < dim; j++) { + if (localCentroids[clusterId][j].containsKey(point.f1[j])) { + localCentroids[clusterId][j].put(point.f1[j], + localCentroids[clusterId][j].get(point.f1[j]) + 1); + } else { + localCentroids[clusterId][j].put(point.f1[j], 1); + } + } + } + + for (int i = 0; i < localCentroids.length; i++) { + collector.collect(new Tuple3 <>((long) i, localCounts[i], localCentroids[i])); + } + } + } + + /** + * calc Global Centroids + */ + public static class DataReduce implements ReduceFunction []>> { + + private static final long serialVersionUID = -7472289261425686956L; + private int dim; + + public DataReduce(int dim) { + this.dim = dim; + } + + @Override + public Tuple3 []> reduce(Tuple3 []> + in1, + Tuple3 []> + in2) { + return new Tuple3 <>(in1.f0, in1.f1 + in2.f1, unionMaps(in1.f2, in2.f2, dim)); + } + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/clustering/LdaModelInfoBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/clustering/LdaModelInfoBatchOp.java index b0ca697bd..70fee8cc5 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/clustering/LdaModelInfoBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/clustering/LdaModelInfoBatchOp.java @@ -3,7 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import org.apache.flink.types.Row; -import com.alibaba.alink.common.lazy.ExtractModelInfoBatchOp; +import com.alibaba.alink.operator.batch.utils.ExtractModelInfoBatchOp; import java.util.List; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/clustering/LdaPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/clustering/LdaPredictBatchOp.java index d2b556fe9..59ec5bab3 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/clustering/LdaPredictBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/clustering/LdaPredictBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; 
import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.operator.batch.utils.ModelMapBatchOp; @@ -14,6 +15,7 @@ */ @ParamSelectColumnSpec(name = "selectedCol", allowedTypeCollections = TypeCollections.VECTOR_TYPES) @NameCn("LDA预测") +@NameEn("LDA Prediction") public final class LdaPredictBatchOp extends ModelMapBatchOp implements LdaPredictParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/clustering/LdaTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/clustering/LdaTrainBatchOp.java index 5e51b057c..5d5ee4bb1 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/clustering/LdaTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/clustering/LdaTrainBatchOp.java @@ -19,6 +19,7 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortSpec; @@ -27,11 +28,11 @@ import com.alibaba.alink.common.comqueue.IterativeComQueue; import com.alibaba.alink.common.comqueue.communication.AllReduce; import com.alibaba.alink.common.exceptions.AkUnsupportedOperationException; -import com.alibaba.alink.common.lazy.WithModelInfoBatchOp; +import com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp; import com.alibaba.alink.common.linalg.DenseMatrix; import com.alibaba.alink.common.linalg.SparseVector; import com.alibaba.alink.common.linalg.Vector; -import com.alibaba.alink.common.utils.DataSetConversionUtil; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; import com.alibaba.alink.common.utils.RowCollector; import com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.operator.batch.BatchOperator; @@ -52,9 +53,10 @@ import com.alibaba.alink.operator.common.nlp.DocCountVectorizerModelData; import com.alibaba.alink.operator.common.nlp.DocCountVectorizerModelMapper; import com.alibaba.alink.operator.common.nlp.FeatureType; -import com.alibaba.alink.operator.common.statistics.StatisticsHelper; +import com.alibaba.alink.operator.batch.statistics.utils.StatisticsHelper; import com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary; import com.alibaba.alink.params.clustering.LdaTrainParams; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; import com.google.common.collect.Lists; import org.apache.commons.math3.random.RandomDataGenerator; @@ -81,6 +83,8 @@ }) @ParamSelectColumnSpec(name = "selectedCol", allowedTypeCollections = TypeCollections.VECTOR_TYPES) @NameCn("LDA训练") +@NameEn("LDA Training") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.clustering.Lda") public class LdaTrainBatchOp extends BatchOperator implements LdaTrainParams , WithModelInfoBatchOp { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/AggLookupBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/AggLookupBatchOp.java index 0a67ed9ac..a649af720 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/AggLookupBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/AggLookupBatchOp.java @@ -3,6 +3,7 @@ import 
org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.batch.utils.ModelMapBatchOp; import com.alibaba.alink.operator.common.dataproc.AggLookupModelMapper; import com.alibaba.alink.params.dataproc.AggLookupParams; @@ -10,6 +11,7 @@ /** */ @NameCn("Agg表查找") +@NameEn("Agg Lookup") public class AggLookupBatchOp extends ModelMapBatchOp implements AggLookupParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/AppendIdBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/AppendIdBatchOp.java index de3dfd1ba..69246e101 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/AppendIdBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/AppendIdBatchOp.java @@ -16,10 +16,11 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.PortSpec; import com.alibaba.alink.common.annotation.PortType; -import com.alibaba.alink.common.utils.DataSetConversionUtil; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; import com.alibaba.alink.common.utils.RowUtil; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.params.dataproc.AppendIdBatchParams; @@ -34,6 +35,7 @@ @InputPorts(values = {@PortSpec(PortType.DATA)}) @OutputPorts(values = {@PortSpec(PortType.DATA)}) @NameCn("添加id列") +@NameEn("Append Id") public final class AppendIdBatchOp extends BatchOperator implements AppendIdBatchParams { public final static String appendIdColName = "append_id"; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/FirstNBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/FirstNBatchOp.java index fcc4d7aba..951c7a4b5 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/FirstNBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/FirstNBatchOp.java @@ -6,6 +6,7 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.PortSpec; import com.alibaba.alink.common.annotation.PortType; @@ -19,6 +20,7 @@ @InputPorts(values = @PortSpec(PortType.DATA)) @OutputPorts(values = @PortSpec(PortType.DATA)) @NameCn("前N个数") +@NameEn("FirstN") public class FirstNBatchOp extends BatchOperator implements FirstNParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/FlattenMTableBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/FlattenMTableBatchOp.java index 8bc86e7cc..96f90789d 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/FlattenMTableBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/FlattenMTableBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.operator.batch.utils.FlatMapBatchOp; @@ -14,7 +15,9 @@ */ 
@ParamSelectColumnSpec(name = "selectedCol", allowedTypeCollections = TypeCollections.MTABLE_TYPES) +@ParamSelectColumnSpec(name = "reservedCols") @NameCn("MTable展开") +@NameEn("Flatten MTable") public class FlattenMTableBatchOp extends FlatMapBatchOp implements FlattenMTableParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/HugeIndexerStringPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/HugeIndexerStringPredictBatchOp.java index 0c2b1de08..09057d307 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/HugeIndexerStringPredictBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/HugeIndexerStringPredictBatchOp.java @@ -19,6 +19,7 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortDesc; @@ -47,7 +48,8 @@ @ParamSelectColumnSpec(name = "selectedCols", allowedTypeCollections = TypeCollections.LONG_TYPES) @ReservedColsWithFirstInputSpec -@NameCn("并行ID化预测") +@NameCn("超大ID化预测") +@NameEn("Huge Indexer String Prediction") public final class HugeIndexerStringPredictBatchOp extends BatchOperator implements HugeMultiStringIndexerPredictParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/HugeLookupBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/HugeLookupBatchOp.java index dfcc8bcc1..bebbce76b 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/HugeLookupBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/HugeLookupBatchOp.java @@ -10,6 +10,7 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortDesc; @@ -32,6 +33,7 @@ @ParamSelectColumnSpec(name="mapKeyCols") @ParamSelectColumnSpec(name="mapValueCols") @NameCn("HugeLookup") +@NameEn("HugeLookup") public class HugeLookupBatchOp extends BatchOperator implements LookupParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/HugeMultiIndexerStringPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/HugeMultiIndexerStringPredictBatchOp.java index 949efd4a4..95c471a37 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/HugeMultiIndexerStringPredictBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/HugeMultiIndexerStringPredictBatchOp.java @@ -18,6 +18,7 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortDesc; @@ -49,6 +50,7 @@ allowedTypeCollections = TypeCollections.LONG_TYPES) @ReservedColsWithFirstInputSpec @NameCn("多列并行反ID化预测") +@NameEn("Huge Multi Indexer String Prediction") public final class HugeMultiIndexerStringPredictBatchOp extends BatchOperator implements HugeMultiStringIndexerPredictParams { diff --git 
a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/HugeMultiStringIndexerPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/HugeMultiStringIndexerPredictBatchOp.java index 6fcf3369e..776f4f839 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/HugeMultiStringIndexerPredictBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/HugeMultiStringIndexerPredictBatchOp.java @@ -22,6 +22,7 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.PortDesc; import com.alibaba.alink.common.annotation.PortSpec; @@ -56,7 +57,8 @@ @OutputPorts(values = {@PortSpec(value = PortType.DATA, desc = PortDesc.OUTPUT_RESULT)}) @SelectedColsWithFirstInputSpec @ReservedColsWithFirstInputSpec -@NameCn("HugeStringIndexer预测") +@NameCn("HugeMultiStringIndexer预测") +@NameEn("Huge Multi String Indexer Prediction") public final class HugeMultiStringIndexerPredictBatchOp extends BatchOperator implements HugeMultiStringIndexerPredictParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/HugeStringIndexerPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/HugeStringIndexerPredictBatchOp.java index 89d9a0abd..59fddeef6 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/HugeStringIndexerPredictBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/HugeStringIndexerPredictBatchOp.java @@ -21,6 +21,7 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.PortDesc; import com.alibaba.alink.common.annotation.PortSpec; @@ -48,6 +49,7 @@ @OutputPorts(values = {@PortSpec(value = PortType.DATA, desc = PortDesc.OUTPUT_RESULT)}) @SelectedColsWithFirstInputSpec @NameCn("并行ID化预测") +@NameEn("Huge String Indexer Prediction") public final class HugeStringIndexerPredictBatchOp extends BatchOperator implements HugeMultiStringIndexerPredictParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/ImputerModelInfoBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/ImputerModelInfoBatchOp.java index 52bca223b..c983956eb 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/ImputerModelInfoBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/ImputerModelInfoBatchOp.java @@ -3,7 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import org.apache.flink.types.Row; -import com.alibaba.alink.common.lazy.ExtractModelInfoBatchOp; +import com.alibaba.alink.operator.batch.utils.ExtractModelInfoBatchOp; import com.alibaba.alink.operator.common.dataproc.ImputerModelInfo; import com.alibaba.alink.params.dataproc.ImputerTrainParams; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/ImputerPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/ImputerPredictBatchOp.java index 52912c8e7..5d452b06c 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/ImputerPredictBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/ImputerPredictBatchOp.java @@ -3,6 +3,7 @@ import 
org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.batch.utils.ModelMapBatchOp; import com.alibaba.alink.operator.common.dataproc.ImputerModelMapper; import com.alibaba.alink.params.dataproc.ImputerPredictParams; @@ -13,6 +14,7 @@ * Strategy support min, max, mean or value. */ @NameCn("缺失值填充批预测") +@NameEn("Imputer Predict") public class ImputerPredictBatchOp extends ModelMapBatchOp implements ImputerPredictParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/ImputerTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/ImputerTrainBatchOp.java index 8be927e4e..26cebb44c 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/ImputerTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/ImputerTrainBatchOp.java @@ -11,7 +11,12 @@ import com.alibaba.alink.common.MLEnvironmentFactory; import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; +import com.alibaba.alink.common.annotation.ParamCond; +import com.alibaba.alink.common.annotation.ParamCond.CondType; +import com.alibaba.alink.common.annotation.ParamMutexRule; +import com.alibaba.alink.common.annotation.ParamMutexRule.ActionType; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortSpec; import com.alibaba.alink.common.annotation.PortSpec.OpType; @@ -19,15 +24,16 @@ import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.common.exceptions.AkIllegalOperatorParameterException; import com.alibaba.alink.common.exceptions.AkUnsupportedOperationException; -import com.alibaba.alink.common.lazy.WithModelInfoBatchOp; +import com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp; import com.alibaba.alink.common.utils.RowCollector; import com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.common.dataproc.ImputerModelDataConverter; import com.alibaba.alink.operator.common.dataproc.ImputerModelInfo; -import com.alibaba.alink.operator.common.statistics.StatisticsHelper; +import com.alibaba.alink.operator.batch.statistics.utils.StatisticsHelper; import com.alibaba.alink.operator.common.statistics.basicstatistic.TableSummary; import com.alibaba.alink.params.dataproc.ImputerTrainParams; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; /** * Imputer completes missing values in a dataSet, but only same type of columns can be selected at the same time. 
@@ -43,6 +49,16 @@ @ParamSelectColumnSpec(name = "selectedCols", allowedTypeCollections = TypeCollections.NUMERIC_TYPES) @NameCn("缺失值填充训练") +@NameEn("Imputer Train") +@ParamMutexRule( + name = "fillValue", type = ActionType.SHOW, + cond = @ParamCond( + name = "strategy", + type = CondType.WHEN_IN_VALUES, + values = "VALUE" + ) +) +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.dataproc.Imputer") public class ImputerTrainBatchOp extends BatchOperator implements ImputerTrainParams , WithModelInfoBatchOp { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/IndexToStringPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/IndexToStringPredictBatchOp.java index 69ef6fde1..3ae0386c4 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/IndexToStringPredictBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/IndexToStringPredictBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.operator.batch.utils.ModelMapBatchOp; @@ -15,6 +16,7 @@ @ParamSelectColumnSpec(name = "selectedCol", allowedTypeCollections = TypeCollections.STRING_TYPE) @NameCn("IndexToString预测") +@NameEn("Index To String Prediction") public final class IndexToStringPredictBatchOp extends ModelMapBatchOp implements IndexToStringPredictParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/JsonValueBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/JsonValueBatchOp.java index 96e8851b6..dfb4437c7 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/JsonValueBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/JsonValueBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.common.utils.JsonPathMapper; @@ -14,6 +15,7 @@ */ @ParamSelectColumnSpec(name="selectedCol",allowedTypeCollections = TypeCollections.STRING_TYPES) @NameCn("JSON值抽取") +@NameEn("Json Value Extraction") public final class JsonValueBatchOp extends MapBatchOp implements JsonValueParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/LookupBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/LookupBatchOp.java index f39518ef2..b6ee5b146 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/LookupBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/LookupBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.SelectedColsWithSecondInputSpec; import com.alibaba.alink.operator.batch.utils.ModelMapBatchOp; import com.alibaba.alink.operator.common.dataproc.LookupModelMapper; @@ -13,6 +14,7 @@ */ @SelectedColsWithSecondInputSpec @NameCn("表查找") +@NameEn("Lookup Table") public class LookupBatchOp extends ModelMapBatchOp implements LookupParams { diff 
--git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/LookupHBaseBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/LookupHBaseBatchOp.java index 4962e5b9e..41e050c0c 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/LookupHBaseBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/LookupHBaseBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.operator.batch.utils.MapBatchOp; @@ -13,7 +14,8 @@ * batch op for lookup from hbase. */ @ParamSelectColumnSpec(name = "rowKeyCol", allowedTypeCollections = TypeCollections.STRING_TYPE) -@NameCn("添加HBase数据") +@NameCn("查询HBase数据表") +@NameEn("Lookup HBase Table") public class LookupHBaseBatchOp extends MapBatchOp implements LookupHBaseParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/LookupRecentDaysBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/LookupRecentDaysBatchOp.java index d3d31e3d2..b290a83d9 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/LookupRecentDaysBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/LookupRecentDaysBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.SelectedColsWithSecondInputSpec; import com.alibaba.alink.operator.batch.utils.ModelMapBatchOp; import com.alibaba.alink.operator.common.dataproc.LookupRecentDaysModelMapper; @@ -13,6 +14,7 @@ */ @SelectedColsWithSecondInputSpec @NameCn("表查找") +@NameEn("Lookup Recent Days Table") public class LookupRecentDaysBatchOp extends ModelMapBatchOp implements LookupRecentDaysParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/LookupRedisRowBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/LookupRedisRowBatchOp.java index f4e5bf32f..e2f766e1e 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/LookupRedisRowBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/LookupRedisRowBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.SelectedColsWithFirstInputSpec; import com.alibaba.alink.operator.batch.utils.MapBatchOp; import com.alibaba.alink.operator.common.dataproc.LookupRedisMapper; @@ -12,7 +13,8 @@ * batch op for lookup from redis. 
*/ @SelectedColsWithFirstInputSpec -@NameCn("Redis 表查找") +@NameCn("Redis 表查找Row类型") +@NameEn("Lookup Redis Table For Row") public class LookupRedisRowBatchOp extends MapBatchOp implements LookupRedisParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/LookupRedisStringBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/LookupRedisStringBatchOp.java index e06229b8e..5f4ee6adb 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/LookupRedisStringBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/LookupRedisStringBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.SelectedColsWithFirstInputSpec; import com.alibaba.alink.operator.batch.utils.MapBatchOp; import com.alibaba.alink.operator.common.dataproc.LookupRedisStringMapper; @@ -13,6 +14,7 @@ */ @SelectedColsWithFirstInputSpec @NameCn("Redis 表查找String类型") +@NameEn("Lookup Redis Table For String") public class LookupRedisStringBatchOp extends MapBatchOp implements LookupStringRedisParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/MaxAbsScalerModelInfoBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/MaxAbsScalerModelInfoBatchOp.java index 189175371..2d91099df 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/MaxAbsScalerModelInfoBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/MaxAbsScalerModelInfoBatchOp.java @@ -3,7 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import org.apache.flink.types.Row; -import com.alibaba.alink.common.lazy.ExtractModelInfoBatchOp; +import com.alibaba.alink.operator.batch.utils.ExtractModelInfoBatchOp; import com.alibaba.alink.operator.common.dataproc.MaxAbsScalarModelInfo; import java.util.List; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/MaxAbsScalerPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/MaxAbsScalerPredictBatchOp.java index 1144afec1..c47e7b37b 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/MaxAbsScalerPredictBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/MaxAbsScalerPredictBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.batch.utils.ModelMapBatchOp; import com.alibaba.alink.operator.common.dataproc.MaxAbsScalerModelMapper; import com.alibaba.alink.params.dataproc.MaxAbsScalerPredictParams; @@ -13,6 +14,7 @@ * MaxAbsPredict will scale the dataSet with model which trained from MaxAbsTrain. 
*/ @NameCn("绝对值最大化批预测") +@NameEn("MaxAbs Scaler Prediction") public final class MaxAbsScalerPredictBatchOp extends ModelMapBatchOp implements MaxAbsScalerPredictParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/MaxAbsScalerTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/MaxAbsScalerTrainBatchOp.java index 56280b7bd..9975c7907 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/MaxAbsScalerTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/MaxAbsScalerTrainBatchOp.java @@ -10,20 +10,22 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortSpec; import com.alibaba.alink.common.annotation.PortSpec.OpType; import com.alibaba.alink.common.annotation.PortType; import com.alibaba.alink.common.annotation.TypeCollections; -import com.alibaba.alink.common.lazy.WithModelInfoBatchOp; +import com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp; import com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.common.dataproc.MaxAbsScalarModelInfo; import com.alibaba.alink.operator.common.dataproc.MaxAbsScalerModelDataConverter; -import com.alibaba.alink.operator.common.statistics.StatisticsHelper; +import com.alibaba.alink.operator.batch.statistics.utils.StatisticsHelper; import com.alibaba.alink.operator.common.statistics.basicstatistic.TableSummary; import com.alibaba.alink.params.dataproc.MaxAbsScalerTrainParams; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; /** * MaxAbsScaler transforms a dataSet of rows, rescaling each feature to range @@ -34,6 +36,8 @@ @OutputPorts(values = @PortSpec(value = PortType.MODEL)) @ParamSelectColumnSpec(name = "selectedCols", allowedTypeCollections = TypeCollections.NUMERIC_TYPES) @NameCn("绝对值最大化训练") +@NameEn("MaxAbs Scaler Training") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.dataproc.MaxAbsScaler") public class MaxAbsScalerTrainBatchOp extends BatchOperator implements MaxAbsScalerTrainParams , WithModelInfoBatchOp { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/MinMaxScalerModelInfoBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/MinMaxScalerModelInfoBatchOp.java index 189811112..b508e818b 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/MinMaxScalerModelInfoBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/MinMaxScalerModelInfoBatchOp.java @@ -3,7 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import org.apache.flink.types.Row; -import com.alibaba.alink.common.lazy.ExtractModelInfoBatchOp; +import com.alibaba.alink.operator.batch.utils.ExtractModelInfoBatchOp; import com.alibaba.alink.operator.common.dataproc.MinMaxScalerModelInfo; import java.util.List; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/MinMaxScalerPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/MinMaxScalerPredictBatchOp.java index 4ea2e46a6..5a884214d 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/MinMaxScalerPredictBatchOp.java +++ 
b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/MinMaxScalerPredictBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.batch.utils.ModelMapBatchOp; import com.alibaba.alink.operator.common.dataproc.MinMaxScalerModelMapper; import com.alibaba.alink.params.dataproc.MinMaxScalerPredictParams; @@ -13,6 +14,7 @@ * MinMaxScalerPredict will scale the dataSet with model which trained from MaxAbsTrain. */ @NameCn("归一化批预测") +@NameEn("Min Max Scaler Batch Predict") public final class MinMaxScalerPredictBatchOp extends ModelMapBatchOp implements MinMaxScalerPredictParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/MinMaxScalerTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/MinMaxScalerTrainBatchOp.java index 6d4b84e2b..98ab22492 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/MinMaxScalerTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/MinMaxScalerTrainBatchOp.java @@ -11,20 +11,22 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortSpec; import com.alibaba.alink.common.annotation.PortSpec.OpType; import com.alibaba.alink.common.annotation.PortType; import com.alibaba.alink.common.annotation.TypeCollections; -import com.alibaba.alink.common.lazy.WithModelInfoBatchOp; +import com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp; import com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.common.dataproc.MinMaxScalerModelDataConverter; import com.alibaba.alink.operator.common.dataproc.MinMaxScalerModelInfo; -import com.alibaba.alink.operator.common.statistics.StatisticsHelper; +import com.alibaba.alink.operator.batch.statistics.utils.StatisticsHelper; import com.alibaba.alink.operator.common.statistics.basicstatistic.TableSummary; import com.alibaba.alink.params.dataproc.MinMaxScalerTrainParams; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; /** * MinMaxScaler transforms a dataSet of rows, rescaling each feature @@ -35,6 +37,8 @@ @OutputPorts(values = @PortSpec(value = PortType.MODEL)) @ParamSelectColumnSpec(name = "selectedCols", allowedTypeCollections = TypeCollections.NUMERIC_TYPES) @NameCn("归一化训练") +@NameEn("Min Max Scaler Train") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.dataproc.MinMaxScaler") public class MinMaxScalerTrainBatchOp extends BatchOperator implements MinMaxScalerTrainParams , WithModelInfoBatchOp { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/MultiStringIndexerTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/MultiStringIndexerTrainBatchOp.java index 9560392eb..551fcc0cc 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/MultiStringIndexerTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/MultiStringIndexerTrainBatchOp.java @@ -30,6 +30,7 @@ import com.alibaba.alink.params.dataproc.HasStringOrderTypeDefaultAsRandom; import 
com.alibaba.alink.params.dataproc.MultiStringIndexerTrainParams; import com.alibaba.alink.params.shared.colname.HasSelectedCols; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; /** * Encode several columns of strings to bigint type indices. The indices are consecutive bigint type @@ -50,6 +51,7 @@ @ParamSelectColumnSpec(name = "selectedCols", allowedTypeCollections = TypeCollections.INT_LONG_STRING_TYPES) @NameCn("多字段字符串编码训练") @NameEn("Multiple String Indexer Train") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.dataproc.MultiStringIndexer") public final class MultiStringIndexerTrainBatchOp extends BatchOperator implements MultiStringIndexerTrainParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/RebalanceBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/RebalanceBatchOp.java index b7e96f564..31f60a81d 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/RebalanceBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/RebalanceBatchOp.java @@ -8,6 +8,7 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.PortDesc; import com.alibaba.alink.common.annotation.PortSpec; @@ -22,6 +23,7 @@ @InputPorts(values = {@PortSpec(PortType.DATA)}) @OutputPorts(values = {@PortSpec(value = PortType.DATA, desc = PortDesc.OUTPUT_RESULT)}) @NameCn("数据Rebalance") +@NameEn("Data Rebalance") public final class RebalanceBatchOp extends BatchOperator { private static final long serialVersionUID = -4236329417415800780L; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/SampleBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/SampleBatchOp.java index c3f690ccc..638e54c06 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/SampleBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/SampleBatchOp.java @@ -7,6 +7,7 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.PortSpec; import com.alibaba.alink.common.annotation.PortType; @@ -21,6 +22,7 @@ @InputPorts(values = @PortSpec(PortType.DATA)) @OutputPorts(values = @PortSpec(PortType.DATA)) @NameCn("随机采样") +@NameEn("Data Sampling") public final class SampleBatchOp extends BatchOperator implements SampleParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/SampleWithSizeBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/SampleWithSizeBatchOp.java index 906a8dace..098542dbd 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/SampleWithSizeBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/SampleWithSizeBatchOp.java @@ -7,6 +7,7 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.PortSpec; import com.alibaba.alink.common.annotation.PortType; @@ -19,6 +20,7 @@ @InputPorts(values = @PortSpec(PortType.DATA)) @OutputPorts(values 
= @PortSpec(PortType.DATA)) @NameCn("固定条数随机采样") +@NameEn("Data Sampling With Fixed Size") public class SampleWithSizeBatchOp extends BatchOperator implements SampleWithSizeParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/ShuffleBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/ShuffleBatchOp.java index 92b42d3fb..20696e715 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/ShuffleBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/ShuffleBatchOp.java @@ -10,6 +10,7 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.PortSpec; import com.alibaba.alink.common.annotation.PortType; @@ -26,6 +27,7 @@ @InputPorts(values = @PortSpec(PortType.DATA)) @OutputPorts(values = @PortSpec(PortType.DATA)) @NameCn("打乱数据顺序") +@NameEn("Data Shuffling") public final class ShuffleBatchOp extends BatchOperator { private static final long serialVersionUID = 4849933592970017744L; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/SparseFeatureIndexerTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/SparseFeatureIndexerTrainBatchOp.java index 15fde3dc9..cdc08ade9 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/SparseFeatureIndexerTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/SparseFeatureIndexerTrainBatchOp.java @@ -4,6 +4,7 @@ import org.apache.flink.api.common.functions.FlatMapFunction; import org.apache.flink.api.common.functions.MapFunction; import org.apache.flink.api.common.functions.ReduceFunction; +import org.apache.flink.api.common.functions.RichFilterFunction; import org.apache.flink.api.common.functions.RichGroupReduceFunction; import org.apache.flink.api.common.functions.RichMapPartitionFunction; import org.apache.flink.api.common.typeinfo.TypeInformation; @@ -12,6 +13,8 @@ import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.tuple.Tuple3; import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.api.java.utils.DataSetUtils; +import org.apache.flink.configuration.Configuration; import org.apache.flink.ml.api.misc.param.Params; import org.apache.flink.table.api.Table; import org.apache.flink.types.Row; @@ -22,11 +25,12 @@ import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; +import com.alibaba.alink.common.annotation.PortDesc; import com.alibaba.alink.common.annotation.PortSpec; import com.alibaba.alink.common.annotation.PortType; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.common.exceptions.AkIllegalDataException; -import com.alibaba.alink.common.utils.DataSetConversionUtil; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; import com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.common.dataproc.HugeStringIndexerUtil; @@ -36,10 +40,14 @@ import java.util.Comparator; import java.util.HashSet; import java.util.Iterator; +import java.util.List; import java.util.PriorityQueue; @InputPorts(values = @PortSpec(value = PortType.DATA)) -@OutputPorts(values = @PortSpec(value = 
PortType.MODEL)) +@OutputPorts(values = { + @PortSpec(value = PortType.MODEL, desc = PortDesc.MODEL_INFO), + @PortSpec(value = PortType.DATA, desc = PortDesc.FEATURE_FREQUENCY) +}) @ParamSelectColumnSpec(name = "selectedCol", allowedTypeCollections = TypeCollections.STRING_TYPE) @NameCn("稀疏特征编码训练") @NameEn("Sparse Feature Indexer Train") @@ -64,12 +72,26 @@ public SparseFeatureIndexerTrainBatchOp linkFrom(BatchOperator ... inputs) { String feaSplit = getSpareFeatureDelimiter(); String feaValueSplit = getKvValDelimiter(); boolean hasWeight = getHasValue(); + double minSupport = getMinPercent(); + String[] candidateTags = getCandidateTags(); BatchOperator in = checkAndGetFirst(inputs); TypeInformation feaType = TableUtil.findColType(in.getSchema(), featureColName); if (!feaType.equals(Types.STRING)) { throw new AkIllegalDataException("featureColName type must be string, but input type is " + feaType); } DataSet dataset = in.select(featureColName).getDataSet(); + + DataSet cnt = DataSetUtils.countElementsPerPartition(dataset) + .sum(1) + .map(new MapFunction , Long>() { + private static final long serialVersionUID = -8507632108475760763L; + + @Override + public Long map(Tuple2 value) { + return value.f1; + } + }).name("statics_sample_number"); + DataSet> feaFreSta = dataset.flatMap( new FlatMapFunction >() { @Override @@ -105,6 +127,19 @@ public Tuple2 reduce(Tuple2 value1, return Tuple2.of(value1.f0 + value2.f0, value1.f1); } }).name("split_and_count_fea_frequency"); + if (candidateTags != null && candidateTags.length > 0) { + feaFreSta = feaFreSta.filter(new RichFilterFunction >() { + @Override + public boolean filter(Tuple2 value) throws Exception { + for (String tag : candidateTags) { + if (value.f1.contains(tag)) { + return true; + } + } + return false; + } + }).name("filter_candidate_fea_tag"); + } this.setSideOutputTables( new Table[] {DataSetConversionUtil.toTable(getMLEnvironmentId(),feaFreSta.map( new MapFunction , Row>() { @@ -122,6 +157,22 @@ public boolean filter(Tuple2 value) throws Exception { return value.f0 >= minFrequency; } }).name("filter_less_frequency_fea"); + } else if (minSupport > 0) { + feaFreSta = feaFreSta + .filter(new RichFilterFunction >() { + private Integer count; + + @Override + public void open(Configuration parameters) throws Exception { + List countList = getRuntimeContext().getBroadcastVariable("count"); + count = (int) Math.floor(countList.get(0) * minSupport); + } + @Override + public boolean filter(Tuple2 value) throws Exception { + return value.f0 >= count; + } + }).withBroadcastSet(cnt, "count") + .name("filter_less_frequency_fea"); } if (topN > 0) { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/SplitBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/SplitBatchOp.java index 4e6db7159..3f74dbfac 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/SplitBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/SplitBatchOp.java @@ -16,6 +16,7 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.PortSpec; import com.alibaba.alink.common.annotation.PortType; @@ -23,7 +24,7 @@ import com.alibaba.alink.common.exceptions.AkIllegalStateException; import com.alibaba.alink.common.exceptions.AkPreconditions; import 
com.alibaba.alink.common.exceptions.AkUnclassifiedErrorException; -import com.alibaba.alink.common.utils.DataSetConversionUtil; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; import com.alibaba.alink.common.utils.RowUtil; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.params.dataproc.SplitParams; @@ -45,6 +46,7 @@ @InputPorts(values = @PortSpec(PortType.DATA)) @OutputPorts(values = {@PortSpec(PortType.DATA), @PortSpec(PortType.DATA)}) @NameCn("数据拆分") +@NameEn("Data Splitting") public final class SplitBatchOp extends BatchOperator implements SplitParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/StandardScalerModelInfoBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/StandardScalerModelInfoBatchOp.java index 8a6569986..0ba25d111 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/StandardScalerModelInfoBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/StandardScalerModelInfoBatchOp.java @@ -3,7 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import org.apache.flink.types.Row; -import com.alibaba.alink.common.lazy.ExtractModelInfoBatchOp; +import com.alibaba.alink.operator.batch.utils.ExtractModelInfoBatchOp; import com.alibaba.alink.operator.common.dataproc.StandardScalerModelInfo; import java.util.List; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/StandardScalerPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/StandardScalerPredictBatchOp.java index ed9c0889f..8929c198c 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/StandardScalerPredictBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/StandardScalerPredictBatchOp.java @@ -4,6 +4,7 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.PortDesc; import com.alibaba.alink.common.annotation.PortSpec; import com.alibaba.alink.common.annotation.PortType; @@ -19,6 +20,7 @@ @PortSpec(value = PortType.DATA, desc = PortDesc.PREDICT_INPUT_DATA) }) @NameCn("标准化批预测") +@NameEn("Standard Scaler Batch Predict") public final class StandardScalerPredictBatchOp extends ModelMapBatchOp implements StandardPredictParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/StandardScalerTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/StandardScalerTrainBatchOp.java index 4301031f8..a1eaae7b5 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/StandardScalerTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/StandardScalerTrainBatchOp.java @@ -11,28 +11,32 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortSpec; import com.alibaba.alink.common.annotation.PortSpec.OpType; import com.alibaba.alink.common.annotation.PortType; import com.alibaba.alink.common.annotation.TypeCollections; -import com.alibaba.alink.common.lazy.WithModelInfoBatchOp; +import com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp; import 
com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.common.dataproc.StandardScalerModelDataConverter; import com.alibaba.alink.operator.common.dataproc.StandardScalerModelInfo; -import com.alibaba.alink.operator.common.statistics.StatisticsHelper; +import com.alibaba.alink.operator.batch.statistics.utils.StatisticsHelper; import com.alibaba.alink.operator.common.statistics.basicstatistic.TableSummary; import com.alibaba.alink.params.dataproc.StandardTrainParams; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; /** * StandardScaler transforms a dataset, normalizing each feature to have unit standard deviation and/or zero mean. */ @InputPorts(values = @PortSpec(value = PortType.DATA, opType = OpType.BATCH)) @OutputPorts(values = @PortSpec(value = PortType.MODEL)) -@ParamSelectColumnSpec(name = "selectCols", allowedTypeCollections = TypeCollections.NUMERIC_TYPES) +@ParamSelectColumnSpec(name = "selectedCols", allowedTypeCollections = TypeCollections.NUMERIC_TYPES) @NameCn("标准化训练") +@NameEn("Standard Scaler Train") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.dataproc.StandardScaler") public class StandardScalerTrainBatchOp extends BatchOperator implements StandardTrainParams , WithModelInfoBatchOp { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/StratifiedSampleBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/StratifiedSampleBatchOp.java index 72fbf8ef3..2909ca668 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/StratifiedSampleBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/StratifiedSampleBatchOp.java @@ -11,6 +11,7 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortSpec; @@ -34,6 +35,7 @@ @OutputPorts(values = @PortSpec(PortType.DATA)) @ParamSelectColumnSpec(name = "strataCol", portIndices = 0) @NameCn("分层随机采样") +@NameEn("Stratified Sampling") public final class StratifiedSampleBatchOp extends BatchOperator implements StratifiedSampleParams , HashWithReplacementParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/StratifiedSampleWithSizeBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/StratifiedSampleWithSizeBatchOp.java index 468fc7cea..db71699f8 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/StratifiedSampleWithSizeBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/StratifiedSampleWithSizeBatchOp.java @@ -11,6 +11,7 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortSpec; @@ -32,6 +33,7 @@ @OutputPorts(values = @PortSpec(PortType.DATA)) @ParamSelectColumnSpec(name = "strataCol", portIndices = 0) @NameCn("固定条数分层随机采样") +@NameEn("Stratified Sampling With Fixed Size") public final class StratifiedSampleWithSizeBatchOp extends BatchOperator implements StrafiedSampleWithSizeParams { diff --git 
a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/StringIndexerTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/StringIndexerTrainBatchOp.java index 24e613290..6b970666b 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/StringIndexerTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/StringIndexerTrainBatchOp.java @@ -28,6 +28,7 @@ import com.alibaba.alink.operator.common.dataproc.StringIndexerModelDataConverter; import com.alibaba.alink.params.dataproc.HasStringOrderTypeDefaultAsRandom; import com.alibaba.alink.params.dataproc.StringIndexerTrainParams; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; /** * Encode one column of strings to bigint type indices. @@ -49,6 +50,7 @@ @ParamSelectColumnSpec(name = "selectedCols", allowedTypeCollections = TypeCollections.INT_LONG_STRING_TYPES) @NameCn("字符串编码训练") @NameEn("String Indexer Train") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.dataproc.StringIndexer") public final class StringIndexerTrainBatchOp extends BatchOperator implements StringIndexerTrainParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/TensorToVectorBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/TensorToVectorBatchOp.java index 3ccb10999..f30817d60 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/TensorToVectorBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/TensorToVectorBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.operator.batch.utils.MapBatchOp; @@ -14,6 +15,7 @@ */ @ParamSelectColumnSpec(name = "selectedCol", allowedTypeCollections = TypeCollections.NUMERIC_TENSOR_TYPES) @NameCn("张量转向量") +@NameEn("Tensor To Vector") public class TensorToVectorBatchOp extends MapBatchOp implements TensorToVectorParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/ToMTableBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/ToMTableBatchOp.java index 06c8f6eca..a6896a2c3 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/ToMTableBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/ToMTableBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.batch.utils.MapBatchOp; import com.alibaba.alink.operator.common.dataproc.ToMTableMapper; import com.alibaba.alink.params.dataproc.ToMTableParams; @@ -11,6 +12,7 @@ * batch op for transforming to MTable. 
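The SparseFeatureIndexerTrainBatchOp hunk above introduces a minPercent-based filter that cannot decide which features to keep until the total sample count is known, so the count is computed first and shipped to the filter as a broadcast variable. Below is a minimal, self-contained sketch of that Flink DataSet pattern (count per partition, sum, broadcast, read back in open()); the sample data, the class name and the 0.3 threshold are made up for illustration and are not part of the patch.

import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.RichFilterFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.utils.DataSetUtils;
import org.apache.flink.configuration.Configuration;

import java.util.List;

public class MinSupportFilterSketch {

    public static void main(String[] args) throws Exception {
        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

        // Hypothetical (count, feature) pairs standing in for the feature-frequency statistics.
        DataSet<Tuple2<Long, String>> feaFreq = env.fromElements(
            Tuple2.of(8L, "f1"), Tuple2.of(3L, "f2"), Tuple2.of(1L, "f3"));

        // Total number of elements, obtained the same way as in the hunk:
        // count per partition, sum the counts, keep only the Long field.
        DataSet<Long> cnt = DataSetUtils.countElementsPerPartition(feaFreq)
            .sum(1)
            .map(new MapFunction<Tuple2<Integer, Long>, Long>() {
                @Override
                public Long map(Tuple2<Integer, Long> value) {
                    return value.f1;
                }
            });

        final double minSupport = 0.3;

        // The threshold depends on a value only known at runtime, so it is
        // delivered as a broadcast variable and resolved in open().
        DataSet<Tuple2<Long, String>> filtered = feaFreq
            .filter(new RichFilterFunction<Tuple2<Long, String>>() {
                private long threshold;

                @Override
                public void open(Configuration parameters) {
                    List<Long> countList = getRuntimeContext().getBroadcastVariable("count");
                    threshold = (long) Math.floor(countList.get(0) * minSupport);
                }

                @Override
                public boolean filter(Tuple2<Long, String> value) {
                    return value.f0 >= threshold;
                }
            })
            .withBroadcastSet(cnt, "count");

        filtered.print();
    }
}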
*/ @NameCn("转MTable") +@NameEn("To MTable") public class ToMTableBatchOp extends MapBatchOp implements ToMTableParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/ToTensorBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/ToTensorBatchOp.java index eee2f9a38..27588727f 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/ToTensorBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/ToTensorBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.batch.utils.MapBatchOp; import com.alibaba.alink.operator.common.dataproc.ToTensorMapper; import com.alibaba.alink.params.dataproc.ToTensorParams; @@ -11,6 +12,7 @@ * batch op for transforming to tensor. */ @NameCn("转Tensor") +@NameEn("To Tensor") public class ToTensorBatchOp extends MapBatchOp implements ToTensorParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/ToVectorBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/ToVectorBatchOp.java index 64fc06266..b0f085dbc 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/ToVectorBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/ToVectorBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.batch.utils.MapBatchOp; import com.alibaba.alink.operator.common.dataproc.ToVectorMapper; import com.alibaba.alink.params.dataproc.ToVectorParams; @@ -11,6 +12,7 @@ * batch op for transforming to vector. */ @NameCn("转向量") +@NameEn("To Vector") public class ToVectorBatchOp extends MapBatchOp implements ToVectorParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/TypeConvertBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/TypeConvertBatchOp.java index fa55345c9..db366fa28 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/TypeConvertBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/TypeConvertBatchOp.java @@ -7,9 +7,10 @@ import org.apache.flink.table.api.Types; import org.apache.flink.types.Row; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.PortDesc; import com.alibaba.alink.common.annotation.PortSpec; @@ -37,6 +38,7 @@ @InputPorts(values = {@PortSpec(PortType.DATA)}) @OutputPorts(values = {@PortSpec(value = PortType.DATA, desc = PortDesc.OUTPUT_RESULT)}) @NameCn("类型转换") +@NameEn("Type Converter") public final class TypeConvertBatchOp extends BatchOperator implements TypeConvertParams { @@ -150,6 +152,10 @@ public TypeConvertBatchOp linkFrom(BatchOperator ... 
inputs) { outTypes[i] = Types.STRING(); type = "VARCHAR"; break; + case "decimal": + outTypes[i] = Types.DECIMAL(); + type = "DECIMAL"; + break; default: throw new AkUnsupportedOperationException("Not support type:" + this.newType); } diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/VectorToTensorBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/VectorToTensorBatchOp.java index 788aa9342..a4a856b5a 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/VectorToTensorBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/VectorToTensorBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.operator.batch.utils.MapBatchOp; @@ -14,6 +15,7 @@ */ @ParamSelectColumnSpec(name = "selectedCol", allowedTypeCollections = TypeCollections.VECTOR_TYPES) @NameCn("向量转张量") +@NameEn("Vector To Tensor") public class VectorToTensorBatchOp extends MapBatchOp implements VectorToTensorParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/WeightSampleBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/WeightSampleBatchOp.java index 6c1d603f5..8767eedba 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/WeightSampleBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/WeightSampleBatchOp.java @@ -12,6 +12,7 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortSpec; @@ -38,6 +39,7 @@ @OutputPorts(values = @PortSpec(PortType.DATA)) @ParamSelectColumnSpec(name = "weightCol", allowedTypeCollections = {TypeCollections.NUMERIC_TYPES}) @NameCn("加权采样") +@NameEn("Weighted Sampling") public class WeightSampleBatchOp extends BatchOperator implements WeightSampleParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/AnyToTripleBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/AnyToTripleBatchOp.java index 35dca6846..5582f09b8 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/AnyToTripleBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/AnyToTripleBatchOp.java @@ -2,6 +2,7 @@ import org.apache.flink.ml.api.misc.param.Params; +import com.alibaba.alink.common.annotation.Internal; import com.alibaba.alink.common.annotation.NameCn; import com.alibaba.alink.operator.batch.utils.FlatMapBatchOp; import com.alibaba.alink.operator.common.dataproc.format.AnyToTripleFlatMapper; @@ -14,6 +15,7 @@ */ @NameCn("") +@Internal class AnyToTripleBatchOp> extends FlatMapBatchOp implements ToTripleParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/ColumnsToCsvBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/ColumnsToCsvBatchOp.java index 08734cd3c..ec5ba2016 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/ColumnsToCsvBatchOp.java +++ 
b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/ColumnsToCsvBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.common.dataproc.format.FormatType; import com.alibaba.alink.params.dataproc.format.ColumnsToCsvParams; @@ -10,6 +11,7 @@ * Transform data type from Columns to Csv. */ @NameCn("列数据转CSV") +@NameEn("Columns To Csv") public class ColumnsToCsvBatchOp extends BaseFormatTransBatchOp implements ColumnsToCsvParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/ColumnsToJsonBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/ColumnsToJsonBatchOp.java index b73a02a57..1c90b85c9 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/ColumnsToJsonBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/ColumnsToJsonBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.common.dataproc.format.FormatType; import com.alibaba.alink.params.dataproc.format.ColumnsToJsonParams; @@ -10,6 +11,7 @@ * Transform data type from Columns to Json. */ @NameCn("列数据转JSON") +@NameEn("Columns To JSON") public class ColumnsToJsonBatchOp extends BaseFormatTransBatchOp implements ColumnsToJsonParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/ColumnsToKvBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/ColumnsToKvBatchOp.java index e3364e335..1e41c1a5d 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/ColumnsToKvBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/ColumnsToKvBatchOp.java @@ -3,6 +3,8 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.operator.common.dataproc.format.FormatType; import com.alibaba.alink.params.dataproc.format.ColumnsToKvParams; @@ -10,6 +12,8 @@ * Transform data type from Columns to Kv. */ @NameCn("列数据转KV") +@NameEn("table to kv") +@ParamSelectColumnSpec(name = "selectedCols") public class ColumnsToKvBatchOp extends BaseFormatTransBatchOp implements ColumnsToKvParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/ColumnsToTripleBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/ColumnsToTripleBatchOp.java index 44923b3c8..05c547389 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/ColumnsToTripleBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/ColumnsToTripleBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.common.dataproc.format.FormatType; import com.alibaba.alink.params.dataproc.format.ColumnsToTripleParams; @@ -10,6 +11,7 @@ * Transform data type from Columns to Triple. 
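The TypeConvertBatchOp hunk earlier in this patch extends the type-name switch with a "decimal" branch that sets the output type to Types.DECIMAL() and the SQL cast string to "DECIMAL". The sketch below is a hypothetical, abbreviated helper showing that mapping style; the exact set of supported names, and the fact that the real operator also records a cast string and throws AkUnsupportedOperationException for unknown names, should be taken from the operator itself.

import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.table.api.Types;

public class TypeNameMappingSketch {

    // Maps a user-facing type name to the Flink TypeInformation used for the
    // converted column; only a few illustrative cases are shown here.
    static TypeInformation<?> toFlinkType(String newType) {
        switch (newType.trim().toLowerCase()) {
            case "string":
            case "varchar":
                return Types.STRING();
            case "long":
            case "bigint":
                return Types.LONG();
            case "double":
                return Types.DOUBLE();
            case "decimal":
                // Newly supported branch: SQL DECIMAL backed by java.math.BigDecimal.
                return Types.DECIMAL();
            default:
                throw new IllegalArgumentException("Not support type: " + newType);
        }
    }

    public static void main(String[] args) {
        System.out.println(toFlinkType("decimal"));
    }
}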
*/ @NameCn("列数据转三元组") +@NameEn("Columns To Triple") public class ColumnsToTripleBatchOp extends AnyToTripleBatchOp implements ColumnsToTripleParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/ColumnsToVectorBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/ColumnsToVectorBatchOp.java index d8505f6e3..bddee669d 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/ColumnsToVectorBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/ColumnsToVectorBatchOp.java @@ -3,6 +3,9 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; +import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.operator.common.dataproc.format.FormatType; import com.alibaba.alink.params.dataproc.format.ColumnsToVectorParams; @@ -10,6 +13,8 @@ * Transform data type from Columns to Vector. */ @NameCn("列数据转向量") +@NameEn("table to vector") +@ParamSelectColumnSpec(name = "selectedCols", allowedTypeCollections = TypeCollections.NUMERIC_TYPES) public class ColumnsToVectorBatchOp extends BaseFormatTransBatchOp implements ColumnsToVectorParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/CsvToColumnsBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/CsvToColumnsBatchOp.java index 61d1612e5..d2354bc40 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/CsvToColumnsBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/CsvToColumnsBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.operator.common.dataproc.format.FormatType; @@ -13,6 +14,7 @@ */ @ParamSelectColumnSpec(name = "csvCol", allowedTypeCollections = TypeCollections.STRING_TYPES) @NameCn("CSV转列数据") +@NameEn("Csv To Columns") public class CsvToColumnsBatchOp extends BaseFormatTransBatchOp implements CsvToColumnsParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/CsvToJsonBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/CsvToJsonBatchOp.java index d5177cbac..abe15ad9a 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/CsvToJsonBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/CsvToJsonBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.operator.common.dataproc.format.FormatType; @@ -13,6 +14,7 @@ */ @ParamSelectColumnSpec(name = "csvCol", allowedTypeCollections = TypeCollections.STRING_TYPES) @NameCn("CSV转JSON") +@NameEn("Csv To JSON") public class CsvToJsonBatchOp extends BaseFormatTransBatchOp implements CsvToJsonParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/CsvToKvBatchOp.java 
b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/CsvToKvBatchOp.java index 2cb43b0a4..fdbaf2652 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/CsvToKvBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/CsvToKvBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.operator.common.dataproc.format.FormatType; @@ -13,6 +14,7 @@ */ @ParamSelectColumnSpec(name = "csvCol", allowedTypeCollections = TypeCollections.STRING_TYPES) @NameCn("CSV转KV") +@NameEn("Csv To KV") public class CsvToKvBatchOp extends BaseFormatTransBatchOp implements CsvToKvParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/CsvToTripleBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/CsvToTripleBatchOp.java index 0d6e67939..57e97c370 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/CsvToTripleBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/CsvToTripleBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.operator.common.dataproc.format.FormatType; @@ -13,6 +14,7 @@ */ @ParamSelectColumnSpec(name = "csvCol", allowedTypeCollections = TypeCollections.STRING_TYPES) @NameCn("CSV转三元组") +@NameEn("Csv To Triple") public class CsvToTripleBatchOp extends AnyToTripleBatchOp implements CsvToTripleParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/CsvToVectorBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/CsvToVectorBatchOp.java index 92a894849..da4607db4 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/CsvToVectorBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/CsvToVectorBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.operator.common.dataproc.format.FormatType; @@ -13,6 +14,7 @@ */ @ParamSelectColumnSpec(name = "csvCol", allowedTypeCollections = TypeCollections.STRING_TYPES) @NameCn("CSV转向量") +@NameEn("Csv To Vector") public class CsvToVectorBatchOp extends BaseFormatTransBatchOp implements CsvToVectorParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/JsonToColumnsBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/JsonToColumnsBatchOp.java index 3422349d9..3b29d4cab 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/JsonToColumnsBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/JsonToColumnsBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import 
com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.operator.common.dataproc.format.FormatType; @@ -13,6 +14,7 @@ */ @ParamSelectColumnSpec(name = "jsonCol", allowedTypeCollections = TypeCollections.STRING_TYPES) @NameCn("JSON转列数据") +@NameEn("JSON To Columns") public class JsonToColumnsBatchOp extends BaseFormatTransBatchOp implements JsonToColumnsParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/JsonToCsvBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/JsonToCsvBatchOp.java index d944d7b03..303db6deb 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/JsonToCsvBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/JsonToCsvBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.operator.common.dataproc.format.FormatType; @@ -13,6 +14,7 @@ */ @ParamSelectColumnSpec(name = "jsonCol", allowedTypeCollections = TypeCollections.STRING_TYPES) @NameCn("JSON转CSV") +@NameEn("JSON To CSV") public class JsonToCsvBatchOp extends BaseFormatTransBatchOp implements JsonToCsvParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/JsonToKvBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/JsonToKvBatchOp.java index 5de495926..8ae787b42 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/JsonToKvBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/JsonToKvBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.operator.common.dataproc.format.FormatType; @@ -13,6 +14,7 @@ */ @ParamSelectColumnSpec(name = "jsonCol", allowedTypeCollections = TypeCollections.STRING_TYPES) @NameCn("JSON转KV") +@NameEn("JSON To KV") public class JsonToKvBatchOp extends BaseFormatTransBatchOp implements JsonToKvParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/JsonToTripleBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/JsonToTripleBatchOp.java index 7b1efe96e..cd4b642aa 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/JsonToTripleBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/JsonToTripleBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.operator.common.dataproc.format.FormatType; @@ -13,6 +14,7 @@ */ @ParamSelectColumnSpec(name = "jsonCol", allowedTypeCollections = TypeCollections.STRING_TYPES) @NameCn("JSON转三元组") 
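Earlier hunks in this patch add @EstimatorTrainerAnnotation(estimatorName = ...) to several train operators (the scaler trainers, StringIndexerTrainBatchOp, MultiStringIndexerTrainBatchOp), recording which pipeline estimator wraps each trainer. The sketch below shows one way that mapping could be read back reflectively; it assumes the annotation has runtime retention, which is not confirmed by this diff.

import com.alibaba.alink.operator.batch.dataproc.StringIndexerTrainBatchOp;
import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation;

public class EstimatorLookupSketch {
    public static void main(String[] args) {
        // Returns null unless EstimatorTrainerAnnotation is retained at runtime (assumption).
        EstimatorTrainerAnnotation ann =
            StringIndexerTrainBatchOp.class.getAnnotation(EstimatorTrainerAnnotation.class);
        if (ann != null) {
            System.out.println("Trainer is wrapped by estimator: " + ann.estimatorName());
        }
    }
}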
+@NameEn("JSON To Triple") public class JsonToTripleBatchOp extends AnyToTripleBatchOp implements JsonToTripleParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/JsonToVectorBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/JsonToVectorBatchOp.java index e12a67631..9eee1152c 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/JsonToVectorBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/JsonToVectorBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.operator.common.dataproc.format.FormatType; @@ -13,6 +14,7 @@ */ @ParamSelectColumnSpec(name = "jsonCol", allowedTypeCollections = TypeCollections.STRING_TYPES) @NameCn("JSON转向量") +@NameEn("JSON To Vector") public class JsonToVectorBatchOp extends BaseFormatTransBatchOp implements JsonToVectorParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/KvToColumnsBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/KvToColumnsBatchOp.java index c9f68d4d4..541a2cfc4 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/KvToColumnsBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/KvToColumnsBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.operator.common.dataproc.format.FormatType; @@ -13,6 +14,7 @@ */ @ParamSelectColumnSpec(name = "kvCol", allowedTypeCollections = TypeCollections.STRING_TYPES) @NameCn("KV转列数据") +@NameEn("KV To Columns") public class KvToColumnsBatchOp extends BaseFormatTransBatchOp implements KvToColumnsParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/KvToCsvBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/KvToCsvBatchOp.java index 706e96519..07a51622c 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/KvToCsvBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/KvToCsvBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.operator.common.dataproc.format.FormatType; @@ -13,6 +14,7 @@ */ @ParamSelectColumnSpec(name = "kvCol", allowedTypeCollections = TypeCollections.STRING_TYPES) @NameCn("KV转CSV") +@NameEn("KV To CSV") public class KvToCsvBatchOp extends BaseFormatTransBatchOp implements KvToCsvParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/KvToJsonBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/KvToJsonBatchOp.java index ea80fd3cc..4579ab5f5 100644 --- 
a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/KvToJsonBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/KvToJsonBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.operator.common.dataproc.format.FormatType; @@ -13,6 +14,7 @@ */ @ParamSelectColumnSpec(name = "kvCol", allowedTypeCollections = TypeCollections.STRING_TYPES) @NameCn("KV转JSON") +@NameEn("KV To JSON") public class KvToJsonBatchOp extends BaseFormatTransBatchOp implements KvToJsonParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/KvToTripleBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/KvToTripleBatchOp.java index 9cb49866c..7274e9dd1 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/KvToTripleBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/KvToTripleBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.operator.common.dataproc.format.FormatType; @@ -13,6 +14,7 @@ */ @ParamSelectColumnSpec(name = "kvCol", allowedTypeCollections = TypeCollections.STRING_TYPES) @NameCn("KV转三元组") +@NameEn("KV To Triple") public class KvToTripleBatchOp extends AnyToTripleBatchOp implements KvToTripleParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/KvToVectorBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/KvToVectorBatchOp.java index 0e8c97438..ae91c2fea 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/KvToVectorBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/KvToVectorBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.operator.common.dataproc.format.FormatType; @@ -13,6 +14,7 @@ */ @ParamSelectColumnSpec(name = "kvCol", allowedTypeCollections = TypeCollections.STRING_TYPES) @NameCn("KV转向量") +@NameEn("KV To Vector") public class KvToVectorBatchOp extends BaseFormatTransBatchOp implements KvToVectorParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/TripleToAnyBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/TripleToAnyBatchOp.java index 544803f2b..9e6e56c16 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/TripleToAnyBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/TripleToAnyBatchOp.java @@ -13,6 +13,7 @@ import org.apache.flink.util.Collector; import com.alibaba.alink.common.annotation.InputPorts; +import com.alibaba.alink.common.annotation.Internal; import com.alibaba.alink.common.annotation.NameCn; import 
com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.PortDesc; @@ -40,6 +41,7 @@ @InputPorts(values = {@PortSpec(value = PortType.DATA, opType = OpType.BATCH)}) @OutputPorts(values = {@PortSpec(value = PortType.DATA, desc = PortDesc.OUTPUT_RESULT)}) @NameCn("") +@Internal class TripleToAnyBatchOp> extends BatchOperator implements FromTripleParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/TripleToColumnsBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/TripleToColumnsBatchOp.java index 288dc50ba..7d2ee275d 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/TripleToColumnsBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/TripleToColumnsBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.operator.common.dataproc.format.FormatType; @@ -13,6 +14,7 @@ */ @ParamSelectColumnSpec(name = "tripleColumnCol", allowedTypeCollections = TypeCollections.STRING_TYPES) @NameCn("三元组转列数据") +@NameEn("Triple To Columns") public class TripleToColumnsBatchOp extends TripleToAnyBatchOp implements TripleToColumnsParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/TripleToCsvBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/TripleToCsvBatchOp.java index 5430fca29..a62269be3 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/TripleToCsvBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/TripleToCsvBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.operator.common.dataproc.format.FormatType; @@ -13,6 +14,7 @@ */ @ParamSelectColumnSpec(name = "tripleColumnCol", allowedTypeCollections = TypeCollections.STRING_TYPES) @NameCn("三元组转CSV") +@NameEn("Triple To CSV") public class TripleToCsvBatchOp extends TripleToAnyBatchOp implements TripleToCsvParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/TripleToJsonBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/TripleToJsonBatchOp.java index 456f143d6..a561d3710 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/TripleToJsonBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/TripleToJsonBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.common.dataproc.format.FormatType; import com.alibaba.alink.params.dataproc.format.TripleToJsonParams; @@ -10,6 +11,7 @@ * Transform data type from Triple to Json. 
*/ @NameCn("三元组转JSON") +@NameEn("Triple To Json") public class TripleToJsonBatchOp extends TripleToAnyBatchOp implements TripleToJsonParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/TripleToKvBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/TripleToKvBatchOp.java index 8237a2d23..008e82c07 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/TripleToKvBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/TripleToKvBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.common.dataproc.format.FormatType; import com.alibaba.alink.params.dataproc.format.TripleToKvParams; @@ -10,6 +11,7 @@ * Transform data type from Triple to Kv. */ @NameCn("三元组转KV") +@NameEn("Triple To Kv") public class TripleToKvBatchOp extends TripleToAnyBatchOp implements TripleToKvParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/TripleToVectorBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/TripleToVectorBatchOp.java index 64822dcd9..d98f65487 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/TripleToVectorBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/TripleToVectorBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.operator.common.dataproc.format.FormatType; @@ -14,6 +15,7 @@ @ParamSelectColumnSpec(name = "tripleColumnCol", allowedTypeCollections = TypeCollections.INT_LONG_TYPES) @ParamSelectColumnSpec(name = "tripleValueCol", allowedTypeCollections = TypeCollections.NUMERIC_TYPES) @NameCn("三元组转向量") +@NameEn("Triple To Vector") public class TripleToVectorBatchOp extends TripleToAnyBatchOp implements TripleToVectorParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/VectorToColumnsBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/VectorToColumnsBatchOp.java index 7cf1f1081..87594f976 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/VectorToColumnsBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/VectorToColumnsBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.operator.common.dataproc.format.FormatType; @@ -13,6 +14,7 @@ */ @ParamSelectColumnSpec(name = "vectorCol", allowedTypeCollections = TypeCollections.VECTOR_TYPES) @NameCn("向量转列数据") +@NameEn("Vector To Columns") public class VectorToColumnsBatchOp extends BaseFormatTransBatchOp implements VectorToColumnsParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/VectorToCsvBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/VectorToCsvBatchOp.java index 9c243dcbf..124351cdf 100644 --- 
a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/VectorToCsvBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/VectorToCsvBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.operator.common.dataproc.format.FormatType; @@ -13,6 +14,7 @@ */ @ParamSelectColumnSpec(name = "vectorCol", allowedTypeCollections = TypeCollections.VECTOR_TYPES) @NameCn("向量转CSV") +@NameEn("Vector To Csv") public class VectorToCsvBatchOp extends BaseFormatTransBatchOp implements VectorToCsvParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/VectorToJsonBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/VectorToJsonBatchOp.java index 48254122b..154ebe8f3 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/VectorToJsonBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/VectorToJsonBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.operator.common.dataproc.format.FormatType; @@ -13,6 +14,7 @@ */ @ParamSelectColumnSpec(name = "vectorCol", allowedTypeCollections = TypeCollections.VECTOR_TYPES) @NameCn("向量转JSON") +@NameEn("Vector to Json") public class VectorToJsonBatchOp extends BaseFormatTransBatchOp implements VectorToJsonParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/VectorToKvBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/VectorToKvBatchOp.java index a75308f67..14073f3df 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/VectorToKvBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/VectorToKvBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.operator.common.dataproc.format.FormatType; @@ -13,6 +14,7 @@ */ @ParamSelectColumnSpec(name = "vectorCol", allowedTypeCollections = TypeCollections.VECTOR_TYPES) @NameCn("向量转KV") +@NameEn("Vector to Kv") public class VectorToKvBatchOp extends BaseFormatTransBatchOp implements VectorToKvParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/VectorToTripleBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/VectorToTripleBatchOp.java index d173942a9..3db2e46ef 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/VectorToTripleBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/format/VectorToTripleBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import 
com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.operator.common.dataproc.format.FormatType; @@ -14,6 +15,7 @@ @ParamSelectColumnSpec(name = "vectorCol", allowedTypeCollections = TypeCollections.VECTOR_TYPES) @NameCn("向量转三元组") +@NameEn("Vector to Triple") public class VectorToTripleBatchOp extends AnyToTripleBatchOp implements VectorToTripleParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/tensor/TensorReshapeBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/tensor/TensorReshapeBatchOp.java index 90496dcd4..38a70c3af 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/tensor/TensorReshapeBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/tensor/TensorReshapeBatchOp.java @@ -3,11 +3,13 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.batch.utils.MapBatchOp; import com.alibaba.alink.operator.common.dataproc.tensor.TensorReshapeMapper; import com.alibaba.alink.params.dataproc.tensor.TensorReshapeParams; @NameCn("张量重组") +@NameEn("Tensor Reshape") public final class TensorReshapeBatchOp extends MapBatchOp implements TensorReshapeParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorAssemblerBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorAssemblerBatchOp.java index d7355f614..554d5acb7 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorAssemblerBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorAssemblerBatchOp.java @@ -2,8 +2,13 @@ import org.apache.flink.ml.api.misc.param.Params; +import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; +import com.alibaba.alink.common.annotation.PortSpec; +import com.alibaba.alink.common.annotation.PortType; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.operator.batch.utils.MapBatchOp; import com.alibaba.alink.operator.common.dataproc.vector.VectorAssemblerMapper; @@ -17,8 +22,11 @@ * * this operator can transform batch data. 
*/ -@ParamSelectColumnSpec(name = "selectedCol", allowedTypeCollections = TypeCollections.VECTOR_TYPES) +@InputPorts(values = @PortSpec(value = PortType.DATA)) +@OutputPorts(values = @PortSpec(value = PortType.DATA)) +@ParamSelectColumnSpec(name = "selectedCols", allowedTypeCollections = TypeCollections.NUMERIC_AND_VECTOR_TYPES) @NameCn("向量聚合") +@NameEn("Vector Assembler") public final class VectorAssemblerBatchOp extends MapBatchOp implements VectorAssemblerParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorBiFunctionBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorBiFunctionBatchOp.java index 2009f3739..ecb7cdf95 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorBiFunctionBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorBiFunctionBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.batch.utils.MapBatchOp; import com.alibaba.alink.operator.common.dataproc.vector.VectorBiFunctionMapper; import com.alibaba.alink.params.dataproc.vector.VectorBiFunctionParams; @@ -12,6 +13,7 @@ * Vector can be sparse vector or dense vector. */ @NameCn("二元向量函数") +@NameEn("Vector BiFunction") public final class VectorBiFunctionBatchOp extends MapBatchOp implements VectorBiFunctionParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorElementwiseProductBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorElementwiseProductBatchOp.java index bebb7521b..9b574fb01 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorElementwiseProductBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorElementwiseProductBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.operator.batch.utils.MapBatchOp; @@ -16,6 +17,7 @@ */ @ParamSelectColumnSpec(name = "selectedCol", allowedTypeCollections = TypeCollections.VECTOR_TYPES) @NameCn("向量元素依次相乘") +@NameEn("Vector Elementwise Product") public final class VectorElementwiseProductBatchOp extends MapBatchOp implements VectorElementwiseProductParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorFunctionBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorFunctionBatchOp.java index a1d7a16a6..4cd38c519 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorFunctionBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorFunctionBatchOp.java @@ -4,6 +4,7 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortDesc; @@ -22,6 +23,7 @@ @OutputPorts(values = {@PortSpec(value = PortType.DATA, desc = PortDesc.OUTPUT_RESULT)}) 
@ParamSelectColumnSpec(name = "selectedCol", allowedTypeCollections = {TypeCollections.VECTOR_TYPES}) @NameCn("向量函数") +@NameEn("Vector Function") public final class VectorFunctionBatchOp extends MapBatchOp implements VectorFunctionParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorImputerModelInfoBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorImputerModelInfoBatchOp.java index 99a61f284..f2cf09597 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorImputerModelInfoBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorImputerModelInfoBatchOp.java @@ -3,10 +3,8 @@ import org.apache.flink.ml.api.misc.param.Params; import org.apache.flink.types.Row; -import com.alibaba.alink.common.lazy.ExtractModelInfoBatchOp; -import com.alibaba.alink.operator.batch.dataproc.ImputerModelInfoBatchOp; +import com.alibaba.alink.operator.batch.utils.ExtractModelInfoBatchOp; import com.alibaba.alink.operator.common.dataproc.ImputerModelInfo; -import com.alibaba.alink.params.dataproc.ImputerTrainParams; import com.alibaba.alink.params.dataproc.vector.VectorImputerTrainParams; import java.util.List; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorImputerPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorImputerPredictBatchOp.java index 14997632e..69e1bb61f 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorImputerPredictBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorImputerPredictBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.batch.utils.ModelMapBatchOp; import com.alibaba.alink.operator.common.dataproc.vector.VectorImputerModelMapper; import com.alibaba.alink.params.dataproc.vector.VectorImputerPredictParams; @@ -17,6 +18,7 @@ * If value, will replace missing value with the value. 
*/ @NameCn("向量缺失值填充预测") +@NameEn("Vector Imputer Prediction") public class VectorImputerPredictBatchOp extends ModelMapBatchOp implements VectorImputerPredictParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorImputerTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorImputerTrainBatchOp.java index 99c931e95..087f9035d 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorImputerTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorImputerTrainBatchOp.java @@ -10,20 +10,22 @@ import com.alibaba.alink.common.MLEnvironmentFactory; import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortSpec; import com.alibaba.alink.common.annotation.PortType; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.common.exceptions.AkIllegalOperatorParameterException; -import com.alibaba.alink.common.lazy.WithModelInfoBatchOp; +import com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp; import com.alibaba.alink.common.utils.RowCollector; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.common.dataproc.ImputerModelInfo; import com.alibaba.alink.operator.common.dataproc.vector.VectorImputerModelDataConverter; -import com.alibaba.alink.operator.common.statistics.StatisticsHelper; +import com.alibaba.alink.operator.batch.statistics.utils.StatisticsHelper; import com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary; import com.alibaba.alink.params.dataproc.vector.VectorImputerTrainParams; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; /** * Imputer completes missing values in a dataSet, but only same type of columns can be selected at the same time. 
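Taken together, the hunks in this region follow a single pattern: each batch operator gains an English display name via @NameEn next to the existing @NameCn, trainer operators additionally declare their companion pipeline estimator through @EstimatorTrainerAnnotation, and helper imports move from com.alibaba.alink.common.lazy and operator.common.statistics to operator.batch.utils and operator.batch.statistics.utils. A minimal sketch of the resulting trainer header, with placeholder class, params and estimator names that are not part of this change:

    @InputPorts(values = @PortSpec(PortType.DATA))
    @OutputPorts(values = @PortSpec(PortType.MODEL))
    @NameCn("示例训练")
    @NameEn("Example Train")
    @EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.example.Example")
    public final class ExampleTrainBatchOp extends BatchOperator<ExampleTrainBatchOp>
        implements ExampleTrainParams<ExampleTrainBatchOp> {
        // operator body unchanged; these hunks only touch annotations and imports
    }

The English names presumably feed the same operator metadata that @NameCn already populates, so tooling can surface either locale.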
@@ -39,6 +41,8 @@ @OutputPorts(values = @PortSpec(value = PortType.MODEL)) @ParamSelectColumnSpec(name = "selectedCol", allowedTypeCollections = TypeCollections.VECTOR_TYPES) @NameCn("向量缺失值填充训练") +@NameEn("Vector Imputer Train") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.dataproc.vector.VectorImputer") public class VectorImputerTrainBatchOp extends BatchOperator implements VectorImputerTrainParams , WithModelInfoBatchOp { @@ -95,7 +99,8 @@ private boolean isNeedStatModel() { } else if (Strategy.VALUE.equals(strategy)) { return false; } else { - throw new AkIllegalOperatorParameterException("Only support \"MAX\", \"MEAN\", \"MIN\" and \"VALUE\" strategy."); + throw new AkIllegalOperatorParameterException( + "Only support \"MAX\", \"MEAN\", \"MIN\" and \"VALUE\" strategy."); } } diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorInteractionBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorInteractionBatchOp.java index 8b498290c..c0750cf1c 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorInteractionBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorInteractionBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.operator.batch.utils.MapBatchOp; @@ -15,6 +16,7 @@ */ @ParamSelectColumnSpec(name = "selectedCol", allowedTypeCollections = TypeCollections.VECTOR_TYPES) @NameCn("向量元素两两相乘") +@NameEn("Vector Interaction") public final class VectorInteractionBatchOp extends MapBatchOp implements VectorInteractionParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorMaxAbsScalerModelInfoBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorMaxAbsScalerModelInfoBatchOp.java index 745f51c0f..0202899dd 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorMaxAbsScalerModelInfoBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorMaxAbsScalerModelInfoBatchOp.java @@ -3,7 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import org.apache.flink.types.Row; -import com.alibaba.alink.common.lazy.ExtractModelInfoBatchOp; +import com.alibaba.alink.operator.batch.utils.ExtractModelInfoBatchOp; import com.alibaba.alink.operator.common.dataproc.vector.VectorMaxAbsScalarModelInfo; import java.util.List; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorMaxAbsScalerPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorMaxAbsScalerPredictBatchOp.java index d15de4288..c1d67ad80 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorMaxAbsScalerPredictBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorMaxAbsScalerPredictBatchOp.java @@ -4,6 +4,7 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.PortSpec; import com.alibaba.alink.common.annotation.PortType; import 
com.alibaba.alink.operator.batch.utils.ModelMapBatchOp; @@ -17,6 +18,7 @@ */ @InputPorts(values = {@PortSpec(value = PortType.MODEL, suggestions = VectorMaxAbsScalerTrainBatchOp.class), @PortSpec(PortType.DATA)}) @NameCn("向量绝对值最大化预测") +@NameEn("Vector MaxAbs Scaler Prediction") public final class VectorMaxAbsScalerPredictBatchOp extends ModelMapBatchOp implements VectorMaxAbsScalerPredictParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorMaxAbsScalerTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorMaxAbsScalerTrainBatchOp.java index c0900b705..e99c4fee1 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorMaxAbsScalerTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorMaxAbsScalerTrainBatchOp.java @@ -8,18 +8,20 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortSpec; import com.alibaba.alink.common.annotation.PortType; import com.alibaba.alink.common.annotation.TypeCollections; -import com.alibaba.alink.common.lazy.WithModelInfoBatchOp; +import com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.common.dataproc.vector.VectorMaxAbsScalarModelInfo; import com.alibaba.alink.operator.common.dataproc.vector.VectorMaxAbsScalerModelDataConverter; -import com.alibaba.alink.operator.common.statistics.StatisticsHelper; +import com.alibaba.alink.operator.batch.statistics.utils.StatisticsHelper; import com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary; import com.alibaba.alink.params.dataproc.vector.VectorMaxAbsScalerTrainParams; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; /** * MaxAbsScaler transforms a dataSet of Vector rows, rescaling each feature to range @@ -30,6 +32,8 @@ @OutputPorts(values = @PortSpec(value = PortType.MODEL)) @ParamSelectColumnSpec(name = "selectedCol", allowedTypeCollections = TypeCollections.VECTOR_TYPES) @NameCn("向量绝对值最大化训练") +@NameEn("Vector MaxAbs Scaler Train") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.dataproc.vector.VectorMaxAbsScaler") public final class VectorMaxAbsScalerTrainBatchOp extends BatchOperator implements VectorMaxAbsScalerTrainParams , WithModelInfoBatchOp implements VectorMinMaxScalerPredictParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorMinMaxScalerTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorMinMaxScalerTrainBatchOp.java index 962e478cd..bf4cf9b5e 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorMinMaxScalerTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorMinMaxScalerTrainBatchOp.java @@ -9,18 +9,20 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortSpec; import 
com.alibaba.alink.common.annotation.PortType; import com.alibaba.alink.common.annotation.TypeCollections; -import com.alibaba.alink.common.lazy.WithModelInfoBatchOp; +import com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.common.dataproc.vector.VectorMinMaxScalerModelDataConverter; import com.alibaba.alink.operator.common.dataproc.vector.VectorMinMaxScalerModelInfo; -import com.alibaba.alink.operator.common.statistics.StatisticsHelper; +import com.alibaba.alink.operator.batch.statistics.utils.StatisticsHelper; import com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary; import com.alibaba.alink.params.dataproc.vector.VectorMinMaxScalerTrainParams; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; /** * MinMaxScaler transforms a dataSet of rows, rescaling each feature @@ -31,6 +33,8 @@ @OutputPorts(values = @PortSpec(value = PortType.MODEL)) @ParamSelectColumnSpec(name = "selectedCol", allowedTypeCollections = TypeCollections.VECTOR_TYPES) @NameCn("向量归一化训练") +@NameEn("Vector MinAbs Scaler Train") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.dataproc.vector.VectorMinMaxScaler") public final class VectorMinMaxScalerTrainBatchOp extends BatchOperator implements VectorMinMaxScalerTrainParams , WithModelInfoBatchOp implements VectorNormalizeParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorPolynomialExpandBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorPolynomialExpandBatchOp.java index dfc5a3e22..af1b5b9b7 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorPolynomialExpandBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorPolynomialExpandBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.operator.batch.utils.MapBatchOp; @@ -16,6 +17,7 @@ */ @ParamSelectColumnSpec(name = "selectedCol", allowedTypeCollections = TypeCollections.VECTOR_TYPES) @NameCn("向量多项式展开") +@NameEn("Vector Polynomial Expand") public final class VectorPolynomialExpandBatchOp extends MapBatchOp implements VectorPolynomialExpandParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorSizeHintBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorSizeHintBatchOp.java index b355f4ce4..1b2248865 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorSizeHintBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorSizeHintBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.batch.utils.MapBatchOp; import com.alibaba.alink.operator.common.dataproc.vector.VectorSizeHintMapper; import com.alibaba.alink.params.dataproc.vector.VectorSizeHintParams; @@ -13,6 +14,7 @@ * If optimistic, will accept the vector if it is not null. 
*/ @NameCn("向量长度检验") +@NameEn("Vector Size Hint") public final class VectorSizeHintBatchOp extends MapBatchOp implements VectorSizeHintParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorSliceBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorSliceBatchOp.java index 53fdb6ac1..7412a9564 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorSliceBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorSliceBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.batch.utils.MapBatchOp; import com.alibaba.alink.operator.common.dataproc.vector.VectorSliceMapper; import com.alibaba.alink.params.dataproc.vector.VectorSliceParams; @@ -12,6 +13,7 @@ * original features. It is useful for extracting features from a vector column. */ @NameCn("向量切片") +@NameEn("Vector Slice") public final class VectorSliceBatchOp extends MapBatchOp implements VectorSliceParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorStandardScalerModelInfoBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorStandardScalerModelInfoBatchOp.java index 7507ce724..ac20b5700 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorStandardScalerModelInfoBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorStandardScalerModelInfoBatchOp.java @@ -3,7 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import org.apache.flink.types.Row; -import com.alibaba.alink.common.lazy.ExtractModelInfoBatchOp; +import com.alibaba.alink.operator.batch.utils.ExtractModelInfoBatchOp; import com.alibaba.alink.operator.common.dataproc.vector.VectorStandardScalerModelInfo; import java.util.List; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorStandardScalerPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorStandardScalerPredictBatchOp.java index eca096671..27ff3740c 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorStandardScalerPredictBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorStandardScalerPredictBatchOp.java @@ -4,6 +4,7 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.PortDesc; import com.alibaba.alink.common.annotation.PortSpec; @@ -20,6 +21,7 @@ @OutputPorts(values = {@PortSpec(value = PortType.DATA, desc = PortDesc.OUTPUT_RESULT)}) @ReservedColsWithSecondInputSpec @NameCn("向量标准化预测") +@NameEn("Vector Standard Scaler Prediction") public final class VectorStandardScalerPredictBatchOp extends ModelMapBatchOp implements VectorStandardPredictParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorStandardScalerTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorStandardScalerTrainBatchOp.java index 5c2d4d8a2..a78af5e8e 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorStandardScalerTrainBatchOp.java +++ 
b/core/src/main/java/com/alibaba/alink/operator/batch/dataproc/vector/VectorStandardScalerTrainBatchOp.java @@ -9,19 +9,21 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortSpec; import com.alibaba.alink.common.annotation.PortSpec.OpType; import com.alibaba.alink.common.annotation.PortType; import com.alibaba.alink.common.annotation.TypeCollections; -import com.alibaba.alink.common.lazy.WithModelInfoBatchOp; +import com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.common.dataproc.vector.VectorStandardScalerModelDataConverter; import com.alibaba.alink.operator.common.dataproc.vector.VectorStandardScalerModelInfo; -import com.alibaba.alink.operator.common.statistics.StatisticsHelper; +import com.alibaba.alink.operator.batch.statistics.utils.StatisticsHelper; import com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary; import com.alibaba.alink.params.dataproc.vector.VectorStandardTrainParams; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; /** * StandardScaler transforms a dataSet, normalizing each feature to have unit standard deviation and/or zero mean. @@ -31,6 +33,8 @@ @OutputPorts(values = @PortSpec(value = PortType.MODEL)) @ParamSelectColumnSpec(name = "selectedCol", allowedTypeCollections = TypeCollections.VECTOR_TYPES) @NameCn("向量标准化训练") +@NameEn("Vector Standard Scaler Train") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.dataproc.vector.VectorStandardScaler") public final class VectorStandardScalerTrainBatchOp extends BatchOperator implements VectorStandardTrainParams , WithModelInfoBatchOp implements EvalBinaryClassParams , EvaluationMetricsCollector { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/evaluation/EvalClusterBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/evaluation/EvalClusterBatchOp.java index aa9e1becb..e82d295dc 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/evaluation/EvalClusterBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/evaluation/EvalClusterBatchOp.java @@ -24,6 +24,7 @@ import com.alibaba.alink.common.MLEnvironmentFactory; import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortSpec; @@ -32,7 +33,7 @@ import com.alibaba.alink.common.linalg.DenseVector; import com.alibaba.alink.common.linalg.SparseVector; import com.alibaba.alink.common.linalg.Vector; -import com.alibaba.alink.common.utils.DataSetConversionUtil; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.common.distance.FastDistance; import com.alibaba.alink.operator.common.evaluation.BaseMetricsSummary; @@ -42,7 +43,7 @@ import com.alibaba.alink.operator.common.evaluation.EvaluationMetricsCollector; import com.alibaba.alink.operator.common.evaluation.EvaluationUtil; import 
com.alibaba.alink.operator.common.evaluation.LongMatrix; -import com.alibaba.alink.operator.common.statistics.StatisticsHelper; +import com.alibaba.alink.operator.batch.statistics.utils.StatisticsHelper; import com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary; import com.alibaba.alink.operator.common.statistics.basicstatistic.SparseVectorSummary; import com.alibaba.alink.params.evaluation.EvalClusterParams; @@ -64,6 +65,7 @@ @ParamSelectColumnSpec(name = "predictionCol") @ParamSelectColumnSpec(name = "vectorCol", allowedTypeCollections = TypeCollections.VECTOR_TYPES) @NameCn("聚类评估") +@NameEn("Eval Cluster") public final class EvalClusterBatchOp extends BatchOperator implements EvalClusterParams, EvaluationMetricsCollector { public static final String SILHOUETTE_COEFFICIENT = "silhouetteCoefficient"; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/evaluation/EvalMultiClassBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/evaluation/EvalMultiClassBatchOp.java index 5882ced18..8bb28b66d 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/evaluation/EvalMultiClassBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/evaluation/EvalMultiClassBatchOp.java @@ -13,6 +13,7 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortSpec; @@ -52,6 +53,7 @@ @ParamSelectColumnSpec(name = "predictionCol") @ParamSelectColumnSpec(name = "predictionDetailCol", allowedTypeCollections = TypeCollections.STRING_TYPE) @NameCn("多分类评估") +@NameEn("Eval Multi Class") public class EvalMultiClassBatchOp extends BatchOperator implements EvalMultiClassParams , EvaluationMetricsCollector { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/evaluation/EvalMultiLabelBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/evaluation/EvalMultiLabelBatchOp.java index 723cfcfd0..40be6f576 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/evaluation/EvalMultiLabelBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/evaluation/EvalMultiLabelBatchOp.java @@ -14,12 +14,13 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.PortSpec; import com.alibaba.alink.common.annotation.PortType; import com.alibaba.alink.common.exceptions.AkIllegalOperatorParameterException; import com.alibaba.alink.common.exceptions.AkPreconditions; -import com.alibaba.alink.common.utils.DataSetConversionUtil; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; import com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.common.evaluation.BaseMetricsSummary; @@ -40,6 +41,7 @@ @InputPorts(values = @PortSpec(PortType.DATA)) @OutputPorts(values = @PortSpec(PortType.EVAL_METRICS)) @NameCn("多标签分类评估") +@NameEn("Eval Multi Label") public class EvalMultiLabelBatchOp extends BatchOperator implements EvalMultiLabelParams , EvaluationMetricsCollector { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/evaluation/EvalRankingBatchOp.java 
b/core/src/main/java/com/alibaba/alink/operator/batch/evaluation/EvalRankingBatchOp.java index c77d47c2d..8439d083e 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/evaluation/EvalRankingBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/evaluation/EvalRankingBatchOp.java @@ -12,10 +12,11 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.PortSpec; import com.alibaba.alink.common.annotation.PortType; -import com.alibaba.alink.common.utils.DataSetConversionUtil; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; import com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.common.evaluation.BaseMetricsSummary; @@ -32,6 +33,7 @@ @InputPorts(values = @PortSpec(PortType.DATA)) @OutputPorts(values = @PortSpec(PortType.EVAL_METRICS)) @NameCn("排序评估") +@NameEn("Eval Ranking") public class EvalRankingBatchOp extends BatchOperator implements EvalRankingParams , EvaluationMetricsCollector { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/evaluation/EvalRegressionBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/evaluation/EvalRegressionBatchOp.java index 722eaa81b..d9f5a9206 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/evaluation/EvalRegressionBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/evaluation/EvalRegressionBatchOp.java @@ -13,14 +13,15 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortSpec; import com.alibaba.alink.common.annotation.PortType; import com.alibaba.alink.common.exceptions.AkIllegalDataException; import com.alibaba.alink.common.exceptions.AkPreconditions; -import com.alibaba.alink.common.utils.DataSetConversionUtil; -import com.alibaba.alink.common.utils.DataSetUtil; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; +import com.alibaba.alink.operator.batch.utils.DataSetUtil; import com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.common.evaluation.BaseMetricsSummary; @@ -44,6 +45,7 @@ @ParamSelectColumnSpec(name = "labelCol") @ParamSelectColumnSpec(name = "predictionCol") @NameCn("回归评估") +@NameEn("Eval Regression") public final class EvalRegressionBatchOp extends BatchOperator implements EvalRegressionParams , EvaluationMetricsCollector { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/evaluation/EvalTimeSeriesBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/evaluation/EvalTimeSeriesBatchOp.java index dad53bfe1..9274bc300 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/evaluation/EvalTimeSeriesBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/evaluation/EvalTimeSeriesBatchOp.java @@ -13,14 +13,15 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import 
com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortSpec; import com.alibaba.alink.common.annotation.PortType; import com.alibaba.alink.common.exceptions.AkIllegalDataException; import com.alibaba.alink.common.exceptions.AkPreconditions; -import com.alibaba.alink.common.utils.DataSetConversionUtil; -import com.alibaba.alink.common.utils.DataSetUtil; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; +import com.alibaba.alink.operator.batch.utils.DataSetUtil; import com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.common.evaluation.BaseMetricsSummary; @@ -44,6 +45,7 @@ @ParamSelectColumnSpec(name = "labelCol") @ParamSelectColumnSpec(name = "predictionCol") @NameCn("时间序列评估") +@NameEn("Eval Time Series") public final class EvalTimeSeriesBatchOp extends BatchOperator implements EvalTimeSeriesParams , EvaluationMetricsCollector { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/feature/AutoCrossAlgoTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/feature/AutoCrossAlgoTrainBatchOp.java new file mode 100644 index 000000000..7b2213012 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/batch/feature/AutoCrossAlgoTrainBatchOp.java @@ -0,0 +1,16 @@ +package com.alibaba.alink.operator.batch.feature; + +import org.apache.flink.ml.api.misc.param.Params; + +import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; + +@NameCn("AutoCross训练") +@NameEn("AutoCross Training") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.feature.AutoCrossAlgo") +public class AutoCrossAlgoTrainBatchOp extends AutoCrossTrainBatchOp { + public AutoCrossAlgoTrainBatchOp(Params params) { + super(params); + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/feature/AutoCrossPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/feature/AutoCrossPredictBatchOp.java new file mode 100644 index 000000000..cd16365e1 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/batch/feature/AutoCrossPredictBatchOp.java @@ -0,0 +1,25 @@ +package com.alibaba.alink.operator.batch.feature; + +import org.apache.flink.ml.api.misc.param.Params; + +import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.operator.batch.utils.ModelMapBatchOp; +import com.alibaba.alink.operator.common.feature.AutoCross.AutoCrossModelMapper; +import com.alibaba.alink.params.feature.AutoCrossPredictParams; + +@NameCn("AutoCross预测") +@NameEn("AutoCross Prediction") +public class AutoCrossPredictBatchOp extends ModelMapBatchOp + implements AutoCrossPredictParams { + + private static final long serialVersionUID = 3987270029076248190L; + + public AutoCrossPredictBatchOp() { + this(new Params()); + } + + public AutoCrossPredictBatchOp(Params params) { + super(AutoCrossModelMapper::new, params); + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/feature/AutoCrossTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/feature/AutoCrossTrainBatchOp.java new file mode 100644 index 000000000..68421c38e --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/batch/feature/AutoCrossTrainBatchOp.java @@ -0,0 +1,299 @@ +package com.alibaba.alink.operator.batch.feature; + +import 
org.apache.flink.api.common.functions.ReduceFunction; +import org.apache.flink.api.common.functions.RichFlatMapFunction; +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.functions.RichMapPartitionFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.operators.IterativeDataSet; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.table.api.Table; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; + +import com.alibaba.alink.common.AlinkGlobalConfiguration; +import com.alibaba.alink.common.annotation.InputPorts; +import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.common.annotation.OutputPorts; +import com.alibaba.alink.common.annotation.PortDesc; +import com.alibaba.alink.common.annotation.PortSpec; +import com.alibaba.alink.common.annotation.PortType; +import com.alibaba.alink.common.linalg.DenseVector; +import com.alibaba.alink.common.linalg.Vector; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; +import com.alibaba.alink.common.utils.JsonConverter; +import com.alibaba.alink.operator.common.feature.AutoCross.BuildSideOutput; +import com.alibaba.alink.operator.common.feature.AutoCross.DataProfile; +import com.alibaba.alink.operator.common.feature.AutoCross.FeatureEvaluator; +import com.alibaba.alink.operator.common.feature.AutoCross.FeatureSet; +import com.alibaba.alink.operator.common.linear.LinearModelData; +import com.alibaba.alink.operator.common.linear.LinearModelType; +import com.alibaba.alink.operator.common.slidingwindow.SessionSharedData; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; + +import java.util.ArrayList; +import java.util.List; + +//todo here we do not support discretization of numerical cols and only keep them unchanged. +//todo if want to discretizate numerical cols, make another op. +@InputPorts(values = {@PortSpec(PortType.DATA)}) +@OutputPorts(values = { + @PortSpec(PortType.MODEL), + @PortSpec(value = PortType.DATA, desc = PortDesc.CROSSED_FEATURES) +}) +@NameCn("AutoCross训练") +@NameEn("AutoCross Train") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.feature.AutoCross") +public class AutoCrossTrainBatchOp extends BaseCrossTrainBatchOp { + + private static final long serialVersionUID = 2847616118502942858L; + + public AutoCrossTrainBatchOp() { + this(new Params()); + } + + public AutoCrossTrainBatchOp(Params params) { + super(params); + } + public static final String AC_TRAIN_DATA = "AC_TRAIN_DATA"; + public static final int SESSION_ID = SessionSharedData.getNewSessionId(); + + DataSet buildAcModelData(DataSet> trainData, + DataSet featureSizeDataSet, + DataColumnsSaver dataColumnsSaver) { + String[] numericalCols = dataColumnsSaver.numericalCols; + double fraction = getFraction(); + int kCross = getKCross(); + boolean toFixCoef = true; + + final LinearModelType linearModelType = LinearModelType.LR; // todo: support regression and multi-class cls. 
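        // Bootstrap step: only subtask 0 trains an initial LR model on the samples it cached in
        // SessionSharedData, splitting them into train/validation by `fraction`. The resulting
        // coefficient vector and validation score are broadcast and used below to seed the
        // FeatureSet's fixed coefficients before the beam-search iteration starts.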
+ + DataSet > initialCoefsAndScore = trainData + .mapPartition( + new RichMapPartitionFunction , Tuple2>() { + @Override + public void mapPartition(Iterable > values, + Collector > out) throws Exception { + int taskId = getRuntimeContext().getIndexOfThisSubtask(); + if (taskId != 0) { + return; + } + List > samples = + (List >) SessionSharedData.get(AC_TRAIN_DATA, SESSION_ID, taskId); + + Tuple2 >, List >> + splited = + FeatureEvaluator.split(samples, fraction, 0); + DataProfile profile = new DataProfile(linearModelType, true); + LinearModelData model = FeatureEvaluator.train(splited.f0, profile); + DenseVector dv = model.coefVector; + double score = FeatureEvaluator.evaluate(model, splited.f1); + out.collect(Tuple2.of(dv.getData(), score)); + } + }).withBroadcastSet(trainData, "barrier"); + + DataSet featureSet = featureSizeDataSet + .map(new RichMapFunction () { + Tuple2 initialCoef; + + @Override + public void open(Configuration parameters) throws Exception { + super.open(parameters); + this.initialCoef = (Tuple2) getRuntimeContext().getBroadcastVariable("initialCoefsAndScore").get(0); + } + + private static final long serialVersionUID = 2539121870837769846L; + + @Override + public FeatureSet map(int[] value) throws Exception { + //the value is all the categorical cols. + FeatureSet featureSet = new FeatureSet(value); + featureSet.updateFixedCoefs(initialCoef.f0); + return featureSet; + } + }).withBroadcastSet(initialCoefsAndScore, "initialCoefsAndScore"); + + + /* + * This iteration implements the beam search algorithm as described in + * "AutoCross: Automatic Feature Crossing for Tabular Data in Real-World Applications". + */ + final int maxSearchStep = getMaxSearchStep(); + IterativeDataSet loop = featureSet.iterate(maxSearchStep); + + DataSet > crossFeatureAndScore = trainData + .mapPartition(new CrossFeatureOperation(numericalCols.length, linearModelType, fraction, toFixCoef, kCross)) + .withBroadcastSet(loop, "featureSet") + .withBroadcastSet(featureSizeDataSet, "featureSize") + .withBroadcastSet(trainData, "barrier") + .name("train_and_evaluate"); + + //return candidate indices, score and fixed coefs. + DataSet > theBestOne = crossFeatureAndScore + .reduce(new ReduceFunction >() { + private static final long serialVersionUID = 1099754368531239834L; + + @Override + public Tuple3 reduce(Tuple3 value1, + Tuple3 value2) + throws Exception { + return value1.f1 > value2.f1 ? value1 : value2; + } + }).name("reduce the best one"); + + DataSet updatedFeatureSet = loop + .map(new RichMapFunction () { + + private static final long serialVersionUID = 1017420195682258788L; + + @Override + public FeatureSet map(FeatureSet fs) { + List > bc = getRuntimeContext() + .getBroadcastVariable("the_one"); + + if (bc.size() == 0) {//todo check can its size be 0? 
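                        // The broadcast can be empty when no candidate was scored in this superstep
                        // (e.g. the candidate set is exhausted); in that case keep the feature set as is.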
+ return fs; + } + fs.addOneCrossFeature(bc.get(0).f0, bc.get(0).f1); + fs.updateFixedCoefs(bc.get(0).f2); + return fs; + } + }) + .withBroadcastSet(theBestOne, "the_one") + .name("update feature set"); + + featureSet = loop.closeWith(updatedFeatureSet, theBestOne); + + DataSet acModel = featureSet + .flatMap(new BuildModel(oneHotVectorCol, numericalCols, hasDiscrete)) + .withBroadcastSet(featureSizeDataSet, "featureSize"); + return acModel; + } + + void buildSideOutput(OneHotTrainBatchOp oneHotModel, DataSet acModel, + List numericalCols, long mlEnvId) { + DataSet sideDataSet = oneHotModel.getDataSet() + .mapPartition(new BuildSideOutput(numericalCols.size())) + .withBroadcastSet(acModel, "autocrossModel") + .setParallelism(1); + + Table sideModel = DataSetConversionUtil.toTable(mlEnvId, sideDataSet, + new String[] {"index", "feature", "value"}, + new TypeInformation[] {Types.INT, Types.STRING, Types.STRING}); + this.setSideOutputTables(new Table[] {sideModel}); + } + + + private static class CrossFeatureOperation extends + RichMapPartitionFunction , + Tuple3 > { + private static final long serialVersionUID = -4682615150965402842L; + transient FeatureSet featureSet; + transient int numTasks; + transient List candidates; + private int[] featureSize; + int numericalSize; + private final LinearModelType linearModelType; + private final double fraction; + private final boolean toFixCoef; + private final int kCross; + + CrossFeatureOperation(int numericalSize, LinearModelType linearModelType, + double fraction, boolean toFixCoef, int kCross) { + this.numericalSize = numericalSize; + this.linearModelType = linearModelType; + this.fraction = fraction; + this.toFixCoef = toFixCoef; + this.kCross = kCross; + } + + @Override + public void open(Configuration parameters) throws Exception { + numTasks = getRuntimeContext().getNumberOfParallelSubtasks(); + featureSet = (FeatureSet) getRuntimeContext().getBroadcastVariable("featureSet").get(0); + featureSize = (int[]) getIterationRuntimeContext().getBroadcastVariable("featureSize").get(0); + //generate candidate feature combination. + candidates = featureSet.generateCandidateCrossFeatures(); + if (AlinkGlobalConfiguration.isPrintProcessInfo()) { + System.out.println(String.format("\n** step %d, # picked features %d, # total candidates %d", + getIterationRuntimeContext().getSuperstepNumber(), featureSet.crossFeatureSet.size(), + candidates.size())); + } + } + + @Override + public void mapPartition(Iterable > values, + Collector > out) throws + Exception { + int taskId = getRuntimeContext().getIndexOfThisSubtask(); + if (candidates.size() <= taskId) { + return; + } + + List > data = (List >) + SessionSharedData.get(AC_TRAIN_DATA, SESSION_ID, taskId); + if (AlinkGlobalConfiguration.isPrintProcessInfo()) { + System.out.println("taskId: " + taskId + ", data size: " + data.size()); + } + double[] fixedCoefs = featureSet.getFixedCoefs(); + FeatureEvaluator evaluator = new FeatureEvaluator( + linearModelType, + data, + featureSize, + fixedCoefs, + fraction, + toFixCoef, + kCross); + //distribution candidates to each partition. 
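                // Round-robin assignment: subtask `taskId` evaluates candidates taskId, taskId + numTasks, ...
                // Each candidate is appended to the already selected cross features and scored with an LR model.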
+ for (int i = taskId; i < candidates.size(); i += numTasks) { + int[] candidate = candidates.get(i); + List testingFeatures = new ArrayList <>(featureSet.crossFeatureSet); + testingFeatures.add(candidate); + if (AlinkGlobalConfiguration.isPrintProcessInfo()) { + System.out.println("evaluating " + JsonConverter.toJson(testingFeatures)); + } + Tuple2 scoreCoef = evaluator.score(testingFeatures, numericalSize); + double score = scoreCoef.f0; + out.collect(Tuple3.of(candidate, score, scoreCoef.f1)); + } + } + } + + private static class BuildModel extends RichFlatMapFunction { + private static final long serialVersionUID = -3939236593612638919L; + private int[] indexSize; + private String vectorCol; + private boolean hasDiscrete; + private String[] numericalCols; + + BuildModel(String vectorCol, String[] numericalCols, boolean hasDiscrete) { + this.vectorCol = vectorCol; + this.hasDiscrete = hasDiscrete; + this.numericalCols = numericalCols; + } + + @Override + public void open(Configuration parameters) throws Exception { + super.open(parameters); + indexSize = (int[]) getRuntimeContext().getBroadcastVariable("featureSize").get(0); + } + + @Override + public void flatMap(FeatureSet fs, Collector out) throws Exception { + fs.numericalCols = numericalCols; + fs.indexSize = indexSize; + fs.vecColName = vectorCol; + fs.hasDiscrete = hasDiscrete; + out.collect(Row.of(0L, fs.toString(), null)); + for (int i = 0; i < fs.crossFeatureSet.size(); i++) { + out.collect(Row.of((long) (i + 1), JsonConverter.toJson(fs.crossFeatureSet.get(i)), fs.scores.get(i))); + } + } + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/feature/BaseCrossTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/feature/BaseCrossTrainBatchOp.java new file mode 100644 index 000000000..5e524feb1 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/batch/feature/BaseCrossTrainBatchOp.java @@ -0,0 +1,359 @@ +package com.alibaba.alink.operator.batch.feature; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.Partitioner; +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.functions.RichMapPartitionFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.ExecutionEnvironment; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.table.api.TableSchema; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; + +import com.alibaba.alink.common.MLEnvironmentFactory; +import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; +import com.alibaba.alink.common.annotation.ReservedColsWithFirstInputSpec; +import com.alibaba.alink.common.exceptions.AkIllegalArgumentException; +import com.alibaba.alink.common.linalg.SparseVector; +import com.alibaba.alink.common.linalg.Vector; +import com.alibaba.alink.common.linalg.VectorUtil; +import com.alibaba.alink.common.mapper.PipelineModelMapper; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; +import com.alibaba.alink.common.utils.TableUtil; +import com.alibaba.alink.operator.batch.BatchOperator; +import 
com.alibaba.alink.operator.common.dataproc.FirstReducer; +import com.alibaba.alink.operator.common.dataproc.MultiStringIndexerModelData; +import com.alibaba.alink.operator.common.feature.OneHotModelDataConverter; +import com.alibaba.alink.operator.common.feature.OneHotModelMapper; +import com.alibaba.alink.operator.common.slidingwindow.SessionSharedData; +import com.alibaba.alink.params.classification.RandomForestTrainParams; +import com.alibaba.alink.params.feature.AutoCrossTrainParams; +import com.alibaba.alink.params.feature.HasDropLast; +import com.alibaba.alink.params.shared.colname.HasOutputColsDefaultAsNull; +import com.alibaba.alink.params.shared.colname.HasReservedColsDefaultAsNull; +import com.alibaba.alink.pipeline.PipelineModel; +import com.alibaba.alink.pipeline.TransformerBase; +import com.alibaba.alink.pipeline.feature.AutoCrossAlgoModel; +import com.alibaba.alink.pipeline.feature.OneHotEncoderModel; +import org.apache.commons.lang3.ArrayUtils; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import static com.alibaba.alink.operator.batch.feature.AutoCrossTrainBatchOp.AC_TRAIN_DATA; +import static com.alibaba.alink.operator.batch.feature.AutoCrossTrainBatchOp.SESSION_ID; + +@ReservedColsWithFirstInputSpec +@ParamSelectColumnSpec(name = "selectedCols") +@ParamSelectColumnSpec(name = "labelCol") +@NameCn("") +public abstract class BaseCrossTrainBatchOp> + extends BatchOperator + implements AutoCrossTrainParams { + + static final String oneHotVectorCol = "oneHotVectorCol"; + boolean hasDiscrete = true; + + BaseCrossTrainBatchOp(Params params) { + super(params); + } + + //todo construct function. + + @Override + public T linkFrom(BatchOperator ... inputs) { + BatchOperator in = checkAndGetFirst(inputs); + + String[] reversedCols = getParams().get(HasReservedColsDefaultAsNull.RESERVED_COLS); + if (reversedCols == null) { + reversedCols = in.getColNames(); + } + + long mlEnvId = getMLEnvironmentId(); + + String[] featureCols = getSelectedCols(); + final String labelCol = getLabelCol(); + + String[] selectedCols = ArrayUtils.add(featureCols, labelCol); + in = in.select(selectedCols); + TableSchema inputSchema = in.getSchema(); + final ExecutionEnvironment env = MLEnvironmentFactory.get(mlEnvId).getExecutionEnvironment(); + + String[] categoricalCols = TableUtil.getCategoricalCols(in.getSchema(), featureCols, + getParams().contains(RandomForestTrainParams.CATEGORICAL_COLS) ? 
+ getParams().get(RandomForestTrainParams.CATEGORICAL_COLS) : null + ); + if (null == categoricalCols || categoricalCols.length == 0) { + throw new AkIllegalArgumentException("Please input param CategoricalCols!"); + } + String[] numericalCols = ArrayUtils.removeElements(featureCols, categoricalCols); + + Params oneHotParams = new Params().set(AutoCrossTrainParams.SELECTED_COLS, categoricalCols); + if (getParams().contains(AutoCrossTrainParams.DISCRETE_THRESHOLDS_ARRAY)) { + oneHotParams.set(AutoCrossTrainParams.DISCRETE_THRESHOLDS_ARRAY, getDiscreteThresholdsArray()); + } else if (getParams().contains(AutoCrossTrainParams.DISCRETE_THRESHOLDS)) { + oneHotParams.set(AutoCrossTrainParams.DISCRETE_THRESHOLDS, getDiscreteThresholds()); + } + oneHotParams.set(HasDropLast.DROP_LAST, false) + .set(HasOutputColsDefaultAsNull.OUTPUT_COLS, new String[] {oneHotVectorCol}); + + OneHotTrainBatchOp oneHotModel = new OneHotTrainBatchOp(oneHotParams) + .setMLEnvironmentId(mlEnvId) + .linkFrom(in); + + OneHotEncoderModel oneHotEncoderModel = new OneHotEncoderModel(oneHotParams) + .setMLEnvironmentId(mlEnvId); + oneHotEncoderModel.setModelData(oneHotModel); + + //todo first do not train numerical model. + //if (numericalCols.size() != 0) { + // Params numericalParams = new Params() + // .set(AutoCrossTrainParams.SELECTED_COLS, numericalCols.toArray(new String[0])); + // if (getParams().contains(AutoCrossTrainParams.NUM_BUCKETS_ARRAY)) { + // numericalParams.set(AutoCrossTrainParams.NUM_BUCKETS_ARRAY, getNumBucketsArray()); + // } else if (getParams().contains(AutoCrossTrainParams.NUM_BUCKETS)) { + // numericalParams.set(AutoCrossTrainParams.NUM_BUCKETS, getNumBuckets()); + // } + // + // BatchOperator quantile; + // if (getBinningMethod().equals(BinningMethod.QUANTILE)) { + // quantile = new QuantileDiscretizerTrainBatchOp(numericalParams) + // .linkFrom(in); + // } else { + // quantile = new EqualWidthDiscretizerTrainBatchOp(numericalParams) + // .linkFrom(in); + // } + // QuantileDiscretizerModel numericalModel = BinningTrainBatchOp + // .setQuantileDiscretizerModelData( + // quantile, + // numericalCols.toArray(new String[0]), + // getParams().get(ML_ENVIRONMENT_ID)); + // listModel.add(numericalModel); + //} + + TransformerBase [] finalModel = new TransformerBase[2]; + finalModel[0] = oneHotEncoderModel; + + in = new OneHotPredictBatchOp(oneHotParams) + .setMLEnvironmentId(getMLEnvironmentId()) + .linkFrom(oneHotModel, in); + + hasDiscrete = OneHotModelMapper.isEnableElse(oneHotParams); + + DataSet featureSizeDataSet = env.fromElements(new int[0]) + .map(new BuildFeatureSize(hasDiscrete)) + .withBroadcastSet(oneHotModel.getDataSet(), "oneHotModel"); + + DataSet positiveLabel = in + .select(labelCol) + .getDataSet().reduceGroup(new FirstReducer(1)) + .map(new MapFunction () { + private static final long serialVersionUID = 110081999458221448L; + + @Override + public Object map(Row value) throws Exception { + return value.getField(0); + } + }); + + int svIndex = TableUtil.findColIndex(in.getColNames(), oneHotVectorCol); + int labelIndex = TableUtil.findColIndex(in.getColNames(), labelCol); + + int[] numericalIndices = TableUtil.findColIndicesWithAssert(in.getSchema(), numericalCols); + DataColumnsSaver dataColumnsSaver = new DataColumnsSaver(categoricalCols, numericalCols, numericalIndices); + + //here numerical cols is concatted first. 
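        // Every sample is emitted once per subtask index and routed by the custom partitioner below,
        // so each parallel worker ends up with a full copy of the training data; the subsequent
        // mapPartition caches that copy in SessionSharedData for reuse across beam-search supersteps.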
+ DataSet >> trainDataOrigin = + in.getDataSet().rebalance() + .mapPartition(new GetTrainData(numericalIndices, svIndex, labelIndex)) + .withBroadcastSet(positiveLabel, "positiveLabel") + .partitionCustom(new Partitioner () { + private static final long serialVersionUID = 5552966434608252752L; + + @Override + public int partition(Integer key, int numPartitions) { + return key; + } + }, 0); + + DataSet > trainData = trainDataOrigin + .mapPartition( + new RichMapPartitionFunction >, + Tuple3 >() { + @Override + public void mapPartition(Iterable >> values, + Collector > out) + throws Exception { + List > samples = new ArrayList <>(); + for (Tuple2 > value : values) { + samples.add(value.f1); + } + int taskId = getRuntimeContext().getIndexOfThisSubtask(); + SessionSharedData.put(AC_TRAIN_DATA, SESSION_ID, taskId, samples); + } + }); + + DataSet acModel = buildAcModelData(trainData, featureSizeDataSet, dataColumnsSaver); + + BatchOperator autoCrossBatchModel = BatchOperator + .fromTable( + DataSetConversionUtil.toTable(mlEnvId, + acModel, new String[] {"feature_id", "cross_feature", "score"}, + new TypeInformation[] {Types.LONG, Types.STRING, Types.DOUBLE})); + + Params autoCrossParams = getParams(); + autoCrossParams.set(HasReservedColsDefaultAsNull.RESERVED_COLS, reversedCols); + AutoCrossAlgoModel acPipelineModel = new AutoCrossAlgoModel(autoCrossParams) + .setModelData(autoCrossBatchModel) + .setMLEnvironmentId(mlEnvId); + finalModel[1] = acPipelineModel; + + BatchOperator modelSaved = new PipelineModel(finalModel).save(); + + DataSet modelRows = modelSaved + .getDataSet() + .map(new PipelineModelMapper.ExtendPipelineModelRow(selectedCols.length + 1)); + + setOutput(modelRows, getAutoCrossModelSchema(inputSchema, modelSaved.getSchema(), selectedCols)); + + buildSideOutput(oneHotModel, acModel, Arrays.asList(numericalCols), mlEnvId); + + return (T) this; + } + + abstract DataSet buildAcModelData(DataSet > trainData, + DataSet featureSizeDataSet, + DataColumnsSaver dataColumnsSaver); + + abstract void buildSideOutput(OneHotTrainBatchOp oneHotModel, DataSet acModel, + List numericalCols, long mlEnvId); + + static class DataColumnsSaver { + String[] categoricalCols; + String[] numericalCols; + int[] numericalIndices; + + DataColumnsSaver(String[] categoricalCols, + String[] numericalCols, + int[] numericalIndices) { + this.categoricalCols = categoricalCols; + this.numericalCols = numericalCols; + this.numericalIndices = numericalIndices; + } + } + + public static class BuildFeatureSize extends RichMapFunction { + private static final long serialVersionUID = 873642749154257046L; + private final int additionalSize; + private int[] featureSize; + + BuildFeatureSize(boolean hasDiscrete) { + additionalSize = hasDiscrete ? 2 : 1; + } + + @Override + public void open(Configuration parameters) throws Exception { + super.open(parameters); + List modelRow = getRuntimeContext().getBroadcastVariable("oneHotModel"); + MultiStringIndexerModelData data = new OneHotModelDataConverter().load(modelRow).modelData; + int featureNumber = data.tokenNumber.size(); + featureSize = new int[featureNumber]; + for (int i = 0; i < featureNumber; i++) { + featureSize[i] = (int) (data.tokenNumber.get(i) + additionalSize); + } + } + + @Override + public int[] map(int[] value) throws Exception { + return featureSize; + } + } + + //concat numerical cols and the onehot sv. + //todo note: cannot have null value in the numerical data. 
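+ //Descriptive note: GetTrainData builds one SparseVector per row with size numericalIndices.length + oneHotVector.size();
+ //the numerical values occupy indices [0, numericalIndices.length) and the one-hot indices are shifted up by numericalIndices.length.
+ //The label is encoded as 1.0 when it equals the broadcast positiveLabel and 0.0 otherwise.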
+ public static class GetTrainData + extends RichMapPartitionFunction >> { + private static final long serialVersionUID = -4406174781328407356L; + private int[] numericalIndices; + private int svIndex; + private int labelIndex; + private Object positiveLabel; + + GetTrainData(int[] numericalIndices, int svIndex, int labelIndex) { + this.svIndex = svIndex; + this.labelIndex = labelIndex; + this.numericalIndices = numericalIndices; + } + + @Override + public void mapPartition(Iterable values, + Collector >> out) throws Exception { + int taskNum = getRuntimeContext().getNumberOfParallelSubtasks(); + int vecSize = -1; + int svSize = -1; + int[] vecIndices = null; + double[] vecValues = null; + SparseVector sv; + for (Row rowData : values) { + if (vecSize == -1) { + sv = VectorUtil.getSparseVector(rowData.getField(svIndex)); + vecSize = numericalIndices.length + sv.getIndices().length; + svSize = numericalIndices.length + sv.size(); + vecIndices = new int[vecSize]; + vecValues = new double[vecSize]; + } + for (int i = 0; i < numericalIndices.length; i++) { + vecIndices[i] = i; + vecValues[i] = ((Number) rowData.getField(numericalIndices[i])).doubleValue(); + } + sv = VectorUtil.getSparseVector(rowData.getField(svIndex)); + int[] svIndices = new int[sv.getIndices().length]; + for (int i = 0; i < svIndices.length; i++) { + svIndices[i] = sv.getIndices()[i] + numericalIndices.length; + } + System.arraycopy(svIndices, 0, vecIndices, numericalIndices.length, sv.getIndices().length); + System.arraycopy(sv.getValues(), 0, vecValues, numericalIndices.length, sv.getValues().length); + + for (int i = 0; i < taskNum; i++) { + out.collect(Tuple2.of(i, Tuple3.of( + 1.0, + positiveLabel.equals(rowData.getField(labelIndex)) ? 1. : 0., + new SparseVector(svSize, vecIndices, vecValues)))); + } + + } + } + + @Override + public void open(Configuration parameters) throws Exception { + super.open(parameters); + positiveLabel = getRuntimeContext().getBroadcastVariable("positiveLabel").get(0); + } + } + + public static TableSchema getAutoCrossModelSchema( + TableSchema dataSchema, TableSchema modelSchema, String[] selectedCols) { + + int pipeFieldCount = modelSchema.getFieldNames().length; + String[] modelCols = new String[pipeFieldCount + 1 + selectedCols.length]; + TypeInformation [] modelType = new TypeInformation[pipeFieldCount + 1 + selectedCols.length]; + System.arraycopy(modelSchema.getFieldNames(), 0, modelCols, 0, pipeFieldCount); + System.arraycopy(modelSchema.getFieldTypes(), 0, modelType, 0, pipeFieldCount); + modelCols[pipeFieldCount] = PipelineModelMapper.SPLITER_COL_NAME; + modelType[pipeFieldCount] = PipelineModelMapper.SPLITER_COL_TYPE; + System.arraycopy(selectedCols, 0, modelCols, pipeFieldCount + 1, selectedCols.length); + for (int i = 0; i < selectedCols.length; i++) { + int index = TableUtil.findColIndex(dataSchema, selectedCols[i]); + modelType[i + pipeFieldCount + 1] = dataSchema.getFieldTypes()[index]; + } + return new TableSchema(modelCols, modelType); + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/feature/BinarizerBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/feature/BinarizerBatchOp.java index ed352b081..d93d8db53 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/feature/BinarizerBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/feature/BinarizerBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import 
com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.operator.batch.utils.MapBatchOp; @@ -14,6 +15,7 @@ */ @ParamSelectColumnSpec(name="selectedCol", allowedTypeCollections = TypeCollections.NUMERIC_TYPES) @NameCn("二值化") +@NameEn("Binarize") public final class BinarizerBatchOp extends MapBatchOp implements BinarizerParams { private static final long serialVersionUID = -8285479274916036924L; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/feature/BinningPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/feature/BinningPredictBatchOp.java new file mode 100644 index 000000000..b541b7c06 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/batch/feature/BinningPredictBatchOp.java @@ -0,0 +1,26 @@ +package com.alibaba.alink.operator.batch.feature; + +import org.apache.flink.ml.api.misc.param.Params; + +import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.operator.batch.utils.ModelMapBatchOp; +import com.alibaba.alink.operator.common.feature.BinningModelMapper; +import com.alibaba.alink.params.finance.BinningPredictParams; + +@NameCn("分箱预测") +@NameEn("binning predictor") +public final class BinningPredictBatchOp extends ModelMapBatchOp + implements BinningPredictParams { + + private static final long serialVersionUID = 8864849007288633705L; + + public BinningPredictBatchOp() { + this(null); + } + + public BinningPredictBatchOp(Params params) { + super(BinningModelMapper::new, params); + } + +} diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/feature/BinningTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/feature/BinningTrainBatchOp.java new file mode 100644 index 000000000..b31c4aa11 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/batch/feature/BinningTrainBatchOp.java @@ -0,0 +1,734 @@ +package com.alibaba.alink.operator.batch.feature; + +import org.apache.flink.api.common.functions.GroupReduceFunction; +import org.apache.flink.api.common.functions.JoinFunction; +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.MapPartitionFunction; +import org.apache.flink.api.common.functions.RichFlatMapFunction; +import org.apache.flink.api.common.functions.RichMapPartitionFunction; +import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.tuple.Tuple5; +import org.apache.flink.ml.api.misc.param.ParamInfo; +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +import com.alibaba.alink.common.MLEnvironmentFactory; +import com.alibaba.alink.common.annotation.InputPorts; +import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.common.annotation.OutputPorts; +import com.alibaba.alink.common.annotation.PortDesc; +import com.alibaba.alink.common.annotation.PortSpec; +import com.alibaba.alink.common.annotation.PortType; +import com.alibaba.alink.common.annotation.SelectedColsWithFirstInputSpec; +import com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp; +import 
com.alibaba.alink.common.viz.AlinkViz; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; +import com.alibaba.alink.common.utils.JsonConverter; +import com.alibaba.alink.common.utils.TableUtil; +import com.alibaba.alink.common.viz.VizDataWriterInterface; +import com.alibaba.alink.operator.batch.BatchOperator; +import com.alibaba.alink.operator.batch.utils.DataSetUtil; +import com.alibaba.alink.operator.common.dataproc.StringIndexerUtil; +import com.alibaba.alink.operator.common.feature.BinningModelDataConverter; +import com.alibaba.alink.operator.common.feature.BinningModelInfo; +import com.alibaba.alink.operator.common.feature.BinningModelMapper; +import com.alibaba.alink.operator.common.feature.OneHotModelDataConverter; +import com.alibaba.alink.operator.common.feature.QuantileDiscretizerModelDataConverter; +import com.alibaba.alink.operator.common.feature.binning.BinDivideType; +import com.alibaba.alink.operator.common.feature.binning.BinTypes; +import com.alibaba.alink.operator.common.feature.binning.BinningModelInfoBatchOp; +import com.alibaba.alink.operator.common.feature.binning.Bins; +import com.alibaba.alink.operator.common.feature.binning.FeatureBinsCalculator; +import com.alibaba.alink.operator.common.feature.binning.FeatureBinsUtil; +import com.alibaba.alink.operator.common.similarity.SerializableComparator; +import com.alibaba.alink.operator.batch.statistics.utils.StatisticsHelper; +import com.alibaba.alink.operator.common.statistics.basicstatistic.TableSummary; +import com.alibaba.alink.operator.common.statistics.statistics.IntervalCalculator; +import com.alibaba.alink.params.dataproc.HasHandleInvalid; +import com.alibaba.alink.params.feature.HasDropLast; +import com.alibaba.alink.params.feature.HasEncodeWithoutWoe; +import com.alibaba.alink.params.feature.QuantileDiscretizerTrainParams; +import com.alibaba.alink.params.finance.BinningTrainParams; +import com.alibaba.alink.params.shared.HasMLEnvironmentId; +import com.alibaba.alink.params.shared.colname.HasSelectedCols; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; +import com.alibaba.alink.pipeline.PipelineModel; +import com.alibaba.alink.pipeline.TransformerBase; +import com.alibaba.alink.pipeline.feature.OneHotEncoderModel; +import com.alibaba.alink.pipeline.feature.QuantileDiscretizerModel; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.TreeSet; +@InputPorts(values = {@PortSpec(PortType.DATA)}) +@OutputPorts(values = @PortSpec(value = PortType.MODEL, desc = PortDesc.OUTPUT_RESULT)) +@SelectedColsWithFirstInputSpec +@NameCn("分箱训练") +@NameEn("binning trainer") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.feature.Binning") +public final class BinningTrainBatchOp extends BatchOperator + implements BinningTrainParams , AlinkViz , + WithModelInfoBatchOp { + + private static final long serialVersionUID = 2424584385121349839L; + private static String FEATURE_DELIMITER = ","; + private static String KEY_VALUE_DELIMITER = ":"; + private static int BUCKET_NUMBER = 10000; + private static long DATA_ID_MAP = 1L; + + public BinningTrainBatchOp() {} + + public BinningTrainBatchOp(Params params) { + super(params); + } + + @Override + public BinningTrainBatchOp linkFrom(BatchOperator ... 
inputs) { + BatchOperator in = checkAndGetFirst(inputs); + + List transformers = new ArrayList <>(); + DataSet featureBorderDataSet = null; + + if (this.getFromUserDefined()) { + Preconditions.checkNotNull(this.getUserDefinedBin(), "User defined bin is empty!"); + //numeric, discrete + Tuple5 , HashSet , List , HashSet , + String[]> featureBorders = parseUserDefined(this.getParams()); + //check whether the bindividetype and leftOpen is the same + Tuple2 binDivideTypeLeftOpen = parseUserDefinedNumeric(featureBorders.f0); + //update params + this.setSelectedCols(featureBorders.f4) + .setBinningMethod(BinningMethod.valueOf(binDivideTypeLeftOpen.f0.name())) + .setLeftOpen(binDivideTypeLeftOpen.f1); + + //addOneHotModel + if (featureBorders.f2.size() > 0) { + Tuple2 , OneHotEncoderModel> tuple = discreteFromUserDefined( + featureBorders.f2, + getMLEnvironmentId(), featureBorders.f3.toArray(new String[0])); + featureBorderDataSet = tuple.f0; + transformers.add(tuple.f1); + } + //addQuantileModel + if (featureBorders.f0.size() > 0) { + Tuple2 , QuantileDiscretizerModel> tuple = numericFromUserDefined( + featureBorders.f0, + getMLEnvironmentId(), featureBorders.f1.toArray(new String[0])); + transformers.add(tuple.f1); + featureBorderDataSet = null == featureBorderDataSet ? tuple.f0 : featureBorderDataSet.union(tuple.f0); + } + } else { + //numeric, discrete + Tuple2 , List > t = BinningModelMapper.distinguishNumericDiscrete( + this.getSelectedCols(), + in.getSchema()); + //add onehot model + if (t.f1.size() > 0) { + Tuple2 , OneHotEncoderModel> tuple = discreteFromTrain(in, + t.f1.toArray(new String[0]), getParams()); + featureBorderDataSet = tuple.f0; + transformers.add(tuple.f1); + } + //add quantile model + if (t.f0.size() > 0) { + Tuple2 , QuantileDiscretizerModel> tuple = numericFromTrain(in, + t.f0.toArray(new String[0]), getParams()); + transformers.add(tuple.f1); + featureBorderDataSet = null == featureBorderDataSet ? 
tuple.f0 : featureBorderDataSet.union(tuple.f0); + } + } + + Preconditions.checkNotNull(featureBorderDataSet, "No binning is generated, please check input!"); + + PipelineModel pipelineModel = new PipelineModel(transformers.toArray(new TransformerBase[0])); + BatchOperator index = pipelineModel.transform(in); + featureBorderDataSet = featureBinsStatistics(featureBorderDataSet, index, getParams()); + + DataSet featureBorderModel = featureBorderDataSet + .mapPartition(new SerializeFeatureBinsModel()) + .name("SerializeModel"); + + VizDataWriterInterface writer = this.getVizDataWriter(); + if (writer != null) { + writeVizData(featureBorderDataSet, writer, getSelectedCols()); + writePreciseVizData(in, featureBorderDataSet, getParams(), writer); + } + + this.setOutput(featureBorderModel, new BinningModelDataConverter().getModelSchema()); + return this; + } + + private static class FeatureBinsKey implements KeySelector { + private static final long serialVersionUID = 4363650897636276540L; + + @Override + public String getKey(FeatureBinsCalculator calculator) { + return calculator.getFeatureName(); + } + } + + static DataSet featureBinsStatistics(DataSet featureBorderDataSet, + BatchOperator index, Params param) { + if (null != param.get(BinningTrainParams.LABEL_COL)) { + Preconditions.checkArgument(param.contains(BinningTrainParams.POS_LABEL_VAL_STR), + "PositiveValue is not set!"); + return setFeatureBinsTotalAndWoe(featureBorderDataSet, index, param); + } else { + return setFeatureBinsTotal(featureBorderDataSet, index, param.get(BinningTrainParams.SELECTED_COLS)); + } + } + + static class SerializeFeatureBinsModel extends RichMapPartitionFunction { + private static final long serialVersionUID = 2703991631667194660L; + + @Override + public void mapPartition(Iterable iterable, Collector collector) { + new BinningModelDataConverter().save(iterable, collector); + } + } + + public static Tuple2 , QuantileDiscretizerModel> numericFromTrain(BatchOperator + in, + String[] + numericCols, + Params params) { + BinningMethod binningMethod = params.get(BinningTrainParams.BINNIG_METHOD); + switch (binningMethod) { + case QUANTILE: { + return quantileTrain(in, numericCols, params); + } + case BUCKET: { + return bucketTrain(in, numericCols, params); + } + default: { + throw new IllegalArgumentException("Not support binningMethod: " + binningMethod); + } + } + } + + public static Tuple2 , OneHotEncoderModel> discreteFromTrain(BatchOperator in, + String[] + discreteCols, + Params params) { + Integer[] discreteThreshold = getValueArray(params, BinningTrainParams.DISCRETE_THRESHOLDS, + BinningTrainParams.DISCRETE_THRESHOLDS_ARRAY, BinningTrainParams.DISCRETE_THRESHOLDS_MAP, discreteCols); + + OneHotTrainBatchOp oneHot = new OneHotTrainBatchOp() + .setSelectedCols(discreteCols) + .setDiscreteThresholdsArray(discreteThreshold) + .linkFrom(in); + + DataSet featureBinsCalculator = OneHotTrainBatchOp + .transformModelToFeatureBins(oneHot.getDataSet()); + OneHotEncoderModel model = setOneHotModelData(oneHot, discreteCols, params.get(ML_ENVIRONMENT_ID)); + return Tuple2.of(featureBinsCalculator, model); + } + + private static Tuple2 , QuantileDiscretizerModel> bucketTrain(BatchOperator in, + String[] numericCols, + Params params) { + EqualWidthDiscretizerTrainBatchOp quantile = new EqualWidthDiscretizerTrainBatchOp( + generateNumericTrainParams(params, numericCols)) + .linkFrom(in); + + DataSet featureBinsCalculator = QuantileDiscretizerTrainBatchOp + .transformModelToFeatureBins(quantile.getDataSet(), 
BinDivideType.BUCKET); + + QuantileDiscretizerModel model = setQuantileDiscretizerModelData( + quantile, + numericCols, + params.get(ML_ENVIRONMENT_ID)); + return Tuple2.of(featureBinsCalculator, model); + } + + private static Tuple2 , QuantileDiscretizerModel> quantileTrain(BatchOperator in, + String[] + numericCols, + Params params) { + QuantileDiscretizerTrainBatchOp quantile = new QuantileDiscretizerTrainBatchOp( + generateNumericTrainParams(params, numericCols)) + .linkFrom(in); + + DataSet featureBinsCalculator = QuantileDiscretizerTrainBatchOp + .transformModelToFeatureBins(quantile.getDataSet(), BinDivideType.QUANTILE); + + QuantileDiscretizerModel model = setQuantileDiscretizerModelData( + quantile, + numericCols, + params.get(ML_ENVIRONMENT_ID)); + return Tuple2.of(featureBinsCalculator, model); + } + + private static Tuple2 , OneHotEncoderModel> discreteFromUserDefined( + List featureBinsCalculators, + long environmentId, + String[] discreteCols) { + DataSet featureBorderDataSet = MLEnvironmentFactory + .get(environmentId) + .getExecutionEnvironment() + .fromCollection(featureBinsCalculators) + .name("DiscreteFromUserDefined"); + + OneHotEncoderModel model = setOneHotModelData( + BatchOperator.fromTable( + DataSetConversionUtil + .toTable(environmentId, OneHotTrainBatchOp.transformFeatureBinsToModel(featureBorderDataSet), + new OneHotModelDataConverter().getModelSchema()) + ) + , discreteCols, environmentId); + + return Tuple2.of(featureBorderDataSet, model); + } + + private static Tuple2 , QuantileDiscretizerModel> numericFromUserDefined( + List featureBinsCalculators, + long environmentId, + String[] numericCols) { + DataSet featureBorderDataSet = MLEnvironmentFactory + .get(environmentId) + .getExecutionEnvironment() + .fromCollection(featureBinsCalculators) + .name("NumericFromUserDefined"); + + QuantileDiscretizerModel model = setQuantileDiscretizerModelData( + BatchOperator.fromTable( + DataSetConversionUtil.toTable(environmentId, + QuantileDiscretizerTrainBatchOp.transformFeatureBinsToModel(featureBorderDataSet), + new QuantileDiscretizerModelDataConverter().getModelSchema()) + ) + , numericCols, + environmentId); + + return Tuple2.of(featureBorderDataSet, model); + } + + private static Tuple5 , HashSet , List , HashSet + , String[]> parseUserDefined(Params params) { + HashSet userDefinedNumeric = new HashSet <>(); + HashSet userDefinedDiscrete = new HashSet <>(); + String[] selectedCols = params.get(BinningTrainBatchOp.SELECTED_COLS); + Tuple2 , List > numericDiscrete = BinningModelMapper + .distinguishNumericDiscrete( + FeatureBinsUtil.deSerialize(params.get(BinningTrainBatchOp.USER_DEFINED_BIN)), + selectedCols, + userDefinedNumeric, + userDefinedDiscrete); + + String[] newSelectedCols = new String[userDefinedDiscrete.size() + userDefinedNumeric.size()]; + int c = 0; + for (String s : selectedCols) { + if (userDefinedDiscrete.contains(s) || userDefinedNumeric.contains(s)) { + newSelectedCols[c++] = s; + } + } + return Tuple5.of(numericDiscrete.f0, userDefinedNumeric, numericDiscrete.f1, userDefinedDiscrete, + newSelectedCols); + } + + private static Tuple2 parseUserDefinedNumeric(List list) { + BinDivideType binDivideType = null; + Boolean leftOpen = null; + for (FeatureBinsCalculator calculator : list) { + Preconditions.checkArgument(calculator.isNumeric(), "parseUserDefinedNumeric only supports numeric bins!"); + if (null == binDivideType) { + binDivideType = calculator.getBinDivideType(); + leftOpen = calculator.getLeftOpen(); + } else { + 
Preconditions.checkArgument(binDivideType.equals(calculator.getBinDivideType()), + "Features have different BinDivideType!"); + Preconditions.checkArgument(leftOpen.equals(calculator.getLeftOpen()), + "Features have different leftOpen params!"); + } + } + return Tuple2.of(binDivideType, leftOpen); + } + + public static DataSet setFeatureBinsTotalAndWoe( + DataSet featureBorderDataSet, + BatchOperator index, + Params params) { + Preconditions.checkArgument(TableUtil.findColIndex(params.get(BinningTrainBatchOp.SELECTED_COLS), + params.get(BinningTrainBatchOp.LABEL_COL)) < 0, + "labelCol is included in selectedCols"); + WoeTrainBatchOp op = new WoeTrainBatchOp(params).linkFrom(index); + return WoeTrainBatchOp.setFeatureBinsWoe(featureBorderDataSet, op.getDataSet()); + } + + static DataSet setFeatureBinsTotal(DataSet featureBorderDataSet, + BatchOperator index, + String[] selectedCols) { + DataSet > tokenCounts = StringIndexerUtil.countTokens( + index.select(selectedCols).getDataSet(), true); + DataSet > borderWithName = featureBorderDataSet.map( + new MapFunction >() { + private static final long serialVersionUID = 5564348245552253677L; + + @Override + public Tuple2 map(FeatureBinsCalculator value) { + return Tuple2.of(value.getFeatureName(), value); + } + }); + + DataSet >> featureCounts = tokenCounts + .groupBy(0) + .reduceGroup(new GroupReduceFunction , Tuple2 >> + () { + private static final long serialVersionUID = -3648772506771320958L; + + @Override + public void reduce(Iterable > values, + Collector >> out) { + String featureName = null; + Map map = new HashMap <>(); + for (Tuple3 t : values) { + featureName = selectedCols[t.f0]; + map.put(Long.valueOf(t.f1), t.f2); + } + if (null != featureName) { + out.collect(Tuple2.of(featureName, map)); + } + } + }).name("GetBinTotalMap"); + + return borderWithName + .join(featureCounts) + .where(0) + .equalTo(0) + .with( + new JoinFunction , Tuple2 >, + FeatureBinsCalculator>() { + private static final long serialVersionUID = -7326232704777819964L; + + @Override + public FeatureBinsCalculator join(Tuple2 first, + Tuple2 > second) { + FeatureBinsCalculator border = first.f1; + border.setTotal(second.f1); + return border; + } + }).name("SetBinTotal"); + } + + private static Params generateNumericTrainParams(Params params, String[] numericCols) { + Integer[] discreteThreshold = getValueArray(params, BinningTrainParams.NUM_BUCKETS, + BinningTrainParams.NUM_BUCKETS_ARRAY, BinningTrainParams.NUM_BUCKETS_MAP, numericCols); + return new Params() + .set(QuantileDiscretizerTrainParams.SELECTED_COLS, numericCols) + .set(QuantileDiscretizerTrainParams.NUM_BUCKETS_ARRAY, discreteThreshold) + .set(QuantileDiscretizerTrainParams.LEFT_OPEN, params.get(BinningTrainParams.LEFT_OPEN)); + } + + private static OneHotEncoderModel setOneHotModelData(BatchOperator modelData, + String[] selectedCols, + long environmentId) { + OneHotEncoderModel oneHotEncode = new OneHotEncoderModel( + encodeIndexForWoeTrainParams(selectedCols, environmentId)); + oneHotEncode.setModelData(modelData); + return oneHotEncode; + } + + static QuantileDiscretizerModel setQuantileDiscretizerModelData( + BatchOperator modelData, + String[] selectedCols, + long environmentId) { + QuantileDiscretizerModel quantileDiscretizerModel = new QuantileDiscretizerModel( + encodeIndexForWoeTrainParams(selectedCols, environmentId)); + quantileDiscretizerModel.setModelData(modelData); + return quantileDiscretizerModel; + } + + public static Params encodeIndexForWoeTrainParams(String[] selectedCols, long 
environmentId) { + return new Params() + .set(HasMLEnvironmentId.ML_ENVIRONMENT_ID, environmentId) + .set(HasSelectedCols.SELECTED_COLS, selectedCols) + .set(HasHandleInvalid.HANDLE_INVALID, HasHandleInvalid.HandleInvalid.KEEP) + .set(HasDropLast.DROP_LAST, false) + .set(HasEncodeWithoutWoe.ENCODE, HasEncodeWithoutWoe.Encode.INDEX); + } + + //write quick binning viz data + private static void writePreciseVizData(BatchOperator in, + DataSet originFeatureBins, + Params params, + VizDataWriterInterface writer) { + //numeric, discrete + Tuple2 , List > t = BinningModelMapper.distinguishNumericDiscrete( + params.get(BinningTrainParams.SELECTED_COLS), + in.getSchema()); + //add quantile model + if (t.f0.size() > 0) { + params.set(BinningTrainParams.SELECTED_COLS, t.f0.toArray(new String[0])); + DataSet summary = StatisticsHelper.summary(in, t.f0.toArray(new String[0])); + DataSet preciseFeatureBins = originFeatureBins.flatMap( + new RichFlatMapFunction () { + private static final long serialVersionUID = -1054858627945256161L; + + @Override + public void flatMap(FeatureBinsCalculator value, Collector out) { + BinTypes.ColType colType = value.getColType(); + if (!colType.isNumeric) { + return; + } + Number[] numbers = value.getSplitsArray(); + boolean isFloat = colType.equals(BinTypes.ColType.FLOAT); + for (Number number2 : numbers) { + isFloat |= (number2 instanceof Double || number2 instanceof Float); + } + TreeSet set = generateCloseBucket(numbers, isFloat); + String fatureName = value.getFeatureName(); + + TableSummary summary = (TableSummary) getRuntimeContext().getBroadcastVariable("summary").get( + 0); + set.addAll( + generateGivenNumberBucket(summary.minDouble(fatureName), summary.maxDouble(fatureName), isFloat)); + + out.collect(FeatureBinsCalculator.createNumericCalculator(value.getBinDivideType(), + value.getFeatureName(), + FeatureBinsUtil.getFlinkType(value.getFeatureType()), + set.toArray(new Number[0]), + value.getLeftOpen())); + } + }).withBroadcastSet(summary, "summary"); + + QuantileDiscretizerModel model = setQuantileDiscretizerModelData( + BatchOperator.fromTable( + DataSetConversionUtil.toTable(in.getMLEnvironmentId(), + QuantileDiscretizerTrainBatchOp.transformFeatureBinsToModel(preciseFeatureBins), + new QuantileDiscretizerModelDataConverter().getModelSchema()) + ), + params.get(BinningTrainParams.SELECTED_COLS), + in.getMLEnvironmentId()); + BatchOperator preciseCut = model.transform(in); + preciseFeatureBins = featureBinsStatistics(preciseFeatureBins, preciseCut, params); + writePreciseVizData(preciseFeatureBins, originFeatureBins, params.get(BinningTrainParams.SELECTED_COLS), + writer); + } + } + + static List generateGivenNumberBucket(double minDecimal, double maxDecimal, boolean isFloat) { + List list = new ArrayList <>(); + if (!isFloat) { + long min = (long) minDecimal; + long max = (long) maxDecimal; + IntervalCalculator intervalCalculator = IntervalCalculator.create(new long[] {min - 1, max + 1}, + BUCKET_NUMBER); + long start = intervalCalculator.getLeftBound().longValue(); + long step = intervalCalculator.getStep().longValue(); + for (int i = 0; i < intervalCalculator.getCount().length; i++) { + list.add(start); + start += step; + } + } else { + IntervalCalculator intervalCalculator = IntervalCalculator.create( + new double[] {minDecimal - 0.1, maxDecimal + 0.1}, BUCKET_NUMBER); + double start = intervalCalculator.getLeftBound().doubleValue(); + double step = intervalCalculator.getStep().doubleValue(); + for (int i = 0; i < intervalCalculator.n; i++) { + 
list.add(start); + start += step; + } + } + return list; + } + + static TreeSet generateCloseBucket(Number[] numbers, boolean isFloat) { + TreeSet set = new TreeSet <>(); + if (!isFloat) { + for (Number number1 : numbers) { + set.add(number1.longValue()); + boolean negative = (number1.longValue() < 0); + long number = Math.abs(number1.longValue()); + long deno = 1; + while (deno <= number * 10) { + long tmp = number / deno * deno; + set.add(negative ? -tmp : tmp); + set.add(negative ? -(tmp + deno) : tmp + deno); + deno *= 10; + } + } + } else { + for (Number number1 : numbers) { + set.add(number1.doubleValue()); + boolean negative = (number1.doubleValue() < 0); + double number = Math.abs(number1.doubleValue()); + if (number < 1) { + int deno = 1; + double target = number * deno; + while (Double.compare(target, Math.floor(target)) != 0) { + double tmp = Math.floor(target); + set.add(negative ? -tmp / deno : tmp / deno); + set.add(negative ? -(tmp + 1) / deno : (tmp + 1) / deno); + deno *= 10; + target = number * deno; + } + } else { + long numberN = Math.abs(number1.longValue()); + long deno = 1; + while (deno <= number * 10) { + long tmp = numberN / deno * deno; + set.add(negative ? -(double) tmp : (double) tmp); + set.add(negative ? -(double) (tmp + deno) : (double) (tmp + deno)); + deno *= 10; + } + + + } + } + } + return set; + } + + private static void writePreciseVizData(DataSet featureBorderDataSet, + DataSet originFeatureBins, + String[] selecteCols, + VizDataWriterInterface writer) { + Map keyId = new HashMap <>(); + //write map: column name : column id + long start = DATA_ID_MAP + 1L; + for (String col : selecteCols) { + keyId.put(col, start++); + } + //System.out.println(JsonConverter.toJson(keyId)); + writer.writeBatchData(DATA_ID_MAP, JsonConverter.toJson(keyId), System.currentTimeMillis()); + + DataSet dummy = featureBorderDataSet.join(originFeatureBins) + .where(new FeatureBinsKey()) + .equalTo(new FeatureBinsKey()) + .with(new JoinFunction () { + private static final long serialVersionUID = 4855360130213173442L; + + @Override + public Row join(FeatureBinsCalculator first, FeatureBinsCalculator second) throws Exception { + Preconditions.checkArgument(first.isNumeric(), "Precise only support numeric bins!"); + first.calcStatistics(); + List originBins = Arrays.asList(second.getSplitsArray()); + Number[] cutsArray = first.getSplitsArray(); + Bins bins = first.getBin(); + List map = new ArrayList <>(); + List cutsMap = new ArrayList <>(); + Integer positive = null; + int total = 0; + int pre = -1; + for (int i = 0; i < bins.normBins.size(); i++) { + Bins.BaseBin baseBin = bins.normBins.get(i); + //keep the original cutsArray, only keep two cuts whose total are the same + if (baseBin.getTotal() > 0 || pre != 0 || originBins.contains(cutsArray[i - 1])) { + pre = baseBin.getTotal().intValue(); + total += pre; + if (first.getPositiveTotal() != null) { + positive = ((null == positive) ? 
baseBin.getPositive().intValue() + : positive + baseBin.getPositive().intValue()); + } + map.add(new IntervalStatistics(total, positive)); + if (i > 0) { + cutsMap.add(cutsArray[i - 1]); + } + } + } + int index = TableUtil.findColIndexWithAssert(selecteCols, first.getFeatureName()); + Params params = new Params().set("Interval", cutsMap); + params.set("Statistics", map); + //System.out.println((DATA_ID_MAP + 1 + index) + ":" + params.toJson()); + writer.writeBatchData((DATA_ID_MAP + 1 + index), params.toJson(), System.currentTimeMillis()); + return new Row(1); + } + }); + DataSetUtil.linkDummySink(dummy); + } + + private static void writeVizData(DataSet featureBorderDataSet, + VizDataWriterInterface writer, + String[] selectedCols) { + DataSet dummy = featureBorderDataSet.mapPartition( + new MapPartitionFunction () { + private static final long serialVersionUID = 1298967177624132843L; + + @Override + public void mapPartition(Iterable values, Collector out) { + List list = new ArrayList <>(); + values.forEach(list::add); + list.sort(new SerializableComparator () { + private static final long serialVersionUID = -6390285370495541755L; + + @Override + public int compare(FeatureBinsCalculator o1, FeatureBinsCalculator o2) { + return TableUtil.findColIndex(selectedCols, o1.getFeatureName()) < TableUtil.findColIndex( + selectedCols, o2.getFeatureName()) ? -1 : 1; + } + }); + //System.out.println(0L + ":" + FeatureBinsUtil.serialize(list.toArray(new + // FeatureBinsCalculator[0]))); + writer.writeBatchData(0L, FeatureBinsUtil.serialize(list.toArray(new FeatureBinsCalculator[0])), + System.currentTimeMillis()); + } + }).setParallelism(1).name("WriteFeatureBinsViz"); + DataSetUtil.linkDummySink(dummy); + } + + static Integer[] getValueArray(Params params, ParamInfo single, + ParamInfo array, + ParamInfo map, + String[] selectedCols) { + Preconditions.checkArgument(!(params.contains(single) && (params.contains(array))), + "It can not set " + single.getName() + " " + array.getName() + " at the same time!"); + Preconditions.checkArgument(!(params.contains(array) && (params.contains(map))), + "It can not set " + map.getName() + " " + array.getName() + " at the same time!"); + + Integer[] values; + if (params.contains(array)) { + values = params.get(array); + Preconditions.checkArgument(values.length == selectedCols.length, + "The length of %s must be equal to the length of train cols!", array.getName()); + } else { + values = new Integer[selectedCols.length]; + Arrays.fill(values, params.get(single)); + if (params.contains(map)) { + Map keyValue = parseInputMap(params.get(map)); + for (Map.Entry entry : keyValue.entrySet()) { + int index = TableUtil.findColIndexWithAssertAndHint(selectedCols, entry.getKey()); + values[index] = entry.getValue(); + } + } + } + return values; + } + + static Map parseInputMap(String str) { + String[] cols = str.split(FEATURE_DELIMITER); + Map map = new HashMap <>(); + for (String s : cols) { + String[] nameValue = s.split(KEY_VALUE_DELIMITER); + Preconditions.checkArgument(nameValue.length == 2, "Input Map parse fail!"); + map.put(nameValue[0].trim(), Integer.valueOf(nameValue[1])); + } + return map; + } + + static class IntervalStatistics implements Serializable { + private static final long serialVersionUID = 7680582727625414983L; + Integer total; + Integer positive; + + public IntervalStatistics(Integer total, Integer positive) { + this.total = total; + this.positive = positive; + } + } + + @Override + public BinningModelInfoBatchOp getModelInfoBatchOp() { + return new 
BinningModelInfoBatchOp().linkFrom(this); + } + +} diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/feature/BucketizerBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/feature/BucketizerBatchOp.java index af9b9f8dc..be6374aba 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/feature/BucketizerBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/feature/BucketizerBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.batch.utils.MapBatchOp; import com.alibaba.alink.operator.common.feature.BucketizerMapper; import com.alibaba.alink.params.feature.BucketizerParams; @@ -19,6 +20,7 @@ * segments with delimiter ",". */ @NameCn("分桶") +@NameEn("Bucketize") public final class BucketizerBatchOp extends MapBatchOp implements BucketizerParams { private static final long serialVersionUID = -2658623503634689607L; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/feature/C45EncoderTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/feature/C45EncoderTrainBatchOp.java new file mode 100644 index 000000000..70d6a18e0 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/batch/feature/C45EncoderTrainBatchOp.java @@ -0,0 +1,17 @@ +package com.alibaba.alink.operator.batch.feature; + +import org.apache.flink.ml.api.misc.param.Params; + +import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.operator.batch.classification.C45TrainBatchOp; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; + +@NameCn("C45决策树分类编码器训练") +@NameEn(" C45 Encoder Training") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.feature.C45Encoder") +public class C45EncoderTrainBatchOp extends C45TrainBatchOp { + public C45EncoderTrainBatchOp(Params params) { + super(params); + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/feature/CartEncoderTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/feature/CartEncoderTrainBatchOp.java new file mode 100644 index 000000000..3e826df48 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/batch/feature/CartEncoderTrainBatchOp.java @@ -0,0 +1,17 @@ +package com.alibaba.alink.operator.batch.feature; + +import org.apache.flink.ml.api.misc.param.Params; + +import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.operator.batch.classification.CartTrainBatchOp; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; + +@NameCn("CART决策树分类编码器训练") +@NameEn(" Cart Encoder Training") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.feature.CartEncoder") +public class CartEncoderTrainBatchOp extends CartTrainBatchOp { + public CartEncoderTrainBatchOp(Params params) { + super(params); + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/feature/CartRegEncoderTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/feature/CartRegEncoderTrainBatchOp.java new file mode 100644 index 000000000..b3c84c465 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/batch/feature/CartRegEncoderTrainBatchOp.java @@ -0,0 +1,17 @@ +package com.alibaba.alink.operator.batch.feature; + +import org.apache.flink.ml.api.misc.param.Params; + +import 
com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.operator.batch.regression.CartRegTrainBatchOp; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; + +@NameCn("CART决策树回归编码器训练") +@NameEn(" Cart RegEncoder Training") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.feature.CartRegEncoder") +public class CartRegEncoderTrainBatchOp extends CartRegTrainBatchOp { + public CartRegEncoderTrainBatchOp(Params params) { + super(params); + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/feature/ChiSqSelectorBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/feature/ChiSqSelectorBatchOp.java index 861a9cca2..c9e0a8613 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/feature/ChiSqSelectorBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/feature/ChiSqSelectorBatchOp.java @@ -6,13 +6,14 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortDesc; import com.alibaba.alink.common.annotation.PortSpec; import com.alibaba.alink.common.annotation.PortType; -import com.alibaba.alink.common.lazy.WithModelInfoBatchOp; -import com.alibaba.alink.common.utils.DataSetConversionUtil; +import com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.common.feature.ChiSqSelectorModelDataConverter; import com.alibaba.alink.operator.common.feature.ChisqSelectorModelInfo; @@ -27,6 +28,7 @@ @OutputPorts(values = {@PortSpec(value = PortType.DATA, desc = PortDesc.OUTPUT_RESULT)}) @ParamSelectColumnSpec(name = "selectedCols") @NameCn("卡方选择器") +@NameEn("Chisq Selector") public final class ChiSqSelectorBatchOp extends BatchOperator implements ChiSqSelectorParams , WithModelInfoBatchOp { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/feature/ChisqSelectorModelInfoBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/feature/ChisqSelectorModelInfoBatchOp.java index 6c4b83eb6..d5b4a37ad 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/feature/ChisqSelectorModelInfoBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/feature/ChisqSelectorModelInfoBatchOp.java @@ -3,7 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import org.apache.flink.types.Row; -import com.alibaba.alink.common.lazy.ExtractModelInfoBatchOp; +import com.alibaba.alink.operator.batch.utils.ExtractModelInfoBatchOp; import com.alibaba.alink.operator.common.feature.ChisqSelectorModelInfo; import java.util.List; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/feature/CrossCandidateSelectorPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/feature/CrossCandidateSelectorPredictBatchOp.java new file mode 100644 index 000000000..078620340 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/batch/feature/CrossCandidateSelectorPredictBatchOp.java @@ -0,0 +1,26 @@ +package com.alibaba.alink.operator.batch.feature; + +import org.apache.flink.ml.api.misc.param.Params; + +import com.alibaba.alink.common.annotation.NameCn; +import 
com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.operator.batch.utils.ModelMapBatchOp; +import com.alibaba.alink.operator.common.feature.AutoCross.CrossCandidateSelectorModelMapper; +import com.alibaba.alink.params.feature.featuregenerator.CrossCandidateSelectorPredictParams; + +@NameCn("cross候选特征选择预测") +@NameEn("Cross Candidate Selector Prediction") +public class CrossCandidateSelectorPredictBatchOp + extends ModelMapBatchOp + implements CrossCandidateSelectorPredictParams { + + private static final long serialVersionUID = 3987270029076248190L; + + public CrossCandidateSelectorPredictBatchOp() { + this(new Params()); + } + + public CrossCandidateSelectorPredictBatchOp(Params params) { + super(CrossCandidateSelectorModelMapper::new, params); + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/feature/CrossCandidateSelectorTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/feature/CrossCandidateSelectorTrainBatchOp.java new file mode 100644 index 000000000..6edbb648e --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/batch/feature/CrossCandidateSelectorTrainBatchOp.java @@ -0,0 +1,196 @@ +package com.alibaba.alink.operator.batch.feature; + +import org.apache.flink.api.common.functions.RichMapPartitionFunction; +import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.tuple.Tuple4; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; + +import com.alibaba.alink.common.AlinkGlobalConfiguration; +import com.alibaba.alink.common.annotation.InputPorts; +import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.common.annotation.OutputPorts; +import com.alibaba.alink.common.annotation.PortSpec; +import com.alibaba.alink.common.annotation.PortType; +import com.alibaba.alink.common.linalg.Vector; +import com.alibaba.alink.common.utils.JsonConverter; +import com.alibaba.alink.common.utils.TableUtil; +import com.alibaba.alink.operator.common.feature.AutoCross.FeatureEvaluator; +import com.alibaba.alink.operator.common.feature.AutoCross.FeatureSet; +import com.alibaba.alink.operator.common.linear.LinearModelType; +import com.alibaba.alink.operator.common.slidingwindow.SessionSharedData; +import com.alibaba.alink.params.feature.CrossCandidateSelectorTrainParams; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; + +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; + +import static com.alibaba.alink.operator.batch.feature.AutoCrossTrainBatchOp.AC_TRAIN_DATA; +import static com.alibaba.alink.operator.batch.feature.AutoCrossTrainBatchOp.SESSION_ID; + +@InputPorts(values = {@PortSpec(PortType.DATA)}) +@OutputPorts(values = {@PortSpec(PortType.MODEL)}) +@NameCn("cross候选特征选择训练") +@NameEn("Cross Candidate Selector Train") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.feature.CrossCandidateSelector") +public class CrossCandidateSelectorTrainBatchOp + extends BaseCrossTrainBatchOp + implements CrossCandidateSelectorTrainParams { + + public CrossCandidateSelectorTrainBatchOp() { + this(new Params()); + } + + public CrossCandidateSelectorTrainBatchOp(Params params) { + super(params); + } + + DataSet buildAcModelData(DataSet > trainData, + DataSet 
featureSizeDataSet, + DataColumnsSaver dataColumnsSaver) { + String[] numericalCols = dataColumnsSaver.numericalCols; + String[] categoricalCols = dataColumnsSaver.categoricalCols; + int[] numericalIndices = dataColumnsSaver.numericalIndices; + + DataSet > candidateAndAuc = trainData + .mapPartition(new CalcAucOfCandidate(categoricalCols, numericalIndices.length, getFeatureCandidates())) + .withBroadcastSet(featureSizeDataSet, "featureSize") + .withBroadcastSet(trainData, "barrier") + .partitionByHash(3); + + DataSet acModel = candidateAndAuc + .mapPartition(new FilterAuc(getCrossFeatureNumber(), oneHotVectorCol, numericalCols)) + .withBroadcastSet(featureSizeDataSet, "featureSize"); + return acModel; + } + + @Override + void buildSideOutput(OneHotTrainBatchOp oneHotModel, DataSet acModel, List numericalCols, + long mlEnvId) { + + } + + private static class CalcAucOfCandidate + extends RichMapPartitionFunction , + Tuple4 > { + private List candidateIndices; + private int[] featureSize; + private int numTasks; + private int numericalSize; + + CalcAucOfCandidate(String[] inputCols, int numericalSize, String[] featureCandidates) { + this.numericalSize = numericalSize; + candidateIndices = new ArrayList <>(featureCandidates.length); + for (String stringCandidate : featureCandidates) { + String[] candidate = stringCandidate.split(","); + for (int i = 0; i < candidate.length; i++) { + candidate[i] = candidate[i].trim(); + } + candidateIndices.add(TableUtil.findColIndices(inputCols, candidate)); + } + } + + @Override + public void open(Configuration parameters) throws Exception { + numTasks = getRuntimeContext().getNumberOfParallelSubtasks(); + featureSize = (int[]) getRuntimeContext().getBroadcastVariable("featureSize").get(0); + + } + + @Override + public void mapPartition(Iterable > values, + Collector > out) + throws Exception { + + int taskId = getRuntimeContext().getIndexOfThisSubtask(); + if (candidateIndices.size() <= taskId) { + return; + } + List > data = (List >) + SessionSharedData.get(AC_TRAIN_DATA, SESSION_ID, taskId); + + FeatureEvaluator evaluator = new FeatureEvaluator( + LinearModelType.LR, + data, + featureSize, + null, + 0.8, + false, + 1); + + for (int i = taskId; i < candidateIndices.size(); i += numTasks) { + int[] candidate = candidateIndices.get(i); + List testingFeatures = new ArrayList <>(); + testingFeatures.add(candidate); + if (AlinkGlobalConfiguration.isPrintProcessInfo()) { + System.out.println("evaluating " + JsonConverter.toJson(testingFeatures)); + } + Tuple2 scoreCoef = evaluator.score(testingFeatures, numericalSize); + double score = scoreCoef.f0; + out.collect(Tuple4.of(candidate, score, scoreCoef.f1, 0)); + } + + } + } + + private static class FilterAuc + extends RichMapPartitionFunction , Row> { + + private int crossFeatureNumber; + private int[] indexSize; + private String vectorCol; + private String[] numericalCols; + + FilterAuc(int crossFeatureNumber, String vectorCol, String[] numericalCols) { + this.crossFeatureNumber = crossFeatureNumber; + this.vectorCol = vectorCol; + this.numericalCols = numericalCols; + } + + @Override + public void open(Configuration parameters) throws Exception { + super.open(parameters); + indexSize = (int[]) getRuntimeContext().getBroadcastVariable("featureSize").get(0); + } + + @Override + public void mapPartition(Iterable > values, + Collector out) throws Exception { + List > canAndAuc = new ArrayList <>(); + for (Tuple4 value : values) { + canAndAuc.add(value); + } + if (canAndAuc.size() == 0) { + return; + } + + 
FeatureSet featureSet = new FeatureSet(indexSize); + featureSet.numericalCols = numericalCols; + featureSet.indexSize = indexSize; + featureSet.vecColName = vectorCol; + featureSet.hasDiscrete = true; + canAndAuc.sort(new Comparator >() { + @Override + public int compare(Tuple4 o1, + Tuple4 o2) { + return -o1.f1.compareTo(o2.f1); + } + }); + + for (int i = 0; i < this.crossFeatureNumber; i++) { + featureSet.addOneCrossFeature(canAndAuc.get(i).f0, canAndAuc.get(i).f1); + } + out.collect(Row.of(0L, featureSet.toString(), null)); + for (int i = 0; i < featureSet.crossFeatureSet.size(); i++) { + out.collect(Row.of((long) (i + 1), JsonConverter.toJson(featureSet.crossFeatureSet.get(i)), + featureSet.scores.get(i))); + } + } + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/feature/CrossFeaturePredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/feature/CrossFeaturePredictBatchOp.java index 2fdd1b09c..5b9f78cee 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/feature/CrossFeaturePredictBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/feature/CrossFeaturePredictBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.batch.utils.ModelMapBatchOp; import com.alibaba.alink.operator.common.feature.CrossFeatureModelMapper; import com.alibaba.alink.params.feature.CrossFeaturePredictParams; @@ -11,6 +12,7 @@ * Cross selected columns to build new vector type data. */ @NameCn("Cross特征预测") +@NameEn("Cross Feature Prediction") public class CrossFeaturePredictBatchOp extends ModelMapBatchOp implements CrossFeaturePredictParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/feature/CrossFeatureTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/feature/CrossFeatureTrainBatchOp.java index 106180ae9..ad2b7895d 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/feature/CrossFeatureTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/feature/CrossFeatureTrainBatchOp.java @@ -13,11 +13,12 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortSpec; import com.alibaba.alink.common.annotation.PortType; -import com.alibaba.alink.common.utils.DataSetConversionUtil; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; import com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.common.dataproc.MultiStringIndexerModelData; @@ -27,6 +28,7 @@ import com.alibaba.alink.params.dataproc.HasSelectedColTypes; import com.alibaba.alink.params.feature.CrossFeatureTrainParams; import com.alibaba.alink.params.shared.colname.HasSelectedCols; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; import com.google.common.collect.Lists; import java.util.List; @@ -38,6 +40,8 @@ @OutputPorts(values = {@PortSpec(value = PortType.MODEL)}) @ParamSelectColumnSpec(name = "selectedCols") @NameCn("Cross特征训练") +@NameEn("Cross Feature Training") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.feature.CrossFeature") public class 
CrossFeatureTrainBatchOp extends BatchOperator implements CrossFeatureTrainParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/feature/DCTBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/feature/DCTBatchOp.java index 7f00cf50c..ab08e7348 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/feature/DCTBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/feature/DCTBatchOp.java @@ -4,6 +4,7 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortDesc; @@ -20,7 +21,8 @@ @InputPorts(values = {@PortSpec(PortType.DATA)}) @OutputPorts(values = {@PortSpec(value = PortType.DATA, desc = PortDesc.OUTPUT_RESULT)}) @ParamSelectColumnSpec(name = "selectedCol") @NameCn("离散余弦变换") +@NameEn("Discrete Cosine Transform") public class DCTBatchOp extends MapBatchOp implements DCTParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/feature/DecisionTreeEncoderTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/feature/DecisionTreeEncoderTrainBatchOp.java new file mode 100644 index 000000000..d8d663d3e --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/batch/feature/DecisionTreeEncoderTrainBatchOp.java @@ -0,0 +1,17 @@ +package com.alibaba.alink.operator.batch.feature; + +import org.apache.flink.ml.api.misc.param.Params; + +import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.operator.batch.classification.DecisionTreeTrainBatchOp; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; + +@NameCn("决策树编码器训练") +@NameEn("Decision Tree Encoder Training") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.feature.DecisionTreeEncoder") +public class DecisionTreeEncoderTrainBatchOp extends DecisionTreeTrainBatchOp { + public DecisionTreeEncoderTrainBatchOp(Params params) { + super(params); + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/feature/DecisionTreeRegEncoderTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/feature/DecisionTreeRegEncoderTrainBatchOp.java new file mode 100644 index 000000000..b601e1a9e --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/batch/feature/DecisionTreeRegEncoderTrainBatchOp.java @@ -0,0 +1,17 @@ +package com.alibaba.alink.operator.batch.feature; + +import org.apache.flink.ml.api.misc.param.Params; + +import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.operator.batch.regression.DecisionTreeRegTrainBatchOp; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; + +@NameCn("决策树回归编码器训练") +@NameEn("Decision Tree RegEncoder Training") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.feature.DecisionTreeRegEncoder") +public class DecisionTreeRegEncoderTrainBatchOp extends DecisionTreeRegTrainBatchOp { + public DecisionTreeRegEncoderTrainBatchOp(Params params) { + super(params); + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/feature/EqualWidthDiscretizerModelInfoBatchOp.java 
b/core/src/main/java/com/alibaba/alink/operator/batch/feature/EqualWidthDiscretizerModelInfoBatchOp.java index 95ca1a4e2..9ae7e9626 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/feature/EqualWidthDiscretizerModelInfoBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/feature/EqualWidthDiscretizerModelInfoBatchOp.java @@ -3,7 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import org.apache.flink.types.Row; -import com.alibaba.alink.common.lazy.ExtractModelInfoBatchOp; +import com.alibaba.alink.operator.batch.utils.ExtractModelInfoBatchOp; import com.alibaba.alink.operator.common.feature.QuantileDiscretizerModelInfo; import java.util.List; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/feature/EqualWidthDiscretizerPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/feature/EqualWidthDiscretizerPredictBatchOp.java index 5994ca406..722b6f092 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/feature/EqualWidthDiscretizerPredictBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/feature/EqualWidthDiscretizerPredictBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.SelectedColsWithSecondInputSpec; import com.alibaba.alink.operator.batch.utils.ModelMapBatchOp; import com.alibaba.alink.operator.common.feature.QuantileDiscretizerModelMapper; @@ -15,6 +16,7 @@ */ @SelectedColsWithSecondInputSpec @NameCn("等宽离散化预测") +@NameEn("Equal Width Discretize Prediction") public final class EqualWidthDiscretizerPredictBatchOp extends ModelMapBatchOp implements QuantileDiscretizerPredictParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/feature/EqualWidthDiscretizerTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/feature/EqualWidthDiscretizerTrainBatchOp.java index 692eca88b..cfa96d16f 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/feature/EqualWidthDiscretizerTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/feature/EqualWidthDiscretizerTrainBatchOp.java @@ -8,19 +8,21 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortSpec; import com.alibaba.alink.common.annotation.PortType; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.common.exceptions.AkIllegalOperatorParameterException; -import com.alibaba.alink.common.lazy.WithModelInfoBatchOp; +import com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp; import com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.common.feature.QuantileDiscretizerModelDataConverter; -import com.alibaba.alink.operator.common.statistics.StatisticsHelper; +import com.alibaba.alink.operator.batch.statistics.utils.StatisticsHelper; import com.alibaba.alink.operator.common.statistics.basicstatistic.TableSummary; import com.alibaba.alink.params.feature.QuantileDiscretizerTrainParams; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; import java.util.Comparator; import java.util.HashMap; @@ -37,6 +39,8 @@ 
@OutputPorts(values = {@PortSpec(PortType.MODEL)}) @ParamSelectColumnSpec(name = "selectedCols", allowedTypeCollections = TypeCollections.NUMERIC_TYPES) @NameCn("等宽离散化训练") +@NameEn("Equal Width Discretize Training") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.feature.EqualWidthDiscretizer") public final class EqualWidthDiscretizerTrainBatchOp extends BatchOperator implements QuantileDiscretizerTrainParams , WithModelInfoBatchOp colNameBucketNumber, @Override public void flatMap(TableSummary tableSummary, Collector collector) { for (String colName : tableSummary.getColNames()) { - double min = tableSummary.min(colName); - double max = tableSummary.max(colName); + double min = tableSummary.minDouble(colName); + double max = tableSummary.maxDouble(colName); collector.collect(Row.of(TableUtil.findColIndexWithAssertAndHint(colNames, colName), getSplitPointsFromMinMax(min, max, colNameBucketNumber.get(colName)))); diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/feature/ExclusiveFeatureBundleModelInfoBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/feature/ExclusiveFeatureBundleModelInfoBatchOp.java new file mode 100644 index 000000000..bb26c9685 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/batch/feature/ExclusiveFeatureBundleModelInfoBatchOp.java @@ -0,0 +1,29 @@ +package com.alibaba.alink.operator.batch.feature; + +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.types.Row; + +import com.alibaba.alink.operator.batch.utils.ExtractModelInfoBatchOp; +import com.alibaba.alink.operator.common.feature.ExclusiveFeatureBundleModelInfo; + +import java.util.List; + +/** + * ExclusiveFeatureBundleModelInfoLocalOp can be linked to the output of ExclusiveFeatureBundleTrainLocalOp to summary the model. 
+ */ +public class ExclusiveFeatureBundleModelInfoBatchOp + extends ExtractModelInfoBatchOp { + + public ExclusiveFeatureBundleModelInfoBatchOp() { + this(null); + } + + public ExclusiveFeatureBundleModelInfoBatchOp(Params params) { + super(params); + } + + @Override + public ExclusiveFeatureBundleModelInfo createModelInfo(List rows) { + return new ExclusiveFeatureBundleModelInfo(rows); + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/feature/ExclusiveFeatureBundlePredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/feature/ExclusiveFeatureBundlePredictBatchOp.java new file mode 100644 index 000000000..9b1f15570 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/batch/feature/ExclusiveFeatureBundlePredictBatchOp.java @@ -0,0 +1,23 @@ +package com.alibaba.alink.operator.batch.feature; + +import org.apache.flink.ml.api.misc.param.Params; + +import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.operator.batch.utils.ModelMapBatchOp; +import com.alibaba.alink.operator.common.feature.ExclusiveFeatureBundleModelMapper; +import com.alibaba.alink.params.feature.ExclusiveFeatureBundlePredictParams; + +@NameCn("互斥特征捆绑模型预测") +@NameEn("Exclusive Feature Bundle Prediction") +public class ExclusiveFeatureBundlePredictBatchOp extends ModelMapBatchOp + implements ExclusiveFeatureBundlePredictParams { + + public ExclusiveFeatureBundlePredictBatchOp() { + this(null); + } + + public ExclusiveFeatureBundlePredictBatchOp(Params params) { + super(ExclusiveFeatureBundleModelMapper::new, params); + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/feature/FeatureHasherBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/feature/FeatureHasherBatchOp.java index ea73fb610..fc3a39ec3 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/feature/FeatureHasherBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/feature/FeatureHasherBatchOp.java @@ -4,6 +4,7 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortDesc; @@ -23,6 +24,7 @@ @ParamSelectColumnSpec(name = "selectedCols") @ParamSelectColumnSpec(name = "categoricalCols") @NameCn("特征哈希") +@NameEn("Feature Hasher") public final class FeatureHasherBatchOp extends MapBatchOp implements FeatureHasherParams { private static final long serialVersionUID = 6037792513321750824L; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/feature/GbdtEncoderPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/feature/GbdtEncoderPredictBatchOp.java new file mode 100644 index 000000000..59f7da1c7 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/batch/feature/GbdtEncoderPredictBatchOp.java @@ -0,0 +1,28 @@ +package com.alibaba.alink.operator.batch.feature; + +import org.apache.flink.ml.api.misc.param.Params; + +import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.operator.batch.utils.ModelMapBatchOp; +import com.alibaba.alink.operator.common.tree.predictors.TreeModelEncoderModelMapper; +import com.alibaba.alink.params.feature.TreeModelEncoderParams; + +/** + * Gbdt encoder to encode gbdt model to 
sparse matrix which is used + * as feature for classifier. + */ +@NameCn("GBDT分类编码预测") +@NameEn("Gbdt Encoder Prediction") +public class GbdtEncoderPredictBatchOp extends ModelMapBatchOp + implements TreeModelEncoderParams { + private static final long serialVersionUID = -7596799178072234171L; + + public GbdtEncoderPredictBatchOp() { + this(new Params()); + } + + public GbdtEncoderPredictBatchOp(Params params) { + super(TreeModelEncoderModelMapper::new, params); + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/feature/GbdtEncoderTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/feature/GbdtEncoderTrainBatchOp.java new file mode 100644 index 000000000..70bfc217a --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/batch/feature/GbdtEncoderTrainBatchOp.java @@ -0,0 +1,21 @@ +package com.alibaba.alink.operator.batch.feature; + +import org.apache.flink.ml.api.misc.param.Params; + +import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.operator.batch.classification.GbdtTrainBatchOp; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; + +@NameCn("GBDT分类编码训练") +@NameEn("Gbdt Encoder Training") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.feature.GbdtEncoder") +public class GbdtEncoderTrainBatchOp extends GbdtTrainBatchOp { + public GbdtEncoderTrainBatchOp() { + super(new Params()); + } + + public GbdtEncoderTrainBatchOp(Params params) { + super(params); + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/feature/GbdtRegEncoderTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/feature/GbdtRegEncoderTrainBatchOp.java new file mode 100644 index 000000000..f877bfa40 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/batch/feature/GbdtRegEncoderTrainBatchOp.java @@ -0,0 +1,17 @@ +package com.alibaba.alink.operator.batch.feature; + +import org.apache.flink.ml.api.misc.param.Params; + +import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.operator.batch.regression.GbdtRegTrainBatchOp; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; + +@NameCn("GBDT回归编码器训练") +@NameEn("Gbdt RegEncoder Training") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.feature.GbdtRegEncoder") +public class GbdtRegEncoderTrainBatchOp extends GbdtRegTrainBatchOp { + public GbdtRegEncoderTrainBatchOp(Params params) { + super(params); + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/feature/GenerateFeatureOfLatestBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/feature/GenerateFeatureOfLatestBatchOp.java index 5891e8429..ea1698c3b 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/feature/GenerateFeatureOfLatestBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/feature/GenerateFeatureOfLatestBatchOp.java @@ -10,14 +10,15 @@ import com.alibaba.alink.common.MTable; import com.alibaba.alink.common.MTableUtil; import com.alibaba.alink.common.annotation.NameCn; -import com.alibaba.alink.params.feature.GenerateFeatureOfLatestParams; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.fe.GenerateFeatureUtil; import com.alibaba.alink.common.fe.define.BaseStatFeatures; -import com.alibaba.alink.common.utils.DataSetConversionUtil; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; 
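The tree-based encoder trainers added in this patch (DecisionTreeEncoderTrainBatchOp, GbdtEncoderTrainBatchOp, GbdtRegEncoderTrainBatchOp, and the matching predict ops) are thin subclasses of the existing tree trainers, so they are used through the usual BatchOperator train/predict linking. A minimal sketch of that wiring, assuming illustrative data and column names and the usual tree-operator parameters (setFeatureCols, setLabelCol) plus a prediction-column parameter on the predict op; this snippet is not part of the patch itself:

import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.feature.GbdtEncoderPredictBatchOp;
import com.alibaba.alink.operator.batch.feature.GbdtEncoderTrainBatchOp;
import com.alibaba.alink.operator.batch.source.MemSourceBatchOp;

public class GbdtEncoderExample {
	public static void main(String[] args) throws Exception {
		// Illustrative toy data: two numeric features and a binary label.
		BatchOperator<?> data = new MemSourceBatchOp(
			new Object[][] {{1.0, 2.0, 1}, {2.0, 1.0, 0}, {3.0, 4.0, 1}},
			new String[] {"f0", "f1", "label"});

		// Train a GBDT model that is used only as an encoder.
		BatchOperator<?> encoderModel = new GbdtEncoderTrainBatchOp()
			.setFeatureCols("f0", "f1")
			.setLabelCol("label")
			.linkFrom(data);

		// Map each row to the sparse encoded features described in the Javadoc above.
		BatchOperator<?> encoded = new GbdtEncoderPredictBatchOp()
			.setPredictionCol("encoded_features")
			.linkFrom(encoderModel, data);

		encoded.print();
	}
}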
import com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.batch.dataproc.FlattenMTableBatchOp; import com.alibaba.alink.operator.batch.source.TableSourceBatchOp; +import com.alibaba.alink.params.feature.GenerateFeatureOfLatestParams; import java.sql.Timestamp; import java.util.List; @@ -27,6 +28,7 @@ * Latest Feature Window. */ @NameCn("Latest特征生成") +@NameEn("Generate Feature of Latest") public class GenerateFeatureOfLatestBatchOp extends BatchOperator implements GenerateFeatureOfLatestParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/feature/GenerateFeatureOfLatestNDaysBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/feature/GenerateFeatureOfLatestNDaysBatchOp.java index 1a87b06b9..52642a4fc 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/feature/GenerateFeatureOfLatestNDaysBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/feature/GenerateFeatureOfLatestNDaysBatchOp.java @@ -13,6 +13,7 @@ import com.alibaba.alink.common.MLEnvironmentFactory; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.exceptions.AkIllegalOperatorParameterException; import com.alibaba.alink.common.fe.GenerateFeatureUtil; import com.alibaba.alink.common.fe.define.day.BaseDaysStatFeatures; @@ -23,14 +24,14 @@ import com.alibaba.alink.common.fe.define.statistics.BaseCategoricalStatistics; import com.alibaba.alink.common.fe.define.statistics.BaseNumericStatistics; import com.alibaba.alink.common.fe.udaf.CatesCntUdaf; -import com.alibaba.alink.common.fe.udaf.TotalCountUdaf; import com.alibaba.alink.common.fe.udaf.DistinctCountUdaf; import com.alibaba.alink.common.fe.udaf.KvCntUdaf; import com.alibaba.alink.common.fe.udaf.MaxUdaf; import com.alibaba.alink.common.fe.udaf.MeanUdaf; import com.alibaba.alink.common.fe.udaf.MinUdaf; import com.alibaba.alink.common.fe.udaf.SumUdaf; -import com.alibaba.alink.common.utils.DataSetConversionUtil; +import com.alibaba.alink.common.fe.udaf.TotalCountUdaf; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; import com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.batch.source.MemSourceBatchOp; @@ -52,6 +53,7 @@ * Latest Feature Window. 
*/ @NameCn("Latest特征生成") +@NameEn("Generate Feature of Latest N Days") public class GenerateFeatureOfLatestNDaysBatchOp extends BatchOperator implements GenerateFeatureOfLatestDayParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/feature/GenerateFeatureOfWindowBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/feature/GenerateFeatureOfWindowBatchOp.java index f6e224538..4f45d0243 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/feature/GenerateFeatureOfWindowBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/feature/GenerateFeatureOfWindowBatchOp.java @@ -11,13 +11,14 @@ import com.alibaba.alink.common.MTable; import com.alibaba.alink.common.MTableUtil; import com.alibaba.alink.common.annotation.NameCn; -import com.alibaba.alink.params.feature.GenerateFeatureOfWindowParams; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.fe.GenerateFeatureUtil; import com.alibaba.alink.common.fe.define.BaseStatFeatures; import com.alibaba.alink.common.fe.define.InterfaceWindowStatFeatures; -import com.alibaba.alink.common.utils.DataSetConversionUtil; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; import com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.operator.batch.BatchOperator; +import com.alibaba.alink.params.feature.GenerateFeatureOfWindowParams; import java.sql.Timestamp; import java.util.List; @@ -26,6 +27,7 @@ * Generate Feature Window. */ @NameCn("窗口特征生成") +@NameEn("Generate Feature of Window") public class GenerateFeatureOfWindowBatchOp extends BatchOperator implements GenerateFeatureOfWindowParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/feature/HashCrossFeatureBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/feature/HashCrossFeatureBatchOp.java index 099b59c45..d8231a83f 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/feature/HashCrossFeatureBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/feature/HashCrossFeatureBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.operator.batch.utils.MapBatchOp; import com.alibaba.alink.operator.common.feature.HashCrossFeatureMapper; @@ -10,6 +11,7 @@ @ParamSelectColumnSpec(name = "selectedCols") @NameCn("Hash Cross特征") +@NameEn("Hash Cross Feature") public class HashCrossFeatureBatchOp extends MapBatchOp implements HashCrossFeatureParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/feature/Id3EncoderTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/feature/Id3EncoderTrainBatchOp.java new file mode 100644 index 000000000..b920dd8c0 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/batch/feature/Id3EncoderTrainBatchOp.java @@ -0,0 +1,17 @@ +package com.alibaba.alink.operator.batch.feature; + +import org.apache.flink.ml.api.misc.param.Params; + +import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.operator.batch.classification.Id3TrainBatchOp; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; + +@NameCn("ID3决策树分类编码器训练") +@NameEn(" Id3 Encoder Training") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.feature.Id3Encoder") +public class 
Id3EncoderTrainBatchOp extends Id3TrainBatchOp { + public Id3EncoderTrainBatchOp(Params params) { + super(params); + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/feature/MultiHotModelInfoBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/feature/MultiHotModelInfoBatchOp.java index 41def96d0..30b81e053 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/feature/MultiHotModelInfoBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/feature/MultiHotModelInfoBatchOp.java @@ -3,7 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import org.apache.flink.types.Row; -import com.alibaba.alink.common.lazy.ExtractModelInfoBatchOp; +import com.alibaba.alink.operator.batch.utils.ExtractModelInfoBatchOp; import com.alibaba.alink.operator.common.feature.MultiHotModelInfo; import java.util.List; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/feature/MultiHotPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/feature/MultiHotPredictBatchOp.java index 92a9af1af..47e9950d6 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/feature/MultiHotPredictBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/feature/MultiHotPredictBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.operator.batch.utils.ModelMapBatchOp; @@ -14,6 +15,7 @@ */ @ParamSelectColumnSpec(name = "selectedCols", allowedTypeCollections = TypeCollections.STRING_TYPES) @NameCn("多热编码预测") +@NameEn("Multi Hot Prediction") public class MultiHotPredictBatchOp extends ModelMapBatchOp implements MultiHotPredictParams { private static final long serialVersionUID = -6029385456358959482L; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/feature/MultiHotTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/feature/MultiHotTrainBatchOp.java index b08b694f9..d1fa52963 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/feature/MultiHotTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/feature/MultiHotTrainBatchOp.java @@ -10,19 +10,21 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortSpec; import com.alibaba.alink.common.annotation.PortSpec.OpType; import com.alibaba.alink.common.annotation.PortType; import com.alibaba.alink.common.annotation.TypeCollections; -import com.alibaba.alink.common.lazy.WithModelInfoBatchOp; +import com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp; import com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.common.feature.MultiHotModelData; import com.alibaba.alink.operator.common.feature.MultiHotModelDataConverter; import com.alibaba.alink.operator.common.feature.MultiHotModelInfo; import com.alibaba.alink.params.feature.MultiHotTrainParams; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; import java.util.HashMap; import java.util.Map; @@ -34,6 +36,8 @@ 
@OutputPorts(values = @PortSpec(value = PortType.MODEL)) @ParamSelectColumnSpec(name = "selectedCols", allowedTypeCollections = TypeCollections.STRING_TYPES) @NameCn("多热编码训练") +@NameEn("Multi Hot Training") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.feature.MultiHotEncoder") public final class MultiHotTrainBatchOp extends BatchOperator implements MultiHotTrainParams , WithModelInfoBatchOp { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/feature/OneHotModelInfoBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/feature/OneHotModelInfoBatchOp.java index 02cc73747..9714191ce 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/feature/OneHotModelInfoBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/feature/OneHotModelInfoBatchOp.java @@ -3,7 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import org.apache.flink.types.Row; -import com.alibaba.alink.common.lazy.ExtractModelInfoBatchOp; +import com.alibaba.alink.operator.batch.utils.ExtractModelInfoBatchOp; import com.alibaba.alink.operator.common.feature.OneHotModelInfo; import java.util.List; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/feature/OneHotPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/feature/OneHotPredictBatchOp.java index 8752aa092..d7b3eb44d 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/feature/OneHotPredictBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/feature/OneHotPredictBatchOp.java @@ -3,6 +3,8 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.operator.batch.utils.ModelMapBatchOp; import com.alibaba.alink.operator.common.feature.OneHotModelMapper; import com.alibaba.alink.params.feature.OneHotPredictParams; @@ -12,6 +14,8 @@ * sparse binary vectors. 
*/ @NameCn("独热编码预测") +@NameEn("OneHot Encoder Predict") +@ParamSelectColumnSpec(name = "selectedCols", portIndices = {1}) public final class OneHotPredictBatchOp extends ModelMapBatchOp implements OneHotPredictParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/feature/OneHotTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/feature/OneHotTrainBatchOp.java index d4e36480c..9db1add79 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/feature/OneHotTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/feature/OneHotTrainBatchOp.java @@ -20,18 +20,21 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; +import com.alibaba.alink.common.annotation.ParamsIgnoredOnWebUI; import com.alibaba.alink.common.annotation.PortSpec; import com.alibaba.alink.common.annotation.PortSpec.OpType; import com.alibaba.alink.common.annotation.PortType; import com.alibaba.alink.common.annotation.SelectedColsWithFirstInputSpec; -import com.alibaba.alink.common.lazy.WithModelInfoBatchOp; -import com.alibaba.alink.common.utils.DataSetConversionUtil; +import com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; import com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.common.dataproc.StringIndexerUtil; import com.alibaba.alink.operator.common.feature.OneHotModelDataConverter; import com.alibaba.alink.operator.common.feature.OneHotModelInfo; +import com.alibaba.alink.operator.common.feature.OneHotModelMapper; import com.alibaba.alink.operator.common.feature.binning.BinDivideType; import com.alibaba.alink.operator.common.feature.binning.Bins; import com.alibaba.alink.operator.common.feature.binning.FeatureBinsCalculator; @@ -40,6 +43,7 @@ import com.alibaba.alink.params.feature.HasEnableElse; import com.alibaba.alink.params.feature.OneHotTrainParams; import com.alibaba.alink.params.shared.colname.HasSelectedCols; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; import java.util.ArrayList; import java.util.Arrays; @@ -53,6 +57,10 @@ @OutputPorts(values = @PortSpec(value = PortType.MODEL)) @SelectedColsWithFirstInputSpec @NameCn("独热编码训练") +@NameEn("OneHot Encoder Train") +// Designer前端不支持返回int数组,且xflow中one-hot编码没有相应参数,因此先对Designer用户屏蔽该参数 +@ParamsIgnoredOnWebUI(names = {"discreteThresholdsArray"}) +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.feature.OneHotEncoder") public final class OneHotTrainBatchOp extends BatchOperator implements OneHotTrainParams , WithModelInfoBatchOp { @@ -94,7 +102,7 @@ public OneHotTrainBatchOp linkFrom(BatchOperator ... inputs) { thresholdArray = Arrays.stream(getDiscreteThresholdsArray()).mapToInt(Integer::intValue).toArray(); } - boolean enableElse = isEnableElse(thresholdArray); + boolean enableElse = OneHotModelMapper.isEnableElse(thresholdArray); DataSet inputRows = in.select(selectedColNames).getDataSet(); DataSet > countTokens = StringIndexerUtil.countTokens(inputRows, true) @@ -152,20 +160,14 @@ public OneHotModelInfoBatchOp getModelInfoBatchOp() { return new OneHotModelInfoBatchOp(this.getParams()).linkFrom(this); } + /** * If the thresholdArray is set and greater than 0, enableElse is true. * * @param thresholdArray thresholds for each column. 
* @return enableElse. */ - private static boolean isEnableElse(int[] thresholdArray) { - for (int threshold : thresholdArray) { - if (threshold > 0) { - return true; - } - } - return false; - } + /** * Transform OneHotModel to Binning Featureborder. diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/feature/OverWindowBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/feature/OverWindowBatchOp.java index 01863d1e2..0e8f9b202 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/feature/OverWindowBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/feature/OverWindowBatchOp.java @@ -1,19 +1,21 @@ package com.alibaba.alink.operator.batch.feature; -import org.apache.flink.api.common.functions.GroupReduceFunction; import org.apache.flink.api.common.functions.MapPartitionFunction; +import org.apache.flink.api.common.functions.RichGroupReduceFunction; import org.apache.flink.api.common.operators.Order; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeinfo.Types; import org.apache.flink.api.java.DataSet; import org.apache.flink.api.java.operators.SortedGrouping; import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.configuration.Configuration; import org.apache.flink.ml.api.misc.param.Params; import org.apache.flink.types.Row; import org.apache.flink.util.Collector; import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortSpec; @@ -27,8 +29,9 @@ import com.alibaba.alink.common.sql.builtin.agg.LastTimeUdaf; import com.alibaba.alink.common.sql.builtin.agg.LastValueUdaf; import com.alibaba.alink.common.sql.builtin.agg.ListAggUdaf; +import com.alibaba.alink.common.sql.builtin.agg.MTableAgg; import com.alibaba.alink.common.sql.builtin.agg.SumLastUdaf; -import com.alibaba.alink.common.utils.DataSetConversionUtil; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; import com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.batch.source.TableSourceBatchOp; @@ -44,6 +47,7 @@ * Batch over window feature builder. */ @NameCn("特征构造:OverWindow") +@NameEn("Over Window Feature Builder") @InputPorts(values = @PortSpec(PortType.DATA)) @OutputPorts(values = @PortSpec(PortType.DATA)) @ParamSelectColumnSpec(name = "partitionCols") @@ -80,7 +84,8 @@ public OverWindowBatchOp linkFrom(BatchOperator ... 
inputs) { int[] reversedIndices = TableUtil.findColIndices(inputColNames, reversedCols); String sqlClause = getClause(); - FeatureClause[] featureClauses = FeatureClauseUtil.extractFeatureClauses(sqlClause); + FeatureClause[] featureClauses = FeatureClauseUtil.extractFeatureClauses(sqlClause, in.getSchema(), + orderBy); DataSet res; if (partitionBys != null) { @@ -197,7 +202,7 @@ private Tuple2 parseOrder(String orderClause, String[] inputCol return Tuple2.of(orderIndices, orderTypes); } - private static class GroupOperation implements GroupReduceFunction { + private static class GroupOperation extends RichGroupReduceFunction { FeatureClause[] featureClauses; int sessionId; int[] partitionByIndices; @@ -209,12 +214,17 @@ private static class GroupOperation implements GroupReduceFunction { int[] partitionByIndices, int[] reversedIndices, String[] inputColNames) { this.featureClauses = featureClauses; this.orderIndices = orderIndices; - this.sessionId = SessionSharedData.getNewSessionId(); this.partitionByIndices = partitionByIndices; this.reversedIndices = reversedIndices; this.inputColNames = inputColNames; } + @Override + public void open(Configuration parameters) throws Exception { + //new open func in worker, so sessionId new here, not constructor. + this.sessionId = SessionSharedData.getNewSessionId(); + } + @Override public void reduce(Iterable values, Collector out) throws Exception { Row res = null; @@ -236,7 +246,7 @@ public void reduce(Iterable values, Collector out) throws Exception } } - BaseUdaf[] calcs = (BaseUdaf[]) SessionSharedData.get(keys.toString(), sessionId); + BaseUdaf [] calcs = (BaseUdaf []) SessionSharedData.get(keys.toString(), sessionId); if (calcs == null) { calcs = new BaseUdaf[featureClauses.length]; for (int i = 0; i < featureClauses.length; i++) { @@ -250,6 +260,7 @@ public void reduce(Iterable values, Collector out) throws Exception } else { index++; } + for (int i = 0; i < featureClauses.length; i++) { /* @@ -258,7 +269,7 @@ public void reduce(Iterable values, Collector out) throws Exception Beside, 'distinct' and 'all' are not supported. 
*/ - BaseUdaf udaf = calcs[i]; + BaseUdaf udaf = calcs[i]; if (udaf instanceof LastValueUdaf || udaf instanceof LastDistinctValueUdaf || udaf instanceof LastTimeUdaf || udaf instanceof SumLastUdaf) { int kLength = featureClauses[i].inputParams.length; @@ -272,7 +283,7 @@ public void reduce(Iterable values, Collector out) throws Exception if (udaf instanceof LastDistinctValueUdaf) { for (int j = 0; j < kLength; j++) { aggInputData[inputIndex++] = value.getField( - TableUtil.findColIndex(inputColNames, (String) featureClauses[j].inputParams[0])); + TableUtil.findColIndex(inputColNames, (String) featureClauses[i].inputParams[0])); } } else { for (int j = 0; j < kLength; j++) { @@ -298,6 +309,16 @@ public void reduce(Iterable values, Collector out) throws Exception udaf.accumulateBatch(thisData); } else if (udaf instanceof CountUdaf) { udaf.accumulateBatch(0); + } else if (udaf instanceof MTableAgg) { + Object[] aggInputData = new Object[featureClauses[i].inputParams.length + 1]; + aggInputData[0] = value.getField( + TableUtil.findColIndex(inputColNames, featureClauses[i].inColName)); + for (int j = 0; j < featureClauses[i].inputParams.length; j++) { + aggInputData[1 + j] = value.getField( + TableUtil.findColIndex(inputColNames, (String)featureClauses[i].inputParams[j])); + } + + udaf.accumulateBatch(aggInputData); } else { Object[] aggInputData = new Object[featureClauses[i].inputParams.length + 1]; aggInputData[0] = value.getField( diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/feature/PcaModelInfoBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/feature/PcaModelInfoBatchOp.java index 7db8dbc0b..0f3c49c5d 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/feature/PcaModelInfoBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/feature/PcaModelInfoBatchOp.java @@ -3,7 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import org.apache.flink.types.Row; -import com.alibaba.alink.common.lazy.ExtractModelInfoBatchOp; +import com.alibaba.alink.operator.batch.utils.ExtractModelInfoBatchOp; import com.alibaba.alink.operator.common.feature.pca.PcaModelData; import com.alibaba.alink.operator.common.feature.pca.PcaModelDataConverter; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/feature/PcaPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/feature/PcaPredictBatchOp.java index e01ba99f9..542d92e80 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/feature/PcaPredictBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/feature/PcaPredictBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.operator.batch.utils.ModelMapBatchOp; @@ -14,6 +15,7 @@ */ @ParamSelectColumnSpec(name = "vectorCol", allowedTypeCollections = TypeCollections.VECTOR_TYPES) @NameCn("主成分分析预测") +@NameEn("Pca Prediction") public class PcaPredictBatchOp extends ModelMapBatchOp implements PcaPredictParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/feature/PcaTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/feature/PcaTrainBatchOp.java index ead87291e..6726f7b09 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/feature/PcaTrainBatchOp.java +++ 
b/core/src/main/java/com/alibaba/alink/operator/batch/feature/PcaTrainBatchOp.java @@ -10,6 +10,7 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortSpec; @@ -18,7 +19,7 @@ import com.alibaba.alink.common.exceptions.AkIllegalOperatorParameterException; import com.alibaba.alink.common.exceptions.AkIllegalStateException; import com.alibaba.alink.common.exceptions.AkUnsupportedOperationException; -import com.alibaba.alink.common.lazy.WithModelInfoBatchOp; +import com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp; import com.alibaba.alink.common.linalg.DenseMatrix; import com.alibaba.alink.common.linalg.DenseVector; import com.alibaba.alink.common.linalg.EigenSolver; @@ -27,10 +28,11 @@ import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.common.feature.pca.PcaModelData; import com.alibaba.alink.operator.common.feature.pca.PcaModelDataConverter; -import com.alibaba.alink.operator.common.statistics.StatisticsHelper; +import com.alibaba.alink.operator.batch.statistics.utils.StatisticsHelper; import com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummarizer; import com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary; import com.alibaba.alink.params.feature.PcaTrainParams; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; import java.util.ArrayList; import java.util.Arrays; @@ -46,6 +48,8 @@ @ParamSelectColumnSpec(name = "selectedCols", allowedTypeCollections = TypeCollections.VECTOR_TYPES) @ParamSelectColumnSpec(name = "vectorCol", allowedTypeCollections = TypeCollections.VECTOR_TYPES) @NameCn("主成分分析训练") +@NameEn("Pca Training") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.feature.PCA") public final class PcaTrainBatchOp extends BatchOperator implements PcaTrainParams , WithModelInfoBatchOp { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/feature/QuantileDiscretizerModelInfoBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/feature/QuantileDiscretizerModelInfoBatchOp.java index 0589c94f3..87f29fe93 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/feature/QuantileDiscretizerModelInfoBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/feature/QuantileDiscretizerModelInfoBatchOp.java @@ -3,7 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import org.apache.flink.types.Row; -import com.alibaba.alink.common.lazy.ExtractModelInfoBatchOp; +import com.alibaba.alink.operator.batch.utils.ExtractModelInfoBatchOp; import com.alibaba.alink.operator.common.feature.QuantileDiscretizerModelInfo; import java.util.List; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/feature/QuantileDiscretizerPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/feature/QuantileDiscretizerPredictBatchOp.java index c8a8fb461..bb46689f6 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/feature/QuantileDiscretizerPredictBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/feature/QuantileDiscretizerPredictBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import 
com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.operator.batch.utils.ModelMapBatchOp; @@ -14,6 +15,7 @@ */ @ParamSelectColumnSpec(name = "selectedCols", allowedTypeCollections = TypeCollections.NUMERIC_TYPES) @NameCn("分位数离散化预测") +@NameEn("Quantile Discretizer Prediction") public final class QuantileDiscretizerPredictBatchOp extends ModelMapBatchOp implements QuantileDiscretizerPredictParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/feature/QuantileDiscretizerTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/feature/QuantileDiscretizerTrainBatchOp.java index 41fe06e7f..da0647d1b 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/feature/QuantileDiscretizerTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/feature/QuantileDiscretizerTrainBatchOp.java @@ -19,6 +19,7 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamCond; import com.alibaba.alink.common.annotation.ParamCond.CondType; @@ -30,7 +31,7 @@ import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.common.exceptions.AkIllegalOperatorParameterException; import com.alibaba.alink.common.exceptions.AkUnclassifiedErrorException; -import com.alibaba.alink.common.lazy.WithModelInfoBatchOp; +import com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp; import com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.common.dataproc.SortUtils; @@ -45,6 +46,7 @@ import com.alibaba.alink.operator.common.tree.Preprocessing; import com.alibaba.alink.params.feature.QuantileDiscretizerTrainParams; import com.alibaba.alink.params.statistics.HasRoundMode; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -82,6 +84,8 @@ ) ) @NameCn("分位数离散化训练") +@NameEn("Quantile Discretizer Training") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.feature.QuantileDiscretizer") public final class QuantileDiscretizerTrainBatchOp extends BatchOperator implements QuantileDiscretizerTrainParams , WithModelInfoBatchOp + implements TargetEncoderPredictParams { + + public TargetEncoderPredictBatchOp() { + this(new Params()); + } + + public TargetEncoderPredictBatchOp(Params params) { + super(TargetEncoderModelMapper::new, params); + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/feature/TargetEncoderTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/feature/TargetEncoderTrainBatchOp.java new file mode 100644 index 000000000..080be3c75 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/batch/feature/TargetEncoderTrainBatchOp.java @@ -0,0 +1,301 @@ +package com.alibaba.alink.operator.batch.feature; + +import org.apache.flink.api.common.functions.GroupCombineFunction; +import org.apache.flink.api.common.functions.GroupReduceFunction; +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.MapPartitionFunction; +import org.apache.flink.api.common.functions.RichGroupReduceFunction; +import 
org.apache.flink.api.common.functions.RichMapPartitionFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.api.java.operators.DeltaIteration; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.table.api.Table; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; + +import com.alibaba.alink.common.MLEnvironmentFactory; +import com.alibaba.alink.common.annotation.InputPorts; +import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.common.annotation.OutputPorts; +import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; +import com.alibaba.alink.common.annotation.PortSpec; +import com.alibaba.alink.common.annotation.PortType; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; +import com.alibaba.alink.common.utils.JsonConverter; +import com.alibaba.alink.common.utils.TableUtil; +import com.alibaba.alink.operator.batch.BatchOperator; +import com.alibaba.alink.operator.common.feature.TargetEncoderConverter; +import com.alibaba.alink.operator.common.slidingwindow.SessionSharedData; +import com.alibaba.alink.params.feature.TargetEncoderTrainParams; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; + +import java.util.ArrayList; +import java.util.HashMap; + +@InputPorts(values = {@PortSpec(PortType.DATA)}) +@OutputPorts(values = {@PortSpec(PortType.MODEL)}) +@ParamSelectColumnSpec(name = "selectedCols") +@ParamSelectColumnSpec(name = "labelCol") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.feature.TargetEncoder") +@NameCn("TargetEncoder") +@NameEn("TargetEncoder") +public class TargetEncoderTrainBatchOp extends BatchOperator + implements TargetEncoderTrainParams { + + public TargetEncoderTrainBatchOp(Params params) { + super(params); + } + + public TargetEncoderTrainBatchOp() { + this(new Params()); + } + + @Override + public TargetEncoderTrainBatchOp linkFrom(BatchOperator ... 
inputs) { + + BatchOperator in = checkAndGetFirst(inputs); + String label = getLabelCol(); + + String[] selectedCols = getSelectedCols(); + if (selectedCols == null) { + selectedCols = TableUtil + .getCategoricalCols(in.getSchema(), in.getColNames(), null); + ArrayList listCols = new ArrayList <>(); + for (String s : selectedCols) { + if (!s.equals(label)) { + listCols.add(s); + } + } + selectedCols = listCols.toArray(new String[0]); + } + int[] selectedColIndices = TableUtil.findColIndices(in.getSchema(), selectedCols); + int labelIndex = TableUtil.findColIndex(in.getSchema(), label); + String positiveLabel = getPositiveLabelValueString(); + int originalSize = in.getColNames().length; + int selectedSize = selectedColIndices.length; + int groupIndex = originalSize + selectedSize; + DataSet > inputData = in.getDataSet() + .mapPartition(new MapInputData(selectedSize, originalSize)) + .map(new MapFunction >() { + @Override + public Tuple2 map(Row value) throws Exception { + return Tuple2.of(0, value); + } + }); + + DataSet > res = MLEnvironmentFactory + .get(getMLEnvironmentId()) + .getExecutionEnvironment() + .fromElements(Tuple2.of(0, new Row(0))); + + DeltaIteration , Tuple2 > loop = + res.iterateDelta(inputData, selectedSize, 0); + DataSet > iterData = loop.getWorkset() + .mapPartition(new BuildGroupByCol(groupIndex, selectedColIndices)); + DataSet > means = iterData + .groupBy(new RowKeySelector(groupIndex)) + .combineGroup(new CalcMean(groupIndex, labelIndex, positiveLabel)); + DataSet rowMeans = means + .reduceGroup(new ReduceColumnInfo(selectedCols)) + .map(new MapFunction >, Row>() { + @Override + public Row map(Tuple2 > value) throws Exception { + Row res = new Row(2); + res.setField(0, value.f0); + res.setField(1, value.f1); + return res; + } + }).returns(TypeInformation.of(Row.class)); + + res = res + .mapPartition(new BuildIterRes()) + .withBroadcastSet(rowMeans, "rowMeans"); + + res = loop.closeWith(res, iterData); + + DataSet modelData = res + .mapPartition(new BuildModelData(selectedSize)) + .reduceGroup(new ReduceModelData()); + + Table resTable = DataSetConversionUtil.toTable(getMLEnvironmentId(), modelData, + new TargetEncoderConverter(selectedCols).getModelSchema()); + this.setOutputTable(resTable); + return this; + } + + public static class RowKeySelector implements KeySelector , Comparable> { + + int index; + + public RowKeySelector(int index) { + this.index = index; + } + + @Override + public Comparable getKey(Tuple2 value) { + return (Comparable) (value.f1.getField(index)); + } + } + + public static class MapInputData implements MapPartitionFunction { + int selectedColSize; + int originColSize; + + MapInputData(int selectedColSize, int originColSize) { + this.selectedColSize = selectedColSize; + this.originColSize = originColSize; + } + + @Override + public void mapPartition(Iterable values, Collector out) throws Exception { + int resColSize = selectedColSize + originColSize + 1;//this is for groupBy + Row res = new Row(resColSize); + for (Row value : values) { + for (int i = 0; i < originColSize; i++) { + res.setField(i, value.getField(i)); + } + out.collect(res); + } + } + } + + public static class CalcMean implements GroupCombineFunction , Tuple2 > { + int groupIndex; + int labelIndex; + String positiveLabel; + + CalcMean(int groupIndex, int labelIndex, String positiveLabel) { + this.groupIndex = groupIndex; + this.labelIndex = labelIndex; + this.positiveLabel = positiveLabel; + } + + @Override + public void combine(Iterable > values, Collector > out) + throws 
Exception { + int count = 0; + double sum = 0; + Object groupValue = null; + for (Tuple2 value : values) { + groupValue = value.f1.getField(groupIndex); + ++count; + if (positiveLabel == null) { + sum += (double) value.f1.getField(labelIndex); + } else if (value.f1.getField(labelIndex).toString().equals(positiveLabel)) { + ++sum; + } + } + double mean = sum / count; + out.collect(Tuple2.of(groupValue, mean)); + } + } + + public static class BuildGroupByCol + extends RichMapPartitionFunction , Tuple2 > { + int superStepNumber; + int lastIndex; + int[] selectedColIndices; + + BuildGroupByCol(int lastIndex, int[] selectedColIndices) { + this.lastIndex = lastIndex; + this.selectedColIndices = selectedColIndices; + } + + @Override + public void open(Configuration parameters) throws Exception { + super.open(parameters); + superStepNumber = getIterationRuntimeContext().getSuperstepNumber() - 1; + } + + @Override + public void mapPartition(Iterable > values, + Collector > out) throws Exception { + int selectedIndex = selectedColIndices[superStepNumber]; + for (Tuple2 value : values) { + value.f1.setField(lastIndex, value.f1.getField(selectedIndex)); + out.collect(value); + } + } + } + + private static class ReduceColumnInfo + extends RichGroupReduceFunction , Tuple2 >> { + String[] selectedCols; + + ReduceColumnInfo(String[] selectedCols) { + this.selectedCols = selectedCols; + } + + @Override + public void reduce(Iterable > values, + Collector >> out) throws Exception { + HashMap res = new HashMap <>(); + for (Tuple2 value : values) { + res.put(value.f0.toString(), value.f1); + } + int iterStepNum = getIterationRuntimeContext().getSuperstepNumber() - 1; + out.collect(Tuple2.of(selectedCols[iterStepNum], res)); + } + } + + private static class BuildIterRes extends RichMapPartitionFunction , Tuple2 > { + Row items; + + @Override + public void open(Configuration parameters) throws Exception { + super.open(parameters); + items = (Row) getRuntimeContext().getBroadcastVariable("rowMeans").get(0); + } + + @Override + public void mapPartition(Iterable > values, Collector > out) throws Exception { + int superStepNum = getIterationRuntimeContext().getSuperstepNumber() - 1; + int taskId = getRuntimeContext().getIndexOfThisSubtask(); + int parallelism = getRuntimeContext().getMaxNumberOfParallelSubtasks(); + if (superStepNum % parallelism == taskId) { + SessionSharedData.put("" + superStepNum, taskId, items); + } + } + } + + private static class BuildModelData + extends RichMapPartitionFunction , Tuple2> { + int iterNum; + + BuildModelData(int iterNum) { + this.iterNum = iterNum; + } + + @Override + public void mapPartition(Iterable > values, + Collector > out) throws Exception { + int taskId = getRuntimeContext().getIndexOfThisSubtask(); + int index = taskId; + int parallelism = getRuntimeContext().getMaxNumberOfParallelSubtasks(); + while (index < iterNum) { + out.collect(Tuple2.of(0, (Row) SessionSharedData.get(""+index, taskId))); + index += parallelism; + } + } + } + + private static class ReduceModelData + implements GroupReduceFunction , Row> { + + @Override + public void reduce(Iterable > values, + Collector out) throws Exception { + TargetEncoderConverter converter = new TargetEncoderConverter(); + for (Tuple2 value : values) { + Row rowData = value.f1; + rowData.setField(1, JsonConverter.toJson(rowData.getField(1))); + converter.save(rowData, out); + } + } + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/feature/TreeModelEncoderBatchOp.java 
b/core/src/main/java/com/alibaba/alink/operator/batch/feature/TreeModelEncoderBatchOp.java index 39d977894..5a17ad463 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/feature/TreeModelEncoderBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/feature/TreeModelEncoderBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.batch.utils.ModelMapBatchOp; import com.alibaba.alink.operator.common.tree.predictors.TreeModelEncoderModelMapper; import com.alibaba.alink.params.feature.TreeModelEncoderParams; @@ -12,6 +13,7 @@ * as feature for classifier or regressor. */ @NameCn("决策树模型编码") +@NameEn("Tree Model Encoder") public class TreeModelEncoderBatchOp extends ModelMapBatchOp implements TreeModelEncoderParams { private static final long serialVersionUID = -7596799114572234171L; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/feature/VectorChiSqSelectorBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/feature/VectorChiSqSelectorBatchOp.java index 361743f7f..274544b74 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/feature/VectorChiSqSelectorBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/feature/VectorChiSqSelectorBatchOp.java @@ -6,14 +6,15 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortDesc; import com.alibaba.alink.common.annotation.PortSpec; import com.alibaba.alink.common.annotation.PortType; import com.alibaba.alink.common.annotation.TypeCollections; -import com.alibaba.alink.common.lazy.WithModelInfoBatchOp; -import com.alibaba.alink.common.utils.DataSetConversionUtil; +import com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.common.feature.ChisqSelectorModelInfo; import com.alibaba.alink.operator.common.feature.ChisqSelectorUtil; @@ -28,6 +29,7 @@ @OutputPorts(values = {@PortSpec(value = PortType.DATA, desc = PortDesc.OUTPUT_RESULT)}) @ParamSelectColumnSpec(name = "selectedCol", allowedTypeCollections = TypeCollections.VECTOR_TYPES) @NameCn("向量卡方选择器") +@NameEn("Vector ChiSq Selector") public final class VectorChiSqSelectorBatchOp extends BatchOperator implements VectorChiSqSelectorParams , WithModelInfoBatchOp { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/feature/WoePredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/feature/WoePredictBatchOp.java new file mode 100644 index 000000000..0c260a878 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/batch/feature/WoePredictBatchOp.java @@ -0,0 +1,29 @@ +package com.alibaba.alink.operator.batch.feature; + +import org.apache.flink.ml.api.misc.param.Params; + +import com.alibaba.alink.common.annotation.Internal; +import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; +import com.alibaba.alink.common.annotation.TypeCollections; +import com.alibaba.alink.operator.batch.utils.ModelMapBatchOp; +import com.alibaba.alink.operator.common.feature.WoeModelMapper; +import 
com.alibaba.alink.params.finance.WoePredictParams; + +/** + * Hash a vector in the Jaccard distance space to a new vector of given dimensions. + */ +@Internal +@ParamSelectColumnSpec(name = "selectedCols", + allowedTypeCollections = TypeCollections.NUMERIC_TYPES) +public final class WoePredictBatchOp extends ModelMapBatchOp + implements WoePredictParams { + private static final long serialVersionUID = -9048145108684979562L; + + public WoePredictBatchOp() { + this(new Params()); + } + + public WoePredictBatchOp(Params params) { + super(WoeModelMapper::new, params); + } +} \ No newline at end of file diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/feature/WoeTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/feature/WoeTrainBatchOp.java new file mode 100644 index 000000000..af311d155 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/batch/feature/WoeTrainBatchOp.java @@ -0,0 +1,250 @@ +package com.alibaba.alink.operator.batch.feature; + +import org.apache.flink.api.common.functions.FilterFunction; +import org.apache.flink.api.common.functions.FlatMapFunction; +import org.apache.flink.api.common.functions.GroupReduceFunction; +import org.apache.flink.api.common.functions.JoinFunction; +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.RichGroupReduceFunction; +import org.apache.flink.api.common.functions.RichMapPartitionFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.tuple.Tuple4; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +import com.alibaba.alink.common.annotation.InputPorts; +import com.alibaba.alink.common.annotation.Internal; +import com.alibaba.alink.common.annotation.OutputPorts; +import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; +import com.alibaba.alink.common.annotation.PortSpec; +import com.alibaba.alink.common.annotation.PortType; +import com.alibaba.alink.common.annotation.TypeCollections; +import com.alibaba.alink.common.utils.TableUtil; +import com.alibaba.alink.operator.batch.BatchOperator; +import com.alibaba.alink.operator.common.evaluation.EvaluationUtil; +import com.alibaba.alink.operator.common.feature.WoeModelDataConverter; +import com.alibaba.alink.operator.common.feature.binning.FeatureBinsCalculator; +import com.alibaba.alink.operator.common.io.types.FlinkTypeConverter; +import com.alibaba.alink.params.dataproc.HasSelectedColTypes; +import com.alibaba.alink.params.finance.WoeTrainParams; +import com.alibaba.alink.params.shared.colname.HasSelectedCols; +import org.apache.commons.lang3.ArrayUtils; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +@InputPorts(values = @PortSpec(PortType.DATA)) +@OutputPorts(values = { + @PortSpec(PortType.MODEL) +}) + +@ParamSelectColumnSpec(name = "selectedCols", + allowedTypeCollections = TypeCollections.NUMERIC_TYPES) +@ParamSelectColumnSpec(name = "labelCol") + +@Internal +public final class WoeTrainBatchOp extends BatchOperator + implements WoeTrainParams { + + private static final long serialVersionUID = 5413307707249156884L; + public static String NULL_STR = "WOE_NULL_STRING"; + + public WoeTrainBatchOp() 
{} + + public WoeTrainBatchOp(Params params) { + super(params); + } + + @Override + public WoeTrainBatchOp linkFrom(BatchOperator ... inputs) { + BatchOperator in = checkAndGetFirst(inputs); + String[] selectedCols = this.getSelectedCols(); + final String[] selectedColSqlType = new String[selectedCols.length]; + for (int i = 0; i < selectedCols.length; i++) { + selectedColSqlType[i] = FlinkTypeConverter.getTypeString( + TableUtil.findColType(in.getSchema(), selectedCols[i])); + } + int selectedLen = selectedCols.length; + String labelCol = this.getLabelCol(); + TypeInformation type = TableUtil.findColTypeWithAssertAndHint(in.getSchema(), labelCol); + DataSet > labelCount = in.select(labelCol).getDataSet().map( + new MapFunction >() { + private static final long serialVersionUID = 7244041023667742068L; + + @Override + public Tuple2 map(Row value) throws Exception { + Preconditions.checkNotNull(value.getField(0), "LabelCol contains null value!"); + return Tuple2.of( + new EvaluationUtil.ComparableLabel(value.getField(0).toString(), type).label.toString(), 1L); + } + }) + .groupBy(0) + .sum(1); + EvaluationUtil.ComparableLabel positiveValue = new EvaluationUtil.ComparableLabel( + this.getPositiveLabelValueString(), type); + + DataSet > total = in.select(ArrayUtils.add(selectedCols, labelCol)) + .getDataSet() + .flatMap(new FlatMapFunction >() { + private static final long serialVersionUID = 3246134198223500699L; + + @Override + public void flatMap(Row value, Collector > out) { + Long equalPositive = new EvaluationUtil.ComparableLabel(value.getField(selectedLen), type).equals( + positiveValue) ? 1L : 0L; + for (int i = 0; i < selectedLen; i++) { + Object obj = value.getField(i); + out.collect(Tuple3.of(i, null == obj ? NULL_STR : obj.toString(), equalPositive)); + } + } + }) + .groupBy(0, 1) + .reduceGroup( + new GroupReduceFunction , Tuple4 >() { + private static final long serialVersionUID = 8132981693511963253L; + + @Override + public void reduce(Iterable > values, + Collector > out) throws Exception { + Long binPositiveTotal = 0L; + Long binTotal = 0L; + int colIdx = -1; + String binIndex = null; + for (Tuple3 t : values) { + binTotal++; + colIdx = t.f0; + binIndex = t.f1; + binPositiveTotal += t.f2; + } + if (colIdx >= 0) { + out.collect(Tuple4.of(colIdx, binIndex, binTotal, binPositiveTotal)); + } + } + }); + DataSet values = total + .mapPartition(new RichMapPartitionFunction , Row>() { + private static final long serialVersionUID = 9015674191729072450L; + private long positiveTotal; + private long negativeTotal; + + @Override + public void open(Configuration configuration) { + List > labelCount = this.getRuntimeContext().getBroadcastVariable( + "labelCount"); + Preconditions.checkArgument(labelCount.size() == 2, "Only support binary classification!"); + if (positiveValue.equals(new EvaluationUtil.ComparableLabel(labelCount.get(0).f0, type))) { + positiveTotal = labelCount.get(0).f1; + negativeTotal = labelCount.get(1).f1; + } else if (positiveValue.equals(new EvaluationUtil.ComparableLabel(labelCount.get(1).f0, type))) { + positiveTotal = labelCount.get(1).f1; + negativeTotal = labelCount.get(0).f1; + } else { + throw new IllegalArgumentException("Not contain positiveValue " + positiveValue); + } + } + + @Override + public void mapPartition(Iterable > values, Collector out) { + Params meta = null; + if (getRuntimeContext().getIndexOfThisSubtask() == 0) { + meta = new Params() + .set(HasSelectedCols.SELECTED_COLS, selectedCols) + .set(HasSelectedColTypes.SELECTED_COL_TYPES, 
selectedColSqlType) + .set(WoeModelDataConverter.POSITIVE_TOTAL, positiveTotal) + .set(WoeModelDataConverter.NEGATIVE_TOTAL, negativeTotal); + } + new WoeModelDataConverter().save(Tuple2.of(meta, values), out); + } + }) + .withBroadcastSet(labelCount, "labelCount") + .name("build_model"); + + this.setOutput(values, new WoeModelDataConverter().getModelSchema()); + return this; + } + + public static DataSet setFeatureBinsWoe( + DataSet featureBorderDataSet, + DataSet woeModel) { + DataSet > borderWithName = featureBorderDataSet.map( + new MapFunction >() { + private static final long serialVersionUID = 3810414585464772028L; + + @Override + public Tuple2 map(FeatureBinsCalculator value) { + return Tuple2.of(value.getFeatureName(), value); + } + }); + + DataSet selectedCols = woeModel.filter(new FilterFunction () { + private static final long serialVersionUID = -2272981616877035934L; + + @Override + public boolean filter(Row value) { + return (long) value.getField(0) < 0; + } + }); + + DataSet , Map >> featureCounts = woeModel + .groupBy(0) + .reduceGroup(new RichGroupReduceFunction , Map >>() { + private static final long serialVersionUID = 2877684088626081532L; + + @Override + public void reduce(Iterable values, + Collector , Map >> out) { + String[] selectedCols = Params + .fromJson((String) ((Row) getRuntimeContext() + .getBroadcastVariable("selectedCols") + .get(0)) + .getField(1)) + .get(WoeTrainParams.SELECTED_COLS); + + Map total = new HashMap <>(); + Map positiveTotal = new HashMap <>(); + long colIndex = -1; + for (Row row : values) { + colIndex = (Long) row.getField(0); + if (colIndex >= 0L) { + Long key = Long.valueOf((String) row.getField(1)); + total.put(key, (long) row.getField(2)); + positiveTotal.put(key, (long) row.getField(3)); + } else { + return; + } + } + out.collect(Tuple3.of(selectedCols[(int) colIndex], total, positiveTotal)); + } + }).withBroadcastSet(selectedCols, "selectedCols") + .name("GetBinTotalFromWoeModel"); + + return borderWithName + .join(featureCounts) + .where(0) + .equalTo(0) + .with( + new JoinFunction , Tuple3 , Map >, + FeatureBinsCalculator>() { + private static final long serialVersionUID = -4468441310215491228L; + + @Override + public FeatureBinsCalculator join(Tuple2 first, + Tuple3 , Map > second) { + FeatureBinsCalculator border = first.f1; + border.setTotal(second.f1); + border.setPositiveTotal(second.f2); + return border; + } + }).name("SetBinTotal"); + } + + +} diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/finance/BinarySelectorPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/finance/BinarySelectorPredictBatchOp.java new file mode 100644 index 000000000..5ae1152bd --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/batch/finance/BinarySelectorPredictBatchOp.java @@ -0,0 +1,28 @@ +package com.alibaba.alink.operator.batch.finance; + +import org.apache.flink.ml.api.misc.param.Params; + +import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; +import com.alibaba.alink.operator.batch.utils.ModelMapBatchOp; +import com.alibaba.alink.operator.common.finance.stepwiseSelector.SelectorModelMapper; +import com.alibaba.alink.params.finance.SelectorPredictParams; + +@ParamSelectColumnSpec(name = "selectedCol") +@NameCn("Stepwise二分类筛选预测") +@NameEn("Stepwise Binary Selector Predictor") +public class BinarySelectorPredictBatchOp extends ModelMapBatchOp + implements SelectorPredictParams { 
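+	// Applies a selector model trained by BinarySelectorTrainBatchOp. As with other
+	// ModelMapBatchOp subclasses, the first input of linkFrom is the model and the second
+	// is the data to score. Minimal usage sketch (names below are placeholders):
+	//   new BinarySelectorPredictBatchOp().linkFrom(selectorModel, data);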
+ + private static final long serialVersionUID = -801930887556428652L; + + public BinarySelectorPredictBatchOp() { + this(null); + } + + public BinarySelectorPredictBatchOp(Params params) { + super(SelectorModelMapper::new, params); + } + +} diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/finance/BinarySelectorTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/finance/BinarySelectorTrainBatchOp.java new file mode 100644 index 000000000..6d8a788bf --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/batch/finance/BinarySelectorTrainBatchOp.java @@ -0,0 +1,58 @@ +package com.alibaba.alink.operator.batch.finance; + +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.table.api.Table; + +import com.alibaba.alink.common.annotation.InputPorts; +import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.common.annotation.OutputPorts; +import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; +import com.alibaba.alink.common.annotation.PortDesc; +import com.alibaba.alink.common.annotation.PortSpec; +import com.alibaba.alink.common.annotation.PortType; +import com.alibaba.alink.common.annotation.TypeCollections; +import com.alibaba.alink.operator.batch.BatchOperator; +import com.alibaba.alink.operator.common.finance.stepwiseSelector.BaseStepWiseSelectorBatchOp; +import com.alibaba.alink.params.finance.BaseStepwiseSelectorParams; +import com.alibaba.alink.params.finance.BinarySelectorTrainParams; + +@InputPorts(values = {@PortSpec(PortType.DATA)}) +@OutputPorts(values = { + @PortSpec(value = PortType.MODEL, desc = PortDesc.OUTPUT_RESULT), + @PortSpec(value = PortType.DATA, desc = PortDesc.OUTPUT_RESULT) +}) +@ParamSelectColumnSpec(name = "selectedCols", allowedTypeCollections = TypeCollections.NUMERIC_TYPES) +@ParamSelectColumnSpec(name = "forceSelectedCols", allowedTypeCollections = TypeCollections.NUMERIC_TYPES) +@NameCn("Stepwise二分类筛选训练") +@NameEn("Stepwise Binary Selector Train") +public class BinarySelectorTrainBatchOp extends BatchOperator + implements BinarySelectorTrainParams { + + private static final long serialVersionUID = -7632837469256501259L; + + public BinarySelectorTrainBatchOp() { + super(null); + } + + public BinarySelectorTrainBatchOp(Params params) { + super(params); + } + + @Override + public BinarySelectorTrainBatchOp linkFrom(BatchOperator ... 
inputs) { + BatchOperator in = checkAndGetFirst(inputs); + + BatchOperator op = in.link(new BaseStepWiseSelectorBatchOp( + this.getParams()) + .setStepWiseType(getMethod().name()) + .setLinearModelType(BaseStepwiseSelectorParams.LinearModelType.LR) + .setOptimMethod(getOptimMethod().name()) + ); + + setOutputTable(op.getOutputTable()); + setSideOutputTables(new Table[] {op.getSideOutput(1).getOutputTable()}); + + return this; + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/finance/BinningTrainForScorecardBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/finance/BinningTrainForScorecardBatchOp.java new file mode 100644 index 000000000..89608f301 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/batch/finance/BinningTrainForScorecardBatchOp.java @@ -0,0 +1,83 @@ +package com.alibaba.alink.operator.batch.finance; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.DataSet; +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.TableSchema; +import org.apache.flink.types.Row; + +import com.alibaba.alink.common.MLEnvironmentFactory; +import com.alibaba.alink.common.annotation.InputPorts; +import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.common.annotation.OutputPorts; +import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; +import com.alibaba.alink.common.annotation.PortDesc; +import com.alibaba.alink.common.annotation.PortSpec; +import com.alibaba.alink.common.annotation.PortType; +import com.alibaba.alink.common.annotation.TypeCollections; +import com.alibaba.alink.common.viz.AlinkViz; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; +import com.alibaba.alink.operator.batch.BatchOperator; +import com.alibaba.alink.operator.batch.feature.BinningTrainBatchOp; +import com.alibaba.alink.params.feature.HasConstraint; +import com.alibaba.alink.params.finance.BinningTrainForScorecardParams; +import com.alibaba.alink.params.shared.HasMLEnvironmentId; + +@InputPorts(values = {@PortSpec(PortType.DATA)}) +@OutputPorts(values = { + @PortSpec(value = PortType.MODEL, desc = PortDesc.OUTPUT_RESULT), + @PortSpec(value = PortType.DATA, desc = PortDesc.OUTPUT_RESULT) +}) +@ParamSelectColumnSpec(name = "selectedCols", allowedTypeCollections = TypeCollections.NUMERIC_TYPES) +@NameCn("评分卡分箱训练") +@NameEn("Score Train") +public final class BinningTrainForScorecardBatchOp extends BatchOperator + implements BinningTrainForScorecardParams , + AlinkViz { + + private static final long serialVersionUID = -1215494859549421782L; + public static TableSchema CONSTRAINT_TABLESCHEMA = new TableSchema(new String[] {"constrain"}, + new TypeInformation[] {Types.STRING}); + + public BinningTrainForScorecardBatchOp() { + } + + public BinningTrainForScorecardBatchOp(Params params) { + super(params); + } + + @Override + public BinningTrainForScorecardBatchOp linkFrom(BatchOperator ... inputs) { + BatchOperator in = checkAndGetFirst(inputs); + + BinningTrainBatchOp op = new BinningTrainBatchOp(this.getParams()).linkFrom(in); + + this.setOutput(op.getDataSet(), op.getSchema()); + + String constraint = getParams().get(HasConstraint.CONSTRAINT); + constraint = null == constraint ? 
"" : constraint; + + DataSet dataSet = MLEnvironmentFactory + .get(this.get(HasMLEnvironmentId.ML_ENVIRONMENT_ID)) + .getExecutionEnvironment() + .fromElements(constraint) + .map(new MapFunction () { + private static final long serialVersionUID = 9023058389801218418L; + + @Override + public Row map(String value) throws Exception { + return Row.of(value); + } + }); + + this.setSideOutputTables(new Table[] { + DataSetConversionUtil.toTable(this.getMLEnvironmentId(), + dataSet, CONSTRAINT_TABLESCHEMA)}); + + return this; + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/finance/ConstrainedBinarySelectorPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/finance/ConstrainedBinarySelectorPredictBatchOp.java new file mode 100644 index 000000000..f862ae548 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/batch/finance/ConstrainedBinarySelectorPredictBatchOp.java @@ -0,0 +1,28 @@ +package com.alibaba.alink.operator.batch.finance; + +import org.apache.flink.ml.api.misc.param.Params; + +import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; +import com.alibaba.alink.operator.batch.utils.ModelMapBatchOp; +import com.alibaba.alink.operator.common.finance.stepwiseSelector.SelectorModelMapper; +import com.alibaba.alink.params.finance.SelectorPredictParams; + +@ParamSelectColumnSpec(name = "selectedCol") +@NameCn("带约束的Stepwise二分类筛选预测") +@NameEn("Constrained Binary Selector Predictor") +public class ConstrainedBinarySelectorPredictBatchOp extends ModelMapBatchOp + implements SelectorPredictParams { + + private static final long serialVersionUID = -6112139129156096897L; + + public ConstrainedBinarySelectorPredictBatchOp() { + this(null); + } + + public ConstrainedBinarySelectorPredictBatchOp(Params params) { + super(SelectorModelMapper::new, params); + } + +} diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/finance/ConstrainedBinarySelectorTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/finance/ConstrainedBinarySelectorTrainBatchOp.java new file mode 100644 index 000000000..250e5159d --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/batch/finance/ConstrainedBinarySelectorTrainBatchOp.java @@ -0,0 +1,60 @@ +package com.alibaba.alink.operator.batch.finance; + +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.table.api.Table; + +import com.alibaba.alink.common.annotation.InputPorts; +import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.common.annotation.OutputPorts; +import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; +import com.alibaba.alink.common.annotation.PortDesc; +import com.alibaba.alink.common.annotation.PortSpec; +import com.alibaba.alink.common.annotation.PortType; +import com.alibaba.alink.common.annotation.TypeCollections; +import com.alibaba.alink.operator.batch.BatchOperator; +import com.alibaba.alink.operator.common.finance.stepwiseSelector.BaseStepWiseSelectorBatchOp; +import com.alibaba.alink.params.finance.BaseStepwiseSelectorParams; +import com.alibaba.alink.params.finance.ConstrainedBinarySelectorTrainParams; + +@InputPorts(values = {@PortSpec(PortType.DATA)}) +@OutputPorts(values = { + @PortSpec(value = PortType.MODEL, desc = PortDesc.OUTPUT_RESULT), + @PortSpec(value = PortType.DATA, desc = PortDesc.OUTPUT_RESULT) +}) 
+@ParamSelectColumnSpec(name = "selectedCols", allowedTypeCollections = TypeCollections.NUMERIC_TYPES) +@ParamSelectColumnSpec(name = "forceSelectedCols", allowedTypeCollections = TypeCollections.NUMERIC_TYPES) +@NameCn("带约束的Stepwise二分类筛选训练") +@NameEn("Constrained Binary Selector Trainer") +public class ConstrainedBinarySelectorTrainBatchOp extends BatchOperator + implements ConstrainedBinarySelectorTrainParams { + + private static final long serialVersionUID = 8418940566466615567L; + + public ConstrainedBinarySelectorTrainBatchOp() { + super(null); + } + + public ConstrainedBinarySelectorTrainBatchOp(Params params) { + super(params); + } + + @Override + public ConstrainedBinarySelectorTrainBatchOp linkFrom(BatchOperator ... inputs) { + if (inputs.length != 2) { + throw new RuntimeException("input must be two."); + } + + BatchOperator op = new BaseStepWiseSelectorBatchOp( + this.getParams()) + .setStepWiseType(getMethod().name()) + .setLinearModelType(BaseStepwiseSelectorParams.LinearModelType.LR) + .setOptimMethod(getOptimMethod().name()) + .linkFrom(inputs); + + setOutputTable(op.getOutputTable()); + setSideOutputTables(new Table[] {op.getSideOutput(1).getOutputTable()}); + + return this; + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/finance/ConstrainedLinearRegTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/finance/ConstrainedLinearRegTrainBatchOp.java new file mode 100644 index 000000000..e44f83f08 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/batch/finance/ConstrainedLinearRegTrainBatchOp.java @@ -0,0 +1,30 @@ +package com.alibaba.alink.operator.batch.finance; + +import org.apache.flink.ml.api.misc.param.Params; + +import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.operator.common.linear.BaseConstrainedLinearModelTrainBatchOp; +import com.alibaba.alink.operator.common.linear.LinearModelType; +import com.alibaba.alink.params.finance.ConstrainedLinearRegTrainParams; + +/** + * Train a regression model. + */ +@NameCn("带约束的线性回归训练") +@NameEn("Constrained Linear Selector Trainer") +public final class ConstrainedLinearRegTrainBatchOp + extends BaseConstrainedLinearModelTrainBatchOp + implements ConstrainedLinearRegTrainParams { + + private static final long serialVersionUID = 7603485759107349632L; + + public ConstrainedLinearRegTrainBatchOp() { + this(new Params()); + } + + public ConstrainedLinearRegTrainBatchOp(Params params) { + super(params, LinearModelType.LinearReg, "Linear Regression"); + } + +} diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/finance/ConstrainedLogisticRegressionTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/finance/ConstrainedLogisticRegressionTrainBatchOp.java new file mode 100644 index 000000000..b02d023a0 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/batch/finance/ConstrainedLogisticRegressionTrainBatchOp.java @@ -0,0 +1,29 @@ +package com.alibaba.alink.operator.batch.finance; + +import org.apache.flink.ml.api.misc.param.Params; + +import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.operator.common.linear.BaseConstrainedLinearModelTrainBatchOp; +import com.alibaba.alink.operator.common.linear.LinearModelType; +import com.alibaba.alink.params.finance.ConstrainedLogisticRegressionTrainParams; + +/** + * Logistic regression train batch operator. 
We use the log loss function by setting LinearModelType = LR and model + * name = "Logistic Regression". + */ +@NameCn("带约束的逻辑回归训练") +@NameEn("Constrained Logistic Regression Trainer") +public final class ConstrainedLogisticRegressionTrainBatchOp + extends BaseConstrainedLinearModelTrainBatchOp <ConstrainedLogisticRegressionTrainBatchOp> + implements ConstrainedLogisticRegressionTrainParams <ConstrainedLogisticRegressionTrainBatchOp> { + private static final long serialVersionUID = 3324942229315576654L; + + public ConstrainedLogisticRegressionTrainBatchOp() { + this(new Params()); + } + + public ConstrainedLogisticRegressionTrainBatchOp(Params params) { + super(params, LinearModelType.LR, "Logistic Regression"); + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/finance/ConstrainedRegSelectorPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/finance/ConstrainedRegSelectorPredictBatchOp.java new file mode 100644 index 000000000..65a53e853 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/batch/finance/ConstrainedRegSelectorPredictBatchOp.java @@ -0,0 +1,29 @@ +package com.alibaba.alink.operator.batch.finance; + +import org.apache.flink.ml.api.misc.param.Params; + +import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; +import com.alibaba.alink.common.annotation.TypeCollections; +import com.alibaba.alink.operator.batch.utils.ModelMapBatchOp; +import com.alibaba.alink.operator.common.finance.stepwiseSelector.SelectorModelMapper; +import com.alibaba.alink.params.finance.SelectorPredictParams; + +@ParamSelectColumnSpec(name = "selectedCols", allowedTypeCollections = TypeCollections.NUMERIC_TYPES) +@NameCn("带约束的Stepwise回归筛选预测") +@NameEn("Constrained Linear Selector Predictor") +public class ConstrainedRegSelectorPredictBatchOp extends ModelMapBatchOp <ConstrainedRegSelectorPredictBatchOp> + implements SelectorPredictParams <ConstrainedRegSelectorPredictBatchOp> { + + private static final long serialVersionUID = 152592469378068233L; + + public ConstrainedRegSelectorPredictBatchOp() { + this(null); + } + + public ConstrainedRegSelectorPredictBatchOp(Params params) { + super(SelectorModelMapper::new, params); + } + +} diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/finance/ConstrainedRegSelectorTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/finance/ConstrainedRegSelectorTrainBatchOp.java new file mode 100644 index 000000000..de57a2fe4 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/batch/finance/ConstrainedRegSelectorTrainBatchOp.java @@ -0,0 +1,62 @@ +package com.alibaba.alink.operator.batch.finance; + +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.table.api.Table; + +import com.alibaba.alink.common.annotation.InputPorts; +import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.common.annotation.OutputPorts; +import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; +import com.alibaba.alink.common.annotation.PortDesc; +import com.alibaba.alink.common.annotation.PortSpec; +import com.alibaba.alink.common.annotation.PortType; +import com.alibaba.alink.common.annotation.TypeCollections; +import com.alibaba.alink.operator.batch.BatchOperator; +import com.alibaba.alink.operator.common.finance.stepwiseSelector.BaseStepWiseSelectorBatchOp; +import com.alibaba.alink.params.finance.BaseStepwiseSelectorParams; +import com.alibaba.alink.params.finance.ConstrainedRegSelectorTrainParams; + +import java.security.InvalidParameterException; + 
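+/**
+ * Stepwise regression selector trained with user-defined constraints.
+ *
+ * <p>Minimal usage sketch (column names are placeholders; the setters are assumed to come from
+ * ConstrainedRegSelectorTrainParams):
+ * <pre>
+ * new ConstrainedRegSelectorTrainBatchOp()
+ *     .setSelectedCols("f0", "f1", "f2")
+ *     .setLabelCol("label")
+ *     .linkFrom(data, constraint);  // exactly two inputs: the data and the constraint table
+ * </pre>
+ */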
+@InputPorts(values = {@PortSpec(PortType.DATA)}) +@OutputPorts(values = { + @PortSpec(value = PortType.MODEL, desc = PortDesc.OUTPUT_RESULT), + @PortSpec(value = PortType.DATA, desc = PortDesc.OUTPUT_RESULT) +}) +@ParamSelectColumnSpec(name = "selectedCols", allowedTypeCollections = TypeCollections.NUMERIC_TYPES) +@ParamSelectColumnSpec(name = "forceSelectedCols", allowedTypeCollections = TypeCollections.NUMERIC_TYPES) +@NameCn("带约束的Stepwise回归筛选训练") +@NameEn("Constrained Linear Selector Trainer") +public class ConstrainedRegSelectorTrainBatchOp extends BatchOperator + implements ConstrainedRegSelectorTrainParams { + + private static final long serialVersionUID = 368730481323862021L; + + public ConstrainedRegSelectorTrainBatchOp() { + super(null); + } + + public ConstrainedRegSelectorTrainBatchOp(Params params) { + super(params); + } + + @Override + public ConstrainedRegSelectorTrainBatchOp linkFrom(BatchOperator ... inputs) { + if (inputs.length != 2) { + throw new InvalidParameterException("input size must be two."); + } + + BatchOperator op = new BaseStepWiseSelectorBatchOp( + this.getParams()) + .setStepWiseType(getMethod().name()) + .setLinearModelType(BaseStepwiseSelectorParams.LinearModelType.LinearReg) + .setOptimMethod(getOptimMethod().name()) + .linkFrom(inputs); + + setOutputTable(op.getOutputTable()); + setSideOutputTables(new Table[] {op.getSideOutput(1).getOutputTable()}); + + return this; + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/finance/RegressionSelectorPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/finance/RegressionSelectorPredictBatchOp.java new file mode 100644 index 000000000..cefa9e101 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/batch/finance/RegressionSelectorPredictBatchOp.java @@ -0,0 +1,29 @@ +package com.alibaba.alink.operator.batch.finance; + +import org.apache.flink.ml.api.misc.param.Params; + +import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; +import com.alibaba.alink.common.annotation.TypeCollections; +import com.alibaba.alink.operator.batch.utils.ModelMapBatchOp; +import com.alibaba.alink.operator.common.finance.stepwiseSelector.SelectorModelMapper; +import com.alibaba.alink.params.finance.SelectorPredictParams; + +@ParamSelectColumnSpec(name = "selectedCols", allowedTypeCollections = TypeCollections.NUMERIC_TYPES) +@NameCn("Stepwise回归筛选预测") +@NameEn("Regression Selector Predictor") +public class RegressionSelectorPredictBatchOp extends ModelMapBatchOp + implements SelectorPredictParams { + + private static final long serialVersionUID = 6316651625329228306L; + + public RegressionSelectorPredictBatchOp() { + this(null); + } + + public RegressionSelectorPredictBatchOp(Params params) { + super(SelectorModelMapper::new, params); + } + +} diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/finance/RegressionSelectorTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/finance/RegressionSelectorTrainBatchOp.java new file mode 100644 index 000000000..e4437e8f8 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/batch/finance/RegressionSelectorTrainBatchOp.java @@ -0,0 +1,58 @@ +package com.alibaba.alink.operator.batch.finance; + +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.table.api.Table; + +import com.alibaba.alink.common.annotation.InputPorts; +import 
com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.common.annotation.OutputPorts; +import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; +import com.alibaba.alink.common.annotation.PortDesc; +import com.alibaba.alink.common.annotation.PortSpec; +import com.alibaba.alink.common.annotation.PortType; +import com.alibaba.alink.common.annotation.TypeCollections; +import com.alibaba.alink.operator.batch.BatchOperator; +import com.alibaba.alink.operator.common.finance.stepwiseSelector.BaseStepWiseSelectorBatchOp; +import com.alibaba.alink.params.finance.BaseStepwiseSelectorParams; +import com.alibaba.alink.params.finance.RegressionSelectorParams; + +@InputPorts(values = {@PortSpec(PortType.DATA)}) +@OutputPorts(values = { + @PortSpec(value = PortType.MODEL, desc = PortDesc.OUTPUT_RESULT), + @PortSpec(value = PortType.DATA, desc = PortDesc.OUTPUT_RESULT) +}) +@ParamSelectColumnSpec(name = "selectedCols", allowedTypeCollections = TypeCollections.NUMERIC_TYPES) +@ParamSelectColumnSpec(name = "forceSelectedCols", allowedTypeCollections = TypeCollections.NUMERIC_TYPES) +@NameCn("Stepwise回归筛选训练") +@NameEn("Regression Selector Trainer") +public class RegressionSelectorTrainBatchOp extends BatchOperator <RegressionSelectorTrainBatchOp> + implements RegressionSelectorParams <RegressionSelectorTrainBatchOp> { + + private static final long serialVersionUID = 4105397009104570405L; + + public RegressionSelectorTrainBatchOp() { + super(null); + } + + public RegressionSelectorTrainBatchOp(Params params) { + super(params); + } + + @Override + public RegressionSelectorTrainBatchOp linkFrom(BatchOperator <?>... inputs) { + BatchOperator <?> in = checkAndGetFirst(inputs); + + BatchOperator <?> op = in.link(new BaseStepWiseSelectorBatchOp( + this.getParams()) + .setStepWiseType(getMethod().name()) + .setLinearModelType(BaseStepwiseSelectorParams.LinearModelType.LinearReg) + .setOptimMethod(getOptimMethod().name()) + ); + + setOutputTable(op.getOutputTable()); + setSideOutputTables(new Table[] {op.getSideOutput(1).getOutputTable()}); + + return this; + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/finance/ScorecardPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/finance/ScorecardPredictBatchOp.java new file mode 100644 index 000000000..5b1a89340 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/batch/finance/ScorecardPredictBatchOp.java @@ -0,0 +1,25 @@ +package com.alibaba.alink.operator.batch.finance; + +import org.apache.flink.ml.api.misc.param.Params; + +import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.operator.batch.utils.ModelMapBatchOp; +import com.alibaba.alink.operator.common.finance.ScorecardModelMapper; +import com.alibaba.alink.params.finance.ScorecardPredictParams; + +@NameCn("评分卡预测") +@NameEn("Score Predict") +public class ScorecardPredictBatchOp extends ModelMapBatchOp <ScorecardPredictBatchOp> + implements ScorecardPredictParams <ScorecardPredictBatchOp> { + + private static final long serialVersionUID = -3498559932886873694L; + + public ScorecardPredictBatchOp() { + this(new Params()); + } + + public ScorecardPredictBatchOp(Params params) { + super(ScorecardModelMapper::new, params); + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/finance/ScorecardTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/finance/ScorecardTrainBatchOp.java new file mode 100644 index 000000000..5b0e84fa7 --- /dev/null +++ 
b/core/src/main/java/com/alibaba/alink/operator/batch/finance/ScorecardTrainBatchOp.java @@ -0,0 +1,986 @@ +package com.alibaba.alink.operator.batch.finance; + +import org.apache.flink.api.common.functions.FlatMapFunction; +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.MapPartitionFunction; +import org.apache.flink.api.common.functions.RichFlatMapFunction; +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.functions.RichMapPartitionFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.api.misc.param.ParamInfo; +import org.apache.flink.ml.api.misc.param.ParamInfoFactory; +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.table.api.TableSchema; +import org.apache.flink.table.types.DataType; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +import com.alibaba.alink.common.annotation.FeatureColsVectorColMutexRule; +import com.alibaba.alink.common.annotation.InputPorts; +import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.common.annotation.OutputPorts; +import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; +import com.alibaba.alink.common.annotation.PortSpec; +import com.alibaba.alink.common.annotation.PortType; +import com.alibaba.alink.common.annotation.TypeCollections; +import com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp; +import com.alibaba.alink.common.mapper.PipelineModelMapper; +import com.alibaba.alink.common.viz.AlinkViz; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; +import com.alibaba.alink.common.utils.JsonConverter; +import com.alibaba.alink.common.utils.TableUtil; +import com.alibaba.alink.common.viz.VizDataWriterInterface; +import com.alibaba.alink.operator.batch.BatchOperator; +import com.alibaba.alink.operator.batch.feature.BinningPredictBatchOp; +import com.alibaba.alink.operator.batch.feature.BinningTrainBatchOp; +import com.alibaba.alink.operator.batch.source.DataSetWrapperBatchOp; +import com.alibaba.alink.operator.batch.utils.DataSetUtil; +import com.alibaba.alink.operator.common.feature.binning.Bins; +import com.alibaba.alink.operator.common.feature.binning.FeatureBinsCalculator; +import com.alibaba.alink.operator.common.feature.binning.FeatureBinsUtil; +import com.alibaba.alink.operator.common.finance.ScorecardModelInfo; +import com.alibaba.alink.operator.common.finance.ScorecardModelInfoBatchOp; +import com.alibaba.alink.operator.common.finance.VizData; +import com.alibaba.alink.operator.common.finance.stepwiseSelector.BaseStepWiseSelectorBatchOp; +import com.alibaba.alink.operator.common.finance.stepwiseSelector.ClassificationSelectorResult; +import com.alibaba.alink.operator.common.finance.stepwiseSelector.SelectorResult; +import com.alibaba.alink.operator.common.finance.stepwiseSelector.StepWiseType; +import com.alibaba.alink.operator.common.linear.LinearModelData; +import com.alibaba.alink.operator.common.linear.LinearModelDataConverter; +import com.alibaba.alink.operator.common.linear.ModelSummaryHelper; +import com.alibaba.alink.operator.common.optim.ConstraintBetweenBins; +import com.alibaba.alink.operator.common.optim.FeatureConstraint; +import 
com.alibaba.alink.params.dataproc.HasHandleInvalid; +import com.alibaba.alink.params.feature.HasEncode; +import com.alibaba.alink.params.finance.BaseStepwiseSelectorParams; +import com.alibaba.alink.params.finance.BinningPredictParams; +import com.alibaba.alink.params.finance.FitScaleParams; +import com.alibaba.alink.params.finance.HasConstrainedOptimizationMethod; +import com.alibaba.alink.params.finance.HasPdo; +import com.alibaba.alink.params.finance.HasScaledValue; +import com.alibaba.alink.params.finance.ScorecardTrainParams; +import com.alibaba.alink.params.shared.HasMLEnvironmentId; +import com.alibaba.alink.params.shared.colname.HasFeatureColsDefaultAsNull; +import com.alibaba.alink.params.shared.colname.HasLabelCol; +import com.alibaba.alink.params.shared.colname.HasOutputCol; +import com.alibaba.alink.params.shared.colname.HasOutputColsDefaultAsNull; +import com.alibaba.alink.params.shared.colname.HasSelectedCols; +import com.alibaba.alink.params.shared.colname.HasVectorColDefaultAsNull; +import com.alibaba.alink.params.shared.colname.HasWeightColDefaultAsNull; +import com.alibaba.alink.params.shared.iter.HasMaxIterDefaultAs100; +import com.alibaba.alink.params.shared.linear.HasEpsilonDefaultAs0000001; +import com.alibaba.alink.params.shared.linear.HasL1; +import com.alibaba.alink.params.shared.linear.HasL2; +import com.alibaba.alink.params.shared.linear.HasPositiveLabelValueStringDefaultAs1; +import com.alibaba.alink.params.shared.linear.HasStandardization; +import com.alibaba.alink.params.shared.linear.HasWithIntercept; +import com.alibaba.alink.pipeline.PipelineModel; +import com.alibaba.alink.pipeline.TransformerBase; +import com.alibaba.alink.pipeline.feature.BinningModel; +import com.alibaba.alink.pipeline.finance.ScoreModel; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.TreeSet; + +@InputPorts(values = { + @PortSpec(PortType.DATA), + @PortSpec(PortType.MODEL), + @PortSpec(value = PortType.DATA), +}) +@OutputPorts(values = { + @PortSpec(PortType.MODEL) +}) + +@ParamSelectColumnSpec(name = "selectCols", + allowedTypeCollections = TypeCollections.NUMERIC_TYPES) +@ParamSelectColumnSpec(name = "labelCol") +@ParamSelectColumnSpec(name = "weightCol", + allowedTypeCollections = TypeCollections.NUMERIC_TYPES) +@ParamSelectColumnSpec(name = "forceSelectedCols") +@FeatureColsVectorColMutexRule +@NameCn("评分卡训练") +@NameEn("Score Trainer") +public class ScorecardTrainBatchOp extends BatchOperator + implements ScorecardTrainParams , AlinkViz , + WithModelInfoBatchOp { + private static final long serialVersionUID = 9216204663345544968L; + static double SCALE_A = 1.0; + static double SCALE_B = 0.0; + + public static String BINNING_OUTPUT_COL = "BINNING_PREDICT"; + + private static String UNSCALED_MODEL = "unscaledModel"; + private static String SCALED_MODEL = "scaledModel"; + private static String BIN_COUNT = "binCount"; + private static String INTERCEPT = "intercept"; + private static String STEPWISE_MODEL = "stepwiseModel"; + private static String SCORECARD_MODEL = "scorecardModel"; + + public static ParamInfo WITH_ELSE = ParamInfoFactory + .createParamInfo("withElse", Map.class) + .setDescription("has else or not") + .setHasDefaultValue(null) + .build(); + + public static ParamInfo IN_SCORECARD = ParamInfoFactory + .createParamInfo("inScorecard", Boolean.class) + .setDescription("calculate linear model in scorecard else or not") + 
.setHasDefaultValue(false) + .build(); + + public ScorecardTrainBatchOp(Params params) { + super(params); + } + + public ScorecardTrainBatchOp() { + super(null); + } + + public static Tuple2 loadScaleInfo(Params params) { + if (!params.contains(FitScaleParams.SCALED_VALUE) && !params.contains(FitScaleParams.ODDS) && !params.contains( + FitScaleParams.PDO)) { + return Tuple2.of(SCALE_A, SCALE_B); + } + if (params.contains(FitScaleParams.SCALED_VALUE) && params.contains(FitScaleParams.ODDS) && params.contains( + FitScaleParams.PDO)) { + double odds = params.get(FitScaleParams.ODDS); + double logOdds = Math.log(odds); + double scaleA = (Math.log(odds * 2) - logOdds) / params.get(HasPdo.PDO); + double scaleB = logOdds - scaleA * params.get(HasScaledValue.SCALED_VALUE); + return Tuple2.of(scaleA, scaleB); + } + return Tuple2.of(SCALE_A, SCALE_B); + } + + @Override + public ScorecardTrainBatchOp linkFrom(BatchOperator ... inputs) { + BatchOperator data = inputs[0]; + BatchOperator binningModel = null; + if (inputs.length > 1) { + binningModel = inputs[1]; + } + BatchOperator constraint = null; + if (inputs.length == 3) { + constraint = inputs[2]; + } + TypeInformation labelType = TableUtil.findColTypeWithAssertAndHint(data.getSchema(), this.getLabelCol()); + List > finalModel = new ArrayList <>(); + + Boolean withSelector = getWithSelector(); + + //binning + DataSet featureBorderDataSet = null; + String[] selectedCols = this.getSelectedCols(); + String[] outputCols = this.getSelectedCols(); + String outputCol = null; + Encode encode = binningModel == null ? Encode.NULL : getEncode(); + switch (encode) { + case ASSEMBLED_VECTOR: + case WOE: { + Preconditions.checkNotNull(binningModel, "BinningModel is empty!"); + featureBorderDataSet = FeatureBinsUtil.parseFeatureBinsModel(binningModel.getDataSet()); + Params binningPredictParams = getBinningOutputParams(getParams(), withSelector); + data = new BinningPredictBatchOp(binningPredictParams).linkFrom(binningModel, data); + finalModel.add(setBinningModelData(binningModel, binningPredictParams)); + outputCols = getParams().get(HasOutputColsDefaultAsNull.OUTPUT_COLS); + outputCol = getParams().get(HasOutputCol.OUTPUT_COL); + break; + } + case NULL: { + TableUtil.assertNumericalCols(data.getSchema(), selectedCols); + break; + } + default: { + throw new RuntimeException("Not support " + encode.name()); + } + } + Map withElse = withElse(inputs[0].getSchema(), selectedCols); + constraint = unionConstraint(constraint, featureBorderDataSet, encode, selectedCols, withElse); + + //model + BatchOperator model; + DataSet modeSummary = null; + + if (!withSelector) { + Params linearTrainParams = new Params() + .set(HasMLEnvironmentId.ML_ENVIRONMENT_ID, getParams().get(HasMLEnvironmentId.ML_ENVIRONMENT_ID)) + .set(HasWithIntercept.WITH_INTERCEPT, true) + .set(HasFeatureColsDefaultAsNull.FEATURE_COLS, outputCols) + .set(HasVectorColDefaultAsNull.VECTOR_COL, outputCol) + .set(HasLabelCol.LABEL_COL, getLabelCol()) + .set(HasWeightColDefaultAsNull.WEIGHT_COL, getWeightCol()) + .set(HasPositiveLabelValueStringDefaultAs1.POS_LABEL_VAL_STR, getPositiveLabelValueString()) + .set(WITH_ELSE, withElse) + .set(IN_SCORECARD, true) + .set(HasEpsilonDefaultAs0000001.EPSILON, getEpsilon()) + .set(HasL2.L_2, getL2()) + .set(HasL1.L_1, getL1()) + .set(HasStandardization.STANDARDIZATION, true) + .set(HasMaxIterDefaultAs100.MAX_ITER, getMaxIter()) + .set(HasConstrainedOptimizationMethod.CONST_OPTIM_METHOD, getConstOptimMethod()); + + if (getConstOptimMethod() == null) { + 
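+			// No optimizer was specified: fall back to LBFGS when no constraint input is given,
+			// otherwise use SQP so that the constraints can be enforced.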
linearTrainParams.set(HasConstrainedOptimizationMethod.CONST_OPTIM_METHOD, + null == constraint ? ConstOptimMethod.LBFGS : ConstOptimMethod.SQP); + } + + ConstOptimMethod optMethod = linearTrainParams.get(HasConstrainedOptimizationMethod.CONST_OPTIM_METHOD); + boolean useConstraint = (encode == Encode.ASSEMBLED_VECTOR) && + (optMethod == ConstOptimMethod.SQP || optMethod == ConstOptimMethod.Barrier); + if (ModelSummaryHelper.isLinearRegression(getLinearModelType().toString())) { + if (useConstraint) { + model = new ConstrainedLinearRegTrainBatchOp(linearTrainParams).linkFrom(data, constraint); + } else { + model = new ConstrainedLinearRegTrainBatchOp(linearTrainParams).linkFrom(data, null); + } + } else { + if (useConstraint) { + model = new ConstrainedLogisticRegressionTrainBatchOp(linearTrainParams).linkFrom(data, + constraint); + } else { + model = new ConstrainedLogisticRegressionTrainBatchOp(linearTrainParams).linkFrom(data, null); + } + } + } else { + Params stepwiseParams = new Params() + .set(HasMLEnvironmentId.ML_ENVIRONMENT_ID, getParams().get(HasMLEnvironmentId.ML_ENVIRONMENT_ID)) + .set(BaseStepwiseSelectorParams.SELECTED_COLS, outputCols) + .set(BaseStepwiseSelectorParams.LABEL_COL, getLabelCol()) + .set(BaseStepwiseSelectorParams.LINEAR_MODEL_TYPE, getLinearModelType()) + .set(BaseStepwiseSelectorParams.OPTIM_METHOD, getConstOptimMethod().name()) + .set(BaseStepwiseSelectorParams.ALPHA_ENTRY, getAlphaEntry()) + .set(BaseStepwiseSelectorParams.ALPHA_STAY, getAlphaStay()) + .set(HasPositiveLabelValueStringDefaultAs1.POS_LABEL_VAL_STR, getPositiveLabelValueString()) + .set(IN_SCORECARD, true) + .set(BaseStepwiseSelectorParams.WITH_VIZ, false); + + if (ScorecardTrainParams.Encode.ASSEMBLED_VECTOR == encode) { + stepwiseParams.set(BaseStepwiseSelectorParams.STEP_WISE_TYPE, StepWiseType.marginalContribution); + } else { + if (LinearModelType.LR == getLinearModelType()) { + stepwiseParams.set(BaseStepwiseSelectorParams.STEP_WISE_TYPE, StepWiseType.scoreTest); + } else { + stepwiseParams.set(BaseStepwiseSelectorParams.STEP_WISE_TYPE, StepWiseType.fTest); + } + } + if (getParams().contains(ScorecardTrainParams.FORCE_SELECTED_COLS)) { + String[] forceCols = getForceSelectedCols(); + if (forceCols != null && forceCols.length != 0) { + int[] forceColsIndices = TableUtil.findColIndicesWithAssertAndHint(getSelectedCols(), forceCols); + stepwiseParams.set(BaseStepwiseSelectorParams.FORCE_SELECTED_COLS, forceColsIndices); + } + } + + BaseStepWiseSelectorBatchOp stepwiseModel = new BaseStepWiseSelectorBatchOp(stepwiseParams).linkFrom(data, + constraint); + model = stepwiseModel.getSideOutput(0); + modeSummary = stepwiseModel.getStepWiseSummary(); + } + + Preconditions.checkArgument(model != null, "Unscaled model is not set!"); + + finalModel.add( + setScoreModelData( + model, + new Params().set( + HasMLEnvironmentId.ML_ENVIRONMENT_ID, + getParams().get(HasMLEnvironmentId.ML_ENVIRONMENT_ID) + ), + labelType + ) + ); + + DataSet linearModelDataDataSet = model + .getDataSet() + .mapPartition(new LoadLinearModel()) + .setParallelism(1); + + VizDataWriterInterface writer = this.getVizDataWriter(); + + //statistic viz + statViz(writer, modeSummary, encode, withSelector, + data, linearModelDataDataSet, outputCol, outputCols); + + //fit scale + DataSet scaledLinearModelDataSet = linearModelDataDataSet; + if (getScaleInfo()) { + Preconditions.checkArgument( + null != getOdds() && null != getPdo() && null != getScaledValue(), + "ScaledValue/Pdo/Odds must be set!" 
+ ); + scaledLinearModelDataSet = linearModelDataDataSet.map(new ScaleLinearModel(getParams())); + finalModel.add(setScoreModelData( + BatchOperator.fromTable(DataSetConversionUtil + .toTable(getParams().get(HasMLEnvironmentId.ML_ENVIRONMENT_ID), + scaledLinearModelDataSet.flatMap(new SerializeLinearModel()), + new LinearModelDataConverter(labelType).getModelSchema()) + + ), + new Params() + .set(HasMLEnvironmentId.ML_ENVIRONMENT_ID, getParams().get(HasMLEnvironmentId.ML_ENVIRONMENT_ID)), + labelType)); + } + + //pipeline model transform + BatchOperator savedModel = new PipelineModel(finalModel.toArray(new TransformerBase[0])).save(); + + DataSet modelRows = savedModel + .getDataSet() + .map(new PipelineModelMapper.ExtendPipelineModelRow(selectedCols.length + 1)); + + TypeInformation[] selectedTypes = new TypeInformation[this.getSelectedCols().length]; + Arrays.fill(selectedTypes, labelType); + + TableSchema modelSchema = PipelineModelMapper.getExtendModelSchema( + savedModel.getSchema(), this.getSelectedCols(), selectedTypes); + + this.setOutput(modelRows, modelSchema); + + if (null != featureBorderDataSet) { + featureBorderDataSet = updateStatisticsInFeatureBins(getParams(), inputs[0], binningModel, + featureBorderDataSet); + + DataSet vizData; + if (withSelector) { + vizData = featureBorderDataSet + .flatMap(new FeatureBinsToScorecard(getParams(), withSelector)) + .withBroadcastSet(linearModelDataDataSet, UNSCALED_MODEL) + .withBroadcastSet(scaledLinearModelDataSet, SCALED_MODEL) + .withBroadcastSet(featureNameBinCount(featureBorderDataSet), BIN_COUNT) + .withBroadcastSet(modeSummary, STEPWISE_MODEL) + ; + } else { + vizData = featureBorderDataSet + .flatMap(new FeatureBinsToScorecard(getParams(), withSelector)) + .withBroadcastSet(linearModelDataDataSet, UNSCALED_MODEL) + .withBroadcastSet(scaledLinearModelDataSet, SCALED_MODEL) + .withBroadcastSet(featureNameBinCount(featureBorderDataSet), BIN_COUNT); + } + + if (writer != null) { + DataSet dummy = vizData + .mapPartition(new VizDataWriter(writer)) + .setParallelism(1) + .name("WriteVizData"); + DataSetUtil.linkDummySink(dummy); + } + } + + if (writer != null) { + DataSet dummy = modelRows + .mapPartition(new WriteModelPMML(writer, modelSchema)) + .setParallelism(1); + DataSetUtil.linkDummySink(dummy); + } + return this; + } + + private void statViz(VizDataWriterInterface writer, + DataSet stepwiseSummary, Encode encode, Boolean withSelector, + BatchOperator data, DataSet linearModelDataDataSet, String outputCol, + String[] outputCols) { + if ((encode == Encode.WOE || encode == Encode.NULL) || withSelector) { + DataSet selectorResult; + if (withSelector) { + selectorResult = stepwiseSummary; + } else { + selectorResult = ModelSummaryHelper.calModelSummary(data, getLinearModelType(), + linearModelDataDataSet, outputCol, outputCols, getLabelCol()); + } + DataSetUtil.linkDummySink(selectorResult.flatMap(new StepSummaryVizDataWriter(writer))); + } + } + + private static class WriteModelPMML implements MapPartitionFunction { + private VizDataWriterInterface writer; + + private final String[] modelFieldNames; + + /** + * Field types of the model. 
+ */ + private final DataType[] modelFieldTypes; + + public WriteModelPMML(VizDataWriterInterface writer, TableSchema schema) { + this.writer = writer; + this.modelFieldNames = schema.getFieldNames(); + this.modelFieldTypes = schema.getFieldDataTypes(); + } + + @Override + public void mapPartition(Iterable iterable, Collector collector) { + List modelRows = new ArrayList <>(); + iterable.forEach(modelRows::add); + ScorecardModelInfo summary = new ScorecardModelInfo( + modelRows, TableSchema.builder().fields(modelFieldNames, modelFieldTypes).build() + ); + //System.out.println(summary.getPMML()); + writer.writeBatchData(3L, summary.getPMML(), System.currentTimeMillis()); + } + } + + private static DataSet updateStatisticsInFeatureBins(Params params, + BatchOperator data, + BatchOperator binningModel, + DataSet + featureBorderDataSet) { + BinningPredictBatchOp op = new BinningPredictBatchOp(BinningTrainBatchOp.encodeIndexForWoeTrainParams( + params.get(HasSelectedCols.SELECTED_COLS), params.get(HasMLEnvironmentId.ML_ENVIRONMENT_ID))); + + op.linkFrom(binningModel, data); + + return BinningTrainBatchOp.setFeatureBinsTotalAndWoe(featureBorderDataSet, op, params); + } + + private static class VizDataWriter extends RichMapPartitionFunction { + private static final long serialVersionUID = -6344386801175328688L; + private VizDataWriterInterface writer; + + public VizDataWriter(VizDataWriterInterface writer) { + this.writer = writer; + } + + @Override + public void mapPartition(Iterable iterable, Collector collector) + throws Exception { + + Map > map = new HashMap <>(); + for (VizData.ScorecardVizData data : iterable) { + TreeSet list = map.computeIfAbsent(data.featureName, + k -> new TreeSet <>(VizData.VizDataComparator)); + if (!data.featureName.equals(INTERCEPT) || list.size() < 1) { + list.add(data); + } + } + //System.out.println(JsonConverter.toJson(map)); + writer.writeBatchData(0L, JsonConverter.toJson(map), System.currentTimeMillis()); + } + } + + private static Map withElse(TableSchema tableSchema, String[] selectedCols) { + Map withElse = new HashMap <>(); + for (String s : selectedCols) { + withElse.put(s, !TableUtil.isSupportedNumericType(TableUtil.findColTypeWithAssertAndHint(tableSchema, s))); + } + return withElse; + } + + private BatchOperator unionConstraint(BatchOperator constraint, + DataSet featureBorderDataSet, + Encode encode, + String[] selectedCols, + Map withElse) { + if (encode.equals(Encode.ASSEMBLED_VECTOR)) { + Preconditions.checkNotNull(featureBorderDataSet, "Binning Model is empty!"); + BatchOperator binCountConstraint = binningModelToConstraint(featureBorderDataSet, selectedCols, + this.getMLEnvironmentId()); + if (constraint == null) { + constraint = binCountConstraint; + } else { + DataSet binConstraintDataSet = binCountConstraint.getDataSet(); + DataSet udConstraintDataSet = constraint.getDataSet(); + //withElse will not be null. 
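+			// Merge the user-defined constraints (broadcast as "udConstraint") into the
+			// per-bin constraints derived from the binning model.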
+ DataSet cons = binConstraintDataSet.map(new MapConstraints(withElse)) + .withBroadcastSet(udConstraintDataSet, "udConstraint"); + constraint = new DataSetWrapperBatchOp(cons, + constraint.getColNames(), + new TypeInformation[] {TypeInformation.of(FeatureConstraint.class)}) + .setMLEnvironmentId(this.getMLEnvironmentId()); + } + } else { + constraint = null; + } + return constraint; + } + + private static class StepSummaryVizDataWriter implements FlatMapFunction { + private static final long serialVersionUID = -2495790734379602318L; + private VizDataWriterInterface writer; + + public StepSummaryVizDataWriter(VizDataWriterInterface writer) { + this.writer = writer; + } + + @Override + public void flatMap(SelectorResult val, Collector collector) throws Exception { + String type = "linearReg"; + if (val instanceof ClassificationSelectorResult) { + type = "classification"; + } + val.selectedCols = trimCols(val.selectedCols, BINNING_OUTPUT_COL); + writer.writeBatchData(1L, type, System.currentTimeMillis()); + writer.writeBatchData(2L, val.toVizData(), System.currentTimeMillis()); + } + } + + public static String[] trimCols(String[] cols, String trimStr) { + if (cols == null) { + return null; + } + String[] trimCols = new String[cols.length]; + for (int i = 0; i < trimCols.length; i++) { + int index = cols[i].lastIndexOf(trimStr); + if (index < 0) { + trimCols[i] = cols[i]; + } else { + trimCols[i] = cols[i].substring(0, index); + } + } + return trimCols; + } + + public static class FeatureBinsToScorecard + extends RichFlatMapFunction { + private static final long serialVersionUID = 332595127422145833L; + private ScorecardTransformData transformData; + + public FeatureBinsToScorecard(Params params, boolean withSelector) { + transformData = new ScorecardTransformData(params.get(ScorecardTrainParams.SCALE_INFO), + params.get(ScorecardTrainParams.SELECTED_COLS), + Encode.WOE.equals(params.get(ScorecardTrainParams.ENCODE)), + params.get(ScorecardTrainParams.DEFAULT_WOE), + withSelector); + } + + @Override + public void open(Configuration configuration) { + transformData.unscaledModel = getModelData((LinearModelData) + this.getRuntimeContext().getBroadcastVariable(UNSCALED_MODEL).get(0), transformData); + if (transformData.isScaled) { + transformData.scaledModel = getModelData((LinearModelData) + this.getRuntimeContext().getBroadcastVariable(SCALED_MODEL).get(0), transformData); + } + if (!transformData.isWoe) { + List > binCounts = this.getRuntimeContext().getBroadcastVariable(BIN_COUNT); + Map nameBinCountMap = new HashMap <>(); + binCounts.forEach(binCount -> nameBinCountMap.put(binCount.f0, binCount.f1)); + initializeStartIndex(transformData, nameBinCountMap); + } + if (transformData.withSelector) { + SelectorResult selectorSummary = (SelectorResult) this.getRuntimeContext().getBroadcastVariable( + STEPWISE_MODEL).get(0); + transformData.stepwiseSelectedCols = trimCols( + BaseStepWiseSelectorBatchOp + .getSCurSelectedCols(transformData.selectedCols, selectorSummary.selectedIndices), + BINNING_OUTPUT_COL); + } + } + + @Override + public void flatMap(FeatureBinsCalculator featureBinsCalculator, Collector out) { + featureBinsCalculator.calcStatistics(); + String featureColName = featureBinsCalculator.getFeatureName(); + transformData.featureIndex = TableUtil.findColIndex(transformData.selectedCols, + featureBinsCalculator.getFeatureName()); + if (transformData.featureIndex < 0) { + return; + } + featureBinsCalculator.splitsArrayToInterval(); + if (null != featureBinsCalculator.bin.nullBin) { + 
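+			// Report the NULL bin as a synthetic row with index -1 (the ELSE bin below uses -2).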
VizData.ScorecardVizData vizData = transform(featureBinsCalculator, featureBinsCalculator.bin.nullBin, + FeatureBinsUtil.NULL_LABEL, transformData); + vizData.index = -1L; + out.collect(vizData); + } + if (null != featureBinsCalculator.bin.elseBin) { + VizData.ScorecardVizData vizData = transform(featureBinsCalculator, featureBinsCalculator.bin.elseBin, + FeatureBinsUtil.ELSE_LABEL, transformData); + vizData.index = -2L; + out.collect(vizData); + } + if (null != featureBinsCalculator.bin.normBins) { + for (Bins.BaseBin bin : featureBinsCalculator.bin.normBins) { + out.collect( + transform(featureBinsCalculator, bin, bin.getValueStr(featureBinsCalculator.getColType()), + transformData)); + } + } + //write efficient + VizData.ScorecardVizData firstLine = new VizData.ScorecardVizData(featureBinsCalculator.getFeatureName(), + null, + null); + firstLine.total = featureBinsCalculator.getTotal(); + firstLine.positive = featureBinsCalculator.getPositiveTotal(); + if (firstLine.total != null && firstLine.positive != null) { + firstLine.negative = firstLine.total - firstLine.positive; + firstLine.positiveRate = 100.0; + firstLine.negativeRate = 100.0; + } + if (transformData.isWoe) { + int linearCoefIndex = findLinearModelCoefIdx(featureColName, null, transformData); + firstLine.unscaledValue = FeatureBinsUtil.keepGivenDecimal( + getLinearCoef(transformData.unscaledModel, linearCoefIndex), + 3); + firstLine.scaledValue = transformData.isScaled ? FeatureBinsUtil.keepGivenDecimal( + getLinearCoef(transformData.scaledModel, linearCoefIndex), 3) : null; + out.collect(firstLine); + } else { + out.collect(new VizData.ScorecardVizData(featureBinsCalculator.getFeatureName(), null, null)); + } + //write intercept + VizData.ScorecardVizData intercept = new VizData.ScorecardVizData(INTERCEPT, null, null); + intercept.unscaledValue = FeatureBinsUtil.keepGivenDecimal(transformData.unscaledModel[0], 3); + intercept.scaledValue = transformData.isScaled ? FeatureBinsUtil.keepGivenDecimal( + transformData.scaledModel[0], 3) : null; + out.collect(intercept); + + } + + public static VizData.ScorecardVizData transform(FeatureBinsCalculator featureBinsCalculator, + Bins.BaseBin bin, + String label, + ScorecardTransformData transformData) { + //write efficients, must be calculate first + VizData.ScorecardVizData vizData = new VizData.ScorecardVizData(featureBinsCalculator.getFeatureName(), + bin.getIndex(), + label); + int binIndex = bin.getIndex().intValue(); + String featureColName = featureBinsCalculator.getFeatureName(); + + int linearModelCoefIdx = findLinearModelCoefIdx(featureColName, binIndex, transformData); + + if (linearModelCoefIdx >= 0) { + vizData.unscaledValue = FeatureBinsUtil.keepGivenDecimal( + getModelValue(linearModelCoefIdx, bin.getWoe(), transformData.unscaledModel, transformData.isWoe, + transformData.defaultWoe), 3); + vizData.scaledValue = transformData.isScaled ? FeatureBinsUtil.keepGivenDecimal( + getModelValue(linearModelCoefIdx, bin.getWoe(), transformData.scaledModel, transformData.isWoe, + transformData.defaultWoe), 0) : null; + } + + //change the statistics of the bin, must be set after the efficients are set + featureBinsCalculator.calcBinStatistics(bin); + vizData.total = bin.getTotal(); + vizData.positive = bin.getPositive(); + vizData.negative = bin.getNegative(); + vizData.positiveRate = bin.getPositiveRate(); + vizData.negativeRate = bin.getNegativeRate(); + vizData.woe = bin.getWoe(); + + vizData.positiveRate = (null == vizData.positiveRate ? 
null : FeatureBinsUtil.keepGivenDecimal( + vizData.positiveRate * 100, 2)); + vizData.negativeRate = (null == vizData.negativeRate ? null : FeatureBinsUtil.keepGivenDecimal( + vizData.negativeRate * 100, 2)); + + return vizData; + } + + public static void initializeStartIndex(ScorecardTransformData transformData, + Map nameBinCountMap) { + transformData.startIndex = new int[transformData.selectedCols.length]; + int i = 1; + for (; i < transformData.selectedCols.length; i++) { + transformData.startIndex[i] = nameBinCountMap.get( + transformData.selectedCols[i - 1]) + transformData.startIndex[i - 1]; + } + if (!transformData.withSelector) { + Preconditions.checkArgument( + transformData.startIndex[i - 1] + nameBinCountMap.get( + transformData.selectedCols[i - 1]) == transformData.unscaledModel.length - 1, + "Assembled vector size error!"); + } + } + + public static double[] getModelData(LinearModelData linearModelData, ScorecardTransformData transformData) { + double[] modelData = linearModelData.coefVector.getData(); + Preconditions.checkState(linearModelData.hasInterceptItem, + "LinearModel in Scorecard not have intercept!"); + if (!transformData.withSelector) { + Preconditions.checkState( + !transformData.isWoe || modelData.length == transformData.selectedCols.length + 1, + "SelectedCol length: " + transformData.selectedCols.length + "; Model efficients length: " + + modelData.length); + } + return modelData; + } + + public static int findLinearModelCoefIdx(String featureColName, + Integer binIndex, + ScorecardTransformData transformData) { + int result = 0; + if (transformData.isWoe) { + result = transformData.withSelector + ? TableUtil.findColIndex(transformData.stepwiseSelectedCols, featureColName) + : transformData.featureIndex; + } else { + if (transformData.withSelector) { + int idx = TableUtil.findColIndex(transformData.stepwiseSelectedCols, featureColName); + result = idx == -1 ? -1 : getIdx(transformData, featureColName) + binIndex; + } else { + result = transformData.startIndex[transformData.featureIndex] + binIndex; + } + } + return result; + } + + private static int getIdx(ScorecardTransformData transformData, String featureColName) { + int stepwiseIndex = TableUtil.findColIndex(transformData.stepwiseSelectedCols, featureColName); + int startIndex = 0; + for (int i = 0; i < stepwiseIndex; i++) { + int idx = TableUtil.findColIndex(transformData.selectedCols, transformData.stepwiseSelectedCols[i]); + if (idx < transformData.startIndex.length - 1) { + startIndex += (transformData.startIndex[idx + 1] - transformData.startIndex[idx]); + } else { + startIndex += (transformData.unscaledModel.length - 1 - transformData.startIndex[ + transformData.startIndex.length - 1]); + } + } + return startIndex; + } + + public static Double getModelValue(int linearModelCoefIdx, Double woe, double[] efficients, boolean isWoe, + double defaultWoe) { + if (isWoe) { + Double val = getLinearCoef(efficients, linearModelCoefIdx); + if (val == null) { + return null; + } + if (null == woe) { + return Double.isNaN(defaultWoe) ? null : val * defaultWoe; + } else { + return val * woe; + } + } else { + return getLinearCoef(efficients, linearModelCoefIdx); + } + } + + //if intercept, it will + 1 + private static Double getLinearCoef(double[] efficients, int linearModelCoefIdx) { + return linearModelCoefIdx < 0 ? 
null : efficients[linearModelCoefIdx + 1]; + } + + } + + public static class ScorecardTransformData implements Serializable { + public double[] unscaledModel; + public double[] scaledModel; + public boolean isScaled; + public boolean isWoe; + public String[] selectedCols; + public double defaultWoe; + public int featureIndex; + public int[] startIndex; + + //for stepwise + public boolean withSelector; + public LinearModelType linearModelType; + public String[] stepwiseSelectedCols; + + public ScorecardTransformData(boolean isScaled, String[] selectedCols, boolean isWoe, double defaultWoe, + boolean withSelector) { + this.isScaled = isScaled; + this.selectedCols = selectedCols; + this.isWoe = isWoe; + this.defaultWoe = defaultWoe; + this.withSelector = withSelector; + } + } + + private static DataSet > featureNameBinCount( + DataSet featureBorderDataSet) { + return featureBorderDataSet.map(new MapFunction >() { + private static final long serialVersionUID = 6330797694544909126L; + + @Override + public Tuple2 map(FeatureBinsCalculator value) throws Exception { + return Tuple2.of(value.getFeatureName(), FeatureBinsUtil.getBinEncodeVectorSize(value)); + } + }); + } + + private static BatchOperator binningModelToConstraint(DataSet featureBorderDataSet, + String[] selectedCols, + long environmentId) { + DataSet constraint = featureBorderDataSet.mapPartition( + new MapPartitionFunction () { + private static final long serialVersionUID = 6519391092676175709L; + + @Override + public void mapPartition(Iterable values, Collector out) { + Map featureNameBinCountMap = new HashMap <>(); + values.forEach(featureBorder -> featureNameBinCountMap + .put(featureBorder.getFeatureName(), FeatureBinsUtil.getBinEncodeVectorSize(featureBorder))); + ConstraintBetweenBins[] constraintBetweenBins = new ConstraintBetweenBins[selectedCols.length]; + for (int i = 0; i < selectedCols.length; i++) { + Integer binCount = featureNameBinCountMap.get(selectedCols[i]); + Preconditions.checkNotNull(binCount, "BinCount for %s is not set!", selectedCols[i]); + constraintBetweenBins[i] = new ConstraintBetweenBins(selectedCols[i], binCount); + } + FeatureConstraint featureConstraint = new FeatureConstraint(); + featureConstraint.addBinConstraint(constraintBetweenBins); + out.collect(Row.of(featureConstraint)); + } + }).setParallelism(1); + return new DataSetWrapperBatchOp(constraint, + BinningTrainForScorecardBatchOp.CONSTRAINT_TABLESCHEMA.getFieldNames(), + new TypeInformation [] {TypeInformation.of(FeatureConstraint.class)}).setMLEnvironmentId(environmentId); + } + + private static Params getBinningOutputParams(Params params, boolean withSelector) { + String[] selectedCols = params.get(HasSelectedCols.SELECTED_COLS); + Encode encode = params.get(ScorecardTrainParams.ENCODE); + Params binningPredictParams = new Params() + .set(HasMLEnvironmentId.ML_ENVIRONMENT_ID, params.get(HasMLEnvironmentId.ML_ENVIRONMENT_ID)) + .set(BinningPredictParams.SELECTED_COLS, selectedCols) + .set(BinningPredictParams.DEFAULT_WOE, params.get(BinningPredictParams.DEFAULT_WOE)) + .set(BinningPredictParams.DROP_LAST, false) + .set(BinningPredictParams.HANDLE_INVALID, HasHandleInvalid.HandleInvalid.KEEP); + + switch (encode) { + case WOE: + binningPredictParams.set(BinningPredictParams.ENCODE, HasEncode.Encode.WOE); + break; + case ASSEMBLED_VECTOR: + if (withSelector) { + binningPredictParams.set(BinningPredictParams.ENCODE, HasEncode.Encode.VECTOR); + } else { + binningPredictParams.set(BinningPredictParams.ENCODE, HasEncode.Encode.ASSEMBLED_VECTOR); + 
} + break; + default: { + throw new RuntimeException("Not support!"); + } + } + + switch (encode) { + case WOE: { + String[] binningOutCols = new String[selectedCols.length]; + for (int i = 0; i < selectedCols.length; i++) { + binningOutCols[i] = selectedCols[i] + BINNING_OUTPUT_COL; + } + params.set(BinningPredictParams.OUTPUT_COLS, binningOutCols).set(HasOutputCol.OUTPUT_COL, null); + binningPredictParams.set(BinningPredictParams.OUTPUT_COLS, binningOutCols); + break; + } + case ASSEMBLED_VECTOR: { + if (withSelector) { + String[] binningOutCols = new String[selectedCols.length]; + for (int i = 0; i < selectedCols.length; i++) { + binningOutCols[i] = selectedCols[i] + BINNING_OUTPUT_COL; + } + params.set(BinningPredictParams.OUTPUT_COLS, binningOutCols).set(HasOutputCol.OUTPUT_COL, null); + binningPredictParams.set(BinningPredictParams.OUTPUT_COLS, binningOutCols); + } else { + params.set(HasOutputCol.OUTPUT_COL, BINNING_OUTPUT_COL).set(BinningPredictParams.OUTPUT_COLS, + null); + binningPredictParams.set(BinningPredictParams.OUTPUT_COLS, new String[] {BINNING_OUTPUT_COL}); + } + break; + } + default: { + throw new RuntimeException("Not support!"); + } + } + return binningPredictParams; + } + + private static class MapConstraints extends RichMapFunction { + private static final long serialVersionUID = -8118713814960689315L; + FeatureConstraint udConstraint; + Map withElse; + + public MapConstraints(Map withElse) { + this.withElse = withElse; + } + + @Override + public void open(Configuration parameters) throws Exception { + String consString = + (String) ((Row) getRuntimeContext().getBroadcastVariable("udConstraint").get(0)) + .getField(0); + this.udConstraint = FeatureConstraint.fromJson(consString); + } + + @Override + public Row map(Row value) throws Exception { + FeatureConstraint binConstraint = (FeatureConstraint) value.getField(0); + udConstraint.addDim(binConstraint); + //no matter where are the constraint of else and null, they shall be passed with index of -1 or -2. 
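+ // for reference: FeatureBinsToScorecard.flatMap above writes the null bin with index -1 and the else bin with index -2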
+ udConstraint.modify(withElse); + return Row.of(udConstraint); + } + } + + private static BinningModel setBinningModelData(BatchOperator modelData, Params params) { + BinningModel binningModel = new BinningModel(params); + binningModel.setModelData(modelData); + return binningModel; + } + + private static ScoreModel setScoreModelData( + BatchOperator modelData, Params params, TypeInformation labelType) { + + ScoreModel scoreM = new ScoreModel(params); + scoreM.setModelData(modelData); + return scoreM; + } + + static LinearModelData scaleLinearModelWeight(LinearModelData modelData, Tuple2 scaleInfo) { + double[] efficients = modelData.coefVector.getData(); + double scaleA = scaleInfo.f0; + double scaleB = scaleInfo.f1; + efficients[0] = scaleWeight(efficients[0], scaleA, scaleB, true); + for (int i = 1; i < efficients.length; i++) { + efficients[i] = scaleWeight(efficients[i], scaleA, scaleB, false); + } + return modelData; + } + + private static double scaleWeight(double weight, double scaleA, double scaleB, boolean isIntercept) { + if (isIntercept) { + return (weight - scaleB) / scaleA; + } else { + return weight / scaleA; + } + } + + static class LoadLinearModel implements MapPartitionFunction { + private static final long serialVersionUID = -2042421522639330680L; + + @Override + public void mapPartition(Iterable rows, Collector collector) { + List list = new ArrayList <>(); + rows.forEach(list::add); + collector.collect(new LinearModelDataConverter().load(list)); + } + } + + static class ScaleLinearModel implements MapFunction { + private static final long serialVersionUID = 1506707086583262871L; + private Params params; + + public ScaleLinearModel(Params params) { + this.params = params; + } + + @Override + public LinearModelData map(LinearModelData modelData) { + return scaleLinearModelWeight(modelData, loadScaleInfo(params)); + } + } + + static class SerializeLinearModel implements FlatMapFunction { + private static final long serialVersionUID = -780782639893457512L; + + @Override + public void flatMap(LinearModelData modelData, Collector collector) { + new LinearModelDataConverter(modelData.labelType).save(modelData, collector); + } + } + + @Override + public ScorecardModelInfoBatchOp getModelInfoBatchOp() { + return new ScorecardModelInfoBatchOp(this.getParams()).linkFrom(this); + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/graph/CommonNeighborsBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/graph/CommonNeighborsBatchOp.java index af0cbdca3..0b90a0102 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/graph/CommonNeighborsBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/graph/CommonNeighborsBatchOp.java @@ -19,13 +19,14 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortDesc; import com.alibaba.alink.common.annotation.PortSpec; import com.alibaba.alink.common.annotation.PortSpec.OpType; import com.alibaba.alink.common.annotation.PortType; -import com.alibaba.alink.common.utils.DataSetConversionUtil; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.batch.dataproc.HugeIndexerStringPredictBatchOp; import 
com.alibaba.alink.operator.batch.dataproc.HugeStringIndexerPredictBatchOp; @@ -43,6 +44,7 @@ @ParamSelectColumnSpec(name = "edgeSourceCol", portIndices = 0) @ParamSelectColumnSpec(name = "edgeTargetCol", portIndices = 0) @NameCn("共同邻居计算") +@NameEn("Common Neighbors") public class CommonNeighborsBatchOp extends BatchOperator implements CommonNeighborsTrainParams { private static final long serialVersionUID = -9221019571132151284L; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/graph/CommunityDetection.java b/core/src/main/java/com/alibaba/alink/operator/batch/graph/CommunityDetection.java deleted file mode 100644 index 25afb28dc..000000000 --- a/core/src/main/java/com/alibaba/alink/operator/batch/graph/CommunityDetection.java +++ /dev/null @@ -1,393 +0,0 @@ -package com.alibaba.alink.operator.batch.graph; - -import org.apache.flink.api.common.functions.GroupReduceFunction; -import org.apache.flink.api.common.functions.JoinFunction; -import org.apache.flink.api.common.functions.MapFunction; -import org.apache.flink.api.java.DataSet; -import org.apache.flink.api.java.functions.FunctionAnnotation.ForwardedFields; -import org.apache.flink.api.java.tuple.Tuple2; -import org.apache.flink.api.java.tuple.Tuple3; -import org.apache.flink.api.java.tuple.Tuple4; -import org.apache.flink.api.java.tuple.Tuple5; -import org.apache.flink.graph.Edge; -import org.apache.flink.graph.Graph; -import org.apache.flink.graph.Vertex; -import org.apache.flink.graph.spargel.GatherFunction; -import org.apache.flink.graph.spargel.MessageIterator; -import org.apache.flink.graph.spargel.ScatterFunction; -import org.apache.flink.types.NullValue; -import org.apache.flink.util.Collector; - -import com.alibaba.alink.operator.batch.BatchOperator; - -import java.util.HashMap; -import java.util.Map; -import java.util.Map.Entry; -import java.util.Random; -import java.util.TreeMap; - -/** - * Community Detection Algorithm. - *

- * <p>The Vertex values of the input Graph provide the initial label assignments.
- *
- * <p>
Initially, each vertex is assigned a tuple formed of its own initial value along with a score equal to 1.0. - * The vertices propagate their labels and max scores in iterations, each time adopting the label with the - * highest score from the list of received messages. The chosen label is afterwards re-scored using the fraction - * delta/the superstep number. Delta is passed as a parameter and has 0.5 as a default value. - */ -public class CommunityDetection { - - private int maxIterations; - private double delta; /*(0,1) default 0.5 */ - private int k; /* 1/k nodes not update info */ - - /** - * Creates a new Community Detection algorithm instance. - * The algorithm converges when vertices no longer update their value - * or when the maximum number of iterations is reached. - * - * @param maxIterations The maximum number of iterations to run. - * @param delta The hop attenuation parameter. Its default value is 0.5. - * @see - * Towards real-time community detection in large networks - */ - public CommunityDetection(int maxIterations, double delta, int k) { - this.maxIterations = maxIterations; - this.delta = delta; - this.k = k; - } - - public Graph runCluster(Graph graph) { - return operation(graph); - - } - - public Graph runCluster(DataSet > edges, Boolean directed) { - - Graph graph; - if (directed) { - graph = Graph.fromDataSet(edges, BatchOperator.getExecutionEnvironmentFromDataSets(edges)) - .mapVertices(new MapVertices()).getUndirected(); - } else { - graph = Graph.fromDataSet(edges, BatchOperator.getExecutionEnvironmentFromDataSets(edges)) - .mapVertices(new MapVertices()); - } - return operation(graph); - - } - - public Graph runClassify(Graph , Double> graph) { - //this operation is the last operation of the method operation. - Graph res = graph.runScatterGatherIteration(new LabelMessengerClassify(), - new VertexLabelUpdater(delta, k, false), maxIterations) - .mapVertices(new RemoveScoreFromVertexValuesMapper()); - return res; - } - - public static Graph operation(Graph graph, int maxIterations, - double delta, int k) { - //if clusterAlgorithm is true, it means to run cluster algorithm. - //vertex的三个变量分别为:点的id,点的label和点的分数。 - // tuple3: label, score and k. - DataSet >> initializedVertices = graph - .getVertices() - .map(new AddScoreToVertexValuesMapper(k)); - - //构造新的图 - Graph , Double> graphWithScoredVertices = - Graph.fromDataSet(initializedVertices, graph.getEdges(), graph.getContext()); - //主要的scatter-gather函数 - - Graph res = graphWithScoredVertices - .runScatterGatherIteration(new LabelMessengerCluster(), - new VertexLabelUpdater(delta, k, true), maxIterations) - .mapVertices(new RemoveScoreFromVertexValuesMapper()); - return res; - - } - - private Graph operation(Graph graph) { - return operation(graph, maxIterations, delta, k); - - } - - public static final class LabelMessengerCluster - extends ScatterFunction , Tuple2 , Double> { - - private static final long serialVersionUID = 8226284268536491551L; - - @Override - public void sendMessages(Vertex > vertex) { - //向每个邻点发送信息,主要是label和分数。 - //send msg(label and score) along with edges, to each neighbor. 
- - for (Edge edge : getEdges()) { - sendMessageTo(edge.getTarget(), new Tuple2 <>(vertex.getValue().f0, - vertex.getValue().f1 * edge.getValue()));//将label和分数乘label传递。 - } - } - } - - public static final class LabelMessengerClassify - extends ScatterFunction , Tuple2 , Double> { - - private static final long serialVersionUID = 542036977788012860L; - - @Override - public void sendMessages(Vertex > vertex) { - //向每个邻点发送信息,主要是label和分数。 - //send msg(label and score) along with edges, to each neighbor. - //in community detection cluster algorithm, set the unlabelled vertices as -1. - //if the label is -1, it won't send msg to its neighbor. - if (Math.abs(vertex.getValue().f0 + 1) > 1e-4) { - for (Edge edge : getEdges()) { - sendMessageTo(edge.getTarget(), new Tuple2 <>(vertex.getValue().f0, - vertex.getValue().f1 * edge.getValue()));//将label和分数乘label传递。 - } - } - } - } - - public static final class VertexLabelUpdater - extends GatherFunction , Tuple2 > { - - private static final long serialVersionUID = -4158941524263734994L; - private double delta; - private int k; - private Boolean clusterAlgorithm; - - public VertexLabelUpdater(double delta, int k, Boolean clusterAlgorithm) { - this.delta = delta; - this.k = k; - this.clusterAlgorithm = clusterAlgorithm; - } - - @Override - public void updateVertex(Vertex > vertex, - MessageIterator > inMessages) { - //as for the cluster algorithm, each point has a sign and each iteration only update one label. - //as for the classify algorithm, update only when the vertex label equals to -1. - if (clusterAlgorithm && vertex.f1.f2 % k == 0) { - vertex.f1.f2 = 1; - setNewVertexValue(vertex.f1); - } else if (!((!clusterAlgorithm) && Math.abs(vertex.getValue().f0 + 1) > 1e-4)) { - //the following refers to flink code. - //针对当前的这个点,它收到的label以及score的集合 - // we would like these two maps to be ordered - Map receivedLabelsWithScores = new TreeMap <>(); - Map labelsWithHighestScore = new TreeMap <>(); - - for (Tuple2 message : inMessages) { - // split the message into received label and score - double receivedLabel = message.f0; - double receivedScore = message.f1; - //如果之前没接收到这个label,就将这个label存入TreeMap中;如果接收过,则将score值累加。 - // if the label was received before - if (receivedLabelsWithScores.containsKey(receivedLabel)) { - double newScore = receivedScore + receivedLabelsWithScores.get(receivedLabel); - receivedLabelsWithScores.put(receivedLabel, newScore); - } else { - // first time we see the label - receivedLabelsWithScores.put(receivedLabel, receivedScore); - } - //将每个点最大的score存入。 - // store the labels with the highest scores - if (labelsWithHighestScore.containsKey(receivedLabel)) { - double currentScore = labelsWithHighestScore.get(receivedLabel); - if (currentScore < receivedScore) { - // record the highest score - labelsWithHighestScore.put(receivedLabel, receivedScore); - } - } else { - // first time we see this label - labelsWithHighestScore.put(receivedLabel, receivedScore); - } - } - if (receivedLabelsWithScores.size() > 0) { - //如果等于0则说明没有收到信息。那这样也不可能执行迭代。 - // find the label with the highest score from the ones received - double maxScore = -Double.MAX_VALUE; - //找到当前最大的累加score以及label。并没考虑label重复的情况。 - double maxScoreLabel = vertex.getValue().f0; - for (Double curLabel : receivedLabelsWithScores.keySet()) { - if (receivedLabelsWithScores.get(curLabel) > maxScore) { - maxScore = receivedLabelsWithScores.get(curLabel); - maxScoreLabel = curLabel; - } - } - //如果累加最大score的label不是现在点的label,则对score进行更新。 - // find the highest score of maxScoreLabel - double 
highestScore = labelsWithHighestScore.get(maxScoreLabel); - // re-score the new label - if (maxScoreLabel != vertex.getValue().f0) { - highestScore -= delta / getSuperstepNumber(); - } - // else delta = 0 - // update own label - vertex.f1.f2 += 1; - setNewVertexValue(new Tuple3 <>(maxScoreLabel, highestScore, vertex.f1.f2)); - //} else { - // setNewVertexValue(vertex.f1); - } - } - } - } - - @ForwardedFields("f0") - //initialize vertex set and add score. - public static class AddScoreToVertexValuesMapper - implements MapFunction , Vertex >> { - private static final long serialVersionUID = 8094922891125063932L; - private Random seed; - private int k; - - - public AddScoreToVertexValuesMapper(Integer k) { - this.seed = new Random(); - this.k = k; - } - - public Vertex > map(Vertex vertex) { - return new Vertex <>(vertex.getId(), new Tuple3 <>(vertex.getValue(), 1.0, (int) (vertex.getId() % k))); //todo 改 - } - } - - public static class RemoveScoreFromVertexValuesMapper - implements MapFunction >, Double> { - - private static final long serialVersionUID = 1275388361160360632L; - - @Override - public Double map(Vertex > vertex) throws Exception { - return vertex.getValue().f0; - } - } - - public static class MapVertices implements MapFunction , Double> { - - private static final long serialVersionUID = 456358729085831893L; - - @Override - public Double map(Vertex value) throws Exception { - return value.f0.doubleValue(); - } - } - - public static class ClusterMessageGroupFunction - implements GroupReduceFunction , Tuple3 > { - - @Override - public void reduce(Iterable > values, - Collector > out) throws Exception { - Map receivedLabelsWithScores = new HashMap <>(); - Long lastNodeId = 0L; - for (Tuple3 message : values) { - lastNodeId = message.f0; - Long receivedLabel = message.f1; - float receivedScore = message.f2; - receivedLabelsWithScores.put(receivedLabel, - receivedScore + receivedLabelsWithScores.getOrDefault(receivedLabel, 0.0F)); - } - float maxScore = ((Integer) Integer.MIN_VALUE).floatValue(); - Long maxScoreLabel = Long.MIN_VALUE; - for (Long curLabel : receivedLabelsWithScores.keySet()) { - float weight = receivedLabelsWithScores.get(curLabel); - if (weight > maxScore) { - maxScore = receivedLabelsWithScores.get(curLabel); - maxScoreLabel = curLabel; - } else if (Math.abs(weight - maxScore) <= 1e-4 && curLabel < maxScoreLabel) { - maxScoreLabel = curLabel; - } - } - out.collect(Tuple3.of(lastNodeId, maxScoreLabel, receivedLabelsWithScores.size())); - } - } - - public static class ClusterLabelMerger - implements JoinFunction , Tuple3 , - Tuple4 > { - - @Override - public Tuple4 join(Tuple3 first, - Tuple3 second) - throws Exception { - if (null == first || second.f1.equals(first.f1)) { - return Tuple4.of(second.f0, second.f1, second.f2, false); - } else { - if (Math.random() > 1.0 / (first.f2 + 1)) { - return Tuple4.of(second.f0, first.f1, second.f2, true); - } else { - return Tuple4.of(second.f0, second.f1, second.f2, true); - } - } - } - } - - protected static class ClassifyMessageGroupFunction - implements GroupReduceFunction , Tuple3 > { - - @Override - public void reduce(Iterable > values, - Collector > out) throws Exception { - Map labelScoresMap = new HashMap <>(); - Long currentNode = 0L; - float totalWeight = 0.F; - for (Tuple3 value : values) { - if (value.f2 < 0) { - continue; - } - currentNode = value.f0; - labelScoresMap.put(value.f1, labelScoresMap.getOrDefault(value.f1, 0.F) + value.f2); - totalWeight += value.f2; - } - Integer maxLabel = -1; - float maxScore = 
0.F; - for (Entry entry : labelScoresMap.entrySet()) { - if (entry.getValue() > maxScore) { - maxLabel = entry.getKey(); - maxScore = entry.getValue(); - } - } - if (maxLabel < 0) { - out.collect(Tuple3.of(currentNode, -1, 0F)); - } else { - out.collect(Tuple3.of(currentNode, maxLabel, maxScore / totalWeight)); - } - } - } - - protected static class ClassifyLabelMerger - implements JoinFunction , Tuple4 , - Tuple5 > { - - private float delta; - - ClassifyLabelMerger(double delta) { - this.delta = ((Number) delta).floatValue(); - } - - @Override - public Tuple5 join(Tuple3 first, - Tuple4 second) - throws Exception { - Tuple5 res = null; - // if node not exists in generated label dataset, - // or node exists in input classify nodes (second.f3 = true) - // or new label with negative weight - // or node already been set label in preprocess iteration - // node will keep label unchanged. - if (null == first || second.f3 || first.f2 < 0 || second.f1 >= 0) { - res = Tuple5.of(second.f0, second.f1, second.f2, second.f3, false); - } else { - if (second.f2 == first.f2) { - res = Tuple5.of(second.f0, second.f1, second.f2 * delta + first.f2 * (1 - delta), second.f3, false); - } else { - res = Tuple5.of(second.f0, first.f1, first.f2, second.f3, true); - } - } - return res; - } - } -} - diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/graph/CommunityDetectionClassifyBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/graph/CommunityDetectionClassifyBatchOp.java index 728244a24..a395f4189 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/graph/CommunityDetectionClassifyBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/graph/CommunityDetectionClassifyBatchOp.java @@ -9,6 +9,7 @@ import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeinfo.Types; import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.functions.FunctionAnnotation.ForwardedFields; import org.apache.flink.api.java.operators.IterativeDataSet; import org.apache.flink.api.java.tuple.Tuple1; import org.apache.flink.api.java.tuple.Tuple2; @@ -17,12 +18,19 @@ import org.apache.flink.api.java.tuple.Tuple5; import org.apache.flink.configuration.Configuration; import org.apache.flink.graph.Edge; +import org.apache.flink.graph.Graph; +import org.apache.flink.graph.Vertex; +import org.apache.flink.graph.spargel.GatherFunction; +import org.apache.flink.graph.spargel.MessageIterator; +import org.apache.flink.graph.spargel.ScatterFunction; import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.types.NullValue; import org.apache.flink.types.Row; import org.apache.flink.util.Collector; import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortDesc; @@ -33,14 +41,16 @@ import com.alibaba.alink.common.utils.JsonConverter; import com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.operator.batch.BatchOperator; -import com.alibaba.alink.operator.batch.graph.CommunityDetection.ClassifyLabelMerger; -import com.alibaba.alink.operator.batch.graph.CommunityDetection.ClassifyMessageGroupFunction; -import com.alibaba.alink.operator.common.graph.GraphUtils; +import 
com.alibaba.alink.operator.batch.graph.CommunityDetectionClassifyBatchOp.CommunityDetection.ClassifyLabelMerger; +import com.alibaba.alink.operator.batch.graph.CommunityDetectionClassifyBatchOp.CommunityDetection.ClassifyMessageGroupFunction; import com.alibaba.alink.params.graph.CommunityDetectionClassifyParams; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Map.Entry; +import java.util.Random; +import java.util.TreeMap; @InputPorts(values = { @PortSpec(value = PortType.DATA, opType = OpType.BATCH, desc = PortDesc.GRPAH_EDGES), @@ -55,6 +65,7 @@ @ParamSelectColumnSpec(name = "edgeTargetCol", portIndices = 0) @ParamSelectColumnSpec(name = "edgeWeightCol", portIndices = 0, allowedTypeCollections = TypeCollections.NUMERIC_TYPES) @NameCn("标签传播分类") +@NameEn("Common Detection Classify") public class CommunityDetectionClassifyBatchOp extends BatchOperator implements CommunityDetectionClassifyParams { private static final long serialVersionUID = -2264855900960878969L; @@ -254,4 +265,369 @@ public void flatMap(Row value, Collector out) throws Exception { this.setOutput(res, outputCols, new TypeInformation [] {edgeTypes[0], labelType}); return this; } + + /** + * Community Detection Algorithm. + *

+ * <p>The Vertex values of the input Graph provide the initial label assignments.
+ *
+ * <p>
Initially, each vertex is assigned a tuple formed of its own initial value along with a score equal to 1.0. + * The vertices propagate their labels and max scores in iterations, each time adopting the label with the + * highest score from the list of received messages. The chosen label is afterwards re-scored using the fraction + * delta/the superstep number. Delta is passed as a parameter and has 0.5 as a default value. + */ + public static class CommunityDetection { + + private int maxIterations; + private double delta; /*(0,1) default 0.5 */ + private int k; /* 1/k nodes not update info */ + + /** + * Creates a new Community Detection algorithm instance. + * The algorithm converges when vertices no longer update their value + * or when the maximum number of iterations is reached. + * + * @param maxIterations The maximum number of iterations to run. + * @param delta The hop attenuation parameter. Its default value is 0.5. + * @see + * Towards real-time community detection in large networks + */ + public CommunityDetection(int maxIterations, double delta, int k) { + this.maxIterations = maxIterations; + this.delta = delta; + this.k = k; + } + + public Graph runCluster(Graph graph) { + return operation(graph); + + } + + public Graph runCluster(DataSet > edges, Boolean directed) { + + Graph graph; + if (directed) { + graph = Graph.fromDataSet(edges, getExecutionEnvironmentFromDataSets(edges)) + .mapVertices(new MapVertices()).getUndirected(); + } else { + graph = Graph.fromDataSet(edges, getExecutionEnvironmentFromDataSets(edges)) + .mapVertices(new MapVertices()); + } + return operation(graph); + + } + + public Graph runClassify(Graph , Double> graph) { + //this operation is the last operation of the method operation. + Graph res = graph.runScatterGatherIteration(new LabelMessengerClassify(), + new VertexLabelUpdater(delta, k, false), maxIterations) + .mapVertices(new RemoveScoreFromVertexValuesMapper()); + return res; + } + + public static Graph operation(Graph graph, int maxIterations, + double delta, int k) { + //if clusterAlgorithm is true, it means to run cluster algorithm. + //vertex的三个变量分别为:点的id,点的label和点的分数。 + // tuple3: label, score and k. + DataSet >> initializedVertices = graph + .getVertices() + .map(new AddScoreToVertexValuesMapper(k)); + + //构造新的图 + Graph , Double> graphWithScoredVertices = + Graph.fromDataSet(initializedVertices, graph.getEdges(), graph.getContext()); + //主要的scatter-gather函数 + + Graph res = graphWithScoredVertices + .runScatterGatherIteration(new LabelMessengerCluster(), + new VertexLabelUpdater(delta, k, true), maxIterations) + .mapVertices(new RemoveScoreFromVertexValuesMapper()); + return res; + + } + + private Graph operation(Graph graph) { + return operation(graph, maxIterations, delta, k); + + } + + public static final class LabelMessengerCluster + extends ScatterFunction , Tuple2 , Double> { + + private static final long serialVersionUID = 8226284268536491551L; + + @Override + public void sendMessages(Vertex > vertex) { + //向每个邻点发送信息,主要是label和分数。 + //send msg(label and score) along with edges, to each neighbor. 
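+ // each neighbor receives (label, score * edge weight), so heavier edges contribute more to that label's accumulated score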
+ + for (Edge edge : getEdges()) { + sendMessageTo(edge.getTarget(), new Tuple2 <>(vertex.getValue().f0, + vertex.getValue().f1 * edge.getValue()));//将label和分数乘label传递。 + } + } + } + + public static final class LabelMessengerClassify + extends ScatterFunction , Tuple2 , Double> { + + private static final long serialVersionUID = 542036977788012860L; + + @Override + public void sendMessages(Vertex > vertex) { + //向每个邻点发送信息,主要是label和分数。 + //send msg(label and score) along with edges, to each neighbor. + //in community detection cluster algorithm, set the unlabelled vertices as -1. + //if the label is -1, it won't send msg to its neighbor. + if (Math.abs(vertex.getValue().f0 + 1) > 1e-4) { + for (Edge edge : getEdges()) { + sendMessageTo(edge.getTarget(), new Tuple2 <>(vertex.getValue().f0, + vertex.getValue().f1 * edge.getValue()));//将label和分数乘label传递。 + } + } + } + } + + public static final class VertexLabelUpdater + extends GatherFunction , Tuple2 > { + + private static final long serialVersionUID = -4158941524263734994L; + private double delta; + private int k; + private Boolean clusterAlgorithm; + + public VertexLabelUpdater(double delta, int k, Boolean clusterAlgorithm) { + this.delta = delta; + this.k = k; + this.clusterAlgorithm = clusterAlgorithm; + } + + @Override + public void updateVertex(Vertex > vertex, + MessageIterator > inMessages) { + //as for the cluster algorithm, each point has a sign and each iteration only update one label. + //as for the classify algorithm, update only when the vertex label equals to -1. + if (clusterAlgorithm && vertex.f1.f2 % k == 0) { + vertex.f1.f2 = 1; + setNewVertexValue(vertex.f1); + } else if (!((!clusterAlgorithm) && Math.abs(vertex.getValue().f0 + 1) > 1e-4)) { + //the following refers to flink code. + //针对当前的这个点,它收到的label以及score的集合 + // we would like these two maps to be ordered + Map receivedLabelsWithScores = new TreeMap <>(); + Map labelsWithHighestScore = new TreeMap <>(); + + for (Tuple2 message : inMessages) { + // split the message into received label and score + double receivedLabel = message.f0; + double receivedScore = message.f1; + //如果之前没接收到这个label,就将这个label存入TreeMap中;如果接收过,则将score值累加。 + // if the label was received before + if (receivedLabelsWithScores.containsKey(receivedLabel)) { + double newScore = receivedScore + receivedLabelsWithScores.get(receivedLabel); + receivedLabelsWithScores.put(receivedLabel, newScore); + } else { + // first time we see the label + receivedLabelsWithScores.put(receivedLabel, receivedScore); + } + //将每个点最大的score存入。 + // store the labels with the highest scores + if (labelsWithHighestScore.containsKey(receivedLabel)) { + double currentScore = labelsWithHighestScore.get(receivedLabel); + if (currentScore < receivedScore) { + // record the highest score + labelsWithHighestScore.put(receivedLabel, receivedScore); + } + } else { + // first time we see this label + labelsWithHighestScore.put(receivedLabel, receivedScore); + } + } + if (receivedLabelsWithScores.size() > 0) { + //如果等于0则说明没有收到信息。那这样也不可能执行迭代。 + // find the label with the highest score from the ones received + double maxScore = -Double.MAX_VALUE; + //找到当前最大的累加score以及label。并没考虑label重复的情况。 + double maxScoreLabel = vertex.getValue().f0; + for (Double curLabel : receivedLabelsWithScores.keySet()) { + if (receivedLabelsWithScores.get(curLabel) > maxScore) { + maxScore = receivedLabelsWithScores.get(curLabel); + maxScoreLabel = curLabel; + } + } + //如果累加最大score的label不是现在点的label,则对score进行更新。 + // find the highest score of maxScoreLabel + double 
highestScore = labelsWithHighestScore.get(maxScoreLabel); + // re-score the new label + if (maxScoreLabel != vertex.getValue().f0) { + highestScore -= delta / getSuperstepNumber(); + } + // else delta = 0 + // update own label + vertex.f1.f2 += 1; + setNewVertexValue(new Tuple3 <>(maxScoreLabel, highestScore, vertex.f1.f2)); + //} else { + // setNewVertexValue(vertex.f1); + } + } + } + } + + @ForwardedFields("f0") + //initialize vertex set and add score. + public static class AddScoreToVertexValuesMapper + implements MapFunction , Vertex >> { + private static final long serialVersionUID = 8094922891125063932L; + private Random seed; + private int k; + + + public AddScoreToVertexValuesMapper(Integer k) { + this.seed = new Random(); + this.k = k; + } + + public Vertex > map(Vertex vertex) { + return new Vertex <>(vertex.getId(), new Tuple3 <>(vertex.getValue(), 1.0, (int) (vertex.getId() % k))); //todo 改 + } + } + + public static class RemoveScoreFromVertexValuesMapper + implements MapFunction >, Double> { + + private static final long serialVersionUID = 1275388361160360632L; + + @Override + public Double map(Vertex > vertex) throws Exception { + return vertex.getValue().f0; + } + } + + public static class MapVertices implements MapFunction , Double> { + + private static final long serialVersionUID = 456358729085831893L; + + @Override + public Double map(Vertex value) throws Exception { + return value.f0.doubleValue(); + } + } + + public static class ClusterMessageGroupFunction + implements GroupReduceFunction , Tuple3 > { + + @Override + public void reduce(Iterable > values, + Collector > out) throws Exception { + Map receivedLabelsWithScores = new HashMap <>(); + Long lastNodeId = 0L; + for (Tuple3 message : values) { + lastNodeId = message.f0; + Long receivedLabel = message.f1; + float receivedScore = message.f2; + receivedLabelsWithScores.put(receivedLabel, + receivedScore + receivedLabelsWithScores.getOrDefault(receivedLabel, 0.0F)); + } + float maxScore = ((Integer) Integer.MIN_VALUE).floatValue(); + Long maxScoreLabel = Long.MIN_VALUE; + for (Long curLabel : receivedLabelsWithScores.keySet()) { + float weight = receivedLabelsWithScores.get(curLabel); + if (weight > maxScore) { + maxScore = receivedLabelsWithScores.get(curLabel); + maxScoreLabel = curLabel; + } else if (Math.abs(weight - maxScore) <= 1e-4 && curLabel < maxScoreLabel) { + maxScoreLabel = curLabel; + } + } + out.collect(Tuple3.of(lastNodeId, maxScoreLabel, receivedLabelsWithScores.size())); + } + } + + public static class ClusterLabelMerger + implements JoinFunction , Tuple3 , + Tuple4 > { + + @Override + public Tuple4 join(Tuple3 first, + Tuple3 second) + throws Exception { + if (null == first || second.f1.equals(first.f1)) { + return Tuple4.of(second.f0, second.f1, second.f2, false); + } else { + if (Math.random() > 1.0 / (first.f2 + 1)) { + return Tuple4.of(second.f0, first.f1, second.f2, true); + } else { + return Tuple4.of(second.f0, second.f1, second.f2, true); + } + } + } + } + + protected static class ClassifyMessageGroupFunction + implements GroupReduceFunction , Tuple3 > { + + @Override + public void reduce(Iterable > values, + Collector > out) throws Exception { + Map labelScoresMap = new HashMap <>(); + Long currentNode = 0L; + float totalWeight = 0.F; + for (Tuple3 value : values) { + if (value.f2 < 0) { + continue; + } + currentNode = value.f0; + labelScoresMap.put(value.f1, labelScoresMap.getOrDefault(value.f1, 0.F) + value.f2); + totalWeight += value.f2; + } + Integer maxLabel = -1; + float maxScore = 
0.F; + for (Entry entry : labelScoresMap.entrySet()) { + if (entry.getValue() > maxScore) { + maxLabel = entry.getKey(); + maxScore = entry.getValue(); + } + } + if (maxLabel < 0) { + out.collect(Tuple3.of(currentNode, -1, 0F)); + } else { + out.collect(Tuple3.of(currentNode, maxLabel, maxScore / totalWeight)); + } + } + } + + protected static class ClassifyLabelMerger + implements JoinFunction , Tuple4 , + Tuple5 > { + + private float delta; + + ClassifyLabelMerger(double delta) { + this.delta = ((Number) delta).floatValue(); + } + + @Override + public Tuple5 join(Tuple3 first, + Tuple4 second) + throws Exception { + Tuple5 res = null; + // if node not exists in generated label dataset, + // or node exists in input classify nodes (second.f3 = true) + // or new label with negative weight + // or node already been set label in preprocess iteration + // node will keep label unchanged. + if (null == first || second.f3 || first.f2 < 0 || second.f1 >= 0) { + res = Tuple5.of(second.f0, second.f1, second.f2, second.f3, false); + } else { + if (second.f2 == first.f2) { + res = Tuple5.of(second.f0, second.f1, second.f2 * delta + first.f2 * (1 - delta), second.f3, false); + } else { + res = Tuple5.of(second.f0, first.f1, first.f2, second.f3, true); + } + } + return res; + } + } + } } diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/graph/CommunityDetectionClusterBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/graph/CommunityDetectionClusterBatchOp.java index 15c0a46a5..810e6f86b 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/graph/CommunityDetectionClusterBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/graph/CommunityDetectionClusterBatchOp.java @@ -19,6 +19,7 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortDesc; @@ -27,9 +28,8 @@ import com.alibaba.alink.common.annotation.PortType; import com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.operator.batch.BatchOperator; -import com.alibaba.alink.operator.batch.graph.CommunityDetection.ClusterLabelMerger; -import com.alibaba.alink.operator.batch.graph.CommunityDetection.ClusterMessageGroupFunction; -import com.alibaba.alink.operator.common.graph.GraphUtils; +import com.alibaba.alink.operator.batch.graph.CommunityDetectionClassifyBatchOp.CommunityDetection.ClusterLabelMerger; +import com.alibaba.alink.operator.batch.graph.CommunityDetectionClassifyBatchOp.CommunityDetection.ClusterMessageGroupFunction; import com.alibaba.alink.params.graph.CommunityDetectionClusterParams; @InputPorts(values = { @PortSpec(value = PortType.DATA, opType = OpType.BATCH, desc = PortDesc.GRPAH_EDGES), @@ -42,6 +42,7 @@ @ParamSelectColumnSpec(name = "edgeTargetCol", portIndices = 0) @ParamSelectColumnSpec(name = "edgeWeightCol", portIndices = 0) @NameCn("标签传播聚类") +@NameEn("Common Detection Cluster") public class CommunityDetectionClusterBatchOp extends BatchOperator implements CommunityDetectionClusterParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/graph/ConnectedComponentsBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/graph/ConnectedComponentsBatchOp.java index 9d79064bf..09fc5cdd4 100644 --- 
a/core/src/main/java/com/alibaba/alink/operator/batch/graph/ConnectedComponentsBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/graph/ConnectedComponentsBatchOp.java @@ -16,6 +16,7 @@ import com.alibaba.alink.common.MLEnvironmentFactory; import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortDesc; @@ -24,7 +25,6 @@ import com.alibaba.alink.common.annotation.PortType; import com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.operator.batch.BatchOperator; -import com.alibaba.alink.operator.common.graph.GraphUtilsWithString; import com.alibaba.alink.params.graph.ConnectedComponentParams; import com.alibaba.alink.params.graph.HasSetStable; @@ -37,6 +37,7 @@ @ParamSelectColumnSpec(name = "edgeSourceCol", portIndices = 0) @ParamSelectColumnSpec(name = "edgeTargetCol", portIndices = 0) @NameCn("最大联通分量") +@NameEn("ConnectedComponents") public class ConnectedComponentsBatchOp extends BatchOperator implements ConnectedComponentParams { private static final long serialVersionUID = -7920188555691775911L; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/graph/DeepWalkBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/graph/DeepWalkBatchOp.java index f7e2c2485..3ad2be388 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/graph/DeepWalkBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/graph/DeepWalkBatchOp.java @@ -3,9 +3,11 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.batch.huge.impl.DeepWalkImpl; @NameCn("DeepWalk") +@NameEn("DeepWalk") public class DeepWalkBatchOp extends DeepWalkImpl { private static final long serialVersionUID = -8007362121261574268L; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/graph/EdgeClusterCoefficient.java b/core/src/main/java/com/alibaba/alink/operator/batch/graph/EdgeClusterCoefficient.java deleted file mode 100644 index ba05f6136..000000000 --- a/core/src/main/java/com/alibaba/alink/operator/batch/graph/EdgeClusterCoefficient.java +++ /dev/null @@ -1,151 +0,0 @@ -package com.alibaba.alink.operator.batch.graph; - -import org.apache.flink.api.common.functions.MapFunction; -import org.apache.flink.api.java.DataSet; -import org.apache.flink.api.java.tuple.Tuple2; -import org.apache.flink.api.java.tuple.Tuple6; -import org.apache.flink.graph.Edge; -import org.apache.flink.graph.Graph; -import org.apache.flink.graph.Triplet; -import org.apache.flink.graph.Vertex; -import org.apache.flink.graph.VertexJoinFunction; -import org.apache.flink.graph.spargel.GatherFunction; -import org.apache.flink.graph.spargel.MessageIterator; -import org.apache.flink.graph.spargel.ScatterFunction; -import org.apache.flink.types.LongValue; - -import java.util.Arrays; - -/** - * for each edge of undirected graph, return 1. the degree of its source and target - * 2. the number of triangles based on this edge - * 3. the quotient of the number of triangles and the min value between the two degrees - * - * @author qingzhao - */ - -public class EdgeClusterCoefficient { - public DataSet > run(Graph graph) { - //calculate the degree of each vertex. 
Because it is undirected, we only consider inDegree - DataSet > vertexDataSet = graph.inDegrees().map(new Longvalue2Long()); - //construct the output form, and write all the edges in it - //DataSet> temp = graph.getEdges().map(new MapEdge()); - //write the degrees of sources and targets in the corresponding position. - //write degrees of sources (position 0) in position 2. and degrees of targets (position 1) in position 3 - // for convenience of coGroup, put the sources and targets in a Tuple2 - Graph graphWithDegree = graph - .joinWithVertices(vertexDataSet, new VerticesJoin()) - .mapEdges(new GraphTempMapEdge()); - //write the neighbors of vertices in their values. - //We use Set instead of List, because Set is 10%~20% speeder than List - Graph graphTemp = graphWithDegree - .mapVertices(new GraphTempMapVertex()) - .runScatterGatherIteration(new ScatterGraphTemp(), - new GatherGraphTemp(), - 1); - //operate on triplet and write the neighbor number in the values of edges - return graphTemp - .getTriplets() - .map(new MapTriplet()); - } - - public static class Longvalue2Long - implements MapFunction , Tuple2 > { - private static final long serialVersionUID = -8499849561757464697L; - - @Override - public Tuple2 map(Tuple2 value) throws Exception { - return new Tuple2 <>(value.f0, value.f1.getValue()); - } - } - - public static class VerticesJoin implements VertexJoinFunction { - private static final long serialVersionUID = 3413134536200006612L; - - @Override - public Double vertexJoin(Double aDouble, Long aLong) { - return aLong.doubleValue(); - } - } - - public static class GraphTempMapEdge - implements MapFunction , Long> { - private static final long serialVersionUID = 5872715320371423887L; - - @Override - public Long map(Edge value) throws Exception { - return 0L; - } - } - - public static class GraphTempMapVertex - implements MapFunction , Long[]> { - private static final long serialVersionUID = 897641610324504285L; - - @Override - public Long[] map(Vertex value) throws Exception { - return new Long[value.f1.intValue()]; - } - } - - public static class ScatterGraphTemp - extends ScatterFunction { - private static final long serialVersionUID = -7538585087831930835L; - - @Override - public void sendMessages(Vertex vertex) { - for (Edge edge : getEdges()) { - sendMessageTo(edge.getTarget(), vertex.f0); - } - } - } - - public static class GatherGraphTemp - extends GatherFunction { - private static final long serialVersionUID = 465153561839511086L; - - @Override - public void updateVertex(Vertex vertex, - MessageIterator inMessages) { - int count = 0; - for (Long msg : inMessages) { - vertex.f1[count] = msg; - count += 1; - } - Arrays.sort(vertex.f1); - setNewVertexValue(vertex.f1); - } - } - - public static class MapTriplet implements MapFunction < - Triplet , - Tuple6 > { - private static final long serialVersionUID = -1837449716800435246L; - - @Override - public Tuple6 map( - Triplet value) throws Exception { - int l2 = value.f2.length; - int l3 = value.f3.length; - int index2 = 0; - int index3 = 0; - long count = 0; - while (index2 < l2 && index3 < l3) { - if (value.f2[index2].equals(value.f3[index3])) { - index2 += 1; - index3 += 1; - count += 1; - } else if (value.f2[index2] > value.f3[index3]) { - index3 += 1; - } else { - index2 += 1; - } - } - long f2 = value.f2.length; - long f3 = value.f3.length; - return new Tuple6 <>(value.f0, value.f1, f2, f3, count, count * 1. 
/ Math.min(f2, f3)); - } - } -} \ No newline at end of file diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/graph/EdgeClusterCoefficientBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/graph/EdgeClusterCoefficientBatchOp.java index 66292af97..257d26777 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/graph/EdgeClusterCoefficientBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/graph/EdgeClusterCoefficientBatchOp.java @@ -5,11 +5,18 @@ import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeinfo.Types; import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.tuple.Tuple6; import org.apache.flink.graph.Edge; import org.apache.flink.graph.Graph; +import org.apache.flink.graph.Triplet; import org.apache.flink.graph.Vertex; +import org.apache.flink.graph.VertexJoinFunction; +import org.apache.flink.graph.spargel.GatherFunction; +import org.apache.flink.graph.spargel.MessageIterator; +import org.apache.flink.graph.spargel.ScatterFunction; import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.types.LongValue; import org.apache.flink.types.NullValue; import org.apache.flink.types.Row; import org.apache.flink.util.Collector; @@ -17,6 +24,7 @@ import com.alibaba.alink.common.MLEnvironmentFactory; import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortDesc; @@ -26,13 +34,16 @@ import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.operator.batch.BatchOperator; -import com.alibaba.alink.operator.common.graph.GraphUtilsWithString; import com.alibaba.alink.params.graph.EdgeClusterCoefficientParams; + +import java.util.Arrays; + @InputPorts(values = @PortSpec(value = PortType.DATA, opType = OpType.BATCH, desc = PortDesc.GRPAH_EDGES)) @OutputPorts(values = @PortSpec(value = PortType.DATA, desc = PortDesc.OUTPUT_RESULT)) @ParamSelectColumnSpec(name = "edgeSourceCol", allowedTypeCollections = TypeCollections.INT_LONG_STRING_TYPES) @ParamSelectColumnSpec(name = "edgeTargetCol", allowedTypeCollections = TypeCollections.INT_LONG_STRING_TYPES) @NameCn("边聚类系数") +@NameEn("Edge Cluster Coefficient") public class EdgeClusterCoefficientBatchOp extends BatchOperator implements EdgeClusterCoefficientParams { @@ -105,4 +116,136 @@ public Double map(Vertex value) throws Exception { return 1.; } } + + /** + * for each edge of undirected graph, return 1. the degree of its source and target + * 2. the number of triangles based on this edge + * 3. the quotient of the number of triangles and the min value between the two degrees + */ + + public static class EdgeClusterCoefficient { + public DataSet > run(Graph graph) { + //calculate the degree of each vertex. Because it is undirected, we only consider inDegree + DataSet > vertexDataSet = graph.inDegrees().map(new Longvalue2Long()); + //construct the output form, and write all the edges in it + //DataSet> temp = graph.getEdges().map(new MapEdge()); + //write the degrees of sources and targets in the corresponding position. + //write degrees of sources (position 0) in position 2. 
and degrees of targets (position 1) in position 3 + // for convenience of coGroup, put the sources and targets in a Tuple2 + Graph graphWithDegree = graph + .joinWithVertices(vertexDataSet, new VerticesJoin()) + .mapEdges(new GraphTempMapEdge()); + //write the neighbors of vertices in their values. + //We use Set instead of List, because Set is 10%~20% speeder than List + Graph graphTemp = graphWithDegree + .mapVertices(new GraphTempMapVertex()) + .runScatterGatherIteration(new ScatterGraphTemp(), + new GatherGraphTemp(), + 1); + //operate on triplet and write the neighbor number in the values of edges + return graphTemp + .getTriplets() + .map(new MapTriplet()); + } + + public static class Longvalue2Long + implements MapFunction , Tuple2 > { + private static final long serialVersionUID = -8499849561757464697L; + + @Override + public Tuple2 map(Tuple2 value) throws Exception { + return new Tuple2 <>(value.f0, value.f1.getValue()); + } + } + + public static class VerticesJoin implements VertexJoinFunction { + private static final long serialVersionUID = 3413134536200006612L; + + @Override + public Double vertexJoin(Double aDouble, Long aLong) { + return aLong.doubleValue(); + } + } + + public static class GraphTempMapEdge + implements MapFunction , Long> { + private static final long serialVersionUID = 5872715320371423887L; + + @Override + public Long map(Edge value) throws Exception { + return 0L; + } + } + + public static class GraphTempMapVertex + implements MapFunction , Long[]> { + private static final long serialVersionUID = 897641610324504285L; + + @Override + public Long[] map(Vertex value) throws Exception { + return new Long[value.f1.intValue()]; + } + } + + public static class ScatterGraphTemp + extends ScatterFunction { + private static final long serialVersionUID = -7538585087831930835L; + + @Override + public void sendMessages(Vertex vertex) { + for (Edge edge : getEdges()) { + sendMessageTo(edge.getTarget(), vertex.f0); + } + } + } + + public static class GatherGraphTemp + extends GatherFunction { + private static final long serialVersionUID = 465153561839511086L; + + @Override + public void updateVertex(Vertex vertex, + MessageIterator inMessages) { + int count = 0; + for (Long msg : inMessages) { + vertex.f1[count] = msg; + count += 1; + } + Arrays.sort(vertex.f1); + setNewVertexValue(vertex.f1); + } + } + + public static class MapTriplet implements MapFunction < + Triplet , + Tuple6 > { + private static final long serialVersionUID = -1837449716800435246L; + + @Override + public Tuple6 map( + Triplet value) throws Exception { + int l2 = value.f2.length; + int l3 = value.f3.length; + int index2 = 0; + int index3 = 0; + long count = 0; + while (index2 < l2 && index3 < l3) { + if (value.f2[index2].equals(value.f3[index3])) { + index2 += 1; + index3 += 1; + count += 1; + } else if (value.f2[index2] > value.f3[index3]) { + index3 += 1; + } else { + index2 += 1; + } + } + long f2 = value.f2.length; + long f3 = value.f3.length; + return new Tuple6 <>(value.f0, value.f1, f2, f3, count, count * 1. 
/ Math.min(f2, f3)); + } + } + } } diff --git a/core/src/main/java/com/alibaba/alink/operator/common/graph/GraphUtils.java b/core/src/main/java/com/alibaba/alink/operator/batch/graph/GraphUtils.java similarity index 99% rename from core/src/main/java/com/alibaba/alink/operator/common/graph/GraphUtils.java rename to core/src/main/java/com/alibaba/alink/operator/batch/graph/GraphUtils.java index 1c9ffe2a3..0c9d985bd 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/graph/GraphUtils.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/graph/GraphUtils.java @@ -1,4 +1,4 @@ -package com.alibaba.alink.operator.common.graph; +package com.alibaba.alink.operator.batch.graph; import org.apache.flink.api.common.functions.CoGroupFunction; import org.apache.flink.api.common.functions.FlatMapFunction; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/graph/GraphUtilsWithString.java b/core/src/main/java/com/alibaba/alink/operator/batch/graph/GraphUtilsWithString.java similarity index 99% rename from core/src/main/java/com/alibaba/alink/operator/common/graph/GraphUtilsWithString.java rename to core/src/main/java/com/alibaba/alink/operator/batch/graph/GraphUtilsWithString.java index 972d85eea..5667e272b 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/graph/GraphUtilsWithString.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/graph/GraphUtilsWithString.java @@ -1,4 +1,4 @@ -package com.alibaba.alink.operator.common.graph; +package com.alibaba.alink.operator.batch.graph; import org.apache.flink.api.common.functions.CoGroupFunction; import org.apache.flink.api.common.functions.JoinFunction; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/graph/KCore.java b/core/src/main/java/com/alibaba/alink/operator/batch/graph/KCore.java deleted file mode 100644 index 96a21df79..000000000 --- a/core/src/main/java/com/alibaba/alink/operator/batch/graph/KCore.java +++ /dev/null @@ -1,196 +0,0 @@ -package com.alibaba.alink.operator.batch.graph; - -import org.apache.flink.api.common.functions.FilterFunction; -import org.apache.flink.api.common.functions.FlatMapFunction; -import org.apache.flink.api.common.functions.GroupReduceFunction; -import org.apache.flink.api.common.functions.MapFunction; -import org.apache.flink.api.java.DataSet; -import org.apache.flink.api.java.operators.IterativeDataSet; -import org.apache.flink.api.java.tuple.Tuple5; -import org.apache.flink.graph.Edge; -import org.apache.flink.util.Collector; - -import java.util.ArrayList; -import java.util.List; - -/** - * This algorithm iteratively delete all vertices whose degree is not larger than k ,so that it will select - * a graph whose vertices all have degrees larger than k. - * Dealing with undirectedGraph, we groupby the edge dataset so that we can get the degree of vertices - * through the numbers of vertices appear in the field 0 or 1 in the dataset, and iteratively delete - * edges denoting small degree until the remaining dataset meets the requirement. - * - * @author qingzhao - */ -public class KCore { - public static int k; - public int maxIter; - - /** - * @param k Remove all vertices with degree not larger than k. - * @param maxIter The maximum number of iterations to run. 
- */ - public KCore(int k, int maxIter) { - this.k = k; - this.maxIter = maxIter; - } - - public DataSet > run(DataSet > edges, Boolean directed) { - DataSet > initialState; - if (directed) { - initialState = edges.flatMap(new FlatMapFunction - , Tuple5 >() { - private static final long serialVersionUID = 3415098877090917677L; - - @Override - public void flatMap(Edge value, - Collector > out) { - Tuple5 res = new Tuple5 (); - res.f0 = value.f0; - res.f1 = value.f1; - res.f2 = -1L; - res.f3 = -1L; - res.f4 = 0.; - out.collect(res); - res.f0 = value.f1; - res.f1 = value.f0; - out.collect(res); - } - }); - } else { - initialState = edges.flatMap(new FlatMapFunction - , Tuple5 >() { - private static final long serialVersionUID = -1356257363097879387L; - - @Override - public void flatMap(Edge value, - Collector > out) { - Tuple5 res = new Tuple5 (); - res.f0 = value.f0; - res.f1 = value.f1; - res.f2 = -1L; - res.f3 = -1L; - res.f4 = 0.; - out.collect(res); - } - }); - } - DataSet > outState = operation(initialState); - return outState.map(new MapFunction , Edge >() { - private static final long serialVersionUID = -1652848589684719913L; - - @Override - public Edge map(Tuple5 value) throws Exception { - return new Edge (value.f0, value.f1, 1.); - } - }); - } - - public DataSet > operation( - DataSet > initialState) { - IterativeDataSet > state = initialState - .iterate(this.maxIter); - //Count numbers with field 0, and then filter edges denoting small degree. - DataSet > secondState = state - .groupBy(0) - .reduceGroup(new ReduceOnFirstField()) - .filter(new FilterSmallOnesOnFirstField(k)).name("firstStep"); - //If there is no vertices with small degree, then break the iteration. - DataSet > thirdState = secondState - .groupBy(1) - .reduceGroup(new ReduceOnSecondField()).name("secondStep"); - //seems fussy, but may not avoid - DataSet > outState = state - .closeWith( - thirdState.filter(new FilterSmallOnesOnSecondField(k)).name("filterSmallOne"), - thirdState.filter(new FilterLargeOnesOnThiState(k)).name("filterLargeOne")); - - return outState; - } - - public static class ReduceOnFirstField - implements - GroupReduceFunction , Tuple5 > { - private static final long serialVersionUID = 263920722211539724L; - - @Override - public void reduce(Iterable > values, - Collector > out) throws Exception { - long counter = 0L; - List > l = new ArrayList <>(); - for (Tuple5 i : values) { - counter += 1L; - l.add(i); - } - for (Tuple5 i : l) { - out.collect(new Tuple5 <>(i.f0, i.f1, counter, i.f3, i.f4)); - } - } - } - - public static class ReduceOnSecondField - implements - GroupReduceFunction , Tuple5 > { - private static final long serialVersionUID = 7840099990204577056L; - - @Override - public void reduce(Iterable > values, - Collector > out) throws Exception { - long counter = 0L; - List > l = new ArrayList <>(); - for (Tuple5 i : values) { - counter += 1L; - l.add(i); - } - for (Tuple5 i : l) { - out.collect(new Tuple5 <>(i.f0, i.f1, i.f2, counter, i.f4)); - } - } - } - - public static class FilterSmallOnesOnFirstField - implements FilterFunction > { - private static final long serialVersionUID = -4414815465890029511L; - private long k; - - private FilterSmallOnesOnFirstField(long k) { - this.k = k; - } - - @Override - public boolean filter(Tuple5 value) throws Exception { - return value.f2 > this.k; - } - } - - public static class FilterSmallOnesOnSecondField - implements FilterFunction > { - private static final long serialVersionUID = 156799354134467716L; - private long k; - - private 
FilterSmallOnesOnSecondField(long k) { - this.k = k; - } - - @Override - public boolean filter(Tuple5 value) throws Exception { - return value.f3 > this.k; - } - } - - public static class FilterLargeOnesOnThiState - implements FilterFunction > { - private static final long serialVersionUID = 1257898737107879380L; - private long k; - - private FilterLargeOnesOnThiState(long k) { - this.k = k; - } - - @Override - public boolean filter(Tuple5 value) throws Exception { - return value.f3 <= this.k; - } - } - -} \ No newline at end of file diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/graph/KCoreBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/graph/KCoreBatchOp.java index 37ba47691..9e87ed0cb 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/graph/KCoreBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/graph/KCoreBatchOp.java @@ -1,9 +1,15 @@ package com.alibaba.alink.operator.batch.graph; +import org.apache.flink.api.common.functions.FilterFunction; import org.apache.flink.api.common.functions.FlatMapFunction; +import org.apache.flink.api.common.functions.GroupReduceFunction; +import org.apache.flink.api.common.functions.MapFunction; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.operators.IterativeDataSet; import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple5; +import org.apache.flink.graph.Edge; import org.apache.flink.ml.api.misc.param.Params; import org.apache.flink.types.Row; import org.apache.flink.util.Collector; @@ -11,6 +17,7 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortDesc; @@ -24,12 +31,16 @@ import com.alibaba.alink.operator.batch.graph.memory.MemoryVertexCentricIteration; import com.alibaba.alink.params.graph.KCoreParams; +import java.util.ArrayList; import java.util.Iterator; +import java.util.List; + @InputPorts(values = @PortSpec(value = PortType.DATA, opType = OpType.BATCH, desc = PortDesc.GRPAH_EDGES)) @OutputPorts(values = @PortSpec(value = PortType.DATA)) @ParamSelectColumnSpec(name = "edgeSourceCol", allowedTypeCollections = TypeCollections.INT_LONG_STRING_TYPES) @ParamSelectColumnSpec(name = "edgeTargetCol", allowedTypeCollections = TypeCollections.INT_LONG_STRING_TYPES) @NameCn("KCore算法") +@NameEn("KCore") public class KCoreBatchOp extends BatchOperator implements KCoreParams { private static final long serialVersionUID = -7537644695230031028L; @@ -139,4 +150,184 @@ public void initEdgesValues() { setAllEdgeValues(VALID_EDGE); } } + + /** + * This algorithm iteratively delete all vertices whose degree is not larger than k ,so that it will select + * a graph whose vertices all have degrees larger than k. + * Dealing with undirectedGraph, we groupby the edge dataset so that we can get the degree of vertices + * through the numbers of vertices appear in the field 0 or 1 in the dataset, and iteratively delete + * edges denoting small degree until the remaining dataset meets the requirement. + */ + public static class KCore { + public static int k; + public int maxIter; + + /** + * @param k Remove all vertices with degree not larger than k. 
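// [Editorial sketch, not part of this patch] The KCore class being added here repeatedly drops
// edges whose endpoints have degree <= k until only the k-core is left. The same peeling idea,
// written as plain single-machine Java over a toy undirected edge list (data and class name are
// illustrative only):
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class KCoreSketch {
	public static void main(String[] args) {
		int k = 1;
		// toy undirected graph: a triangle {0, 1, 2} plus a pendant vertex 3
		List<long[]> edges = new ArrayList<>(Arrays.asList(
			new long[] {0, 1}, new long[] {1, 2}, new long[] {0, 2}, new long[] {2, 3}));
		boolean changed = true;
		while (changed) {
			// recompute degrees of the surviving subgraph
			Map<Long, Integer> degree = new HashMap<>();
			for (long[] e : edges) {
				degree.merge(e[0], 1, Integer::sum);
				degree.merge(e[1], 1, Integer::sum);
			}
			// keep an edge only if both endpoints still have degree > k
			List<long[]> kept = new ArrayList<>();
			for (long[] e : edges) {
				if (degree.get(e[0]) > k && degree.get(e[1]) > k) {
					kept.add(e);
				}
			}
			changed = kept.size() != edges.size();
			edges = kept;
		}
		// the pendant edge (2, 3) is peeled off; only the triangle survives for k = 1
		for (long[] e : edges) {
			System.out.println(e[0] + " - " + e[1]);
		}
	}
}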
+ * @param maxIter The maximum number of iterations to run. + */ + public KCore(int k, int maxIter) { + this.k = k; + this.maxIter = maxIter; + } + + public DataSet > run(DataSet > edges, Boolean directed) { + DataSet > initialState; + if (directed) { + initialState = edges.flatMap(new FlatMapFunction + , Tuple5 >() { + private static final long serialVersionUID = 3415098877090917677L; + + @Override + public void flatMap(Edge value, + Collector > out) { + Tuple5 res = new Tuple5 (); + res.f0 = value.f0; + res.f1 = value.f1; + res.f2 = -1L; + res.f3 = -1L; + res.f4 = 0.; + out.collect(res); + res.f0 = value.f1; + res.f1 = value.f0; + out.collect(res); + } + }); + } else { + initialState = edges.flatMap(new FlatMapFunction + , Tuple5 >() { + private static final long serialVersionUID = -1356257363097879387L; + + @Override + public void flatMap(Edge value, + Collector > out) { + Tuple5 res = new Tuple5 (); + res.f0 = value.f0; + res.f1 = value.f1; + res.f2 = -1L; + res.f3 = -1L; + res.f4 = 0.; + out.collect(res); + } + }); + } + DataSet > outState = operation(initialState); + return outState.map(new MapFunction , Edge >() { + private static final long serialVersionUID = -1652848589684719913L; + + @Override + public Edge map(Tuple5 value) throws Exception { + return new Edge (value.f0, value.f1, 1.); + } + }); + } + + public DataSet > operation( + DataSet > initialState) { + IterativeDataSet > state = initialState + .iterate(this.maxIter); + //Count numbers with field 0, and then filter edges denoting small degree. + DataSet > secondState = state + .groupBy(0) + .reduceGroup(new ReduceOnFirstField()) + .filter(new FilterSmallOnesOnFirstField(k)).name("firstStep"); + //If there is no vertices with small degree, then break the iteration. + DataSet > thirdState = secondState + .groupBy(1) + .reduceGroup(new ReduceOnSecondField()).name("secondStep"); + //seems fussy, but may not avoid + DataSet > outState = state + .closeWith( + thirdState.filter(new FilterSmallOnesOnSecondField(k)).name("filterSmallOne"), + thirdState.filter(new FilterLargeOnesOnThiState(k)).name("filterLargeOne")); + + return outState; + } + + public static class ReduceOnFirstField + implements + GroupReduceFunction , Tuple5 > { + private static final long serialVersionUID = 263920722211539724L; + + @Override + public void reduce(Iterable > values, + Collector > out) throws Exception { + long counter = 0L; + List > l = new ArrayList <>(); + for (Tuple5 i : values) { + counter += 1L; + l.add(i); + } + for (Tuple5 i : l) { + out.collect(new Tuple5 <>(i.f0, i.f1, counter, i.f3, i.f4)); + } + } + } + + public static class ReduceOnSecondField + implements + GroupReduceFunction , Tuple5 > { + private static final long serialVersionUID = 7840099990204577056L; + + @Override + public void reduce(Iterable > values, + Collector > out) throws Exception { + long counter = 0L; + List > l = new ArrayList <>(); + for (Tuple5 i : values) { + counter += 1L; + l.add(i); + } + for (Tuple5 i : l) { + out.collect(new Tuple5 <>(i.f0, i.f1, i.f2, counter, i.f4)); + } + } + } + + public static class FilterSmallOnesOnFirstField + implements FilterFunction > { + private static final long serialVersionUID = -4414815465890029511L; + private long k; + + private FilterSmallOnesOnFirstField(long k) { + this.k = k; + } + + @Override + public boolean filter(Tuple5 value) throws Exception { + return value.f2 > this.k; + } + } + + public static class FilterSmallOnesOnSecondField + implements FilterFunction > { + private static final long serialVersionUID = 
156799354134467716L; + private long k; + + private FilterSmallOnesOnSecondField(long k) { + this.k = k; + } + + @Override + public boolean filter(Tuple5 value) throws Exception { + return value.f3 > this.k; + } + } + + public static class FilterLargeOnesOnThiState + implements FilterFunction > { + private static final long serialVersionUID = 1257898737107879380L; + private long k; + + private FilterLargeOnesOnThiState(long k) { + this.k = k; + } + + @Override + public boolean filter(Tuple5 value) throws Exception { + return value.f3 <= this.k; + } + } + + } } diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/graph/LineBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/graph/LineBatchOp.java index bbc898542..9955e2369 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/graph/LineBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/graph/LineBatchOp.java @@ -21,9 +21,10 @@ import org.apache.flink.types.Row; import org.apache.flink.util.Collector; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortSpec; @@ -36,7 +37,6 @@ import com.alibaba.alink.operator.batch.huge.line.ApsSerializeModelLine; import com.alibaba.alink.operator.common.aps.ApsContext; import com.alibaba.alink.operator.common.aps.ApsEnv; -import com.alibaba.alink.operator.common.graph.GraphUtilsWithString; import com.alibaba.alink.operator.common.nlp.WordCountUtil; import com.alibaba.alink.params.graph.LineParams; import com.alibaba.alink.params.nlp.HasBatchSize; @@ -51,6 +51,7 @@ @ParamSelectColumnSpec(name="targetCol", allowedTypeCollections = TypeCollections.INT_LONG_STRING_TYPES) @ParamSelectColumnSpec(name="weightCol", allowedTypeCollections = TypeCollections.NUMERIC_TYPES) @NameCn("Line") +@NameEn("Line") public class LineBatchOp extends BatchOperator implements LineParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/graph/MetaPath2VecBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/graph/MetaPath2VecBatchOp.java index 1d062ad17..8e83d8da7 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/graph/MetaPath2VecBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/graph/MetaPath2VecBatchOp.java @@ -3,9 +3,11 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.batch.huge.impl.MetaPath2VecImpl; @NameCn("MetaPath To Vector") +@NameEn("MetaPath To Vector") public class MetaPath2VecBatchOp extends MetaPath2VecImpl { private static final long serialVersionUID = -6118527393338279346L; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/graph/MetaPathWalkBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/graph/MetaPathWalkBatchOp.java index 640df5a28..9eb9ccd6a 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/graph/MetaPathWalkBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/graph/MetaPathWalkBatchOp.java @@ -23,6 +23,7 @@ import com.alibaba.alink.common.MLEnvironmentFactory; import com.alibaba.alink.common.annotation.InputPorts; import 
com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortDesc; @@ -79,6 +80,7 @@ @ParamSelectColumnSpec(name = "vertexCol", portIndices = 1, allowedTypeCollections = {TypeCollections.INT_LONG_TYPES, TypeCollections.STRING_TYPES}) @ParamSelectColumnSpec(name = "typeCol", portIndices = 1, allowedTypeCollections = {TypeCollections.STRING_TYPES}) @NameCn("MetaPath游走") +@NameEn("MetaPath Walk") public class MetaPathWalkBatchOp extends BatchOperator implements MetaPathWalkParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/graph/ModularityCal.java b/core/src/main/java/com/alibaba/alink/operator/batch/graph/ModularityCal.java deleted file mode 100644 index cb33fb0aa..000000000 --- a/core/src/main/java/com/alibaba/alink/operator/batch/graph/ModularityCal.java +++ /dev/null @@ -1,275 +0,0 @@ -package com.alibaba.alink.operator.batch.graph; - -import org.apache.flink.api.common.functions.AbstractRichFunction; -import org.apache.flink.api.common.functions.CrossFunction; -import org.apache.flink.api.common.functions.FilterFunction; -import org.apache.flink.api.common.functions.MapFunction; -import org.apache.flink.api.common.functions.ReduceFunction; -import org.apache.flink.api.common.functions.RichGroupReduceFunction; -import org.apache.flink.api.common.functions.RichMapFunction; -import org.apache.flink.api.common.functions.RichMapPartitionFunction; -import org.apache.flink.api.java.DataSet; -import org.apache.flink.api.java.aggregation.Aggregations; -import org.apache.flink.api.java.functions.KeySelector; -import org.apache.flink.api.java.tuple.Tuple1; -import org.apache.flink.api.java.tuple.Tuple2; -import org.apache.flink.api.java.tuple.Tuple3; -import org.apache.flink.api.java.tuple.Tuple5; -import org.apache.flink.configuration.Configuration; -import org.apache.flink.graph.Edge; -import org.apache.flink.graph.EdgeDirection; -import org.apache.flink.graph.Graph; -import org.apache.flink.graph.NeighborsFunctionWithVertexValue; -import org.apache.flink.graph.Vertex; -import org.apache.flink.util.Collector; - -import java.util.HashMap; -import java.util.List; -import java.util.Map.Entry; - -/** - * This algorithm calculates the modularity of the input graph. - * The values of vertices of input graph represent the community label. - * The edges represent connect relationship between vertices. 
- * - * @author qingzhao - */ -public class ModularityCal { - - public static DataSet > modularity(DataSet > groupWeights) { - DataSet > m = groupWeights.aggregate(Aggregations.SUM, 2).project(2); - DataSet > weights = groupWeights.groupBy(0) - .aggregate(Aggregations.SUM, 1) - .and(Aggregations.SUM, 2); - DataSet > modularity = weights.reduceGroup(new RichGroupReduceFunction , Tuple1>() { - @Override - public void reduce(Iterable > values, Collector > out) - throws Exception { - double m = ((Tuple1 ) getRuntimeContext().getBroadcastVariable("m").get(0)).f0; - double inGroup = 0; - double outGroup = 0; - for (Tuple3 value : values) { - inGroup += value.f1; - outGroup += Math.pow(value.f2, 2); - } - out.collect(Tuple1.of(inGroup / m - outGroup / Math.pow(m, 2))); - } - }).withBroadcastSet(m, "m"); - return modularity; - } - - public static DataSet > run2(Graph graph) { - DataSet > edgeInfo = graph - .groupReduceOnNeighbors(new NeighborsFunctionWithVertexValue >() { - - @Override - public void iterateNeighbors(Vertex vertex, - Iterable , - Vertex >> neighbors, - Collector > out) throws Exception { - for (Tuple2 , Vertex > neighbor : neighbors) { - out.collect(Tuple5.of(vertex.f0, vertex.f1, neighbor.f1.f0, neighbor.f1.f1, neighbor.f0.f2)); - } - } - }, EdgeDirection.OUT); - - DataSet > m = edgeInfo.aggregate(Aggregations.SUM, 4).project(4); - - DataSet inAndOutModularity = edgeInfo - .mapPartition( - new MapModularity()) - .reduce(new ReduceFunction () { - @Override - public HashMap reduce(HashMap value1, - HashMap value2) - throws Exception { - HashMap > v1 = value1; - HashMap > v2 = value2; - for (Entry > entry : v2.entrySet()) { - Tuple2 v1Modu = v1.getOrDefault(entry.getKey(), Tuple2.of(0.0, 0.0)); - v1Modu.f0 += entry.getValue().f0; - v1Modu.f1 += entry.getValue().f1; - v1.put(entry.getKey(), v1Modu); - } - return value1; - } - }); - - DataSet > modularity = inAndOutModularity - .map(new RichMapFunction >() { - @Override - public Tuple1 map(HashMap value) throws Exception { - double m = ((Tuple1 ) getRuntimeContext().getBroadcastVariable("m").get(0)).f0; - double in = 0; - double out = 0; - HashMap > v = value; - for (Entry > entry : v.entrySet()) { - double localout = 0; - in += entry.getValue().f0; - localout = entry.getValue().f0 + entry.getValue().f1; - out += Math.pow(localout, 2); - } - - in /= m; - - return Tuple1.of(in - out / Math.pow(m, 2)); - } - }).withBroadcastSet(m, "m"); - - return modularity; - } - - private static class MapModularity - extends RichMapPartitionFunction , HashMap> { - - @Override - public void mapPartition(Iterable > values, - Collector out) throws Exception { - - HashMap > inModularitys = new HashMap <>(); - for (Tuple5 value : values) { - - Tuple2 inAndOut = inModularitys.getOrDefault(value.f1, Tuple2.of(0.0, 0.0)); - if (value.f1.equals(value.f3)) { - double inModu = inAndOut.f0; - inModu += value.f4; - inAndOut.f0 = inModu; - inModularitys.put(value.f1, inAndOut); - } else { - double outModu = inAndOut.f1; - outModu += value.f4; - inAndOut.f1 = outModu; - inModularitys.put(value.f1, inAndOut); - } - } - out.collect(inModularitys); - } - } - - public static DataSet > run(Graph graph) { - //Save the dense matrix in a DataSet with the form of triad. The three elements are row id, column id and the - // value. - //We only need to calculate the matrix through the diag as well as the column, so this design is convenient. - - //change all the edges to the tripe. 
the three position of the Tuple3 represents the two community of - // the two nodes of the edge, and the 3rd position is 1 - // may try it with getTriplets(). - - //node id, neighbor id, edge weight - DataSet > communityInfo = graph - .groupReduceOnNeighbors(new ErgodicEdge(), EdgeDirection.OUT); - //groupby all the edge information and form the k*k matrix - DataSet > communityInfoReduced = communityInfo - .groupBy(new SelectTuple()) - .reduce(new ReduceOnCommunity()); - //the following two steps calculate m. - //this step calculate sum on row - DataSet > reducedOnRow = communityInfoReduced - .groupBy(1) - .aggregate(Aggregations.SUM, 2) - .project(2); - DataSet > m = reducedOnRow - .aggregate(Aggregations.SUM, 0); - DataSet > temp2 = reducedOnRow - .map(new MapSquare()).aggregate(Aggregations.SUM, 0); - //.reduce(new Sum()); - DataSet > temp1 = communityInfoReduced - .filter(new FilterDiag()) - .aggregate(Aggregations.SUM, 2) - .project(2); - //temp1, temp2 and m are DataSet that only contains one element. - return temp1 - .cross(temp2) - .with(new CrossStep()) - .withBroadcastSet(m, "m"); - } - - public static class ErgodicEdge - implements NeighborsFunctionWithVertexValue > { - private static final long serialVersionUID = 5295386257754049577L; - - @Override - public void iterateNeighbors(Vertex vertex, - Iterable , Vertex >> neighbors, - Collector > out) { - long f0 = vertex.f1.longValue(); - for (Tuple2 , Vertex > neighbor : neighbors) { - //long f1 = neighbor.f1.f1.longValue(); - long f1 = neighbor.f1.f0; - out.collect(Tuple3.of(f0, f1, 1L)); - } - } - } - - public static class SelectTuple - implements KeySelector , Tuple2 > { - private static final long serialVersionUID = 5638365638596494304L; - - @Override - public Tuple2 getKey(Tuple3 value) throws Exception { - return Tuple2.of(value.f0, value.f1); - } - } - - public static class ReduceOnCommunity - implements ReduceFunction > { - private static final long serialVersionUID = 3502336992662864358L; - - @Override - public Tuple3 reduce(Tuple3 value1, - Tuple3 value2) throws Exception { - return new Tuple3 <>(value1.f0, value1.f1, value1.f2 + value2.f2); - } - } - - public static class MapSquare - implements MapFunction , Tuple1 > { - private static final long serialVersionUID = -1719101888137570397L; - - @Override - public Tuple1 map(Tuple1 value) throws Exception { - return new Tuple1 <>(value.f0 * value.f0); - } - } - - public static class FilterDiag implements FilterFunction > { - private static final long serialVersionUID = 6595663411872011784L; - - @Override - public boolean filter(Tuple3 value) throws Exception { - return value.f0.equals(value.f1); - } - } - // - // public static class Sum implements ReduceFunction > { - // private static final long serialVersionUID = -5418729191039529263L; - // - // @Override - // public Tuple1 reduce(Tuple1 value1, Tuple1 value2) throws Exception { - // return new Tuple1 <>(value1.f0 + value2.f0); - // } - // } - - protected static class CrossStep extends AbstractRichFunction - implements CrossFunction , Tuple1 , Tuple1 > { - private static final long serialVersionUID = -7359362890112928974L; - private Tuple1 mTuple; - - @Override - public void open(Configuration parameters) throws Exception { - List > dicList = getRuntimeContext().getBroadcastVariable("m"); - for (Tuple1 s : dicList) { - mTuple = s; - } - } - - @Override - public Tuple1 cross(Tuple1 temp1Tuple, Tuple1 temp2Tuple) throws Exception { - long temp1 = temp1Tuple.f0; - long temp2 = temp2Tuple.f0; - long m = mTuple.f0; - return new 
Tuple1 <>(1. * temp1 / m - 1. * temp2 / (m * m)); - } - } -} \ No newline at end of file diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/graph/ModularityCalBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/graph/ModularityCalBatchOp.java index 492cba1b3..49dc25b24 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/graph/ModularityCalBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/graph/ModularityCalBatchOp.java @@ -1,20 +1,36 @@ package com.alibaba.alink.operator.batch.graph; +import org.apache.flink.api.common.functions.AbstractRichFunction; +import org.apache.flink.api.common.functions.CrossFunction; +import org.apache.flink.api.common.functions.FilterFunction; import org.apache.flink.api.common.functions.JoinFunction; import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.ReduceFunction; +import org.apache.flink.api.common.functions.RichGroupReduceFunction; +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.functions.RichMapPartitionFunction; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeinfo.Types; import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.aggregation.Aggregations; +import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.api.java.tuple.Tuple1; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.tuple.Tuple5; +import org.apache.flink.configuration.Configuration; import org.apache.flink.graph.Edge; +import org.apache.flink.graph.EdgeDirection; +import org.apache.flink.graph.Graph; +import org.apache.flink.graph.NeighborsFunctionWithVertexValue; import org.apache.flink.graph.Vertex; import org.apache.flink.ml.api.misc.param.Params; import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortDesc; @@ -23,8 +39,12 @@ import com.alibaba.alink.common.annotation.PortType; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.operator.batch.BatchOperator; -import com.alibaba.alink.operator.common.graph.GraphUtils; import com.alibaba.alink.params.graph.ModularityCalParams; + +import java.util.HashMap; +import java.util.List; +import java.util.Map.Entry; + @InputPorts(values = { @PortSpec(value = PortType.DATA, opType = OpType.BATCH, desc = PortDesc.GRPAH_EDGES), @PortSpec(value = PortType.DATA, opType = OpType.BATCH, desc = PortDesc.GRAPH_VERTICES), @@ -36,6 +56,7 @@ @ParamSelectColumnSpec(name = "edgeTargetCol", portIndices = 0) @ParamSelectColumnSpec(name = "edgeWeightCol", portIndices = 0) @NameCn("模块度计算") +@NameEn("Calculate Modularity") public class ModularityCalBatchOp extends BatchOperator implements ModularityCalParams { private static final long serialVersionUID = -7765756516724178687L; @@ -111,4 +132,248 @@ public Row map(Tuple1 value) throws Exception { return this; } + /** + * This algorithm calculates the modularity of the input graph. + * The values of vertices of input graph represent the community label. 
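// [Editorial sketch, not part of this patch] Reading the modularity(...) method below: with m the
// total edge weight, in_c the weight of edges inside community c, and tot_c the total weight attached
// to community c, the score it produces appears to be the usual Q = sum_c in_c / m - sum_c (tot_c / m)^2.
// A dependency-free check of that formula on made-up per-community numbers:
public class ModularitySketch {
	public static void main(String[] args) {
		// per-community pairs of (inWeight, totalWeight) -- illustrative values only
		double[][] communities = {{6.0, 8.0}, {4.0, 6.0}};
		double m = 0;
		for (double[] c : communities) {
			m += c[1];
		}
		double q = 0;
		for (double[] c : communities) {
			q += c[0] / m - Math.pow(c[1] / m, 2);
		}
		System.out.println("modularity Q = " + q);
	}
}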
+ * The edges represent connect relationship between vertices. + */ + public static class ModularityCal { + + public static DataSet > modularity(DataSet > groupWeights) { + DataSet > m = groupWeights.aggregate(Aggregations.SUM, 2).project(2); + DataSet > weights = groupWeights.groupBy(0) + .aggregate(Aggregations.SUM, 1) + .and(Aggregations.SUM, 2); + DataSet > modularity = weights.reduceGroup(new RichGroupReduceFunction , Tuple1>() { + @Override + public void reduce(Iterable > values, Collector > out) + throws Exception { + double m = ((Tuple1 ) getRuntimeContext().getBroadcastVariable("m").get(0)).f0; + double inGroup = 0; + double outGroup = 0; + for (Tuple3 value : values) { + inGroup += value.f1; + outGroup += Math.pow(value.f2, 2); + } + out.collect(Tuple1.of(inGroup / m - outGroup / Math.pow(m, 2))); + } + }).withBroadcastSet(m, "m"); + return modularity; + } + + public static DataSet > run2(Graph graph) { + DataSet > edgeInfo = graph + .groupReduceOnNeighbors(new NeighborsFunctionWithVertexValue >() { + + @Override + public void iterateNeighbors(Vertex vertex, + Iterable , + Vertex >> neighbors, + Collector > out) throws Exception { + for (Tuple2 , Vertex > neighbor : neighbors) { + out.collect(Tuple5.of(vertex.f0, vertex.f1, neighbor.f1.f0, neighbor.f1.f1, neighbor.f0.f2)); + } + } + }, EdgeDirection.OUT); + + DataSet > m = edgeInfo.aggregate(Aggregations.SUM, 4).project(4); + + DataSet inAndOutModularity = edgeInfo + .mapPartition( + new MapModularity()) + .reduce(new ReduceFunction () { + @Override + public HashMap reduce(HashMap value1, + HashMap value2) + throws Exception { + HashMap > v1 = value1; + HashMap > v2 = value2; + for (Entry > entry : v2.entrySet()) { + Tuple2 v1Modu = v1.getOrDefault(entry.getKey(), Tuple2.of(0.0, 0.0)); + v1Modu.f0 += entry.getValue().f0; + v1Modu.f1 += entry.getValue().f1; + v1.put(entry.getKey(), v1Modu); + } + return value1; + } + }); + + DataSet > modularity = inAndOutModularity + .map(new RichMapFunction >() { + @Override + public Tuple1 map(HashMap value) throws Exception { + double m = ((Tuple1 ) getRuntimeContext().getBroadcastVariable("m").get(0)).f0; + double in = 0; + double out = 0; + HashMap > v = value; + for (Entry > entry : v.entrySet()) { + double localout = 0; + in += entry.getValue().f0; + localout = entry.getValue().f0 + entry.getValue().f1; + out += Math.pow(localout, 2); + } + + in /= m; + + return Tuple1.of(in - out / Math.pow(m, 2)); + } + }).withBroadcastSet(m, "m"); + + return modularity; + } + + private static class MapModularity + extends RichMapPartitionFunction , HashMap> { + + @Override + public void mapPartition(Iterable > values, + Collector out) throws Exception { + + HashMap > inModularitys = new HashMap <>(); + for (Tuple5 value : values) { + + Tuple2 inAndOut = inModularitys.getOrDefault(value.f1, Tuple2.of(0.0, 0.0)); + if (value.f1.equals(value.f3)) { + double inModu = inAndOut.f0; + inModu += value.f4; + inAndOut.f0 = inModu; + inModularitys.put(value.f1, inAndOut); + } else { + double outModu = inAndOut.f1; + outModu += value.f4; + inAndOut.f1 = outModu; + inModularitys.put(value.f1, inAndOut); + } + } + out.collect(inModularitys); + } + } + + public static DataSet > run(Graph graph) { + //Save the dense matrix in a DataSet with the form of triad. The three elements are row id, column id and the + // value. + //We only need to calculate the matrix through the diag as well as the column, so this design is convenient. + + //change all the edges to the tripe. 
the three position of the Tuple3 represents the two community of + // the two nodes of the edge, and the 3rd position is 1 + // may try it with getTriplets(). + + //node id, neighbor id, edge weight + DataSet > communityInfo = graph + .groupReduceOnNeighbors(new ErgodicEdge(), EdgeDirection.OUT); + //groupby all the edge information and form the k*k matrix + DataSet > communityInfoReduced = communityInfo + .groupBy(new SelectTuple()) + .reduce(new ReduceOnCommunity()); + //the following two steps calculate m. + //this step calculate sum on row + DataSet > reducedOnRow = communityInfoReduced + .groupBy(1) + .aggregate(Aggregations.SUM, 2) + .project(2); + DataSet > m = reducedOnRow + .aggregate(Aggregations.SUM, 0); + DataSet > temp2 = reducedOnRow + .map(new MapSquare()).aggregate(Aggregations.SUM, 0); + //.reduce(new Sum()); + DataSet > temp1 = communityInfoReduced + .filter(new FilterDiag()) + .aggregate(Aggregations.SUM, 2) + .project(2); + //temp1, temp2 and m are DataSet that only contains one element. + return temp1 + .cross(temp2) + .with(new CrossStep()) + .withBroadcastSet(m, "m"); + } + + public static class ErgodicEdge + implements NeighborsFunctionWithVertexValue > { + private static final long serialVersionUID = 5295386257754049577L; + + @Override + public void iterateNeighbors(Vertex vertex, + Iterable , Vertex >> neighbors, + Collector > out) { + long f0 = vertex.f1.longValue(); + for (Tuple2 , Vertex > neighbor : neighbors) { + //long f1 = neighbor.f1.f1.longValue(); + long f1 = neighbor.f1.f0; + out.collect(Tuple3.of(f0, f1, 1L)); + } + } + } + + public static class SelectTuple + implements KeySelector , Tuple2 > { + private static final long serialVersionUID = 5638365638596494304L; + + @Override + public Tuple2 getKey(Tuple3 value) throws Exception { + return Tuple2.of(value.f0, value.f1); + } + } + + public static class ReduceOnCommunity + implements ReduceFunction > { + private static final long serialVersionUID = 3502336992662864358L; + + @Override + public Tuple3 reduce(Tuple3 value1, + Tuple3 value2) throws Exception { + return new Tuple3 <>(value1.f0, value1.f1, value1.f2 + value2.f2); + } + } + + public static class MapSquare + implements MapFunction , Tuple1 > { + private static final long serialVersionUID = -1719101888137570397L; + + @Override + public Tuple1 map(Tuple1 value) throws Exception { + return new Tuple1 <>(value.f0 * value.f0); + } + } + + public static class FilterDiag implements FilterFunction > { + private static final long serialVersionUID = 6595663411872011784L; + + @Override + public boolean filter(Tuple3 value) throws Exception { + return value.f0.equals(value.f1); + } + } + // + // public static class Sum implements ReduceFunction > { + // private static final long serialVersionUID = -5418729191039529263L; + // + // @Override + // public Tuple1 reduce(Tuple1 value1, Tuple1 value2) throws Exception { + // return new Tuple1 <>(value1.f0 + value2.f0); + // } + // } + + protected static class CrossStep extends AbstractRichFunction + implements CrossFunction , Tuple1 , Tuple1 > { + private static final long serialVersionUID = -7359362890112928974L; + private Tuple1 mTuple; + + @Override + public void open(Configuration parameters) throws Exception { + List > dicList = getRuntimeContext().getBroadcastVariable("m"); + for (Tuple1 s : dicList) { + mTuple = s; + } + } + + @Override + public Tuple1 cross(Tuple1 temp1Tuple, Tuple1 temp2Tuple) throws Exception { + long temp1 = temp1Tuple.f0; + long temp2 = temp2Tuple.f0; + long m = mTuple.f0; + return new 
Tuple1 <>(1. * temp1 / m - 1. * temp2 / (m * m)); + } + } + } } diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/graph/MultiSourceShortestPathBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/graph/MultiSourceShortestPathBatchOp.java index f05374571..aa9fc0095 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/graph/MultiSourceShortestPathBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/graph/MultiSourceShortestPathBatchOp.java @@ -24,6 +24,7 @@ import com.alibaba.alink.common.MLEnvironmentFactory; import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortDesc; @@ -32,7 +33,6 @@ import com.alibaba.alink.common.annotation.PortType; import com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.operator.batch.BatchOperator; -import com.alibaba.alink.operator.common.graph.GraphUtilsWithString; import com.alibaba.alink.params.graph.MultiSourceShortestPathParams; import org.apache.commons.lang3.StringUtils; @@ -55,6 +55,7 @@ @ParamSelectColumnSpec(name = "edgeTargetCol", portIndices = 0) @ParamSelectColumnSpec(name = "edgeWeightCol", portIndices = 0) @NameCn("多源最短路径") +@NameEn("Multi Source Shortest Path") public class MultiSourceShortestPathBatchOp extends BatchOperator implements MultiSourceShortestPathParams { private static final long serialVersionUID = -1637471953684406867L; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/graph/Node2VecBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/graph/Node2VecBatchOp.java index 446df059c..00f597062 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/graph/Node2VecBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/graph/Node2VecBatchOp.java @@ -3,9 +3,11 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.batch.huge.impl.Node2VecImpl; @NameCn("Node To Vector") +@NameEn("Node To Vector") public class Node2VecBatchOp extends Node2VecImpl { private static final long serialVersionUID = 8596107700297808776L; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/graph/Node2VecWalkBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/graph/Node2VecWalkBatchOp.java index 7c5c87f9f..327b9b45f 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/graph/Node2VecWalkBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/graph/Node2VecWalkBatchOp.java @@ -20,6 +20,7 @@ import com.alibaba.alink.common.MLEnvironmentFactory; import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortDesc; @@ -80,6 +81,7 @@ @ParamSelectColumnSpec(name = "targetCol", portIndices = 0, allowedTypeCollections = {TypeCollections.INT_LONG_TYPES, TypeCollections.STRING_TYPES}) @ParamSelectColumnSpec(name = "weightCol", portIndices = 0, allowedTypeCollections = {TypeCollections.NUMERIC_TYPES}) @NameCn("Node2Vec游走") +@NameEn("Node2Vec 
Walk") public final class Node2VecWalkBatchOp extends BatchOperator implements Node2VecWalkParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/graph/PageRankBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/graph/PageRankBatchOp.java index b1215efbb..174ff72e6 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/graph/PageRankBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/graph/PageRankBatchOp.java @@ -14,6 +14,7 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortDesc; @@ -22,8 +23,8 @@ import com.alibaba.alink.common.annotation.PortType; import com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.operator.batch.BatchOperator; -import com.alibaba.alink.operator.batch.graph.memory.MemoryVertexCentricIteration; import com.alibaba.alink.operator.batch.graph.memory.MemoryComputeFunction; +import com.alibaba.alink.operator.batch.graph.memory.MemoryVertexCentricIteration; import com.alibaba.alink.params.graph.PageRankParams; import java.util.Iterator; @@ -39,6 +40,7 @@ @ParamSelectColumnSpec(name = "edgeTargetCol", portIndices = 0) @ParamSelectColumnSpec(name = "edgeWeightCol", portIndices = 0) @NameCn("PageRank算法") +@NameEn("PageRank") public class PageRankBatchOp extends BatchOperator implements PageRankParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/graph/RandomWalkBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/graph/RandomWalkBatchOp.java index a1983191e..deef51710 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/graph/RandomWalkBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/graph/RandomWalkBatchOp.java @@ -19,6 +19,7 @@ import com.alibaba.alink.common.MLEnvironmentFactory; import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortDesc; @@ -70,6 +71,7 @@ @ParamSelectColumnSpec(name = "targetCol", portIndices = 0, allowedTypeCollections = {TypeCollections.INT_LONG_TYPES, TypeCollections.STRING_TYPES}) @ParamSelectColumnSpec(name = "weightCol", portIndices = 0, allowedTypeCollections = {TypeCollections.NUMERIC_TYPES}) @NameCn("随机游走") +@NameEn("Random Walk") public final class RandomWalkBatchOp extends BatchOperator implements RandomWalkParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/graph/SingleSourceShortestPathBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/graph/SingleSourceShortestPathBatchOp.java index cff0251ac..8ef875236 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/graph/SingleSourceShortestPathBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/graph/SingleSourceShortestPathBatchOp.java @@ -20,6 +20,7 @@ import com.alibaba.alink.common.MLEnvironmentFactory; import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import 
com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortDesc; @@ -28,7 +29,6 @@ import com.alibaba.alink.common.annotation.PortType; import com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.operator.batch.BatchOperator; -import com.alibaba.alink.operator.common.graph.GraphUtilsWithString; import com.alibaba.alink.params.graph.SingleSourceShortestPathParams; @InputPorts(values = @PortSpec(value = PortType.DATA, opType = OpType.BATCH, desc = PortDesc.GRPAH_EDGES)) @OutputPorts(values = @PortSpec(value = PortType.DATA, desc = PortDesc.OUTPUT_RESULT)) @@ -36,6 +36,7 @@ @ParamSelectColumnSpec(name = "edgeTargetCol", portIndices = 0) @ParamSelectColumnSpec(name = "edgeWeightCol", portIndices = 0) @NameCn("单源最短路径") +@NameEn("Single Source Shortest Path") public class SingleSourceShortestPathBatchOp extends BatchOperator implements SingleSourceShortestPathParams { private static final long serialVersionUID = -1637471953684406867L; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/graph/TreeDepth.java b/core/src/main/java/com/alibaba/alink/operator/batch/graph/TreeDepth.java deleted file mode 100644 index 4c8651dc8..000000000 --- a/core/src/main/java/com/alibaba/alink/operator/batch/graph/TreeDepth.java +++ /dev/null @@ -1,164 +0,0 @@ -package com.alibaba.alink.operator.batch.graph; - -import org.apache.flink.api.common.functions.MapFunction; -import org.apache.flink.api.java.DataSet; -import org.apache.flink.api.java.tuple.Tuple3; -import org.apache.flink.graph.Edge; -import org.apache.flink.graph.Graph; -import org.apache.flink.graph.Vertex; -import org.apache.flink.graph.asm.degree.annotate.directed.VertexInDegree; -import org.apache.flink.graph.pregel.ComputeFunction; -import org.apache.flink.graph.pregel.MessageIterator; -import org.apache.flink.graph.pregel.VertexCentricConfiguration; -import org.apache.flink.types.LongValue; - -import com.alibaba.alink.operator.batch.BatchOperator; - -/** - * As for a legal input forest, there is not any nodes have more than one precursor, or there is no ring. - * If the input is illegal, exception will be raised. - * Considering the situation that a no-zero-degree-root ring exists, we set the max iteration being 50. - * There cannot exist a input has a depth larger than 2^50. After 50 iterations, if there exists nodes with - * negative depth value, exception will be raised. - *

- * The time complexity of this algorithm is T(log n), relating to the depth. - * Algorithm description: each vertex has a Tuple3 value, the three elements represent its id, its father's id, - * as well as the relative depth between it and its father. If the relative depth is positive, it mean this vertex - * has got linked to a root node, otherwise the negative depth denotes the difference between it and its father. - * The vertices with depth not less than zero have get linked to a root, they won't send back messages unless - * they receive messages from their sons. As for the son vertices, they send messages to their father nodes, - * and update father nodes according to the back replies, until they get linked to a root. - *

- *

- * In a legal tree, each node have at most one pre-node. So if one node has more than one pre-node, the algorithm - * will judge it illegal. - * - * @author qingzhao - */ - -public class TreeDepth { - public Integer maxIter; - - /** - * @param maxIter The maximum number of iterations to run. - */ - public TreeDepth(int maxIter) { - this.maxIter = maxIter; - } - - private Graph , Double> operation(Graph graph) { - DataSet > inVertexTemp; - //choose the root nodes, whose value is 0, while others' are 1. - try { - inVertexTemp = graph.run(new VertexInDegree () - .setIncludeZeroDegreeVertices(true)); - } catch (Exception e) { - throw new RuntimeException(e); - } - - DataSet >> graphNewVertices = inVertexTemp.map( - new MapVertexValue()); - DataSet > edgeDataSet = graph.getEdges().map( - new ReverseEdge()); - Graph , Double> graphExecute = Graph.fromDataSet( - graphNewVertices, - edgeDataSet, - BatchOperator.getExecutionEnvironmentFromDataSets(graphNewVertices, edgeDataSet)); - VertexCentricConfiguration parameters = new VertexCentricConfiguration(); - parameters.setName("tree depth iteration"); - Graph , Double> res = graphExecute - .runVertexCentricIteration( - new Execute(), - null, - maxIter, parameters); - return res; - } - - public DataSet > run(Graph graph) { - return operation(graph).getVertices().map( - new JudgeTupleIllegal()); - } - - //initial step, set the initial vertex value - public static class MapVertexValue - implements MapFunction , Vertex >> { - private static final long serialVersionUID = 2154022863365357679L; - - @Override - public Vertex > map(Vertex value) { - return value.f1.getValue() == 0 ? - new Vertex <>(value.f0, new Tuple3 <>(value.f0, value.f0, 0.)) : - new Vertex <>(value.f0, new Tuple3 <>(value.f0, value.f0, -1.)); - } - } - - public static class ReverseEdge implements MapFunction , Edge > { - private static final long serialVersionUID = 7575794558756147475L; - - @Override - public Edge map(Edge value) { - if (value.f2 <= 0) { - throw new RuntimeException("Edge " + value + " is illegal. Edge weight must be positive!"); - } - return new Edge <>(value.f1, value.f0, value.f2); - } - } - - public static class Execute extends - ComputeFunction , Double, Tuple3 > { - private static final long serialVersionUID = -2503583975560433984L; - - @Override - public void compute(Vertex > vertex, - MessageIterator> messages) - throws Exception { - //initial step. Vertices besides roots send message to their fathers and change their values. - if (vertex.f1.f1.equals(vertex.f0) && vertex.f1.f1.equals(vertex.f1.f0) && vertex.f1.f2 != 0) { - //if a vertex send message to more than one vertex, it mean that the vertex has more than one - // precursor. - boolean flag = false; - for (Edge edge : getEdges()) { - if (!flag) { - Tuple3 temp = new Tuple3 <>(vertex.f1.f0, edge.f1, -edge.f2); - sendMessageTo(edge.getTarget(), temp); - setNewVertexValue(temp); - flag = true; - } else { - throw new Exception("illegal input!!!"); - } - } - } else { - //when receiving a message, judge it is from father or son. 
- for (Tuple3 msg : messages) { - //receive message from son - if (msg.f1.equals(vertex.f0)) { - //send message to its son - sendMessageTo(msg.f0, vertex.f1); - } else { - //receive message from father - if (msg.f2 >= 0) { - setNewVertexValue(new Tuple3 <>(vertex.f0, msg.f1, msg.f2 - vertex.f1.f2)); - } else { - Tuple3 temp = new Tuple3 <>(vertex.f0, msg.f1, msg.f2 + vertex.f1.f2); - setNewVertexValue(temp); - sendMessageTo(temp.f1, temp); - } - } - } - } - } - } - - public static class JudgeTupleIllegal - implements MapFunction >, Tuple3 > { - private static final long serialVersionUID = 858956933724773542L; - - @Override - public Tuple3 map(Vertex > value) throws Exception { - if (value.f1.f2 < 0) { - throw new RuntimeException("illegal input!!!"); - } - return value.f1; - } - } -} \ No newline at end of file diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/graph/TreeDepthBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/graph/TreeDepthBatchOp.java index 7c11bd52c..2b63264bf 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/graph/TreeDepthBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/graph/TreeDepthBatchOp.java @@ -1,16 +1,25 @@ package com.alibaba.alink.operator.batch.graph; +import org.apache.flink.api.common.functions.MapFunction; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeinfo.Types; import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.tuple.Tuple3; import org.apache.flink.graph.Edge; import org.apache.flink.graph.Graph; +import org.apache.flink.graph.Vertex; +import org.apache.flink.graph.asm.degree.annotate.directed.VertexInDegree; +import org.apache.flink.graph.pregel.ComputeFunction; +import org.apache.flink.graph.pregel.MessageIterator; +import org.apache.flink.graph.pregel.VertexCentricConfiguration; import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.types.LongValue; import org.apache.flink.types.Row; import com.alibaba.alink.common.MLEnvironmentFactory; import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortDesc; @@ -20,7 +29,6 @@ import com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.batch.utils.GraphTransformUtils; -import com.alibaba.alink.operator.common.graph.GraphUtilsWithString; import com.alibaba.alink.params.graph.TreeDepthParams; @InputPorts(values = @PortSpec(value = PortType.DATA, opType = OpType.BATCH, desc = PortDesc.GRPAH_EDGES)) @OutputPorts(values = @PortSpec(value = PortType.DATA, desc = PortDesc.OUTPUT_RESULT)) @@ -28,6 +36,7 @@ @ParamSelectColumnSpec(name = "edgeTargetCol", portIndices = 0) @ParamSelectColumnSpec(name = "edgeWeightCol", portIndices = 0) @NameCn("树深度") +@NameEn("Tree Depth") public class TreeDepthBatchOp extends BatchOperator implements TreeDepthParams { private static final long serialVersionUID = -6574485904046547006L; @@ -68,4 +77,151 @@ public TreeDepthBatchOp linkFrom(BatchOperator ... inputs) { return this; } + + /** + * As for a legal input forest, there is not any nodes have more than one precursor, or there is no ring. + * If the input is illegal, exception will be raised. 
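// [Editorial sketch, not part of this patch] The TreeDepth class being added here effectively
// evaluates depth(root) = 0 and depth(v) = depth(parent(v)) + weight(parent(v) -> v) by message
// passing, and rejects inputs where a vertex has more than one precursor. A single-machine
// reference version of that recurrence on a toy forest (ids, weights and the class name are
// illustrative only):
import java.util.HashMap;
import java.util.Map;

public class TreeDepthSketch {
	public static void main(String[] args) {
		// toy forest rooted at 0: edges 0->1 (w=1.0), 1->2 (w=2.0), 0->3 (w=1.5)
		Map<Long, Long> parent = new HashMap<>();
		Map<Long, Double> weightFromParent = new HashMap<>();
		parent.put(1L, 0L);
		weightFromParent.put(1L, 1.0);
		parent.put(2L, 1L);
		weightFromParent.put(2L, 2.0);
		parent.put(3L, 0L);
		weightFromParent.put(3L, 1.5);

		for (long v : new long[] {0L, 1L, 2L, 3L}) {
			double depth = 0.0;
			long cur = v;
			// walk toward the root, accumulating edge weights
			while (parent.containsKey(cur)) {
				depth += weightFromParent.get(cur);
				cur = parent.get(cur);
			}
			System.out.println("vertex " + v + " depth " + depth);
		}
	}
}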
+ * Considering the situation that a no-zero-degree-root ring exists, we set the max iteration being 50. + * There cannot exist a input has a depth larger than 2^50. After 50 iterations, if there exists nodes with + * negative depth value, exception will be raised. + *

+ * The time complexity of this algorithm is T(log n), relating to the depth. + * Algorithm description: each vertex has a Tuple3 value, the three elements represent its id, its father's id, + * as well as the relative depth between it and its father. If the relative depth is positive, it mean this vertex + * has got linked to a root node, otherwise the negative depth denotes the difference between it and its father. + * The vertices with depth not less than zero have get linked to a root, they won't send back messages unless + * they receive messages from their sons. As for the son vertices, they send messages to their father nodes, + * and update father nodes according to the back replies, until they get linked to a root. + *

+ *

+ * In a legal tree, each node have at most one pre-node. So if one node has more than one pre-node, the algorithm + * will judge it illegal. + */ + + public static class TreeDepth { + public Integer maxIter; + + /** + * @param maxIter The maximum number of iterations to run. + */ + public TreeDepth(int maxIter) { + this.maxIter = maxIter; + } + + private Graph , Double> operation(Graph graph) { + DataSet > inVertexTemp; + //choose the root nodes, whose value is 0, while others' are 1. + try { + inVertexTemp = graph.run(new VertexInDegree () + .setIncludeZeroDegreeVertices(true)); + } catch (Exception e) { + throw new RuntimeException(e); + } + + DataSet >> graphNewVertices = inVertexTemp.map( + new MapVertexValue()); + DataSet > edgeDataSet = graph.getEdges().map( + new ReverseEdge()); + Graph , Double> graphExecute = Graph.fromDataSet( + graphNewVertices, + edgeDataSet, + getExecutionEnvironmentFromDataSets(graphNewVertices, edgeDataSet)); + VertexCentricConfiguration parameters = new VertexCentricConfiguration(); + parameters.setName("tree depth iteration"); + Graph , Double> res = graphExecute + .runVertexCentricIteration( + new Execute(), + null, + maxIter, parameters); + return res; + } + + public DataSet > run(Graph graph) { + return operation(graph).getVertices().map( + new JudgeTupleIllegal()); + } + + //initial step, set the initial vertex value + public static class MapVertexValue + implements MapFunction , Vertex >> { + private static final long serialVersionUID = 2154022863365357679L; + + @Override + public Vertex > map(Vertex value) { + return value.f1.getValue() == 0 ? + new Vertex <>(value.f0, new Tuple3 <>(value.f0, value.f0, 0.)) : + new Vertex <>(value.f0, new Tuple3 <>(value.f0, value.f0, -1.)); + } + } + + public static class ReverseEdge implements MapFunction , Edge > { + private static final long serialVersionUID = 7575794558756147475L; + + @Override + public Edge map(Edge value) { + if (value.f2 <= 0) { + throw new RuntimeException("Edge " + value + " is illegal. Edge weight must be positive!"); + } + return new Edge <>(value.f1, value.f0, value.f2); + } + } + + public static class Execute extends + ComputeFunction , Double, Tuple3 > { + private static final long serialVersionUID = -2503583975560433984L; + + @Override + public void compute(Vertex > vertex, + MessageIterator > messages) + throws Exception { + //initial step. Vertices besides roots send message to their fathers and change their values. + if (vertex.f1.f1.equals(vertex.f0) && vertex.f1.f1.equals(vertex.f1.f0) && vertex.f1.f2 != 0) { + //if a vertex send message to more than one vertex, it mean that the vertex has more than one + // precursor. + boolean flag = false; + for (Edge edge : getEdges()) { + if (!flag) { + Tuple3 temp = new Tuple3 <>(vertex.f1.f0, edge.f1, -edge.f2); + sendMessageTo(edge.getTarget(), temp); + setNewVertexValue(temp); + flag = true; + } else { + throw new Exception("illegal input!!!"); + } + } + } else { + //when receiving a message, judge it is from father or son. 
+ for (Tuple3 msg : messages) { + //receive message from son + if (msg.f1.equals(vertex.f0)) { + //send message to its son + sendMessageTo(msg.f0, vertex.f1); + } else { + //receive message from father + if (msg.f2 >= 0) { + setNewVertexValue(new Tuple3 <>(vertex.f0, msg.f1, msg.f2 - vertex.f1.f2)); + } else { + Tuple3 temp = new Tuple3 <>(vertex.f0, msg.f1, msg.f2 + vertex.f1.f2); + setNewVertexValue(temp); + sendMessageTo(temp.f1, temp); + } + } + } + } + } + } + + public static class JudgeTupleIllegal + implements MapFunction >, Tuple3 > { + private static final long serialVersionUID = 858956933724773542L; + + @Override + public Tuple3 map(Vertex > value) throws Exception { + if (value.f1.f2 < 0) { + throw new RuntimeException("illegal input!!!"); + } + return value.f1; + } + } + } } diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/graph/TriangleList.java b/core/src/main/java/com/alibaba/alink/operator/batch/graph/TriangleList.java deleted file mode 100644 index e12c5e92e..000000000 --- a/core/src/main/java/com/alibaba/alink/operator/batch/graph/TriangleList.java +++ /dev/null @@ -1,59 +0,0 @@ -package com.alibaba.alink.operator.batch.graph; - -import org.apache.flink.api.common.functions.FlatMapFunction; -import org.apache.flink.api.java.DataSet; -import org.apache.flink.api.java.tuple.Tuple3; -import org.apache.flink.graph.Graph; -import org.apache.flink.graph.asm.translate.TranslateFunction; -import org.apache.flink.graph.library.clustering.directed.TriangleListing; -import org.apache.flink.types.DoubleValue; -import org.apache.flink.types.LongValue; -import org.apache.flink.util.Collector; - -/** - * Return the numbers of triangles of the input graph. - * - * @author qingzhao - */ -public class TriangleList { - public static DataSet > run(Graph graph) throws Exception { - DataSet > inn = graph - .translateGraphIds(new LongToLongValue()) - .translateVertexValues(new DoubleToDoubleValue()) - .translateEdgeValues(new DoubleToDoubleValue()) - .run(new TriangleListing <>()); - return inn.flatMap(new FlatMapOut()); - } - - public static class LongToLongValue implements TranslateFunction { - private static final long serialVersionUID = -8328414272957924783L; - - @Override - public LongValue translate(Long value, LongValue reuse) { - return new LongValue(value); - } - } - - public static class DoubleToDoubleValue implements TranslateFunction { - private static final long serialVersionUID = -1396593185269273081L; - - @Override - public DoubleValue translate(Double value, DoubleValue reuse) { - return new DoubleValue(value); - } - } - - public static class FlatMapOut - implements FlatMapFunction , Tuple3 > { - private static final long serialVersionUID = -7744611721605759120L; - - @Override - public void flatMap(TriangleListing.Result value, Collector > out) { - Tuple3 temp = new Tuple3 <>(); - temp.f0 = value.getVertexId0().getValue(); - temp.f1 = value.getVertexId1().getValue(); - temp.f2 = value.getVertexId2().getValue(); - out.collect(temp); - } - } -} diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/graph/TriangleListBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/graph/TriangleListBatchOp.java index f8f2a5503..431c4168f 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/graph/TriangleListBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/graph/TriangleListBatchOp.java @@ -1,5 +1,6 @@ package com.alibaba.alink.operator.batch.graph; +import org.apache.flink.api.common.functions.FlatMapFunction; import 
org.apache.flink.api.common.functions.MapFunction; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.DataSet; @@ -7,13 +8,20 @@ import org.apache.flink.graph.Edge; import org.apache.flink.graph.Graph; import org.apache.flink.graph.Vertex; +import org.apache.flink.graph.asm.translate.TranslateFunction; +import org.apache.flink.graph.library.clustering.directed.TriangleListing; +import org.apache.flink.graph.library.clustering.directed.TriangleListing.Result; import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.types.DoubleValue; +import org.apache.flink.types.LongValue; import org.apache.flink.types.NullValue; import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; import com.alibaba.alink.common.MLEnvironmentFactory; import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortDesc; @@ -22,7 +30,6 @@ import com.alibaba.alink.common.annotation.PortType; import com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.operator.batch.BatchOperator; -import com.alibaba.alink.operator.common.graph.GraphUtilsWithString; import com.alibaba.alink.params.graph.TriangleListParams; @InputPorts(values = @PortSpec(value = PortType.DATA, opType = OpType.BATCH, desc = PortDesc.GRPAH_EDGES)) @@ -30,6 +37,7 @@ @ParamSelectColumnSpec(name = "edgeSourceCol", portIndices = 0) @ParamSelectColumnSpec(name = "edgeTargetCol", portIndices = 0) @NameCn("计数三角形") +@NameEn("Count Triangles") public class TriangleListBatchOp extends BatchOperator implements TriangleListParams { private static final long serialVersionUID = -5985547688589472574L; @@ -84,4 +92,49 @@ public Double map(Vertex value) throws Exception { } } + /** + * Return the numbers of triangles of the input graph. 
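For checking what such a listing produces on a small graph, triangles can also be enumerated by brute force over vertex triples. A hedged stand-alone sketch; the TriangleSketch name and the undirected edge-set representation are assumptions, not the Gelly TriangleListing used in the nested class that follows:

    import java.util.ArrayList;
    import java.util.HashSet;
    import java.util.List;
    import java.util.Set;
    import java.util.TreeSet;

    // Illustrative sketch: lists triangles (i, j, k) with i < j < k by testing all
    // vertex triples against an edge set, treating edges as undirected. O(n^3), for small checks only.
    public class TriangleSketch {
        public static List<long[]> triangles(long[][] edges) {
            Set<String> edgeSet = new HashSet<>();
            Set<Long> vertices = new TreeSet<>();
            for (long[] e : edges) {
                edgeSet.add(Math.min(e[0], e[1]) + "-" + Math.max(e[0], e[1]));
                vertices.add(e[0]);
                vertices.add(e[1]);
            }
            Long[] v = vertices.toArray(new Long[0]);
            List<long[]> result = new ArrayList<>();
            for (int i = 0; i < v.length; i++) {
                for (int j = i + 1; j < v.length; j++) {
                    if (!edgeSet.contains(v[i] + "-" + v[j])) { continue; }
                    for (int k = j + 1; k < v.length; k++) {
                        if (edgeSet.contains(v[i] + "-" + v[k]) && edgeSet.contains(v[j] + "-" + v[k])) {
                            result.add(new long[] {v[i], v[j], v[k]});
                        }
                    }
                }
            }
            return result;
        }

        public static void main(String[] args) {
            long[][] edges = {{1, 2}, {2, 3}, {1, 3}, {3, 4}};
            for (long[] t : triangles(edges)) {
                System.out.println(t[0] + " " + t[1] + " " + t[2]);   // prints: 1 2 3
            }
        }
    }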
+ */ + public static class TriangleList { + public static DataSet > run(Graph graph) throws Exception { + DataSet > inn = graph + .translateGraphIds(new LongToLongValue()) + .translateVertexValues(new DoubleToDoubleValue()) + .translateEdgeValues(new DoubleToDoubleValue()) + .run(new TriangleListing <>()); + return inn.flatMap(new FlatMapOut()); + } + + public static class LongToLongValue implements TranslateFunction { + private static final long serialVersionUID = -8328414272957924783L; + + @Override + public LongValue translate(Long value, LongValue reuse) { + return new LongValue(value); + } + } + + public static class DoubleToDoubleValue implements TranslateFunction { + private static final long serialVersionUID = -1396593185269273081L; + + @Override + public DoubleValue translate(Double value, DoubleValue reuse) { + return new DoubleValue(value); + } + } + + public static class FlatMapOut + implements FlatMapFunction , Tuple3 > { + private static final long serialVersionUID = -7744611721605759120L; + + @Override + public void flatMap(Result value, Collector > out) { + Tuple3 temp = new Tuple3 <>(); + temp.f0 = value.getVertexId0().getValue(); + temp.f1 = value.getVertexId1().getValue(); + temp.f2 = value.getVertexId2().getValue(); + out.collect(temp); + } + } + } } diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/graph/VertexClusterCoefficient.java b/core/src/main/java/com/alibaba/alink/operator/batch/graph/VertexClusterCoefficient.java deleted file mode 100644 index 7b2953130..000000000 --- a/core/src/main/java/com/alibaba/alink/operator/batch/graph/VertexClusterCoefficient.java +++ /dev/null @@ -1,111 +0,0 @@ -package com.alibaba.alink.operator.batch.graph; - -import org.apache.flink.api.common.functions.CoGroupFunction; -import org.apache.flink.api.common.functions.FlatMapFunction; -import org.apache.flink.api.java.DataSet; -import org.apache.flink.api.java.aggregation.Aggregations; -import org.apache.flink.api.java.tuple.Tuple2; -import org.apache.flink.api.java.tuple.Tuple4; -import org.apache.flink.graph.Graph; -import org.apache.flink.graph.asm.translate.TranslateFunction; -import org.apache.flink.graph.library.clustering.undirected.TriangleListing; -import org.apache.flink.types.LongValue; -import org.apache.flink.util.Collector; - -/** - * as for each vertex, return 1. its degree 2. the number of triangles with it as a vertex - * - * @author qingzhao - */ - -public class VertexClusterCoefficient { - public DataSet > run(Graph graph) throws Exception { - //get all the triangles and write all the vertices id in Tuple2 position 0. We will then calculate - // how many times the vertices appear in Tuplw2 position 1 in the following operation. - DataSet > triangleVertex = graph.translateGraphIds(new LongToLongvalue()) - .translateVertexValues(new DoubleToLongvalue()) - .translateEdgeValues(new DoubleToLongvalue()) - .run(new TriangleListing ()) - .flatMap(new Result2Long()); - DataSet > vertexCounted = triangleVertex - .groupBy(0).aggregate(Aggregations.SUM, 1); - //.reduceGroup(new CountVertex()); - //这个coGroup配对的两个DataSet规模是点的数目。。 - //采用编码解码来优化?? 
- return graph.inDegrees() - .coGroup(vertexCounted) - .where(0) - .equalTo(0) - .with(new CoGroupStep()); - } - - public static class LongToLongvalue implements TranslateFunction { - private static final long serialVersionUID = 6836903282078114665L; - - @Override - public LongValue translate(Long value, LongValue reuse) { - return new LongValue(value); - } - } - - public static class DoubleToLongvalue implements TranslateFunction { - private static final long serialVersionUID = -1849025103879518660L; - - @Override - public LongValue translate(Double value, LongValue reuse) { - return new LongValue(value.intValue()); - } - } - - public static class Result2Long - implements FlatMapFunction , Tuple2 > { - private static final long serialVersionUID = 2997438245762067649L; - - @Override - public void flatMap( - TriangleListing.Result value, - Collector > out) { - out.collect(new Tuple2 <>(value.getVertexId0().getValue(), 1L)); - out.collect(new Tuple2 <>(value.getVertexId1().getValue(), 1L)); - out.collect(new Tuple2 <>(value.getVertexId2().getValue(), 1L)); - } - } - - // public static class CountVertex - // implements GroupReduceFunction, Tuple2> { - // private static final long serialVersionUID = -659852844684797387L; - // - // @Override - // public void reduce(Iterable> values, - // Collector> out) throws Exception { - // long id = -1L; - // long count = 0L; - // for (Tuple2 i : values) { - // id = i.f0; - // count += 1L; - // } - // out.collect(new Tuple2<>(id, count)); - // } - // } - - public static class CoGroupStep implements CoGroupFunction , - Tuple2 , - Tuple4 > { - private static final long serialVersionUID = -6391324728861498560L; - - @Override - public void coGroup(Iterable > first, - Iterable > second, - Collector > out) { - for (Tuple2 i : second) { - Tuple4 outSingle = new Tuple4 <>(); - Tuple2 firstSingle = first.iterator().next(); - outSingle.f0 = firstSingle.f0; - outSingle.f1 = firstSingle.f1.getValue(); - outSingle.f2 = i.f1; - outSingle.f3 = outSingle.f2 * 2. 
/ (outSingle.f1 * (outSingle.f1 - 1)); - out.collect(outSingle); - } - } - } -} diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/graph/VertexClusterCoefficientBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/graph/VertexClusterCoefficientBatchOp.java index c10c0062b..9d2861ff1 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/graph/VertexClusterCoefficientBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/graph/VertexClusterCoefficientBatchOp.java @@ -1,19 +1,30 @@ package com.alibaba.alink.operator.batch.graph; +import org.apache.flink.api.common.functions.CoGroupFunction; +import org.apache.flink.api.common.functions.FlatMapFunction; import org.apache.flink.api.common.functions.MapFunction; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeinfo.Types; import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.aggregation.Aggregations; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple4; import org.apache.flink.graph.Edge; import org.apache.flink.graph.Graph; import org.apache.flink.graph.Vertex; +import org.apache.flink.graph.asm.translate.TranslateFunction; +import org.apache.flink.graph.library.clustering.undirected.TriangleListing; +import org.apache.flink.graph.library.clustering.undirected.TriangleListing.Result; import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.types.LongValue; import org.apache.flink.types.NullValue; import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; import com.alibaba.alink.common.MLEnvironmentFactory; import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortDesc; @@ -22,7 +33,6 @@ import com.alibaba.alink.common.annotation.PortType; import com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.operator.batch.BatchOperator; -import com.alibaba.alink.operator.common.graph.GraphUtilsWithString; import com.alibaba.alink.params.graph.VertexClusterCoefficientParams; @InputPorts(values = @PortSpec(value = PortType.DATA, opType = OpType.BATCH, desc = PortDesc.GRPAH_EDGES)) @@ -30,6 +40,7 @@ @ParamSelectColumnSpec(name = "edgeSourceCol", portIndices = 0) @ParamSelectColumnSpec(name = "edgeTargetCol", portIndices = 0) @NameCn("点聚类系数") +@NameEn("Vertex Cluster Coefficient") public class VertexClusterCoefficientBatchOp extends BatchOperator implements VertexClusterCoefficientParams { private static final long serialVersionUID = 3694935054423399372L; @@ -85,4 +96,101 @@ public Double map(Vertex value) throws Exception { return 1.; } } + + /** + * as for each vertex, return 1. its degree 2. the number of triangles with it as a vertex + * + */ + + public static class VertexClusterCoefficient { + public DataSet > run(Graph graph) throws Exception { + //get all the triangles and write all the vertices id in Tuple2 position 0. We will then calculate + // how many times the vertices appear in Tuplw2 position 1 in the following operation. 
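The value written into the last tuple field by the CoGroupStep further below, f3 = 2 * f2 / (f1 * (f1 - 1)), is the usual local clustering coefficient derived from a vertex's degree and its triangle count. A tiny stand-alone illustration of that formula; the ClusterCoefficientSketch name and the degree-below-two guard are assumptions added here:

    // Illustrative sketch: local clustering coefficient from a vertex's degree and the
    // number of triangles it participates in, matching f3 = 2 * f2 / (f1 * (f1 - 1)).
    public class ClusterCoefficientSketch {
        public static double coefficient(long degree, long triangleCount) {
            if (degree < 2) {
                return 0.0;                       // fewer than two neighbours: no triangle is possible
            }
            return 2.0 * triangleCount / (degree * (degree - 1));
        }

        public static void main(String[] args) {
            // A vertex with 4 neighbours can sit in at most 4 * 3 / 2 = 6 triangles; 3 of them gives 0.5.
            System.out.println(coefficient(4, 3));   // 0.5
        }
    }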
+ DataSet > triangleVertex = graph.translateGraphIds(new LongToLongvalue()) + .translateVertexValues(new DoubleToLongvalue()) + .translateEdgeValues(new DoubleToLongvalue()) + .run(new TriangleListing ()) + .flatMap(new Result2Long()); + DataSet > vertexCounted = triangleVertex + .groupBy(0).aggregate(Aggregations.SUM, 1); + //.reduceGroup(new CountVertex()); + //The two DataSets paired by this coGroup are both as large as the number of vertices. + //Could this be optimized with an encoding/decoding step? + return graph.inDegrees() + .coGroup(vertexCounted) + .where(0) + .equalTo(0) + .with(new CoGroupStep()); + } + + public static class LongToLongvalue implements TranslateFunction { + private static final long serialVersionUID = 6836903282078114665L; + + @Override + public LongValue translate(Long value, LongValue reuse) { + return new LongValue(value); + } + } + + public static class DoubleToLongvalue implements TranslateFunction { + private static final long serialVersionUID = -1849025103879518660L; + + @Override + public LongValue translate(Double value, LongValue reuse) { + return new LongValue(value.intValue()); + } + } + + public static class Result2Long + implements FlatMapFunction , Tuple2 > { + private static final long serialVersionUID = 2997438245762067649L; + + @Override + public void flatMap( + Result value, + Collector > out) { + out.collect(new Tuple2 <>(value.getVertexId0().getValue(), 1L)); + out.collect(new Tuple2 <>(value.getVertexId1().getValue(), 1L)); + out.collect(new Tuple2 <>(value.getVertexId2().getValue(), 1L)); + } + } + + // public static class CountVertex + // implements GroupReduceFunction, Tuple2> { + // private static final long serialVersionUID = -659852844684797387L; + // + // @Override + // public void reduce(Iterable> values, + // Collector> out) throws Exception { + // long id = -1L; + // long count = 0L; + // for (Tuple2 i : values) { + // id = i.f0; + // count += 1L; + // } + // out.collect(new Tuple2<>(id, count)); + // } + // } + + public static class CoGroupStep implements CoGroupFunction , + Tuple2 , + Tuple4 > { + private static final long serialVersionUID = -6391324728861498560L; + + @Override + public void coGroup(Iterable > first, + Iterable > second, + Collector > out) { + for (Tuple2 i : second) { + Tuple4 outSingle = new Tuple4 <>(); + Tuple2 firstSingle = first.iterator().next(); + outSingle.f0 = firstSingle.f0; + outSingle.f1 = firstSingle.f1.getValue(); + outSingle.f2 = i.f1; + outSingle.f3 = outSingle.f2 * 2.
/ (outSingle.f1 * (outSingle.f1 - 1)); + out.collect(outSingle); + } + } + } + } } diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/graph/VertexNeighborSearch.java b/core/src/main/java/com/alibaba/alink/operator/batch/graph/VertexNeighborSearch.java deleted file mode 100644 index 6d422b45e..000000000 --- a/core/src/main/java/com/alibaba/alink/operator/batch/graph/VertexNeighborSearch.java +++ /dev/null @@ -1,107 +0,0 @@ -package com.alibaba.alink.operator.batch.graph; - -import org.apache.flink.api.common.functions.FilterFunction; -import org.apache.flink.api.java.DataSet; -import org.apache.flink.graph.Edge; -import org.apache.flink.graph.Graph; -import org.apache.flink.graph.GraphAlgorithm; -import org.apache.flink.graph.Vertex; -import org.apache.flink.graph.asm.translate.TranslateFunction; -import org.apache.flink.graph.pregel.ComputeFunction; -import org.apache.flink.graph.pregel.MessageCombiner; -import org.apache.flink.graph.pregel.MessageIterator; - -import java.util.HashSet; - -/** - * This class implements a graph algorithm that returns a sub-graph induced by - * a node list `sources` and their neighbors with at most `depth` hops. - * - * @author Fan Hong - */ - -public class VertexNeighborSearch implements GraphAlgorithm > { - - private int depth; - private HashSet sources; - - public VertexNeighborSearch(HashSet sources, int depth) { - this.sources = sources; - this.depth = depth; - } - - @Override - public Graph run(Graph graph) throws Exception { - Graph subgraph = graph - .translateVertexValues(new SetLongMaxValue()) - .runVertexCentricIteration(new VertexNeighborComputeFunction(sources), new MinimumDistanceCombiner(), - depth + 1) - .filterOnVertices(new FilterByValue(depth)); - - DataSet > vertices = subgraph.getVertices(); - return subgraph; - } - - public static final class SetLongMaxValue implements TranslateFunction { - private static final long serialVersionUID = 6439208445273249327L; - - @Override - public Long translate(Long aLong, Long o) { - return Long.MAX_VALUE / 2L; - } - } - - public static final class FilterByValue implements FilterFunction > { - private static final long serialVersionUID = -3443337881858305297L; - private int thresh; - - FilterByValue(int thresh) { - this.thresh = thresh; - } - - @Override - public boolean filter(Vertex vertex) { - return vertex.getValue() <= thresh; - } - } - - public static final class VertexNeighborComputeFunction extends ComputeFunction { - private static final long serialVersionUID = -4927352871373053814L; - private HashSet sources; - - VertexNeighborComputeFunction(HashSet sources) { - this.sources = sources; - } - - @Override - public void compute(Vertex vertex, MessageIterator messages) { - long minDistance = sources.contains(vertex.getId()) ? 
0L : Long.MAX_VALUE / 2; - - for (Long msg : messages) { - minDistance = Math.min(minDistance, msg); - } - - if (minDistance < vertex.getValue()) { - setNewVertexValue(minDistance); - for (Edge e : getEdges()) { - sendMessageTo(e.getTarget(), minDistance + 1); - } - } - } - } - - public static final class MinimumDistanceCombiner extends MessageCombiner { - - private static final long serialVersionUID = 1916706983491173310L; - - public void combineMessages(MessageIterator messages) { - - long minMessage = Long.MAX_VALUE / 2; - for (Long msg : messages) { - minMessage = Math.min(minMessage, msg); - } - sendCombinedMessage(minMessage); - } - } -} - diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/graph/VertexNeighborSearchBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/graph/VertexNeighborSearchBatchOp.java new file mode 100644 index 000000000..64b5f8661 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/batch/graph/VertexNeighborSearchBatchOp.java @@ -0,0 +1,460 @@ +package com.alibaba.alink.operator.batch.graph; + +import org.apache.flink.api.common.functions.CrossFunction; +import org.apache.flink.api.common.functions.FilterFunction; +import org.apache.flink.api.common.functions.GroupCombineFunction; +import org.apache.flink.api.common.functions.JoinFunction; +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.tuple.Tuple1; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.graph.Edge; +import org.apache.flink.graph.Graph; +import org.apache.flink.graph.GraphAlgorithm; +import org.apache.flink.graph.Vertex; +import org.apache.flink.graph.asm.translate.TranslateFunction; +import org.apache.flink.graph.pregel.ComputeFunction; +import org.apache.flink.graph.pregel.MessageCombiner; +import org.apache.flink.graph.pregel.MessageIterator; +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.table.api.Table; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; + +import com.alibaba.alink.common.type.AlinkTypes; +import com.alibaba.alink.common.MLEnvironmentFactory; +import com.alibaba.alink.common.annotation.InputPorts; +import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.common.annotation.OutputPorts; +import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; +import com.alibaba.alink.common.annotation.PortDesc; +import com.alibaba.alink.common.annotation.PortSpec; +import com.alibaba.alink.common.annotation.PortSpec.OpType; +import com.alibaba.alink.common.annotation.PortType; +import com.alibaba.alink.common.utils.AlinkSerializable; +import com.alibaba.alink.common.viz.AlinkViz; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; +import com.alibaba.alink.common.utils.TableUtil; +import com.alibaba.alink.common.viz.VizDataWriterInterface; +import com.alibaba.alink.common.viz.VizOpChartData; +import com.alibaba.alink.common.viz.VizOpDataInfo; +import com.alibaba.alink.common.viz.VizOpMeta; +import com.alibaba.alink.operator.batch.BatchOperator; +import com.alibaba.alink.operator.batch.utils.DataSetUtil; +import com.alibaba.alink.operator.common.dataproc.FirstReducer; +import com.alibaba.alink.params.graph.VertexNeighborSearchParams; + +import java.util.ArrayList; +import 
java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import static com.alibaba.alink.common.utils.JsonConverter.gson; + +/** + * This algorithm implements vertex neighbor search. + * It returns an induced sub-graph whose vertices are within most k-hops of sources vertices. + * + * @author Fan Hong + */ +@InputPorts(values = { + @PortSpec(value = PortType.DATA, opType = OpType.BATCH, desc = PortDesc.GRPAH_EDGES), + @PortSpec(value = PortType.DATA, opType = OpType.BATCH, desc = PortDesc.GRAPH_VERTICES, isOptional = true) +}) +@OutputPorts(values = { + @PortSpec(value = PortType.DATA, desc = PortDesc.OUTPUT_RESULT) +}) +@ParamSelectColumnSpec(name = "vertexIdCol", portIndices = 1) +@ParamSelectColumnSpec(name = "edgeSourceCol", portIndices = 0) +@ParamSelectColumnSpec(name = "edgeTargetCol", portIndices = 0) + +@NameCn("点邻居搜索") +@NameEn("Vertex Neighbor Search") +public final class VertexNeighborSearchBatchOp extends BatchOperator + implements VertexNeighborSearchParams , AlinkViz { + + private static final long serialVersionUID = 4341880061845091326L; + static int MAX_NUM_EDGES_TO_ES = 1200; + + public VertexNeighborSearchBatchOp() { + super(new Params()); + } + + public VertexNeighborSearchBatchOp(Params params) { + super(params); + } + + @Override + public VertexNeighborSearchBatchOp linkFrom(BatchOperator ... inputs) { + checkMinOpSize(1, inputs); + + VizDataWriterInterface writer = this.getVizDataWriter(); + + // Parse parameters and inputs + BatchOperator edgesOperator = inputs[0]; + + Boolean isUndirected = getAsUndirectedGraph(); + String vertexIdColName = getVertexIdCol(); + String edgeSourceColName = getEdgeSourceCol(); + String edgeTargetColName = getEdgeTargetCol(); + + int edgeSourceColId = TableUtil.findColIndexWithAssertAndHint(edgesOperator.getColNames(), edgeSourceColName); + int edgeTargetColId = TableUtil.findColIndexWithAssertAndHint(edgesOperator.getColNames(), edgeTargetColName); + + int depth = getDepth(); + HashSet sources = new HashSet <>(Arrays.asList(getSources())); + + DataSet > edges = edgesOperator.getDataSet() + .map(new Row2EdgeTuple(edgeSourceColId, edgeTargetColId)); + + // Construct graph + Graph graph; + BatchOperator verticesOperator = null; + int vertexIdColId = 0; + DataSet > vertices = null; + if (inputs.length > 1) { + verticesOperator = inputs[1]; + vertexIdColId = TableUtil.findColIndexWithAssertAndHint(verticesOperator.getColNames(), vertexIdColName); + vertices = verticesOperator.getDataSet().map(new Row2VertexTuple(vertexIdColId)); + graph = Graph.fromTupleDataSet(vertices, edges, + MLEnvironmentFactory.get(getMLEnvironmentId()).getExecutionEnvironment()); + } else { + graph = Graph.fromTupleDataSet(edges, new VertexValueInitializer(), + MLEnvironmentFactory.get(getMLEnvironmentId()).getExecutionEnvironment()); + } + + if (isUndirected) { + graph = graph.getUndirected(); + } + + // Vertex neighbor search algorithm + Graph subgraph = graph; + try { + subgraph = new VertexNeighborSearch(sources, depth).run(graph); + } catch (Exception e) { + e.printStackTrace(); + } + + // Filter and output vertices and edges from original data to keep additional attributes + DataSet inducedEdges = subgraph.getEdgesAsTuple3() + .joinWithHuge(edgesOperator.getDataSet()) + .where(0, 1) + .equalTo(0, 1) + .with(new OnlySecondJoinFunction()); + this.setOutput(inducedEdges, edgesOperator.getSchema()); + Table[] sideOutputs = new Table[1]; + this.setSideOutputTables(new Table[1]); + DataSet 
inducedVertices = null; + String[] outVerticesColNames = null; + if (verticesOperator != null) { + inducedVertices = subgraph.getVerticesAsTuple2() + .join(verticesOperator.getDataSet()) + .where(0) + .equalTo(0) + .with(new OnlySecondJoinFunction()); + outVerticesColNames = verticesOperator.getColNames(); + sideOutputs[0] = DataSetConversionUtil.toTable(getMLEnvironmentId(), inducedVertices, + verticesOperator.getSchema() + + ); + } else { + // Even there is no input tables for vertices, we still create an output one + inducedVertices = subgraph.getVerticesAsTuple2() + .project(0) + .map(new MapFunction , Row>() { + private static final long serialVersionUID = 599089156563158818L; + + @Override + public Row map(Tuple1 o) { + return Row.of(o.f0); + } + }); + outVerticesColNames = new String[] {"name"}; + vertexIdColName = "name"; + sideOutputs[0] = DataSetConversionUtil.toTable(getMLEnvironmentId(), inducedVertices, + new String[] {"String"}, + new TypeInformation [] {AlinkTypes.STRING}); + } + + this.setSideOutputTables(sideOutputs); + // Construct and write results + if (writer != null) { + /** + * Write all information about edges and vertices in to VizWriter. + * However, if #edges is larger than MAX_NUM_EDGES_TO_ES, + * only first `MAX_NUM_EDGES_TO_ES` edges and no vertices are written. + */ + final DataSet visWriterResult = inducedEdges + .reduceGroup(new FirstReducer <>(MAX_NUM_EDGES_TO_ES)) + .combineGroup(new AllInOneGroupCombineFunction ()) + .cross(inducedVertices.combineGroup(new AllInOneGroupCombineFunction ())) + .with(new Graph2JsonCrossFunction(sources, + edgesOperator.getColNames(), edgeSourceColName, edgeTargetColName, + outVerticesColNames, vertexIdColName, + isUndirected, MAX_NUM_EDGES_TO_ES)) + .map(new VizWriterMapFunction(0, writer)); + DataSetUtil.linkDummySink(visWriterResult); + + // Write meta to VizDataWriter + VizOpMeta meta = new VizOpMeta(); + meta.dataInfos = new VizOpDataInfo[1]; + meta.dataInfos[0] = new VizOpDataInfo(0); + + meta.cascades = new HashMap <>(); + meta.cascades.put( + gson.toJson(new String[] {"图可视化"}), + new VizOpChartData(0)); + + meta.setSchema(edgesOperator.getSchema()); + meta.params = getParams(); + meta.isOutput = false; + meta.opName = "VertexNeighborSearchBatchOp"; + + writer.writeBatchMeta(meta); + } + + return this; + } + + public static class VertexValueInitializer implements MapFunction { + private static final long serialVersionUID = -8771018283053295267L; + + @Override + public Long map(String s) throws Exception { + return 0L; + } + } + + public static class Row2EdgeTuple implements MapFunction > { + private static final long serialVersionUID = -7430905996667254712L; + private int sourceColId; + private int targetColId; + + Row2EdgeTuple(int sourceColId, int targetColId) { + this.sourceColId = sourceColId; + this.targetColId = targetColId; + } + + @Override + public Tuple3 map(Row value) throws Exception { + return Tuple3.of((String) value.getField(sourceColId), (String) value.getField(targetColId), 0L); + } + } + + public static class Row2VertexTuple implements MapFunction > { + private static final long serialVersionUID = -1149958337899075070L; + private int vertexColId; + + Row2VertexTuple(int vertexColId) { + this.vertexColId = vertexColId; + } + + @Override + public Tuple2 map(Row value) throws Exception { + return Tuple2.of((String) value.getField(vertexColId), 0L); + } + } + + public static class OnlySecondJoinFunction implements JoinFunction { + private static final long serialVersionUID = 5961146867726046621L; + + 
@Override + public IN2 join(IN1 in1, IN2 in2) { + return in2; + } + } + + public static class AllInOneGroupCombineFunction implements GroupCombineFunction > { + private static final long serialVersionUID = -7055580437134926663L; + + @Override + public void combine(Iterable iterable, Collector > collector) { + List list = new ArrayList <>(); + for (T t : iterable) { + list.add(t); + } + collector.collect(list); + } + } + + public static class Graph2JsonCrossFunction + implements CrossFunction , List , String>, AlinkSerializable { + private static final long serialVersionUID = -6978604589815110412L; + private Set selectedVertexIds; + private String[] edgesColNames; + private String edgeSourceColName; + private String edgeTargetColName; + private String[] verticesColNames; + private String vertexIdColName; + private boolean isUndirected; + private int maxNumEdgesToEs; + + private Object[][] edges; + private Object[][] vertices; + + Graph2JsonCrossFunction(Set selectedVertexIds, + String[] edgesColNames, String edgeSourceColName, String edgeTargetColName, + String[] verticesColNames, String vertexIdColName, + boolean isUndirected, int maxNumEdgesToEs) { + this.selectedVertexIds = selectedVertexIds; + this.edgesColNames = edgesColNames; + this.edgeSourceColName = edgeSourceColName; + this.edgeTargetColName = edgeTargetColName; + this.verticesColNames = verticesColNames; + this.vertexIdColName = vertexIdColName; + this.isUndirected = isUndirected; + this.maxNumEdgesToEs = maxNumEdgesToEs; + } + + @Override + public String cross(List edges, List vertices) { + + List edgesList = new ArrayList <>(); + + int counter = 0; + Set relatedVerticesSet = new HashSet <>(edges.size() * 2); + for (Row row : edges) { + Object[] obj = new Object[edgesColNames.length]; + for (int i = 0; i < edgesColNames.length; i += 1) { + obj[i] = row.getField(i); + } + edgesList.add(obj); + relatedVerticesSet.add((String) obj[0]); + relatedVerticesSet.add((String) obj[1]); + counter += 1; + if (counter >= maxNumEdgesToEs) { // do not write too many edges to es + break; + } + } + this.edges = new Object[edgesList.size()][]; + edgesList.toArray(this.edges); + + List verticesList = new ArrayList <>(); + for (Row row : vertices) { + Object[] obj = new Object[verticesColNames.length]; + for (int i = 0; i < verticesColNames.length; i += 1) { + obj[i] = row.getField(i); + } + if (relatedVerticesSet.contains(obj[0])) { + verticesList.add(obj); + } + } + this.vertices = new Object[verticesList.size()][]; + verticesList.toArray(this.vertices); + return gson.toJson(this); + } + } + + public static class VizWriterMapFunction implements MapFunction { + private static final long serialVersionUID = -3562810264038340464L; + int dataId; + VizDataWriterInterface writer; + + VizWriterMapFunction(int dataId, VizDataWriterInterface writer) { + this.dataId = dataId; + this.writer = writer; + } + + @Override + public String map(String s) throws Exception { + writer.writeBatchData(dataId, s, System.currentTimeMillis()); + return s; + } + } + + /** + * This class implements a graph algorithm that returns a sub-graph induced by + * a node list `sources` and their neighbors with at most `depth` hops. 
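The sub-graph semantics described in the class comment above, keeping every vertex within at most `depth` hops of any vertex in `sources`, can be mimicked on a small in-memory graph with a multi-source BFS. A minimal sketch under that assumption; the NeighborSearchSketch name is illustrative, the graph is treated as undirected (as after getUndirected()), and the step that joins the kept vertices back to the original edge rows is omitted:

    import java.util.ArrayDeque;
    import java.util.HashMap;
    import java.util.HashSet;
    import java.util.List;
    import java.util.Map;
    import java.util.Queue;
    import java.util.Set;

    // Illustrative sketch: multi-source BFS that keeps the vertices within `depth` hops of the sources.
    public class NeighborSearchSketch {
        public static Set<String> withinDepth(List<String[]> edges, Set<String> sources, int depth) {
            Map<String, Set<String>> adj = new HashMap<>();
            for (String[] e : edges) {
                adj.computeIfAbsent(e[0], k -> new HashSet<>()).add(e[1]);
                adj.computeIfAbsent(e[1], k -> new HashSet<>()).add(e[0]);
            }
            Map<String, Integer> dist = new HashMap<>();
            Queue<String> queue = new ArrayDeque<>();
            for (String s : sources) {
                dist.put(s, 0);
                queue.add(s);
            }
            while (!queue.isEmpty()) {
                String cur = queue.poll();
                if (dist.get(cur) == depth) {
                    continue;                               // do not expand past the hop limit
                }
                for (String next : adj.getOrDefault(cur, new HashSet<>())) {
                    if (!dist.containsKey(next)) {
                        dist.put(next, dist.get(cur) + 1);
                        queue.add(next);
                    }
                }
            }
            return dist.keySet();
        }

        public static void main(String[] args) {
            List<String[]> edges = List.of(new String[] {"a", "b"}, new String[] {"b", "c"}, new String[] {"c", "d"});
            // Vertices within 2 hops of "a": a, b and c.
            System.out.println(withinDepth(edges, Set.of("a"), 2));
        }
    }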
+ */ + + public static class VertexNeighborSearch implements + GraphAlgorithm > { + + private int depth; + private HashSet sources; + + public VertexNeighborSearch(HashSet sources, int depth) { + this.sources = sources; + this.depth = depth; + } + + @Override + public Graph run(Graph graph) throws Exception { + Graph subgraph = graph + .translateVertexValues(new SetLongMaxValue()) + .runVertexCentricIteration(new VertexNeighborComputeFunction(sources), new MinimumDistanceCombiner(), + depth + 1) + .filterOnVertices(new FilterByValue(depth)); + + DataSet > vertices = subgraph.getVertices(); + return subgraph; + } + + public static final class SetLongMaxValue implements TranslateFunction { + private static final long serialVersionUID = 6439208445273249327L; + + @Override + public Long translate(Long aLong, Long o) { + return Long.MAX_VALUE / 2L; + } + } + + public static final class FilterByValue implements FilterFunction > { + private static final long serialVersionUID = -3443337881858305297L; + private int thresh; + + FilterByValue(int thresh) { + this.thresh = thresh; + } + + @Override + public boolean filter(Vertex vertex) { + return vertex.getValue() <= thresh; + } + } + + public static final class VertexNeighborComputeFunction extends ComputeFunction { + private static final long serialVersionUID = -4927352871373053814L; + private HashSet sources; + + VertexNeighborComputeFunction(HashSet sources) { + this.sources = sources; + } + + @Override + public void compute(Vertex vertex, MessageIterator messages) { + long minDistance = sources.contains(vertex.getId()) ? 0L : Long.MAX_VALUE / 2; + + for (Long msg : messages) { + minDistance = Math.min(minDistance, msg); + } + + if (minDistance < vertex.getValue()) { + setNewVertexValue(minDistance); + for (Edge e : getEdges()) { + sendMessageTo(e.getTarget(), minDistance + 1); + } + } + } + } + + public static final class MinimumDistanceCombiner extends MessageCombiner { + + private static final long serialVersionUID = 1916706983491173310L; + + public void combineMessages(MessageIterator messages) { + + long minMessage = Long.MAX_VALUE / 2; + for (Long msg : messages) { + minMessage = Math.min(minMessage, msg); + } + sendCombinedMessage(minMessage); + } + } + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/huge/HugeDeepWalkTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/huge/HugeDeepWalkTrainBatchOp.java index fa8817c58..84ec6eb6d 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/huge/HugeDeepWalkTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/huge/HugeDeepWalkTrainBatchOp.java @@ -3,11 +3,13 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.batch.huge.impl.DeepWalkImpl; import com.alibaba.alink.operator.common.aps.ApsCheckpoint; import com.alibaba.alink.params.huge.HasNumCheckpoint; @NameCn("大规模DeepWalk") +@NameEn("Huge Data DeepWalk") public final class HugeDeepWalkTrainBatchOp extends DeepWalkImpl implements HasNumCheckpoint { private static final long serialVersionUID = 5413242732809242754L; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/huge/HugeLabeledWord2VecTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/huge/HugeLabeledWord2VecTrainBatchOp.java index 6a3ab8303..202e350a2 100644 --- 
a/core/src/main/java/com/alibaba/alink/operator/batch/huge/HugeLabeledWord2VecTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/huge/HugeLabeledWord2VecTrainBatchOp.java @@ -3,11 +3,13 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.batch.huge.impl.LabeledWord2VecImpl; import com.alibaba.alink.operator.common.aps.ApsCheckpoint; import com.alibaba.alink.params.huge.HasNumCheckpoint; @NameCn("大规模带标签的Word2Vec") +@NameEn("Huge Data Labeled Word to Vector") public final class HugeLabeledWord2VecTrainBatchOp extends LabeledWord2VecImpl implements HasNumCheckpoint { private static final long serialVersionUID = -3014286578422196705L; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/huge/HugeMetaPath2VecTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/huge/HugeMetaPath2VecTrainBatchOp.java index 3d42579ee..18c016a86 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/huge/HugeMetaPath2VecTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/huge/HugeMetaPath2VecTrainBatchOp.java @@ -3,11 +3,13 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.batch.huge.impl.MetaPath2VecImpl; import com.alibaba.alink.operator.common.aps.ApsCheckpoint; import com.alibaba.alink.params.huge.HasNumCheckpoint; @NameCn("大规模MethPath2Vec") +@NameEn("Huge Data MetaPath2Vec") public final class HugeMetaPath2VecTrainBatchOp extends MetaPath2VecImpl implements HasNumCheckpoint { private static final long serialVersionUID = -8398787630956847264L; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/huge/HugeNode2VecTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/huge/HugeNode2VecTrainBatchOp.java index 47996f38d..89b8b7aaa 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/huge/HugeNode2VecTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/huge/HugeNode2VecTrainBatchOp.java @@ -3,11 +3,13 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.batch.huge.impl.Node2VecImpl; import com.alibaba.alink.operator.common.aps.ApsCheckpoint; import com.alibaba.alink.params.huge.HasNumCheckpoint; @NameCn("大规模Node2Vec") +@NameEn("Huge Node2Vec Training") public final class HugeNode2VecTrainBatchOp extends Node2VecImpl implements HasNumCheckpoint { private static final long serialVersionUID = 4360078150555638432L; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/huge/HugeWord2VecTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/huge/HugeWord2VecTrainBatchOp.java index 8619cd65d..d7af83414 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/huge/HugeWord2VecTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/huge/HugeWord2VecTrainBatchOp.java @@ -3,11 +3,13 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.batch.huge.impl.Word2VecImpl; import com.alibaba.alink.operator.common.aps.ApsCheckpoint; import com.alibaba.alink.params.huge.HasNumCheckpoint; 
@NameCn("大规模Word2Vec") +@NameEn("Huge Word2Vec Training") public final class HugeWord2VecTrainBatchOp extends Word2VecImpl implements HasNumCheckpoint { private static final long serialVersionUID = -1222790480709681729L; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/huge/impl/DeepWalkImpl.java b/core/src/main/java/com/alibaba/alink/operator/batch/huge/impl/DeepWalkImpl.java index 029f198f3..bd2407171 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/huge/impl/DeepWalkImpl.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/huge/impl/DeepWalkImpl.java @@ -15,7 +15,6 @@ import com.alibaba.alink.operator.batch.graph.RandomWalkBatchOp; import com.alibaba.alink.operator.batch.sql.JoinBatchOp; import com.alibaba.alink.operator.common.aps.ApsCheckpoint; -import com.alibaba.alink.operator.common.graph.GraphEmbedding; import com.alibaba.alink.params.nlp.DeepWalkParams; @InputPorts(values = @PortSpec(value = PortType.DATA, desc = PortDesc.GRAPH)) diff --git a/core/src/main/java/com/alibaba/alink/operator/common/graph/GraphEmbedding.java b/core/src/main/java/com/alibaba/alink/operator/batch/huge/impl/GraphEmbedding.java similarity index 98% rename from core/src/main/java/com/alibaba/alink/operator/common/graph/GraphEmbedding.java rename to core/src/main/java/com/alibaba/alink/operator/batch/huge/impl/GraphEmbedding.java index ac2fd7181..dea849321 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/graph/GraphEmbedding.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/huge/impl/GraphEmbedding.java @@ -1,4 +1,4 @@ -package com.alibaba.alink.operator.common.graph; +package com.alibaba.alink.operator.batch.huge.impl; import org.apache.flink.api.common.functions.MapFunction; import org.apache.flink.api.common.operators.base.JoinOperatorBase; @@ -11,7 +11,7 @@ import org.apache.flink.ml.api.misc.param.Params; import org.apache.flink.types.Row; -import com.alibaba.alink.common.utils.DataSetConversionUtil; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; import com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.batch.source.TableSourceBatchOp; @@ -23,7 +23,7 @@ import com.alibaba.alink.params.nlp.walk.HasVertexCol; import com.alibaba.alink.params.shared.colname.HasWeightColDefaultAsNull; -public class GraphEmbedding { +class GraphEmbedding { public static final String SOURCE_COL = "sourcecol"; public static final String TARGET_COL = "targetcol"; public static final String WEIGHT_COL = "weightcol"; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/huge/impl/MetaPath2VecImpl.java b/core/src/main/java/com/alibaba/alink/operator/batch/huge/impl/MetaPath2VecImpl.java index 30d84675f..ce1915b33 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/huge/impl/MetaPath2VecImpl.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/huge/impl/MetaPath2VecImpl.java @@ -16,7 +16,6 @@ import com.alibaba.alink.operator.batch.graph.RandomWalkBatchOp; import com.alibaba.alink.operator.batch.sql.JoinBatchOp; import com.alibaba.alink.operator.common.aps.ApsCheckpoint; -import com.alibaba.alink.operator.common.graph.GraphEmbedding; import com.alibaba.alink.params.nlp.MetaPath2VecParams; @InputPorts(values = { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/huge/impl/Node2VecImpl.java b/core/src/main/java/com/alibaba/alink/operator/batch/huge/impl/Node2VecImpl.java index e4017eda2..ff4c75381 100644 --- 
a/core/src/main/java/com/alibaba/alink/operator/batch/huge/impl/Node2VecImpl.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/huge/impl/Node2VecImpl.java @@ -15,7 +15,6 @@ import com.alibaba.alink.operator.batch.graph.Node2VecWalkBatchOp; import com.alibaba.alink.operator.batch.sql.JoinBatchOp; import com.alibaba.alink.operator.common.aps.ApsCheckpoint; -import com.alibaba.alink.operator.common.graph.GraphEmbedding; import com.alibaba.alink.params.nlp.Node2VecParams; @InputPorts(values = @PortSpec(value = PortType.DATA, desc = PortDesc.GRAPH)) diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/huge/impl/Word2VecImpl.java b/core/src/main/java/com/alibaba/alink/operator/batch/huge/impl/Word2VecImpl.java index f4f536125..23e4a1afd 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/huge/impl/Word2VecImpl.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/huge/impl/Word2VecImpl.java @@ -36,7 +36,7 @@ import com.alibaba.alink.common.exceptions.AkIllegalDataException; import com.alibaba.alink.common.linalg.DenseVector; import com.alibaba.alink.common.linalg.VectorUtil; -import com.alibaba.alink.common.utils.DataSetConversionUtil; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.batch.huge.word2vec.ApsIteratorW2V; import com.alibaba.alink.operator.batch.huge.word2vec.ApsSerializeDataW2V; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/image/ReadImageToTensorBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/image/ReadImageToTensorBatchOp.java index 58bb3cd73..e27db808f 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/image/ReadImageToTensorBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/image/ReadImageToTensorBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.operator.batch.utils.MapBatchOp; @@ -11,6 +12,7 @@ @ParamSelectColumnSpec(name = "relativeFilePathCol", allowedTypeCollections = TypeCollections.STRING_TYPES) @NameCn("图片转张量") +@NameEn("Read Image To Tensor") public class ReadImageToTensorBatchOp extends MapBatchOp implements ReadImageToTensorParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/image/WriteTensorToImageBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/image/WriteTensorToImageBatchOp.java index 71c20aede..5e46c990d 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/image/WriteTensorToImageBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/image/WriteTensorToImageBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.operator.batch.utils.MapBatchOp; @@ -12,6 +13,7 @@ @ParamSelectColumnSpec(name = "tensorCol", allowedTypeCollections = TypeCollections.TENSOR_TYPES) @ParamSelectColumnSpec(name = "relativeFilePathCol", allowedTypeCollections = TypeCollections.STRING_TYPES) @NameCn("张量转图片") +@NameEn("Write Tensor To Image") public 
class WriteTensorToImageBatchOp extends MapBatchOp implements WriteTensorToImageParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/nlp/BertTextEmbeddingBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/nlp/BertTextEmbeddingBatchOp.java index e85c861b7..8df908dd9 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/nlp/BertTextEmbeddingBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/nlp/BertTextEmbeddingBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.operator.batch.utils.MapBatchOp; @@ -14,6 +15,7 @@ */ @ParamSelectColumnSpec(name = "selectedCol", allowedTypeCollections = TypeCollections.STRING_TYPES) @NameCn("Bert文本嵌入") +@NameEn("Bert Text Embedding") public class BertTextEmbeddingBatchOp extends MapBatchOp implements BertTextEmbeddingParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/nlp/DocCountVectorizerPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/nlp/DocCountVectorizerPredictBatchOp.java index 7f3401cd9..268ef8973 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/nlp/DocCountVectorizerPredictBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/nlp/DocCountVectorizerPredictBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.operator.batch.utils.ModelMapBatchOp; @@ -15,6 +16,7 @@ */ @ParamSelectColumnSpec(name = "selectedCol", allowedTypeCollections = TypeCollections.STRING_TYPES) @NameCn("文本特征生成预测") +@NameEn("Doc Count Vectorizer Prediction") public final class DocCountVectorizerPredictBatchOp extends ModelMapBatchOp implements DocCountVectorizerPredictParams { private static final long serialVersionUID = 2584222216856311012L; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/nlp/DocCountVectorizerTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/nlp/DocCountVectorizerTrainBatchOp.java index 4bd3b6ae1..db339f7a3 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/nlp/DocCountVectorizerTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/nlp/DocCountVectorizerTrainBatchOp.java @@ -15,6 +15,7 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortSpec; @@ -30,6 +31,7 @@ import com.alibaba.alink.operator.common.nlp.NLPConstant; import com.alibaba.alink.params.nlp.DocCountVectorizerTrainParams; import com.alibaba.alink.params.nlp.DocHashCountVectorizerTrainParams; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; import java.util.ArrayList; import java.util.List; @@ -44,6 +46,8 @@ @OutputPorts(values = {@PortSpec(value = PortType.MODEL)}) @ParamSelectColumnSpec(name = "selectedCol", allowedTypeCollections = 
TypeCollections.STRING_TYPES) @NameCn("文本特征生成训练") +@NameEn("Doc Count Vectorizer Training") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.nlp.DocCountVectorizer") public final class DocCountVectorizerTrainBatchOp extends BatchOperator implements DocCountVectorizerTrainParams { private static final String WORD_COL_NAME = "word"; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/nlp/DocHashCountVectorizerPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/nlp/DocHashCountVectorizerPredictBatchOp.java index ccb43591c..de8f0b28a 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/nlp/DocHashCountVectorizerPredictBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/nlp/DocHashCountVectorizerPredictBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.operator.batch.utils.ModelMapBatchOp; @@ -15,6 +16,7 @@ */ @ParamSelectColumnSpec(name = "selectedCol", allowedTypeCollections = TypeCollections.STRING_TYPES) @NameCn("文本哈希特征生成预测") +@NameEn("DocHash Count Vectorizer Prediction") public class DocHashCountVectorizerPredictBatchOp extends ModelMapBatchOp implements DocHashCountVectorizerPredictParams { private static final long serialVersionUID = -6029385456358959482L; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/nlp/DocHashCountVectorizerTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/nlp/DocHashCountVectorizerTrainBatchOp.java index 7d2a1f6ac..28b37f0be 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/nlp/DocHashCountVectorizerTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/nlp/DocHashCountVectorizerTrainBatchOp.java @@ -12,6 +12,7 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortSpec; @@ -23,6 +24,7 @@ import com.alibaba.alink.operator.common.nlp.DocHashCountVectorizerModelDataConverter; import com.alibaba.alink.operator.common.nlp.NLPConstant; import com.alibaba.alink.params.nlp.DocHashCountVectorizerTrainParams; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; import java.util.HashMap; import java.util.Iterator; @@ -39,6 +41,8 @@ @OutputPorts(values = {@PortSpec(value = PortType.MODEL)}) @ParamSelectColumnSpec(name = "selectedCol", allowedTypeCollections = TypeCollections.STRING_TYPES) @NameCn("文本哈希特征生成训练") +@NameEn("DocHash Count Vectorizer Training") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.nlp.DocHashCountVectorizer") public class DocHashCountVectorizerTrainBatchOp extends BatchOperator implements DocHashCountVectorizerTrainParams { private static final long serialVersionUID = 6469196128919853279L; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/nlp/DocWordCountBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/nlp/DocWordCountBatchOp.java index ee2f7aaa3..5b0c04672 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/nlp/DocWordCountBatchOp.java +++ 
b/core/src/main/java/com/alibaba/alink/operator/batch/nlp/DocWordCountBatchOp.java @@ -4,6 +4,7 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortDesc; @@ -22,6 +23,7 @@ @ParamSelectColumnSpec(name = "docIdCol") @ParamSelectColumnSpec(name = "contentCol", allowedTypeCollections = TypeCollections.STRING_TYPES) @NameCn("文本词频统计") +@NameEn("Document Word Count") public final class DocWordCountBatchOp extends BatchOperator implements DocWordCountParams { private static final long serialVersionUID = 4163509124304798730L; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/nlp/KeywordsExtractionBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/nlp/KeywordsExtractionBatchOp.java index e269b81d1..9934e7a93 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/nlp/KeywordsExtractionBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/nlp/KeywordsExtractionBatchOp.java @@ -13,6 +13,7 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortDesc; @@ -20,7 +21,7 @@ import com.alibaba.alink.common.annotation.PortType; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.common.exceptions.AkUnsupportedOperationException; -import com.alibaba.alink.common.utils.DataSetConversionUtil; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; import com.alibaba.alink.common.utils.OutputColsHelper; import com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.operator.batch.BatchOperator; @@ -43,6 +44,7 @@ @OutputPorts(values = @PortSpec(value = PortType.DATA, desc = PortDesc.OUTPUT_RESULT)) @ParamSelectColumnSpec(name = "selectedCol", allowedTypeCollections = TypeCollections.STRING_TYPE) @NameCn("关键词抽取") +@NameEn("Keywords Extraction") public final class KeywordsExtractionBatchOp extends BatchOperator implements KeywordsExtractionParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/nlp/NGramBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/nlp/NGramBatchOp.java index 867c805d1..cceb507ff 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/nlp/NGramBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/nlp/NGramBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.operator.batch.utils.MapBatchOp; @@ -17,6 +18,7 @@ */ @ParamSelectColumnSpec(name = "selectedCol", allowedTypeCollections = TypeCollections.STRING_TYPES) @NameCn("NGram") +@NameEn("NGram") public class NGramBatchOp extends MapBatchOp implements NGramParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/nlp/RegexTokenizerBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/nlp/RegexTokenizerBatchOp.java index 
67e5f2c25..6d27fe3e0 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/nlp/RegexTokenizerBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/nlp/RegexTokenizerBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.operator.batch.utils.MapBatchOp; @@ -15,6 +16,7 @@ */ @ParamSelectColumnSpec(name = "selectedCol", allowedTypeCollections = TypeCollections.STRING_TYPES) @NameCn("RegexTokenizer") +@NameEn("RegexTokenizer") public final class RegexTokenizerBatchOp extends MapBatchOp implements RegexTokenizerParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/nlp/SegmentBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/nlp/SegmentBatchOp.java index b09d6d4ca..d84292044 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/nlp/SegmentBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/nlp/SegmentBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.operator.batch.utils.MapBatchOp; @@ -14,6 +15,7 @@ */ @ParamSelectColumnSpec(name = "selectedCol", allowedTypeCollections = TypeCollections.STRING_TYPES) @NameCn("分词") +@NameEn("Segment") public final class SegmentBatchOp extends MapBatchOp implements SegmentParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/nlp/StopWordsRemoverBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/nlp/StopWordsRemoverBatchOp.java index a941a40a1..5cf661daa 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/nlp/StopWordsRemoverBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/nlp/StopWordsRemoverBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.operator.batch.utils.MapBatchOp; @@ -14,6 +15,7 @@ */ @ParamSelectColumnSpec(name = "selectedCol", allowedTypeCollections = TypeCollections.STRING_TYPES) @NameCn("停用词过滤") +@NameEn("StopWordsRemover") public final class StopWordsRemoverBatchOp extends MapBatchOp implements StopWordsRemoverParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/nlp/TfidfBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/nlp/TfidfBatchOp.java index 0fc94baf5..43e6958f2 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/nlp/TfidfBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/nlp/TfidfBatchOp.java @@ -10,6 +10,7 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortDesc; @@ -31,6 +32,7 @@ @ParamSelectColumnSpec(name = 
"countCol", allowedTypeCollections = TypeCollections.LONG_TYPES) @ParamSelectColumnSpec(name = "docIdCol") @NameCn("TF-IDF") +@NameEn("Tfidf") public final class TfidfBatchOp extends BatchOperator implements TfIdfParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/nlp/TokenizerBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/nlp/TokenizerBatchOp.java index 3fab44f40..8e21a8164 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/nlp/TokenizerBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/nlp/TokenizerBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.operator.batch.utils.MapBatchOp; @@ -14,6 +15,7 @@ */ @ParamSelectColumnSpec(name = "selectedCol", allowedTypeCollections = TypeCollections.STRING_TYPES) @NameCn("文本分解") +@NameEn("Tokenizer") public final class TokenizerBatchOp extends MapBatchOp implements TokenizerParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/nlp/Word2VecPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/nlp/Word2VecPredictBatchOp.java index fd71f2b23..d7de53f32 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/nlp/Word2VecPredictBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/nlp/Word2VecPredictBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.operator.batch.utils.ModelMapBatchOp; @@ -23,6 +24,7 @@ */ @ParamSelectColumnSpec(name = "selectedCols", allowedTypeCollections = TypeCollections.STRING_TYPES) @NameCn("Word2Vec预测") +@NameEn("Word2Vec Prediction") public class Word2VecPredictBatchOp extends ModelMapBatchOp implements Word2VecPredictParams { private static final long serialVersionUID = 1415195739005424277L; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/nlp/Word2VecTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/nlp/Word2VecTrainBatchOp.java index e6cf460d8..ae389d283 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/nlp/Word2VecTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/nlp/Word2VecTrainBatchOp.java @@ -26,6 +26,7 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortSpec; @@ -39,9 +40,9 @@ import com.alibaba.alink.common.comqueue.communication.AllReduce; import com.alibaba.alink.common.io.directreader.DefaultDistributedInfo; import com.alibaba.alink.common.io.directreader.DistributedInfo; -import com.alibaba.alink.common.lazy.WithTrainInfo; +import com.alibaba.alink.operator.batch.utils.WithTrainInfo; import com.alibaba.alink.common.linalg.DenseVector; -import com.alibaba.alink.common.utils.DataSetConversionUtil; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; import 
com.alibaba.alink.common.utils.ExpTableArray; import com.alibaba.alink.common.utils.RowUtil; import com.alibaba.alink.operator.batch.BatchOperator; @@ -51,6 +52,7 @@ import com.alibaba.alink.operator.common.nlp.WordCountUtil; import com.alibaba.alink.params.nlp.Word2VecTrainParams; import com.alibaba.alink.params.shared.tree.HasSeed; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -83,6 +85,8 @@ allowedTypeCollections = TypeCollections.STRING_TYPES ) @NameCn("Word2Vec训练") +@NameEn("Word2Vec Training") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.nlp.Word2Vec") public class Word2VecTrainBatchOp extends BatchOperator implements Word2VecTrainParams , WithTrainInfo { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/nlp/WordCountBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/nlp/WordCountBatchOp.java index bd64ba027..443d59785 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/nlp/WordCountBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/nlp/WordCountBatchOp.java @@ -4,6 +4,7 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortDesc; @@ -22,6 +23,7 @@ @OutputPorts(values = {@PortSpec(value = PortType.DATA, desc = PortDesc.OUTPUT_RESULT)}) @ParamSelectColumnSpec(name = "selectedCol", allowedTypeCollections = TypeCollections.STRING_TYPES) @NameCn("单词计数") +@NameEn("WordCount") public final class WordCountBatchOp extends BatchOperator implements WordCountParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/outlier/IForestModelOutlierPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/outlier/IForestModelOutlierPredictBatchOp.java index 03bad1d03..e24dab64a 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/outlier/IForestModelOutlierPredictBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/outlier/IForestModelOutlierPredictBatchOp.java @@ -3,10 +3,12 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.common.outlier.BaseModelOutlierPredictBatchOp; import com.alibaba.alink.operator.common.outlier.IForestModelDetector; @NameCn("IForest模型异常检测预测") +@NameEn("IForest Model Outlier Predict") public class IForestModelOutlierPredictBatchOp extends BaseModelOutlierPredictBatchOp { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/outlier/IForestModelOutlierTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/outlier/IForestModelOutlierTrainBatchOp.java index c9c21351c..c89b32465 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/outlier/IForestModelOutlierTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/outlier/IForestModelOutlierTrainBatchOp.java @@ -37,6 +37,7 @@ import com.alibaba.alink.params.outlier.WithMultiVarParams; import com.alibaba.alink.params.dataproc.HasTargetType.TargetType; import com.alibaba.alink.params.dataproc.NumericalTypeCastParams; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; import org.apache.commons.lang3.SerializationUtils; import 
org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -54,6 +55,7 @@ @OutputPorts(values = @PortSpec(value = PortType.MODEL)) @NameCn("IForest模型异常检测训练") @NameEn("IForest model outlier") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.outlier.IForestModelOutlier") public class IForestModelOutlierTrainBatchOp extends BatchOperator implements IForestTrainParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/outlier/OcsvmModelOutlierPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/outlier/OcsvmModelOutlierPredictBatchOp.java new file mode 100644 index 000000000..83a8af164 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/batch/outlier/OcsvmModelOutlierPredictBatchOp.java @@ -0,0 +1,23 @@ +package com.alibaba.alink.operator.batch.outlier; + +import org.apache.flink.ml.api.misc.param.Params; + +import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.operator.common.outlier.BaseModelOutlierPredictBatchOp; +import com.alibaba.alink.operator.common.outlier.OcsvmModelDetector; +@NameCn("One Class SVM异常检测模型预测") +@NameEn("Ocsvm outlier model predict") +public final class OcsvmModelOutlierPredictBatchOp + extends BaseModelOutlierPredictBatchOp { + + private static final long serialVersionUID = 7075220963340722343L; + + public OcsvmModelOutlierPredictBatchOp() { + this(new Params()); + } + + public OcsvmModelOutlierPredictBatchOp(Params params) { + super(OcsvmModelDetector::new, params); + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/outlier/OcsvmModelOutlierTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/outlier/OcsvmModelOutlierTrainBatchOp.java new file mode 100644 index 000000000..310c945e6 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/batch/outlier/OcsvmModelOutlierTrainBatchOp.java @@ -0,0 +1,314 @@ +package com.alibaba.alink.operator.batch.outlier; + +import org.apache.flink.api.common.functions.AbstractRichFunction; +import org.apache.flink.api.common.functions.GroupReduceFunction; +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.MapPartitionFunction; +import org.apache.flink.api.common.functions.ReduceFunction; +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.functions.RichMapPartitionFunction; +import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; + +import com.alibaba.alink.common.MLEnvironmentFactory; +import com.alibaba.alink.common.annotation.InputPorts; +import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.common.annotation.OutputPorts; +import com.alibaba.alink.common.annotation.PortSpec; +import com.alibaba.alink.common.annotation.PortType; +import com.alibaba.alink.common.linalg.DenseVector; +import com.alibaba.alink.common.linalg.SparseVector; +import com.alibaba.alink.common.linalg.Vector; +import com.alibaba.alink.common.linalg.VectorUtil; +import com.alibaba.alink.common.utils.TableUtil; +import com.alibaba.alink.operator.batch.BatchOperator; +import com.alibaba.alink.operator.common.outlier.OcsvmModelData; +import 
com.alibaba.alink.operator.common.outlier.OcsvmModelData.SvmModelData; +import com.alibaba.alink.operator.common.outlier.OcsvmModelDataConverter; +import com.alibaba.alink.params.outlier.OcsvmModelTrainParams; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; + +import java.util.ArrayList; +import java.util.List; +import java.util.Random; + +import static com.alibaba.alink.operator.common.outlier.OcsvmKernel.svmTrain; + +/** + * One class SVM algorithm. we use the libsvm package (https://www.csie.ntu.edu.tw/~cjlin/libsvm/) to solve one class + * svm problem with one thread, and using bagging algo to improve the scale of dataset. + * + * @author weibo zhao + */ +@InputPorts(values = @PortSpec(value = PortType.DATA)) +@OutputPorts(values = @PortSpec(value = PortType.MODEL)) +@NameCn("One Class SVM异常检测模型训练") +@NameEn("Ocsvm outlier model train") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.outlier.OcsvmModelOutlier") +public final class OcsvmModelOutlierTrainBatchOp extends BatchOperator + implements OcsvmModelTrainParams { + + private static final long serialVersionUID = 6727016080849088600L; + + public OcsvmModelOutlierTrainBatchOp() { + super(new Params()); + } + + public OcsvmModelOutlierTrainBatchOp(Params params) { + super(params); + } + + @Override + public OcsvmModelOutlierTrainBatchOp linkFrom(BatchOperator ... inputs) { + BatchOperator in = checkAndGetFirst(inputs); + + String[] featureColNames = getFeatureCols(); + String tensorColName = getVectorCol(); + if ("".equals(tensorColName)) { + tensorColName = null; + } + if (featureColNames != null && featureColNames.length == 0) { + featureColNames = null; + } + final double nu = getNu(); + DataSet data; + if (tensorColName != null || featureColNames == null) { + // select feature data + data = in.select(tensorColName).getDataSet(); + } else { + int[] featureIndices = new int[featureColNames.length]; + for (int i = 0; i < featureColNames.length; ++i) { + featureIndices[i] = TableUtil.findColIndexWithAssertAndHint(in.getColNames(), + featureColNames[i]); + } + data = in.getDataSet().map(new SelectFeat(featureIndices)); + } + + DataSet bNumber = + data.mapPartition(new CalculateNum(featureColNames, tensorColName)) + .reduce(new ReduceFunction >() { + private static final long serialVersionUID = 2307240714136503892L; + + @Override + public Tuple3 reduce(Tuple3 t1, + Tuple3 t2) { + return Tuple3.of(t1.f0 + t2.f0, Math.max(t1.f1, t2.f1), t2.f2); + } + }).map(new RichMapFunction , Integer>() { + @Override + public Integer map(Tuple3 num) { + if (num.f1 < 10) { + return Math.max(num.f2, (int) Math.ceil(num.f0 * num.f1 * nu / 20000.0)); + } else if (num.f1 < 100 && num.f1 > 10) { // if feature length < 100 + return Math.max(num.f2, (int) Math.ceil(num.f0 * num.f1 * nu / 100000.0)); + } else { + return Math.max(num.f2, (int) Math.ceil(num.f0 * nu / 1000.0)); + } + } + }); + + // append the key to dataOp for groupby + DataSet > trainData + = data.mapPartition(new RichMapPartitionFunction >() { + private int bNumber; + private final Random rand = new Random(); + + @Override + public void open(Configuration parameters) { + bNumber = (Integer) getRuntimeContext().getBroadcastVariable("bNumber").get(0); + } + + @Override + public void mapPartition(Iterable rows, Collector > out) { + for (Row row : rows) { + Integer key = this.rand.nextInt(this.bNumber); + out.collect(Tuple2.of(key, row)); + } + } + }).withBroadcastSet(bNumber, "bNumber"); + + // train + DataSet > models = trainData.groupBy(0) + 
.reduceGroup(new TrainSvm(getParams())); + + // transform the models + DataSet model = models + .mapPartition(new Transform(getParams())) + .withBroadcastSet(bNumber, "bNumber") + .setParallelism(1); + + this.setOutput(model, new OcsvmModelDataConverter().getModelSchema()); + + return this; + } + + public static class SelectFeat implements MapFunction { + private static final long serialVersionUID = 331016784088329722L; + private final int[] featureIndices; + + public SelectFeat(int[] featureIndices) { + this.featureIndices = featureIndices; + } + + @Override + public Row map(Row value) throws Exception { + Row ret = new Row(featureIndices.length); + for (int i = 0; i < featureIndices.length; ++i) { + ret.setField(i, ((Number) value.getField(featureIndices[i])).doubleValue()); + } + return ret; + } + } + + public static class Transform extends AbstractRichFunction + implements MapPartitionFunction , Row> { + private static final long serialVersionUID = -8875298030671722207L; + private final String[] featureColNames; + private int baggingNumber; + private final KernelType kernelType; + private final int degree; + private double gamma; + private final double coef0; + private final String vectorCol; + + public Transform(Params params) { + this.featureColNames = params.get(OcsvmModelTrainParams.FEATURE_COLS); + this.kernelType = params.get(OcsvmModelTrainParams.KERNEL_TYPE); + this.degree = params.get(OcsvmModelTrainParams.DEGREE); + this.coef0 = params.get(OcsvmModelTrainParams.COEF0); + this.vectorCol = params.get(OcsvmModelTrainParams.VECTOR_COL); + } + + @Override + public void open(Configuration parameters) throws Exception { + baggingNumber = (Integer) getRuntimeContext() + .getBroadcastVariable("bNumber").get(0); + } + + @Override + public void mapPartition(Iterable > iterable, Collector collector) throws Exception { + List models = new ArrayList <>(); + int size = 0; + for (Tuple2 model : iterable) { + models.add(model.f1); + gamma = model.f0; + size++; + } + + SvmModelData[] modelArray = new SvmModelData[size]; + for (int i = 0; i < size; ++i) { + modelArray[i] = models.get(i); + } + if (modelArray.length != 0) { + OcsvmModelData ocsvmModelData = new OcsvmModelData(); + ocsvmModelData.models = modelArray; + ocsvmModelData.featureColNames = featureColNames; + ocsvmModelData.baggingNumber = baggingNumber; + ocsvmModelData.kernelType = kernelType; + ocsvmModelData.coef0 = coef0; + ocsvmModelData.degree = degree; + ocsvmModelData.gamma = gamma; + ocsvmModelData.vectorCol = vectorCol; + new OcsvmModelDataConverter().save(ocsvmModelData, collector); + + } + } + } + + public static class TrainSvm implements GroupReduceFunction , Tuple2> { + private static final long serialVersionUID = -2783415250850319839L; + private final Params param; + private final String tensorColName; + + TrainSvm(Params param) { + this.param = param; + this.tensorColName = param.get(OcsvmModelTrainParams.VECTOR_COL); + } + + @Override + public void reduce(Iterable > its, + Collector > collector) throws Exception { + List vy = new ArrayList <>(); + List vectors = new ArrayList <>(); + int maxIndex = 0; + int numRows = 0; + if (this.tensorColName != null) { + for (Tuple2 it : its) { + numRows++; + vy.add(0.0); + Object obj = it.f1.getField(0); + vectors.add(VectorUtil.getVector(obj)); + } + } else { + for (Tuple2 it : its) { + numRows++; + vy.add(0.0); + int size = it.f1.getArity(); + maxIndex = size; + Vector vec = new DenseVector(size); + for (int i = 0; i < size; ++i) { + vec.set(i, ((Number) 
it.f1.getField(i)).doubleValue()); + } + vectors.add(vec); + } + } + if (numRows > 0) { + Vector[] sample = new Vector[vy.size()]; + for (int i = 0; i < vy.size(); i++) { + sample[i] = vectors.get(i); + } + + if (Math.abs(param.get(OcsvmModelTrainParams.GAMMA)) < 1.0e-18 && maxIndex > 0) { + param.set(OcsvmModelTrainParams.GAMMA, 1.0 / maxIndex); + } + SvmModelData model = svmTrain(sample, param); + collector.collect(Tuple2.of(1.0/maxIndex, model)); + } + } + } + + public static class CalculateNum extends RichMapPartitionFunction > { + private static final long serialVersionUID = -679835005763383100L; + private final String[] featureColnames; + private final String tensorColName; + + public CalculateNum(String[] featureColnames, String tensorColName) { + this.featureColnames = featureColnames; + this.tensorColName = tensorColName; + } + + @Override + public void mapPartition(Iterable values, Collector > out) + throws Exception { + int parallel = getRuntimeContext().getNumberOfParallelSubtasks(); + int count = 0; + int featureLen = -1; + for (Row row : values) { + count++; + if (tensorColName != null) { + Object obj = row.getField(0); + Vector vec = VectorUtil.getVector(obj); + if (vec instanceof SparseVector) { + + int[] indices = ((SparseVector) vec).getIndices(); + featureLen = Math.max(featureLen, + (indices.length == 0 ? -1 : indices[indices.length - 1]) + 1); + } else { + featureLen = vec.size(); + } + } else { + featureLen = this.featureColnames.length; + } + } + out.collect(Tuple3.of(count, featureLen, parallel)); + } + } +} + + diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/outlier/OcsvmOutlier4GroupedDataBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/outlier/OcsvmOutlier4GroupedDataBatchOp.java index 945f946ef..ed7e8775f 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/outlier/OcsvmOutlier4GroupedDataBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/outlier/OcsvmOutlier4GroupedDataBatchOp.java @@ -3,11 +3,13 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.common.outlier.BaseOutlier4GroupedDataBatchOp; import com.alibaba.alink.operator.common.outlier.OcsvmDetector; import com.alibaba.alink.params.outlier.OcsvmDetectorParams; @NameCn("One-Class SVM分组异常检测") +@NameEn("One-Class Outlier For Grouped Data") public class OcsvmOutlier4GroupedDataBatchOp extends BaseOutlier4GroupedDataBatchOp implements OcsvmDetectorParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/outlier/SosOutlierBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/outlier/SosOutlierBatchOp.java index 3452b822e..6c44f2be7 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/outlier/SosOutlierBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/outlier/SosOutlierBatchOp.java @@ -3,11 +3,13 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.common.outlier.BaseOutlierBatchOp; import com.alibaba.alink.operator.common.outlier.SosDetector; import com.alibaba.alink.params.outlier.SosDetectorParams; @NameCn("SOS 异常检测") +@NameEn("Sos Outlier") public class SosOutlierBatchOp extends BaseOutlierBatchOp implements SosDetectorParams { diff --git 
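For orientation on the new OcsvmModelOutlierTrainBatchOp introduced above: it draws a random bag id for every row, trains one single-threaded libsvm one-class model per bag (OcsvmKernel.svmTrain), and stores all bag models together in one OcsvmModelData. Below is a minimal, self-contained sketch of that bagging dataflow in plain Java; Point, BagModel and trainSingleBag are hypothetical stand-ins for the row type and the libsvm call, not Alink classes.

import java.util.ArrayList;
import java.util.List;
import java.util.Random;

public final class OcsvmBaggingSketch {

    // Hypothetical stand-in for one feature row (dense features only).
    static final class Point {
        final double[] features;
        Point(double[] features) { this.features = features; }
    }

    // Hypothetical stand-in for one trained bag model (in Alink this is
    // OcsvmModelData.SvmModelData produced by libsvm).
    static final class BagModel {
        final List<Point> bagSample;
        BagModel(List<Point> bagSample) { this.bagSample = bagSample; }
    }

    // Stand-in for the single-threaded libsvm call (OcsvmKernel.svmTrain).
    static BagModel trainSingleBag(List<Point> bag) {
        return new BagModel(new ArrayList<>(bag));
    }

    // Split rows into random bags, train one model per non-empty bag,
    // and keep every bag model -- the same shape as linkFrom above.
    static List<BagModel> train(List<Point> data, int baggingNumber, long seed) {
        Random rand = new Random(seed);
        List<List<Point>> bags = new ArrayList<>();
        for (int i = 0; i < baggingNumber; i++) {
            bags.add(new ArrayList<>());
        }
        for (Point p : data) {
            bags.get(rand.nextInt(baggingNumber)).add(p);   // random bag id per row
        }
        List<BagModel> models = new ArrayList<>();
        for (List<Point> bag : bags) {
            if (!bag.isEmpty()) {
                models.add(trainSingleBag(bag));
            }
        }
        return models;
    }
}

The number of bags itself is derived in the operator from the row count, feature dimension, parallelism and nu, so that each single-threaded libsvm call only ever sees a manageable slice of the data.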
a/core/src/main/java/com/alibaba/alink/operator/batch/pytorch/TorchModelPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/pytorch/TorchModelPredictBatchOp.java index 17d11fb3d..b44897710 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/pytorch/TorchModelPredictBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/pytorch/TorchModelPredictBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.batch.utils.MapBatchOp; import com.alibaba.alink.operator.common.pytorch.TorchModelPredictMapper; import com.alibaba.alink.params.dl.TorchModelPredictParams; @@ -11,6 +12,7 @@ * This operator loads TorchScript model and do predictions. */ @NameCn("PyTorch模型预测") +@NameEn("Torch Model Prediction") public class TorchModelPredictBatchOp extends MapBatchOp implements TorchModelPredictParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/AlsImplicitTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/AlsImplicitTrainBatchOp.java index b2a438f38..132e19c19 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/AlsImplicitTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/AlsImplicitTrainBatchOp.java @@ -7,6 +7,7 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortDesc; @@ -52,6 +53,7 @@ @ParamSelectColumnSpec(name = "rateCol", allowedTypeCollections = TypeCollections.NUMERIC_TYPES) @NameCn("ALS隐式训练") +@NameEn("Implicit Als Training") public final class AlsImplicitTrainBatchOp extends BatchOperator implements AlsImplicitTrainParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/AlsItemsPerUserRecommBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/AlsItemsPerUserRecommBatchOp.java index 45a4843c3..9a2077920 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/AlsItemsPerUserRecommBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/AlsItemsPerUserRecommBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.common.recommendation.AlsRecommKernel; import com.alibaba.alink.operator.common.recommendation.RecommType; import com.alibaba.alink.params.recommendation.BaseItemsPerUserRecommParams; @@ -11,6 +12,7 @@ * This op recommend items for user with als model. 
*/ @NameCn("ALS:ItemsPerUser推荐") +@NameEn("Als Items Per User Recommend") public class AlsItemsPerUserRecommBatchOp extends BaseRecommBatchOp implements BaseItemsPerUserRecommParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/AlsModelInfoBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/AlsModelInfoBatchOp.java index 9e0440d5c..76be47919 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/AlsModelInfoBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/AlsModelInfoBatchOp.java @@ -10,10 +10,10 @@ import org.apache.flink.types.Row; import org.apache.flink.util.Collector; -import com.alibaba.alink.common.lazy.ExtractModelInfoBatchOp; +import com.alibaba.alink.operator.batch.utils.ExtractModelInfoBatchOp; import com.alibaba.alink.common.lazy.LazyEvaluation; import com.alibaba.alink.common.lazy.LazyObjectsManager; -import com.alibaba.alink.common.utils.DataSetConversionUtil; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.batch.source.BaseSourceBatchOp; import com.alibaba.alink.operator.common.recommendation.AlsModelInfo; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/AlsRateRecommBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/AlsRateRecommBatchOp.java index df27605c3..d1a3e687d 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/AlsRateRecommBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/AlsRateRecommBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.common.recommendation.AlsRecommKernel; import com.alibaba.alink.operator.common.recommendation.RecommType; import com.alibaba.alink.params.recommendation.BaseRateRecommParams; @@ -11,6 +12,7 @@ * this op rating user item pair with als model. */ @NameCn("ALS:打分推荐推荐") +@NameEn("Als Rate Recommend") public class AlsRateRecommBatchOp extends BaseRecommBatchOp implements BaseRateRecommParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/AlsSimilarItemsRecommBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/AlsSimilarItemsRecommBatchOp.java index feaf1dec5..21be974db 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/AlsSimilarItemsRecommBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/AlsSimilarItemsRecommBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.common.recommendation.AlsRecommKernel; import com.alibaba.alink.operator.common.recommendation.RecommType; import com.alibaba.alink.params.recommendation.BaseSimilarItemsRecommParams; @@ -11,6 +12,7 @@ * Recommend similar items for the given item. 
*/ @NameCn("ALS:相似items推荐") +@NameEn("ALS Similar Items Recommend") public class AlsSimilarItemsRecommBatchOp extends BaseRecommBatchOp implements BaseSimilarItemsRecommParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/AlsSimilarUsersRecommBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/AlsSimilarUsersRecommBatchOp.java index 29bbbdb5b..aafc1f956 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/AlsSimilarUsersRecommBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/AlsSimilarUsersRecommBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.common.recommendation.AlsRecommKernel; import com.alibaba.alink.operator.common.recommendation.RecommType; import com.alibaba.alink.params.recommendation.BaseSimilarUsersRecommParams; @@ -11,6 +12,7 @@ * Recommend similar items for the given item. */ @NameCn("ALS:相似users推荐") +@NameEn("ALS Similar Users Recommend") public class AlsSimilarUsersRecommBatchOp extends BaseRecommBatchOp implements BaseSimilarUsersRecommParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/AlsTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/AlsTrainBatchOp.java index 646f53b47..b38eb5813 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/AlsTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/AlsTrainBatchOp.java @@ -7,6 +7,7 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortDesc; @@ -14,7 +15,7 @@ import com.alibaba.alink.common.annotation.PortType; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.common.exceptions.AkIllegalArgumentException; -import com.alibaba.alink.common.lazy.WithModelInfoBatchOp; +import com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.common.recommendation.AlsModelInfo; import com.alibaba.alink.operator.common.recommendation.HugeMfAlsImpl; @@ -54,6 +55,7 @@ @ParamSelectColumnSpec(name = "rateCol", allowedTypeCollections = TypeCollections.NUMERIC_TYPES) @NameCn("ALS训练") +@NameEn("ALS Training") public final class AlsTrainBatchOp extends BatchOperator implements AlsTrainParams , diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/AlsUsersPerItemRecommBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/AlsUsersPerItemRecommBatchOp.java index a405b5aa3..0dcd5806f 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/AlsUsersPerItemRecommBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/AlsUsersPerItemRecommBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.common.recommendation.AlsRecommKernel; import 
com.alibaba.alink.operator.common.recommendation.RecommType; import com.alibaba.alink.params.recommendation.BaseUsersPerItemRecommParams; @@ -11,6 +12,7 @@ * This op recommend users for item with als model. */ @NameCn("ALS:UsersPerItem推荐") +@NameEn("ALS Users Per Item Recommend") public class AlsUsersPerItemRecommBatchOp extends BaseRecommBatchOp implements BaseUsersPerItemRecommParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/FlattenKObjectBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/FlattenKObjectBatchOp.java index c99e993f8..e229d89dd 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/FlattenKObjectBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/FlattenKObjectBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.operator.batch.utils.FlatMapBatchOp; @@ -15,6 +16,7 @@ @ParamSelectColumnSpec(name = "selectedCol", allowedTypeCollections = TypeCollections.STRING_TYPES) @NameCn("展开KObject") +@NameEn("Flatten KObject") public class FlattenKObjectBatchOp extends FlatMapBatchOp implements FlattenKObjectParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/FmItemsPerUserRecommBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/FmItemsPerUserRecommBatchOp.java index 58eb3c29d..15aec344c 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/FmItemsPerUserRecommBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/FmItemsPerUserRecommBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.common.recommendation.FmRecommKernel; import com.alibaba.alink.operator.common.recommendation.RecommType; import com.alibaba.alink.params.recommendation.BaseItemsPerUserRecommParams; @@ -11,6 +12,7 @@ * Fm recommendation batch op for recommending items to user. */ @NameCn("FM:ItemsPerUser推荐") +@NameEn("FM Items Per User Recommend") public class FmItemsPerUserRecommBatchOp extends BaseRecommBatchOp implements BaseItemsPerUserRecommParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/FmRateRecommBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/FmRateRecommBatchOp.java index 19cde604b..da8a01f24 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/FmRateRecommBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/FmRateRecommBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.common.recommendation.FmRecommKernel; import com.alibaba.alink.operator.common.recommendation.RecommType; import com.alibaba.alink.params.recommendation.BaseRateRecommParams; @@ -11,6 +12,7 @@ * Fm rating batch op for recommendation. 
*/ @NameCn("FM:打分推荐") +@NameEn("FM:Rate Recommend") public class FmRateRecommBatchOp extends BaseRecommBatchOp implements BaseRateRecommParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/FmRecommBinaryImplicitTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/FmRecommBinaryImplicitTrainBatchOp.java index 7e9c2406b..45227ef23 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/FmRecommBinaryImplicitTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/FmRecommBinaryImplicitTrainBatchOp.java @@ -4,6 +4,7 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortDesc; @@ -41,6 +42,7 @@ @ParamSelectColumnSpec(name = "itemCategoricalFeatureCols", portIndices = 2) @NameCn("FM二分类隐式训练") +@NameEn("Fm Recommend Binary Implicit Training") public final class FmRecommBinaryImplicitTrainBatchOp extends BatchOperator implements FmRecommBinaryImplicitTrainParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/FmRecommTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/FmRecommTrainBatchOp.java index 28134b53d..a8e76439c 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/FmRecommTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/FmRecommTrainBatchOp.java @@ -5,6 +5,7 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortDesc; @@ -57,6 +58,7 @@ @ParamSelectColumnSpec(name = "itemCategoricalFeatureCols", portIndices = 2) @NameCn("FM推荐训练") +@NameEn("Fm Recommend Training") public final class FmRecommTrainBatchOp extends BatchOperator implements FmRecommTrainParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/FmUsersPerItemRecommBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/FmUsersPerItemRecommBatchOp.java index dcc81836b..60e9044c3 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/FmUsersPerItemRecommBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/FmUsersPerItemRecommBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.common.recommendation.FmRecommKernel; import com.alibaba.alink.operator.common.recommendation.RecommType; import com.alibaba.alink.params.recommendation.BaseUsersPerItemRecommParams; @@ -11,6 +12,7 @@ * Fm recommendation batch op for recommending users to item. 
*/ @NameCn("FM:UsersPerItem推荐") +@NameEn("Fm Users Per Item Recommend") public class FmUsersPerItemRecommBatchOp extends BaseRecommBatchOp implements BaseUsersPerItemRecommParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/ItemCfItemsPerUserRecommBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/ItemCfItemsPerUserRecommBatchOp.java index aad3519ce..74548dc27 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/ItemCfItemsPerUserRecommBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/ItemCfItemsPerUserRecommBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.common.recommendation.ItemCfRecommKernel; import com.alibaba.alink.operator.common.recommendation.RecommType; import com.alibaba.alink.params.recommendation.BaseItemsPerUserRecommParams; @@ -11,6 +12,7 @@ * Recommend similar items for the given item. */ @NameCn("ItemCf:ItemsPerUser推荐") +@NameEn("ItemCf Items Per User Recommend") public class ItemCfItemsPerUserRecommBatchOp extends BaseRecommBatchOp implements BaseItemsPerUserRecommParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/ItemCfModelInfoBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/ItemCfModelInfoBatchOp.java index 628a23a22..0d17cf24d 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/ItemCfModelInfoBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/ItemCfModelInfoBatchOp.java @@ -3,7 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import org.apache.flink.types.Row; -import com.alibaba.alink.common.lazy.ExtractModelInfoBatchOp; +import com.alibaba.alink.operator.batch.utils.ExtractModelInfoBatchOp; import com.alibaba.alink.operator.common.recommendation.ItemCfModelInfo; import java.util.List; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/ItemCfRateRecommBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/ItemCfRateRecommBatchOp.java index f4a4816f8..0fce6cff0 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/ItemCfRateRecommBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/ItemCfRateRecommBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.common.recommendation.ItemCfRecommKernel; import com.alibaba.alink.operator.common.recommendation.RecommType; import com.alibaba.alink.params.recommendation.BaseRateRecommParams; @@ -11,6 +12,7 @@ * Rating for user-item pair with item CF model. 
*/ @NameCn("ItemCf:打分推荐") +@NameEn("ItemCf RateRecommend") public class ItemCfRateRecommBatchOp extends BaseRecommBatchOp implements BaseRateRecommParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/ItemCfSimilarItemsRecommBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/ItemCfSimilarItemsRecommBatchOp.java index 8b8d23694..281e026c4 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/ItemCfSimilarItemsRecommBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/ItemCfSimilarItemsRecommBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.common.recommendation.ItemCfRecommKernel; import com.alibaba.alink.operator.common.recommendation.RecommType; import com.alibaba.alink.params.recommendation.BaseSimilarItemsRecommParams; @@ -11,6 +12,7 @@ * Recommend similar items for the given item. */ @NameCn("ItemCf:相似items推荐") +@NameEn("ItemCf Similar Items Recommend") public class ItemCfSimilarItemsRecommBatchOp extends BaseRecommBatchOp implements BaseSimilarItemsRecommParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/ItemCfTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/ItemCfTrainBatchOp.java index 62389e3e1..f05e396df 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/ItemCfTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/ItemCfTrainBatchOp.java @@ -12,9 +12,10 @@ import org.apache.flink.types.Row; import org.apache.flink.util.Collector; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortSpec; @@ -23,7 +24,7 @@ import com.alibaba.alink.common.exceptions.AkIllegalArgumentException; import com.alibaba.alink.common.exceptions.AkIllegalDataException; import com.alibaba.alink.common.exceptions.AkPreconditions; -import com.alibaba.alink.common.lazy.WithModelInfoBatchOp; +import com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp; import com.alibaba.alink.common.linalg.SparseVector; import com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.operator.batch.BatchOperator; @@ -58,6 +59,7 @@ @ParamSelectColumnSpec(name = "rateCol", allowedTypeCollections = TypeCollections.NUMERIC_TYPES) @NameCn("ItemCf训练") +@NameEn("ItemCf Training") public class ItemCfTrainBatchOp extends BatchOperator implements ItemCfRecommTrainParams , WithModelInfoBatchOp { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/ItemCfUsersPerItemRecommBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/ItemCfUsersPerItemRecommBatchOp.java index 627cad509..dcd45841c 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/ItemCfUsersPerItemRecommBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/ItemCfUsersPerItemRecommBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import 
com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.common.recommendation.ItemCfRecommKernel; import com.alibaba.alink.operator.common.recommendation.RecommType; import com.alibaba.alink.params.recommendation.BaseUsersPerItemRecommParams; @@ -11,6 +12,7 @@ * Recommend similar users for the given item. */ @NameCn("ItemCf:UsersPerItem推荐") +@NameEn("ItemCf Users Per Item Recommend") public class ItemCfUsersPerItemRecommBatchOp extends BaseRecommBatchOp implements BaseUsersPerItemRecommParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/LeaveKObjectOutBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/LeaveKObjectOutBatchOp.java index b4350dcff..991142f13 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/LeaveKObjectOutBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/LeaveKObjectOutBatchOp.java @@ -11,11 +11,12 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortSpec; import com.alibaba.alink.common.annotation.PortType; -import com.alibaba.alink.common.utils.DataSetConversionUtil; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; import com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.batch.source.DataSetWrapperBatchOp; @@ -40,6 +41,7 @@ @ParamSelectColumnSpec(name = "groupCol") @ParamSelectColumnSpec(name = "objectCol") @NameCn("推荐结果采样处理") +@NameEn("Leave K Object Out") public class LeaveKObjectOutBatchOp extends BatchOperator implements LeaveKObjectOutParams { private static final long serialVersionUID = 8447591038487459735L; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/LeaveTopKObjectOutBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/LeaveTopKObjectOutBatchOp.java index 09c8037bd..05092a9b1 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/LeaveTopKObjectOutBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/LeaveTopKObjectOutBatchOp.java @@ -11,12 +11,13 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortSpec; import com.alibaba.alink.common.annotation.PortType; import com.alibaba.alink.common.annotation.TypeCollections; -import com.alibaba.alink.common.utils.DataSetConversionUtil; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; import com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.batch.source.DataSetWrapperBatchOp; @@ -42,6 +43,7 @@ @ParamSelectColumnSpec(name = "rateCol", allowedTypeCollections = TypeCollections.NUMERIC_TYPES) @NameCn("推荐结果TopK采样处理") +@NameEn("Leave TopK Object Out") public class LeaveTopKObjectOutBatchOp extends BatchOperator implements LeaveTopKObjectOutParams { diff --git 
a/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/NegativeItemSamplingBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/NegativeItemSamplingBatchOp.java index b75634e0b..efc3a4d0c 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/NegativeItemSamplingBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/NegativeItemSamplingBatchOp.java @@ -15,6 +15,7 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.PortSpec; import com.alibaba.alink.common.annotation.PortType; @@ -42,6 +43,7 @@ @PortSpec(PortType.DATA) }) @NameCn("推荐负采样") +@NameEn("Negative Item Sampling") public final class NegativeItemSamplingBatchOp extends BatchOperator implements NegativeItemSamplingParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/RecommendationRankingBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/RecommendationRankingBatchOp.java index 719b55999..34e0c13ed 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/RecommendationRankingBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/RecommendationRankingBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.operator.batch.utils.ModelMapBatchOp; @@ -11,6 +12,7 @@ @ParamSelectColumnSpec(name = "mTableCol", allowedTypeCollections = TypeCollections.MTABLE_TYPES) @NameCn("推荐组件:精排") +@NameEn("Recommendation Ranking") public class RecommendationRankingBatchOp extends ModelMapBatchOp implements RecommendationRankingParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/SwingTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/SwingTrainBatchOp.java index a45b1a563..e92013f5e 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/SwingTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/SwingTrainBatchOp.java @@ -21,7 +21,7 @@ import com.alibaba.alink.common.annotation.PortSpec; import com.alibaba.alink.common.annotation.PortType; import com.alibaba.alink.common.exceptions.AkIllegalDataException; -import com.alibaba.alink.common.utils.DataSetConversionUtil; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; import com.alibaba.alink.common.utils.JsonConverter; import com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.operator.batch.BatchOperator; @@ -38,7 +38,6 @@ import java.util.Arrays; import java.util.Comparator; import java.util.HashMap; -import java.util.HashSet; import java.util.Map.Entry; /** @@ -53,8 +52,8 @@ @ParamSelectColumnSpec(name = "itemCol") @NameCn("swing训练") @NameEn("Swing Recommendation Training") -public class SwingTrainBatchOp extends BatchOperator - implements SwingTrainParams { +public class SwingTrainBatchOp extends BatchOperator + implements SwingTrainParams { private static final long serialVersionUID = 6094224433980263495L; private static final String ITEM_ID_COLNAME 
= "alink_itemID_in_swing"; @@ -67,20 +66,20 @@ public SwingTrainBatchOp() { } @Override - public SwingTrainBatchOp linkFrom(BatchOperator... inputs) { + public SwingTrainBatchOp linkFrom(BatchOperator ... inputs) { String userCol = getUserCol(); String itemCol = getItemCol(); Integer maxUserItems = getMaxUserItems(); Integer minUserItems = getMinUserItems(); Integer maxItemNumber = getMaxItemNumber(); boolean normalize = getResultNormalize(); - String[] selectedCols = new String[]{userCol, itemCol}; + String[] selectedCols = new String[] {userCol, itemCol}; - BatchOperator in = checkAndGetFirst(inputs) + BatchOperator in = checkAndGetFirst(inputs) .select(selectedCols); long mlEnvId = getMLEnvironmentId(); - TypeInformation itemType = TableUtil.findColType(in.getSchema(), itemCol); + TypeInformation itemType = TableUtil.findColType(in.getSchema(), itemCol); if (!itemType.equals(Types.STRING) && !itemType.equals(Types.INT) && !itemType.equals(Types.LONG)) { throw new AkIllegalDataException("not supported item type:" + itemType + ", should be int,long or string"); } @@ -101,20 +100,22 @@ public SwingTrainBatchOp linkFrom(BatchOperator... inputs) { } //存储item ID,同用户其他item ID - DataSet, Long, Long[]>> mainItemData = in.getDataSet() + DataSet , Long, Long[]>> mainItemData = in.getDataSet() .groupBy(new RowKeySelector(0)) .reduceGroup(new BuildSwingData(maxUserItems, minUserItems, idIndex)) .name("build_main_item_data"); - DataSet itemSimilarity = mainItemData + DataSet itemSimilarity = mainItemData .groupBy(1) - .reduceGroup(new CalcSimilarity(getAlpha(), maxItemNumber, getUserAlpha(), getUserBeta(), normalize)) + .reduceGroup( + new CalcSimilarity(getAlpha(), maxItemNumber, maxUserItems, getUserAlpha(), getUserBeta(), normalize)) .name("compute_similarity"); - BatchOperator itemResult = BatchOperator.fromTable(DataSetConversionUtil.toTable(getMLEnvironmentId(), itemSimilarity, - new String[]{itemCol, "swing_items", "swing_scores"}, - new TypeInformation[]{itemType, Types.OBJECT_ARRAY(Types.LONG), Types.OBJECT_ARRAY(Types.FLOAT)} - )); + BatchOperator itemResult = BatchOperator.fromTable( + DataSetConversionUtil.toTable(getMLEnvironmentId(), itemSimilarity, + new String[] {itemCol, "swing_items", "swing_scores"}, + new TypeInformation[] {itemType, Types.OBJECT_ARRAY(Types.LONG), Types.OBJECT_ARRAY(Types.FLOAT)} + )); if (itemType.equals(Types.STRING)) { itemResult = new HugeIndexerStringPredictBatchOp() @@ -123,7 +124,7 @@ public SwingTrainBatchOp linkFrom(BatchOperator... inputs) { .linkFrom(model, itemResult); } Params meta = getParams().set(SwingRecommKernel.ITEM_TYPE, FlinkTypeConverter.getTypeString(itemType)); - DataSet modelData = itemResult.getDataSet() + DataSet modelData = itemResult.getDataSet() .mapPartition(new BuildModelData(itemCol, meta)) .name("build_model_data"); @@ -131,7 +132,7 @@ public SwingTrainBatchOp linkFrom(BatchOperator... inputs) { return this; } - public static class RowKeySelector implements KeySelector> { + public static class RowKeySelector implements KeySelector > { private static final long serialVersionUID = 7514280642434354647L; int index; @@ -140,8 +141,8 @@ public RowKeySelector(int index) { } @Override - public Comparable getKey(Row value) { - return (Comparable) value.getField(index); + public Comparable getKey(Row value) { + return (Comparable ) value.getField(index); } } @@ -149,7 +150,7 @@ public Comparable getKey(Row value) { * group by user col. 
 */
	private static class BuildSwingData
-		implements GroupReduceFunction<Row, Tuple3<Comparable<?>, Long, Long[]>> {
+		implements GroupReduceFunction <Row, Tuple3 <Comparable <?>, Long, Long[]>> {
 		private static final long serialVersionUID = 6417591701594465880L;
 		int maxUserItems;
@@ -163,18 +164,19 @@ private static class BuildSwingData
 		}
 
 		@Override
-		public void reduce(Iterable<Row> values,
-						   Collector<Tuple3<Comparable<?>, Long, Long[]>> out) throws Exception {
-			HashMap<Long, Comparable<?>> userItemMap = new HashMap <>();
+		public void reduce(Iterable <Row> values,
+						   Collector <Tuple3 <Comparable <?>, Long, Long[]>> out) throws Exception {
+			HashMap <Long, Comparable <?>> userItemMap = new HashMap <>();
 			for (Row value : values) {
-				userItemMap.put(Long.valueOf(String.valueOf(value.getField(idIndex))), (Comparable<?>) value.getField(1));
+				userItemMap.put(Long.valueOf(String.valueOf(value.getField(idIndex))),
+					(Comparable <?>) value.getField(1));
 			}
 			if (userItemMap.size() < this.minUserItems || userItemMap.size() > this.maxUserItems) {
 				return;
 			}
 			Long[] userItemIDs = new Long[userItemMap.size()];
 			int index = 0;
-			for (Entry<Long, Comparable<?>> pair : userItemMap.entrySet()) {
+			for (Entry <Long, Comparable <?>> pair : userItemMap.entrySet()) {
 				userItemIDs[index++] = pair.getKey();
 			}
 			for (Long userItemID : userItemIDs) {
@@ -184,38 +186,41 @@ public void reduce(Iterable<Row> values,
 	}
 
 	private static class CalcSimilarity
-		extends RichGroupReduceFunction<Tuple3<Comparable<?>, Long, Long[]>, Row> {
+		extends RichGroupReduceFunction <Tuple3 <Comparable <?>, Long, Long[]>, Row> {
 		private static final long serialVersionUID = -2438120820385058339L;
 		private final float alpha;
 		int maxItemNumber;
+		int maxUserItems;
 		float userAlpha;
 		float userBeta;
 		boolean normalize;
 
-		CalcSimilarity(float alpha, int maxItemNumber, float userAlpha, float userBeta, boolean normalize) {
+		CalcSimilarity(float alpha, int maxItemNumber, int maxUserItems, float userAlpha, float userBeta,
+					   boolean normalize) {
 			this.alpha = alpha;
 			this.userAlpha = userAlpha;
 			this.userBeta = userBeta;
 			this.maxItemNumber = maxItemNumber;
+			this.maxUserItems = maxUserItems;
 			this.normalize = normalize;
 		}
 
 		private float computeUserWeight(int size) {
-			return (float)(1.0 / Math.pow(userAlpha + size, userBeta));
+			return (float) (1.0 / Math.pow(userAlpha + size, userBeta));
 		}
 
 		@Override
-		public void reduce(Iterable<Tuple3<Comparable<?>, Long, Long[]>> values,
-						   Collector<Row> out) throws Exception {
-			Comparable<?> item = null;
+		public void reduce(Iterable <Tuple3 <Comparable <?>, Long, Long[]>> values,
+						   Collector <Row> out) throws Exception {
+			Comparable <?> item = null;
 			Long mainItem = null;
-			ArrayList<Long[]> dataList = new ArrayList <>();
-			for (Tuple3<Comparable<?>, Long, Long[]> value : values) {
+			ArrayList <Long[]> dataList = new ArrayList <>();
+			for (Tuple3 <Comparable <?>, Long, Long[]> value : values) {
 				item = value.f0;
 				mainItem = value.f1;
 				if (dataList.size() == this.maxItemNumber) {
-					int randomIndex = (int)(Math.random() * (this.maxItemNumber + 1));
+					int randomIndex = (int) (Math.random() * (this.maxItemNumber + 1));
 					if (randomIndex < this.maxItemNumber) {
 						dataList.set(randomIndex, value.f2);
 					}
@@ -223,27 +228,25 @@ public void reduce(Iterable<Tuple3<Comparable<?>, Long, Long[]>> values,
 					dataList.add(value.f2);
 				}
 			}
-			ArrayList<HashSet<Long>> itemSetList = new ArrayList <>(dataList.size());
 			float[] userWeights = new float[dataList.size()];
 			int weightIndex = 0;
 			for (Long[] value : dataList) {
-				HashSet<Long> itemSet = new HashSet <>(value.length);
-				itemSet.addAll(Arrays.asList(value));
-				itemSetList.add(itemSet);
+				Arrays.sort(value);
 				userWeights[weightIndex++] = computeUserWeight(value.length);
 			}
 			//双重遍历,计算swing权重
 			HashMap <Long, Float> id2swing = new HashMap <>();
-			for (int i = 0; i < itemSetList.size(); i++) {
-				for (int j = i + 1; j < itemSetList.size(); j++) {
-					HashSet<Long> interaction = (HashSet<Long>) itemSetList.get(i).clone();
-					interaction.retainAll(itemSetList.get(j));
-					if (interaction.size() == 0) {
+			long[] interaction = new long[maxUserItems];
+			for (int i = 0; i < dataList.size(); i++) {
+				for (int j = i + 1; j < dataList.size(); j++) {
+					int interactionSize = countCommonItems(dataList.get(i), dataList.get(j), interaction);
+					if (interactionSize == 0) {
 						continue;
 					}
-					float similarity = userWeights[i] * userWeights[j] / (alpha + interaction.size());
-					for (Long id : interaction) {
+					float similarity = userWeights[i] * userWeights[j] / (alpha + interactionSize);
+					for (int k = 0; k < interactionSize; k++) {
+						Long id = interaction[k];
 						if (id.equals(mainItem)) {
 							continue;
 						}
@@ -252,7 +255,7 @@ public void reduce(Iterable<Tuple3<Comparable<?>, Long, Long[]>> values,
 					}
 				}
 			}
-			ArrayList<Tuple2<Long, Float>> itemAndScore = new ArrayList<>();
+			ArrayList <Tuple2 <Long, Float>> itemAndScore = new ArrayList <>();
 			id2swing.forEach(
 				(key, value) -> itemAndScore.add(Tuple2.of(key, value))
 			);
@@ -275,23 +278,42 @@ public int compare(Tuple2<Long, Float> o1, Tuple2<Long, Float> o2) {
 			}
 			out.collect(Row.of(item, itemIds, itemScores));
 		}
+
+		private static int countCommonItems(Long[] u, Long[] v, long[] interaction) {
+			int pointerU = 0;
+			int pointerV = 0;
+			int interactionSize = 0;
+			while (pointerU < u.length && pointerV < v.length) {
+				if (u[pointerU].equals(v[pointerV])) {
+					interaction[interactionSize++] = u[pointerU];
+					pointerU++;
+					pointerV++;
+				} else if (u[pointerU] < v[pointerV]) {
+					pointerU++;
+				} else {
+					pointerV++;
+				}
+			}
+			return interactionSize;
+		}
 	}
 
-	private static class BuildModelData extends RichMapPartitionFunction<Row, Row> {
+	private static class BuildModelData extends RichMapPartitionFunction <Row, Row> {
 		private final String itemCol;
 		private final Params meta;
+
 		BuildModelData(String itemCol, Params meta) {
 			this.itemCol = itemCol;
 			this.meta = meta;
 		}
 
 		@Override
-		public void mapPartition(Iterable<Row> values, Collector<Row> out) throws Exception {
+		public void mapPartition(Iterable <Row> values, Collector <Row> out) throws Exception {
 			if (getRuntimeContext().getIndexOfThisSubtask() == 0) {
 				out.collect(Row.of(null, meta.toJson()));
 			}
 			for (Row value : values) {
-				Comparable<?> originMainItem = (Comparable<?>) value.getField(0);
+				Comparable <?> originMainItem = (Comparable <?>) value.getField(0);
 				Object itemsValue = value.getField(1);
 				Object[] items;
 				if (itemsValue instanceof String) {
diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/UserCfItemsPerUserRecommBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/UserCfItemsPerUserRecommBatchOp.java
index 06f5945ba..b95930fcc 100644
--- a/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/UserCfItemsPerUserRecommBatchOp.java
+++ b/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/UserCfItemsPerUserRecommBatchOp.java
@@ -3,6 +3,7 @@
 import org.apache.flink.ml.api.misc.param.Params;
 
 import com.alibaba.alink.common.annotation.NameCn;
+import com.alibaba.alink.common.annotation.NameEn;
 import com.alibaba.alink.operator.common.recommendation.RecommType;
 import com.alibaba.alink.operator.common.recommendation.UserCfRecommKernel;
 import com.alibaba.alink.params.recommendation.BaseItemsPerUserRecommParams;
@@ -11,6 +12,7 @@
 * Recommend similar items for the given item.
*/ @NameCn("UserCf:ItemsPerUser推荐") +@NameEn("UserCf Items Per User Recommend") public class UserCfItemsPerUserRecommBatchOp extends BaseRecommBatchOp implements BaseItemsPerUserRecommParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/UserCfModelInfoBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/UserCfModelInfoBatchOp.java index 30cb97fd0..41550b1f2 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/UserCfModelInfoBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/UserCfModelInfoBatchOp.java @@ -3,7 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import org.apache.flink.types.Row; -import com.alibaba.alink.common.lazy.ExtractModelInfoBatchOp; +import com.alibaba.alink.operator.batch.utils.ExtractModelInfoBatchOp; import com.alibaba.alink.operator.common.recommendation.UserCfModelInfo; import java.util.List; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/UserCfRateRecommBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/UserCfRateRecommBatchOp.java index 8ec4610ed..f5946d668 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/UserCfRateRecommBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/UserCfRateRecommBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.common.recommendation.RecommType; import com.alibaba.alink.operator.common.recommendation.UserCfRecommKernel; import com.alibaba.alink.params.recommendation.BaseRateRecommParams; @@ -11,6 +12,7 @@ * Rating for user-item pair with user CF model. */ @NameCn("UserCf:打分推荐") +@NameEn("UserCf Rate Recommend") public class UserCfRateRecommBatchOp extends BaseRecommBatchOp implements BaseRateRecommParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/UserCfSimilarUsersRecommBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/UserCfSimilarUsersRecommBatchOp.java index f868a7c3c..f3fb951e3 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/UserCfSimilarUsersRecommBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/UserCfSimilarUsersRecommBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.common.recommendation.RecommType; import com.alibaba.alink.operator.common.recommendation.UserCfRecommKernel; import com.alibaba.alink.params.recommendation.BaseSimilarUsersRecommParams; @@ -11,6 +12,7 @@ * Recommend similar items for the given item. 
*/ @NameCn("UserCf:相似users推荐") +@NameEn("UserCf Similar Users Recommend") public class UserCfSimilarUsersRecommBatchOp extends BaseRecommBatchOp implements BaseSimilarUsersRecommParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/UserCfTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/UserCfTrainBatchOp.java index a86d4933c..9c9af87cf 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/UserCfTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/UserCfTrainBatchOp.java @@ -4,12 +4,13 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortSpec; import com.alibaba.alink.common.annotation.PortType; import com.alibaba.alink.common.annotation.TypeCollections; -import com.alibaba.alink.common.lazy.WithModelInfoBatchOp; +import com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.common.recommendation.UserCfModelInfo; import com.alibaba.alink.params.recommendation.UserCfRecommTrainParams; @@ -26,6 +27,7 @@ @ParamSelectColumnSpec(name = "rateCol", allowedTypeCollections = TypeCollections.NUMERIC_TYPES) @NameCn("UserCf训练") +@NameEn("UserCf Training") public class UserCfTrainBatchOp extends BatchOperator implements UserCfRecommTrainParams , WithModelInfoBatchOp { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/UserCfUsersPerItemRecommBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/UserCfUsersPerItemRecommBatchOp.java index f281d9444..e8b375310 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/UserCfUsersPerItemRecommBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/UserCfUsersPerItemRecommBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.common.recommendation.RecommType; import com.alibaba.alink.operator.common.recommendation.UserCfRecommKernel; import com.alibaba.alink.params.recommendation.BaseUsersPerItemRecommParams; @@ -11,6 +12,7 @@ * Recommend similar users for the given item. 
*/ @NameCn("UserCf:UsersPerItem推荐") +@NameEn("UserCf Users Per Item Recommend") public class UserCfUsersPerItemRecommBatchOp extends BaseRecommBatchOp implements BaseUsersPerItemRecommParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/regression/AftSurvivalRegModelInfoBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/regression/AftSurvivalRegModelInfoBatchOp.java index 1ca4bdf72..9af9987f8 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/regression/AftSurvivalRegModelInfoBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/regression/AftSurvivalRegModelInfoBatchOp.java @@ -3,8 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import org.apache.flink.types.Row; -import com.alibaba.alink.common.lazy.ExtractModelInfoBatchOp; -import com.alibaba.alink.operator.batch.BatchOperator; +import com.alibaba.alink.operator.batch.utils.ExtractModelInfoBatchOp; import com.alibaba.alink.operator.common.linear.LinearRegressorModelInfo; import java.util.List; @@ -27,8 +26,4 @@ protected LinearRegressorModelInfo createModelInfo(List rows) { return new LinearRegressorModelInfo(rows); } - @Override - protected BatchOperator processModel() { - return this; - } } diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/regression/AftSurvivalRegPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/regression/AftSurvivalRegPredictBatchOp.java index 07e2df307..29d50af07 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/regression/AftSurvivalRegPredictBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/regression/AftSurvivalRegPredictBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.operator.batch.utils.ModelMapBatchOp; @@ -18,6 +19,7 @@ @ParamSelectColumnSpec(name = "vectorCol", allowedTypeCollections = TypeCollections.VECTOR_TYPES) @NameCn("生存回归预测") +@NameEn("Aft Survival Regression Prediction") public class AftSurvivalRegPredictBatchOp extends ModelMapBatchOp implements AftRegPredictParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/regression/AftSurvivalRegTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/regression/AftSurvivalRegTrainBatchOp.java index c366eeaaa..964da2e78 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/regression/AftSurvivalRegTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/regression/AftSurvivalRegTrainBatchOp.java @@ -18,6 +18,7 @@ import com.alibaba.alink.common.annotation.FeatureColsVectorColMutexRule; import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortDesc; @@ -27,8 +28,8 @@ import com.alibaba.alink.common.exceptions.AkIllegalArgumentException; import com.alibaba.alink.common.exceptions.AkIllegalModelException; import com.alibaba.alink.common.exceptions.AkIllegalOperatorParameterException; -import com.alibaba.alink.common.lazy.WithModelInfoBatchOp; -import com.alibaba.alink.common.lazy.WithTrainInfo; +import 
com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp; +import com.alibaba.alink.operator.batch.utils.WithTrainInfo; import com.alibaba.alink.common.linalg.DenseVector; import com.alibaba.alink.common.linalg.SparseVector; import com.alibaba.alink.common.linalg.Vector; @@ -45,6 +46,7 @@ import com.alibaba.alink.params.shared.colname.HasVectorCol; import com.alibaba.alink.params.shared.linear.HasWithIntercept; import com.alibaba.alink.params.shared.linear.LinearTrainParams; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; import java.util.ArrayList; import java.util.List; @@ -74,6 +76,8 @@ @FeatureColsVectorColMutexRule @NameCn("生存回归训练") +@NameEn("Aft Survival Regression Training") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.regression.AftSurvivalRegression") public class AftSurvivalRegTrainBatchOp extends BatchOperator implements AftRegTrainParams , WithTrainInfo , @@ -200,7 +204,7 @@ public Integer map(Tuple3 value) { } }); this.setSideOutputTables( - BaseLinearModelTrainBatchOp.getSideTablesOfCoefficient(modelRows, initData, featSize, + BaseLinearModelTrainBatchOp.getSideTablesOfCoefficient(coefVectorSet.project(1), modelRows, initData, featSize, params.get(LinearTrainParams.FEATURE_COLS), params.get(LinearTrainParams.WITH_INTERCEPT), getMLEnvironmentId())); diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/regression/BertTextPairRegressorPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/regression/BertTextPairRegressorPredictBatchOp.java index f0e13bfc4..f3613f8d1 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/regression/BertTextPairRegressorPredictBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/regression/BertTextPairRegressorPredictBatchOp.java @@ -3,11 +3,13 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; /** * Prediction with a text pair regressor using Bert models. */ @NameCn("Bert文本对回归预测") +@NameEn("Bert Text Pair Regression Prediction") public class BertTextPairRegressorPredictBatchOp extends TFTableModelRegressorPredictBatchOp { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/regression/BertTextPairRegressorTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/regression/BertTextPairRegressorTrainBatchOp.java index a8642d16f..406a1ecdf 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/regression/BertTextPairRegressorTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/regression/BertTextPairRegressorTrainBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.common.dl.BaseEasyTransferTrainBatchOp; @@ -11,6 +12,7 @@ import com.alibaba.alink.params.dl.HasTaskType; import com.alibaba.alink.params.tensorflow.bert.BertTextPairTrainParams; import com.alibaba.alink.params.tensorflow.bert.HasTaskName; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; /** * Train a text pair regressor using Bert models. 
@@ -19,6 +21,8 @@ @ParamSelectColumnSpec(name = "textPairCol", allowedTypeCollections = TypeCollections.STRING_TYPE) @ParamSelectColumnSpec(name = "labelCol") @NameCn("Bert文本对回归训练") +@NameEn("Bert Text Pair Regression Training") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.regression.BertTextPairRegressor") public class BertTextPairRegressorTrainBatchOp extends BaseEasyTransferTrainBatchOp implements BertTextPairTrainParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/regression/BertTextRegressorPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/regression/BertTextRegressorPredictBatchOp.java index ad0301552..fe6805230 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/regression/BertTextRegressorPredictBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/regression/BertTextRegressorPredictBatchOp.java @@ -3,11 +3,13 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; /** * Prediction with a text regressor using Bert models. */ @NameCn("Bert文本回归预测") +@NameEn("Bert Text Regression Prediction") public class BertTextRegressorPredictBatchOp extends TFTableModelRegressorPredictBatchOp { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/regression/BertTextRegressorTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/regression/BertTextRegressorTrainBatchOp.java index 3552d627b..5d48ef0aa 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/regression/BertTextRegressorTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/regression/BertTextRegressorTrainBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.common.dl.BaseEasyTransferTrainBatchOp; @@ -11,6 +12,7 @@ import com.alibaba.alink.params.dl.HasTaskType; import com.alibaba.alink.params.tensorflow.bert.BertTextTrainParams; import com.alibaba.alink.params.tensorflow.bert.HasTaskName; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; /** * Train a text regressor using Bert models. 
@@ -18,6 +20,8 @@ @ParamSelectColumnSpec(name = "textCol", allowedTypeCollections = TypeCollections.STRING_TYPE) @ParamSelectColumnSpec(name = "labelCol") @NameCn("Bert文本回归训练") +@NameEn("Bert Text Regression Training") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.regression.BertTextRegressor") public class BertTextRegressorTrainBatchOp extends BaseEasyTransferTrainBatchOp implements BertTextTrainParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/regression/CartRegPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/regression/CartRegPredictBatchOp.java index d04ee8040..cebb6c1b0 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/regression/CartRegPredictBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/regression/CartRegPredictBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.batch.utils.ModelMapBatchOp; import com.alibaba.alink.operator.common.tree.predictors.RandomForestModelMapper; import com.alibaba.alink.params.regression.CartRegPredictParams; @@ -11,6 +12,7 @@ * The batch operator that predict the data using the cart regression model. */ @NameCn("CART决策树回归预测") +@NameEn("Cart Regression Prediction") public final class CartRegPredictBatchOp extends ModelMapBatchOp implements CartRegPredictParams { private static final long serialVersionUID = 8351046637860036501L; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/regression/CartRegTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/regression/CartRegTrainBatchOp.java index 5be152736..0c002cbcb 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/regression/CartRegTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/regression/CartRegTrainBatchOp.java @@ -3,7 +3,8 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; -import com.alibaba.alink.common.lazy.WithModelInfoBatchOp; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp; import com.alibaba.alink.operator.common.tree.BaseRandomForestTrainBatchOp; import com.alibaba.alink.operator.common.tree.TreeModelInfo; import com.alibaba.alink.operator.common.tree.TreeUtil; @@ -11,12 +12,15 @@ import com.alibaba.alink.params.shared.tree.HasFeatureSubsamplingRatio; import com.alibaba.alink.params.shared.tree.HasNumTreesDefaltAs10; import com.alibaba.alink.params.shared.tree.HasSubsamplingRatio; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; /** * Fit a cart regression model. 
*/ @NameCn("CART决策树回归训练") -public final class CartRegTrainBatchOp extends BaseRandomForestTrainBatchOp implements +@NameEn("Cart Regression Training") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.regression.CartReg") +public class CartRegTrainBatchOp extends BaseRandomForestTrainBatchOp implements CartRegTrainParams , WithModelInfoBatchOp { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/regression/DecisionTreeRegPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/regression/DecisionTreeRegPredictBatchOp.java index 4ca579bde..248243330 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/regression/DecisionTreeRegPredictBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/regression/DecisionTreeRegPredictBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.batch.utils.ModelMapBatchOp; import com.alibaba.alink.operator.common.tree.predictors.RandomForestModelMapper; import com.alibaba.alink.params.regression.DecisionTreeRegPredictParams; @@ -28,6 +29,7 @@ * @see Random_forest */ @NameCn("决策树回归预测") +@NameEn("Decision Tree Regression Prediction") public final class DecisionTreeRegPredictBatchOp extends ModelMapBatchOp implements DecisionTreeRegPredictParams { private static final long serialVersionUID = -8850643703068190492L; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/regression/DecisionTreeRegTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/regression/DecisionTreeRegTrainBatchOp.java index 1b72553fc..3e74915b2 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/regression/DecisionTreeRegTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/regression/DecisionTreeRegTrainBatchOp.java @@ -3,7 +3,8 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; -import com.alibaba.alink.common.lazy.WithModelInfoBatchOp; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp; import com.alibaba.alink.operator.common.tree.BaseRandomForestTrainBatchOp; import com.alibaba.alink.operator.common.tree.TreeModelInfo; import com.alibaba.alink.operator.common.tree.TreeUtil; @@ -11,6 +12,7 @@ import com.alibaba.alink.params.shared.tree.HasFeatureSubsamplingRatio; import com.alibaba.alink.params.shared.tree.HasNumTreesDefaltAs10; import com.alibaba.alink.params.shared.tree.HasSubsamplingRatio; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; /** * The random forest use the bagging to prevent the overfitting. 
@@ -33,7 +35,9 @@ * @see Random_forest */ @NameCn("决策树回归训练") -public final class DecisionTreeRegTrainBatchOp extends BaseRandomForestTrainBatchOp +@NameEn("Decision Tree Regression Training") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.regression.DecisionTreeRegressor") +public class DecisionTreeRegTrainBatchOp extends BaseRandomForestTrainBatchOp implements DecisionTreeRegTrainParams , WithModelInfoBatchOp { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/regression/FmRegressorModelInfoBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/regression/FmRegressorModelInfoBatchOp.java index 945b14f3e..11ddf9bd6 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/regression/FmRegressorModelInfoBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/regression/FmRegressorModelInfoBatchOp.java @@ -3,8 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import org.apache.flink.types.Row; -import com.alibaba.alink.common.lazy.ExtractModelInfoBatchOp; -import com.alibaba.alink.operator.batch.BatchOperator; +import com.alibaba.alink.operator.batch.utils.ExtractModelInfoBatchOp; import com.alibaba.alink.operator.common.fm.FmRegressorModelInfo; import java.util.List; @@ -34,8 +33,4 @@ protected FmRegressorModelInfo createModelInfo(List rows) { return new FmRegressorModelInfo(rows); } - @Override - protected BatchOperator processModel() { - return this; - } } diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/regression/FmRegressorTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/regression/FmRegressorTrainBatchOp.java index 4a50347cb..241425d45 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/regression/FmRegressorTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/regression/FmRegressorTrainBatchOp.java @@ -5,13 +5,14 @@ import com.alibaba.alink.common.annotation.NameCn; import com.alibaba.alink.common.annotation.NameEn; -import com.alibaba.alink.common.lazy.WithModelInfoBatchOp; -import com.alibaba.alink.common.lazy.WithTrainInfo; +import com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp; +import com.alibaba.alink.operator.batch.utils.WithTrainInfo; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.common.fm.FmRegressorModelInfo; import com.alibaba.alink.operator.common.fm.FmRegressorModelTrainInfo; import com.alibaba.alink.operator.common.fm.FmTrainBatchOp; import com.alibaba.alink.params.recommendation.FmTrainParams; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; import java.util.List; @@ -20,6 +21,7 @@ */ @NameCn("FM回归训练") @NameEn("FM Regression Training") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.regression.FmRegressor") public class FmRegressorTrainBatchOp extends FmTrainBatchOp implements FmTrainParams , WithModelInfoBatchOp , diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/regression/GbdtRegPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/regression/GbdtRegPredictBatchOp.java index 2d1580321..61a56ffe5 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/regression/GbdtRegPredictBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/regression/GbdtRegPredictBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import 
com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.operator.batch.utils.ModelMapBatchOp; @@ -31,6 +32,7 @@ */ @ParamSelectColumnSpec(name = "vectorCol", allowedTypeCollections = TypeCollections.VECTOR_TYPES) @NameCn("GBDT回归预测") +@NameEn("GBDT Regression Prediction") public final class GbdtRegPredictBatchOp extends ModelMapBatchOp implements GbdtRegPredictParams { private static final long serialVersionUID = 5866895002748842133L; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/regression/GbdtRegTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/regression/GbdtRegTrainBatchOp.java index f8f757cee..ffb05171e 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/regression/GbdtRegTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/regression/GbdtRegTrainBatchOp.java @@ -3,12 +3,14 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; -import com.alibaba.alink.common.lazy.WithModelInfoBatchOp; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp; import com.alibaba.alink.operator.common.tree.TreeModelInfo; import com.alibaba.alink.operator.common.tree.parallelcart.BaseGbdtTrainBatchOp; import com.alibaba.alink.operator.common.tree.parallelcart.loss.LossType; import com.alibaba.alink.operator.common.tree.parallelcart.loss.LossUtils; import com.alibaba.alink.params.regression.GbdtRegTrainParams; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; /** * Gradient Boosting(often abbreviated to GBDT or GBM) is a popular supervised learning model. @@ -31,7 +33,9 @@ * for an introduction on data-parallel, feature-parallel, etc., algorithms to construct decision forests. 
*/ @NameCn("GBDT回归训练") -public final class GbdtRegTrainBatchOp extends BaseGbdtTrainBatchOp +@NameEn("GBDT Regression Training") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.regression.GbdtRegressor") +public class GbdtRegTrainBatchOp extends BaseGbdtTrainBatchOp implements GbdtRegTrainParams , WithModelInfoBatchOp { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/regression/GlmEvaluationBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/regression/GlmEvaluationBatchOp.java index 72b1bd8cf..7aee49436 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/regression/GlmEvaluationBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/regression/GlmEvaluationBatchOp.java @@ -10,12 +10,13 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortSpec; import com.alibaba.alink.common.annotation.PortType; import com.alibaba.alink.common.annotation.TypeCollections; -import com.alibaba.alink.common.utils.DataSetConversionUtil; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; import com.alibaba.alink.common.utils.JsonConverter; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.common.regression.glm.FamilyLink; @@ -37,6 +38,7 @@ @ParamSelectColumnSpec(name = "labelCol", allowedTypeCollections = TypeCollections.NUMERIC_TYPES) @ParamSelectColumnSpec(name = "weightCol", allowedTypeCollections = TypeCollections.NUMERIC_TYPES) @NameCn("广义线性回归评估") +@NameEn("GLM Evaluation") public class GlmEvaluationBatchOp extends BatchOperator implements GlmTrainParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/regression/GlmModelInfoBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/regression/GlmModelInfoBatchOp.java index 479a8b638..a0edc3f2c 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/regression/GlmModelInfoBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/regression/GlmModelInfoBatchOp.java @@ -3,9 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import org.apache.flink.types.Row; -import com.alibaba.alink.common.lazy.ExtractModelInfoBatchOp; -import com.alibaba.alink.operator.batch.BatchOperator; -import com.alibaba.alink.operator.common.linear.LinearRegressorModelInfo; +import com.alibaba.alink.operator.batch.utils.ExtractModelInfoBatchOp; import com.alibaba.alink.operator.common.regression.glm.GlmModelInfo; import java.util.List; @@ -28,8 +26,4 @@ protected GlmModelInfo createModelInfo(List rows) { return new GlmModelInfo(rows); } - @Override - protected BatchOperator processModel() { - return this; - } } diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/regression/GlmPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/regression/GlmPredictBatchOp.java index e64854157..07e28cae9 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/regression/GlmPredictBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/regression/GlmPredictBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import 
com.alibaba.alink.operator.batch.utils.ModelMapBatchOp; import com.alibaba.alink.operator.common.regression.GlmModelMapper; import com.alibaba.alink.params.regression.GlmPredictParams; @@ -11,6 +12,7 @@ * Generalized Linear Model. https://en.wikipedia.org/wiki/Generalized_linear_model. */ @NameCn("广义线性回归预测") +@NameEn("GLM Prediction") public class GlmPredictBatchOp extends ModelMapBatchOp implements GlmPredictParams { private static final long serialVersionUID = 1855615229106536018L; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/regression/GlmTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/regression/GlmTrainBatchOp.java index 29c318130..0fb1fdb01 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/regression/GlmTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/regression/GlmTrainBatchOp.java @@ -13,13 +13,14 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortSpec; import com.alibaba.alink.common.annotation.PortType; import com.alibaba.alink.common.annotation.TypeCollections; -import com.alibaba.alink.common.lazy.WithModelInfoBatchOp; -import com.alibaba.alink.common.utils.DataSetConversionUtil; +import com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; import com.alibaba.alink.common.utils.JsonConverter; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.common.regression.GlmModelData; @@ -30,6 +31,7 @@ import com.alibaba.alink.operator.common.regression.glm.GlmUtil.GlmModelSummary; import com.alibaba.alink.operator.common.regression.glm.GlmUtil.WeightedLeastSquaresModel; import com.alibaba.alink.params.regression.GlmTrainParams; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; /** * Generalized Linear Model. https://en.wikipedia.org/wiki/Generalized_linear_model. 
@@ -45,6 +47,8 @@ @ParamSelectColumnSpec(name = "labelCol", allowedTypeCollections = TypeCollections.NUMERIC_TYPES) @ParamSelectColumnSpec(name = "weightCol", allowedTypeCollections = TypeCollections.NUMERIC_TYPES) @NameCn("广义线性回归训练") +@NameEn("GLM Training") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.regression.GeneralizedLinearRegression") public final class GlmTrainBatchOp extends BatchOperator implements GlmTrainParams , WithModelInfoBatchOp { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/regression/IsotonicRegPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/regression/IsotonicRegPredictBatchOp.java index c86c46927..954d55aa7 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/regression/IsotonicRegPredictBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/regression/IsotonicRegPredictBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.batch.utils.ModelMapBatchOp; import com.alibaba.alink.operator.common.regression.IsotonicRegressionModelMapper; import com.alibaba.alink.params.regression.IsotonicRegPredictParams; @@ -12,6 +13,7 @@ * Implement parallelized pool adjacent violators algorithm. */ @NameCn("保序回归预测") +@NameEn("Isotonic Regression Prediction") public final class IsotonicRegPredictBatchOp extends ModelMapBatchOp implements IsotonicRegPredictParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/regression/IsotonicRegTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/regression/IsotonicRegTrainBatchOp.java index 18871ee27..0315aedb7 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/regression/IsotonicRegTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/regression/IsotonicRegTrainBatchOp.java @@ -12,6 +12,7 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamCond; import com.alibaba.alink.common.annotation.ParamCond.CondType; @@ -29,6 +30,7 @@ import com.alibaba.alink.operator.common.regression.IsotonicRegressionModelData; import com.alibaba.alink.operator.common.regression.isotonicReg.LinkedData; import com.alibaba.alink.params.regression.IsotonicRegTrainParams; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; import com.google.common.collect.Lists; import java.nio.ByteBuffer; @@ -66,6 +68,8 @@ ) @NameCn("保序回归训练") +@NameEn("Isotonic Regression Training") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.regression.IsotonicRegression") public final class IsotonicRegTrainBatchOp extends BatchOperator implements IsotonicRegTrainParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/regression/KerasSequentialRegressorPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/regression/KerasSequentialRegressorPredictBatchOp.java index 43fd01df7..d15642738 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/regression/KerasSequentialRegressorPredictBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/regression/KerasSequentialRegressorPredictBatchOp.java @@ -3,11 +3,13 @@ import org.apache.flink.ml.api.misc.param.Params; import 
com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; /** * Prediction with a regressor using a Keras Sequential model. */ @NameCn("KerasSequential回归预测") +@NameEn("KerasSequential Regression Prediction") public class KerasSequentialRegressorPredictBatchOp extends TFTableModelRegressorPredictBatchOp { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/regression/KerasSequentialRegressorTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/regression/KerasSequentialRegressorTrainBatchOp.java index eba58ba97..a22cbae15 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/regression/KerasSequentialRegressorTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/regression/KerasSequentialRegressorTrainBatchOp.java @@ -3,14 +3,18 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.dl.BaseKerasSequentialTrainBatchOp; import com.alibaba.alink.common.dl.TaskType; import com.alibaba.alink.params.dl.HasTaskType; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; /** * Train a regressor using a Keras Sequential model. */ @NameCn("KerasSequential回归训练") +@NameEn("KerasSequential Regression Training") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.regression.KerasSequentialRegressor") public class KerasSequentialRegressorTrainBatchOp extends BaseKerasSequentialTrainBatchOp { public KerasSequentialRegressorTrainBatchOp() { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/regression/LassoRegModelInfoBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/regression/LassoRegModelInfoBatchOp.java index feeb8761e..2849396ef 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/regression/LassoRegModelInfoBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/regression/LassoRegModelInfoBatchOp.java @@ -3,8 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import org.apache.flink.types.Row; -import com.alibaba.alink.common.lazy.ExtractModelInfoBatchOp; -import com.alibaba.alink.operator.batch.BatchOperator; +import com.alibaba.alink.operator.batch.utils.ExtractModelInfoBatchOp; import com.alibaba.alink.operator.common.linear.LinearRegressorModelInfo; import java.util.List; @@ -27,8 +26,4 @@ protected LinearRegressorModelInfo createModelInfo(List rows) { return new LinearRegressorModelInfo(rows); } - @Override - protected BatchOperator processModel() { - return this; - } } diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/regression/LassoRegTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/regression/LassoRegTrainBatchOp.java index 0bd6f9924..a6a643fb8 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/regression/LassoRegTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/regression/LassoRegTrainBatchOp.java @@ -4,17 +4,19 @@ import com.alibaba.alink.common.annotation.NameCn; import com.alibaba.alink.common.annotation.NameEn; -import com.alibaba.alink.common.lazy.WithModelInfoBatchOp; +import com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp; import com.alibaba.alink.operator.common.linear.BaseLinearModelTrainBatchOp; import com.alibaba.alink.operator.common.linear.LinearModelType; import com.alibaba.alink.operator.common.linear.LinearRegressorModelInfo; import 
com.alibaba.alink.params.regression.LassoRegTrainParams; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; /** * Train a regression model with L1-regularization. */ @NameCn("Lasso回归训练") @NameEn("Lasso Regression Training") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.regression.LassoRegression") public final class LassoRegTrainBatchOp extends BaseLinearModelTrainBatchOp implements LassoRegTrainParams , WithModelInfoBatchOp { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/regression/LinearRegModelInfoBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/regression/LinearRegModelInfoBatchOp.java index 0a5086965..e8dd99e24 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/regression/LinearRegModelInfoBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/regression/LinearRegModelInfoBatchOp.java @@ -3,8 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import org.apache.flink.types.Row; -import com.alibaba.alink.common.lazy.ExtractModelInfoBatchOp; -import com.alibaba.alink.operator.batch.BatchOperator; +import com.alibaba.alink.operator.batch.utils.ExtractModelInfoBatchOp; import com.alibaba.alink.operator.common.linear.LinearRegressorModelInfo; import java.util.List; @@ -27,8 +26,4 @@ protected LinearRegressorModelInfo createModelInfo(List rows) { return new LinearRegressorModelInfo(rows); } - @Override - protected BatchOperator processModel() { - return this; - } } diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/regression/LinearRegStepwisePredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/regression/LinearRegStepwisePredictBatchOp.java new file mode 100644 index 000000000..402fa3c72 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/batch/regression/LinearRegStepwisePredictBatchOp.java @@ -0,0 +1,30 @@ +package com.alibaba.alink.operator.batch.regression; + +import org.apache.flink.ml.api.misc.param.Params; + +import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.operator.batch.utils.ModelMapBatchOp; +import com.alibaba.alink.operator.common.linear.LinearModelMapper; +import com.alibaba.alink.params.regression.LinearRegStepwisePredictParams; + +/** + * * + * + * @author weibo zhao + */ +@NameCn("线性回归Stepwise预测") +@NameEn("Stepwise Linear Regression Prediction") +public final class LinearRegStepwisePredictBatchOp extends ModelMapBatchOp + implements LinearRegStepwisePredictParams { + + private static final long serialVersionUID = -4433560932329911993L; + + public LinearRegStepwisePredictBatchOp() { + this(new Params()); + } + + public LinearRegStepwisePredictBatchOp(Params params) { + super(LinearModelMapper::new, params); + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/regression/LinearRegStepwiseTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/regression/LinearRegStepwiseTrainBatchOp.java new file mode 100644 index 000000000..8422d4874 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/batch/regression/LinearRegStepwiseTrainBatchOp.java @@ -0,0 +1,49 @@ +package com.alibaba.alink.operator.batch.regression; + +import org.apache.flink.ml.api.misc.param.Params; + +import com.alibaba.alink.common.annotation.InputPorts; +import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.common.annotation.OutputPorts; +import 
com.alibaba.alink.common.annotation.ParamSelectColumnSpec; +import com.alibaba.alink.common.annotation.PortSpec; +import com.alibaba.alink.common.annotation.PortType; +import com.alibaba.alink.common.annotation.TypeCollections; +import com.alibaba.alink.operator.batch.BatchOperator; +import com.alibaba.alink.operator.batch.regression.LinearRegWithSummaryResult.LinearRegType; +import com.alibaba.alink.params.regression.LinearRegStepwiseTrainParams; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; + +@InputPorts(values = @PortSpec(PortType.DATA)) +@OutputPorts(values = { + @PortSpec(PortType.MODEL) +}) +@ParamSelectColumnSpec(name = "featureCols", + allowedTypeCollections = TypeCollections.NUMERIC_TYPES) +@ParamSelectColumnSpec(name = "labelCol") +@NameCn("线性回归Stepwise训练") +@NameEn("Stepwise Linear Regression Training") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.regression.LinearRegStepwise") +public final class LinearRegStepwiseTrainBatchOp extends BatchOperator + implements LinearRegStepwiseTrainParams { + + private static final long serialVersionUID = -1316297704040780024L; + + public LinearRegStepwiseTrainBatchOp() { + super(); + } + + public LinearRegStepwiseTrainBatchOp(Params params) { + super(params); + } + + @Override + public LinearRegStepwiseTrainBatchOp linkFrom(BatchOperator ... inputs) { + BatchOperator in = checkAndGetFirst(inputs); + this.setOutputTable(new LinearRegWithSummaryResult(this.getParams(), LinearRegType.stepwise).linkFrom(in) + .getOutputTable()); + return this; + } + +} diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/regression/LinearRegTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/regression/LinearRegTrainBatchOp.java index c37a66a67..701df796c 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/regression/LinearRegTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/regression/LinearRegTrainBatchOp.java @@ -4,17 +4,19 @@ import com.alibaba.alink.common.annotation.NameCn; import com.alibaba.alink.common.annotation.NameEn; -import com.alibaba.alink.common.lazy.WithModelInfoBatchOp; +import com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp; import com.alibaba.alink.operator.common.linear.BaseLinearModelTrainBatchOp; import com.alibaba.alink.operator.common.linear.LinearModelType; import com.alibaba.alink.operator.common.linear.LinearRegressorModelInfo; import com.alibaba.alink.params.regression.LinearRegTrainParams; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; /** * Train a regression model. 
*/ @NameCn("线性回归训练") @NameEn("Linear Regression Training") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.regression.LinearRegression") public final class LinearRegTrainBatchOp extends BaseLinearModelTrainBatchOp implements LinearRegTrainParams , WithModelInfoBatchOp { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/regression/LinearRegWithSummaryResult.java b/core/src/main/java/com/alibaba/alink/operator/batch/regression/LinearRegWithSummaryResult.java new file mode 100644 index 000000000..2bda99d1f --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/batch/regression/LinearRegWithSummaryResult.java @@ -0,0 +1,121 @@ +package com.alibaba.alink.operator.batch.regression; + +import org.apache.flink.api.common.functions.FlatMapFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.ml.api.misc.param.ParamInfo; +import org.apache.flink.ml.api.misc.param.ParamInfoFactory; +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; + +import com.alibaba.alink.common.exceptions.AkIllegalArgumentException; +import com.alibaba.alink.common.exceptions.AkUnsupportedOperationException; +import com.alibaba.alink.common.linalg.DenseVector; +import com.alibaba.alink.common.utils.TableUtil; +import com.alibaba.alink.operator.batch.BatchOperator; +import com.alibaba.alink.operator.common.linear.LinearModelData; +import com.alibaba.alink.operator.common.linear.LinearModelDataConverter; +import com.alibaba.alink.operator.common.linear.LinearModelType; +import com.alibaba.alink.operator.common.regression.LinearReg; +import com.alibaba.alink.operator.common.regression.LinearRegressionModel; +import com.alibaba.alink.operator.common.regression.LinearRegressionStepwise; +import com.alibaba.alink.operator.common.regression.LinearRegressionStepwiseModel; +import com.alibaba.alink.operator.common.regression.RidgeRegressionProcess; +import com.alibaba.alink.operator.batch.statistics.utils.StatisticsHelper; +import com.alibaba.alink.operator.common.statistics.statistics.SummaryResultTable; +import com.alibaba.alink.params.regression.LinearRegStepwiseTrainParams; +import com.alibaba.alink.params.regression.LinearRegStepwiseTrainParams.Method; +import com.alibaba.alink.params.regression.LinearRegTrainParams; +import com.alibaba.alink.params.regression.RidgeRegTrainParams; +import com.alibaba.alink.params.statistics.HasStatLevel_L1; + +class LinearRegWithSummaryResult extends BatchOperator { + + private static final long serialVersionUID = 9007546963532152447L; + private static final ParamInfo REG_TYPE = ParamInfoFactory + .createParamInfo("regType", LinearRegWithSummaryResult.LinearRegType.class) + .setDescription("regType") + .setRequired() + .build(); + + public LinearRegWithSummaryResult(Params params, LinearRegType regType) { + super(params); + this.getParams().set(REG_TYPE, regType); + } + + private static LinearModelData getLinearModel(String modelName, String[] nameX, double[] beta) { + LinearModelData retData = new LinearModelData(); + retData.coefVector = new DenseVector(beta.clone()); + retData.modelName = modelName; + retData.linearModelType = LinearModelType.LinearReg; + retData.featureNames = nameX; + return retData; + } + + @Override + public LinearRegWithSummaryResult linkFrom(BatchOperator ... 
inputs) { + BatchOperator in = checkAndGetFirst(inputs); + String yVar = getParams().get(LinearRegTrainParams.LABEL_COL); + try { + int yColIndex = TableUtil.findColIndex(in.getColNames(), yVar); + if (yColIndex < 0) { + throw new AkIllegalArgumentException("There is no column(" + yVar + ") in the training dataset."); + } + setOutput(StatisticsHelper.getSRT(in, HasStatLevel_L1.StatLevel.L2).flatMap(new MyReg(this.getParams(), in.getColTypes()[yColIndex])), + new LinearModelDataConverter(in.getColTypes()[yColIndex]).getModelSchema()); + return this; + } catch (Exception ex) { + ex.printStackTrace(); + throw new AkUnsupportedOperationException(ex.getMessage()); + } + } + + public enum LinearRegType { + common, + ridge, + stepwise + } + + public static class MyReg implements FlatMapFunction { + private static final long serialVersionUID = 5647053026802533733L; + private final Params params; + private final TypeInformation labelType; + + public MyReg(Params params, TypeInformation labelType) { + this.params = params; + this.labelType = labelType; + } + + @Override + public void flatMap(SummaryResultTable srt, Collector collector) throws Exception { + String yVar = this.params.get(LinearRegTrainParams.LABEL_COL); + String[] xVars = this.params.get(LinearRegTrainParams.FEATURE_COLS); + LinearRegType regType = this.params.get(REG_TYPE); + LinearRegressionModel lrm; + switch (regType) { + case common: + lrm = LinearReg.train(srt, yVar, xVars); + new LinearModelDataConverter(labelType) + .save(getLinearModel("Linear Regression", lrm.nameX, lrm.beta), collector); + break; + case ridge: + double lambda = this.params.get(RidgeRegTrainParams.LAMBDA); + RidgeRegressionProcess rrp = new RidgeRegressionProcess(srt, yVar, xVars); + lrm = rrp.calc(new double[] {lambda}).lrModels[0]; + new LinearModelDataConverter(labelType) + .save(getLinearModel("Ridge Regression", lrm.nameX, lrm.beta), collector); + break; + case stepwise: + Method method = this.params.get(LinearRegStepwiseTrainParams.METHOD); + LinearRegressionStepwiseModel lrsm = LinearRegressionStepwise.step(srt, yVar, xVars, method); + lrm = lrsm.lrr; + new LinearModelDataConverter(labelType) + .save(getLinearModel("Linear Regression Stepwise", lrm.nameX, lrm.beta), collector); + break; + default: + throw new AkUnsupportedOperationException("Not support this regression type : " + regType); + } + } + } + +} diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/regression/LinearSvrPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/regression/LinearSvrPredictBatchOp.java new file mode 100644 index 000000000..5b9cdc77f --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/batch/regression/LinearSvrPredictBatchOp.java @@ -0,0 +1,34 @@ +package com.alibaba.alink.operator.batch.regression; + +import org.apache.flink.ml.api.misc.param.Params; + +import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; +import com.alibaba.alink.common.annotation.TypeCollections; +import com.alibaba.alink.operator.batch.utils.ModelMapBatchOp; +import com.alibaba.alink.operator.common.linear.LinearModelMapper; +import com.alibaba.alink.params.regression.LinearSvrPredictParams; + +/** + * * + * + * @author weibo zhao + */ +@ParamSelectColumnSpec(name = "vectorCol", + allowedTypeCollections = TypeCollections.VECTOR_TYPES) +@NameCn("线性SVR预测") +@NameEn("Linear SVR Prediction") +public final class LinearSvrPredictBatchOp 
extends ModelMapBatchOp + implements LinearSvrPredictParams { + + private static final long serialVersionUID = 4438160354556417595L; + + public LinearSvrPredictBatchOp() { + this(new Params()); + } + + public LinearSvrPredictBatchOp(Params params) { + super(LinearModelMapper::new, params); + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/regression/LinearSvrTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/regression/LinearSvrTrainBatchOp.java new file mode 100644 index 000000000..a0be43ba7 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/batch/regression/LinearSvrTrainBatchOp.java @@ -0,0 +1,33 @@ +package com.alibaba.alink.operator.batch.regression; + +import org.apache.flink.ml.api.misc.param.Params; + +import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.operator.common.linear.BaseLinearModelTrainBatchOp; +import com.alibaba.alink.operator.common.linear.LinearModelType; +import com.alibaba.alink.params.regression.LinearSvrTrainParams; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; + +/** + * * + * + * @author yangxu + */ +@NameCn("线性SVR训练") +@NameEn("Linear SVR Training") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.regression.LinearSvr") +public final class LinearSvrTrainBatchOp extends BaseLinearModelTrainBatchOp + implements LinearSvrTrainParams { + + private static final long serialVersionUID = 713597255264745408L; + + public LinearSvrTrainBatchOp() { + this(new Params()); + } + + public LinearSvrTrainBatchOp(Params params) { + super(params.clone(), LinearModelType.LinearReg, "Linear SVR"); + } + +} diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/regression/RandomForestRegPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/regression/RandomForestRegPredictBatchOp.java index 407f62133..dee387cdd 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/regression/RandomForestRegPredictBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/regression/RandomForestRegPredictBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.batch.utils.ModelMapBatchOp; import com.alibaba.alink.operator.common.tree.predictors.RandomForestModelMapper; import com.alibaba.alink.params.regression.RandomForestRegPredictParams; @@ -28,6 +29,7 @@ * @see Random_forest */ @NameCn("随机森林回归预测") +@NameEn("Random Forest Regression Prediction") public final class RandomForestRegPredictBatchOp extends ModelMapBatchOp implements RandomForestRegPredictParams { private static final long serialVersionUID = 1645429815373572620L; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/regression/RandomForestRegTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/regression/RandomForestRegTrainBatchOp.java index bf3b15139..62ee59418 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/regression/RandomForestRegTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/regression/RandomForestRegTrainBatchOp.java @@ -3,11 +3,13 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; -import com.alibaba.alink.common.lazy.WithModelInfoBatchOp; +import com.alibaba.alink.common.annotation.NameEn; +import 
com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp; import com.alibaba.alink.operator.common.tree.BaseRandomForestTrainBatchOp; import com.alibaba.alink.operator.common.tree.TreeModelInfo; import com.alibaba.alink.operator.common.tree.TreeUtil; import com.alibaba.alink.params.regression.RandomForestRegTrainParams; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; /** * The random forest use the bagging to prevent the overfitting. @@ -30,7 +32,9 @@ * @see Random_forest */ @NameCn("随机森林回归训练") -public final class RandomForestRegTrainBatchOp extends BaseRandomForestTrainBatchOp +@NameEn("Random Forest Regression Training") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.regression.RandomForestRegressor") +public class RandomForestRegTrainBatchOp extends BaseRandomForestTrainBatchOp implements RandomForestRegTrainParams , WithModelInfoBatchOp { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/regression/RidgeRegModelInfoBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/regression/RidgeRegModelInfoBatchOp.java index b533e964b..41761fa6a 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/regression/RidgeRegModelInfoBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/regression/RidgeRegModelInfoBatchOp.java @@ -3,8 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import org.apache.flink.types.Row; -import com.alibaba.alink.common.lazy.ExtractModelInfoBatchOp; -import com.alibaba.alink.operator.batch.BatchOperator; +import com.alibaba.alink.operator.batch.utils.ExtractModelInfoBatchOp; import com.alibaba.alink.operator.common.linear.LinearRegressorModelInfo; import java.util.List; @@ -27,8 +26,4 @@ protected LinearRegressorModelInfo createModelInfo(List rows) { return new LinearRegressorModelInfo(rows); } - @Override - protected BatchOperator processModel() { - return this; - } } diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/regression/RidgeRegTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/regression/RidgeRegTrainBatchOp.java index 225ca6b44..3d191ca26 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/regression/RidgeRegTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/regression/RidgeRegTrainBatchOp.java @@ -4,17 +4,19 @@ import com.alibaba.alink.common.annotation.NameCn; import com.alibaba.alink.common.annotation.NameEn; -import com.alibaba.alink.common.lazy.WithModelInfoBatchOp; +import com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp; import com.alibaba.alink.operator.common.linear.BaseLinearModelTrainBatchOp; import com.alibaba.alink.operator.common.linear.LinearModelType; import com.alibaba.alink.operator.common.linear.LinearRegressorModelInfo; import com.alibaba.alink.params.regression.RidgeRegTrainParams; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; /** * Train a regression model with L2-regularization. 
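* (Editor's note, illustrative: with regularization weight lambda, the RidgeRegTrainParams.LAMBDA value consumed by MyReg above, ridge regression fits the coefficient vector w by minimizing ||y - Xw||^2 + lambda * ||w||^2.)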
*/ @NameCn("岭回归训练") @NameEn("Linear Regression Training") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.regression.RidgeRegression") public final class RidgeRegTrainBatchOp extends BaseLinearModelTrainBatchOp implements RidgeRegTrainParams , WithModelInfoBatchOp { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/regression/TFTableModelRegressorPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/regression/TFTableModelRegressorPredictBatchOp.java index c24707094..4e5b3e023 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/regression/TFTableModelRegressorPredictBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/regression/TFTableModelRegressorPredictBatchOp.java @@ -4,12 +4,14 @@ import com.alibaba.alink.common.annotation.Internal; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.batch.utils.FlatModelMapBatchOp; import com.alibaba.alink.operator.common.regression.tensorflow.TFTableModelRegressionFlatModelMapper; import com.alibaba.alink.params.regression.TFTableModelRegressionPredictParams; @Internal @NameCn("TF表模型回归预测") +@NameEn("TF Table Model Regression Prediction") public class TFTableModelRegressorPredictBatchOp> extends FlatModelMapBatchOp implements TFTableModelRegressionPredictParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/regression/XGBoostRegTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/regression/XGBoostRegTrainBatchOp.java index 4132b382e..82173e3d8 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/regression/XGBoostRegTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/regression/XGBoostRegTrainBatchOp.java @@ -6,9 +6,11 @@ import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.common.tree.BaseXGBoostTrainBatchOp; import com.alibaba.alink.params.xgboost.XGBoostRegTrainParams; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; @NameCn("XGBoost 回归训练") @NameEn("XGBoost Regression Train") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.regression.XGBoostRegressor") public final class XGBoostRegTrainBatchOp extends BaseXGBoostTrainBatchOp implements XGBoostRegTrainParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/similarity/StringApproxNearestNeighborPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/similarity/StringApproxNearestNeighborPredictBatchOp.java index ef9b6a0ee..2b05eaab1 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/similarity/StringApproxNearestNeighborPredictBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/similarity/StringApproxNearestNeighborPredictBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.operator.batch.utils.ModelMapBatchOp; @@ -14,6 +15,7 @@ */ @ParamSelectColumnSpec(name = "selectCol", allowedTypeCollections = TypeCollections.STRING_TYPES) @NameCn("字符串近似最近邻预测") +@NameEn("String Approx Nearest Neighbor Prediction") public class StringApproxNearestNeighborPredictBatchOp extends ModelMapBatchOp implements NearestNeighborPredictParams { diff --git 
a/core/src/main/java/com/alibaba/alink/operator/batch/similarity/StringApproxNearestNeighborTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/similarity/StringApproxNearestNeighborTrainBatchOp.java index d35d2e33a..1cfb517bd 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/similarity/StringApproxNearestNeighborTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/similarity/StringApproxNearestNeighborTrainBatchOp.java @@ -3,17 +3,21 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.operator.common.similarity.TrainType; import com.alibaba.alink.operator.common.similarity.dataConverter.StringModelDataConverter; import com.alibaba.alink.params.similarity.StringTextApproxNearestNeighborTrainParams; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; /** * Find the approximate nearest neighbor of query string. */ @ParamSelectColumnSpec(name = "selectCol", allowedTypeCollections = TypeCollections.STRING_TYPES) @NameCn("字符串近似最近邻训练") +@NameEn("String Approx Nearest Neighbor Training") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.similarity.StringApproxNearestNeighbor") public class StringApproxNearestNeighborTrainBatchOp extends BaseNearestNeighborTrainBatchOp implements StringTextApproxNearestNeighborTrainParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/similarity/StringNearestNeighborPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/similarity/StringNearestNeighborPredictBatchOp.java index e2c9f2ba2..e406f5505 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/similarity/StringNearestNeighborPredictBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/similarity/StringNearestNeighborPredictBatchOp.java @@ -4,6 +4,7 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortDesc; import com.alibaba.alink.common.annotation.PortSpec; @@ -22,6 +23,7 @@ }) @ParamSelectColumnSpec(name = "selectedCol", allowedTypeCollections = TypeCollections.STRING_TYPES) @NameCn("字符串最近邻预测") +@NameEn("String Nearest Neighbor Prediction") public class StringNearestNeighborPredictBatchOp extends ModelMapBatchOp implements NearestNeighborPredictParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/similarity/StringNearestNeighborTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/similarity/StringNearestNeighborTrainBatchOp.java index 9a376a934..24fe1c635 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/similarity/StringNearestNeighborTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/similarity/StringNearestNeighborTrainBatchOp.java @@ -3,17 +3,21 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.operator.common.similarity.TrainType; import 
com.alibaba.alink.operator.common.similarity.dataConverter.StringModelDataConverter; import com.alibaba.alink.params.similarity.StringTextNearestNeighborTrainParams; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; /** * Find the nearest neighbor of query string. */ @ParamSelectColumnSpec(name = "selectedCol", allowedTypeCollections = TypeCollections.STRING_TYPES) @NameCn("字符串最近邻训练") +@NameEn("String Nearest Neighbor Training") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.similarity.StringNearestNeighbor") public class StringNearestNeighborTrainBatchOp extends BaseNearestNeighborTrainBatchOp implements StringTextNearestNeighborTrainParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/similarity/StringSimilarityPairwiseBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/similarity/StringSimilarityPairwiseBatchOp.java index b8c576cdb..0cb210222 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/similarity/StringSimilarityPairwiseBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/similarity/StringSimilarityPairwiseBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.operator.batch.utils.MapBatchOp; @@ -29,6 +30,7 @@ */ @ParamSelectColumnSpec(name = "selectedCols", allowedTypeCollections = TypeCollections.STRING_TYPES) @NameCn("字符串两两相似度计算") +@NameEn("String Similarity Pairwise") public class StringSimilarityPairwiseBatchOp extends MapBatchOp implements StringTextPairwiseParams { private static final long serialVersionUID = 6952374807123805800L; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/similarity/TextApproxNearestNeighborPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/similarity/TextApproxNearestNeighborPredictBatchOp.java index 39e500c50..eb5e24e01 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/similarity/TextApproxNearestNeighborPredictBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/similarity/TextApproxNearestNeighborPredictBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.batch.utils.ModelMapBatchOp; import com.alibaba.alink.operator.common.similarity.NearestNeighborsMapper; import com.alibaba.alink.params.similarity.NearestNeighborPredictParams; @@ -11,6 +12,7 @@ * Find the approximate nearest neighbor of query texts. 
*/ @NameCn("文本近似最近邻预测") +@NameEn("Text Approx Nearest Neighbor Prediction") public class TextApproxNearestNeighborPredictBatchOp extends ModelMapBatchOp implements NearestNeighborPredictParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/similarity/TextApproxNearestNeighborTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/similarity/TextApproxNearestNeighborTrainBatchOp.java index 815c35527..f96ec6eed 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/similarity/TextApproxNearestNeighborTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/similarity/TextApproxNearestNeighborTrainBatchOp.java @@ -3,14 +3,18 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.common.similarity.TrainType; import com.alibaba.alink.operator.common.similarity.dataConverter.StringModelDataConverter; import com.alibaba.alink.params.similarity.StringTextApproxNearestNeighborTrainParams; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; /** * Find the approximate nearest neighbor of query texts. */ @NameCn("文本近似最近邻训练") +@NameEn("Text Approx Nearest Neighbor Training") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.similarity.TextApproxNearestNeighbor") public class TextApproxNearestNeighborTrainBatchOp extends BaseNearestNeighborTrainBatchOp implements StringTextApproxNearestNeighborTrainParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/similarity/TextNearestNeighborPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/similarity/TextNearestNeighborPredictBatchOp.java index c432a0d19..580713d26 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/similarity/TextNearestNeighborPredictBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/similarity/TextNearestNeighborPredictBatchOp.java @@ -4,6 +4,7 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.PortSpec; import com.alibaba.alink.common.annotation.PortType; import com.alibaba.alink.operator.batch.utils.ModelMapBatchOp; @@ -15,6 +16,7 @@ */ @InputPorts(values = {@PortSpec(value = PortType.MODEL, suggestions = TextNearestNeighborTrainBatchOp.class), @PortSpec(PortType.DATA)}) @NameCn("文本最近邻预测") +@NameEn("Text Nearest Neighbor Prediction") public class TextNearestNeighborPredictBatchOp extends ModelMapBatchOp implements NearestNeighborPredictParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/similarity/TextNearestNeighborTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/similarity/TextNearestNeighborTrainBatchOp.java index 5652e3c67..468d4a338 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/similarity/TextNearestNeighborTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/similarity/TextNearestNeighborTrainBatchOp.java @@ -3,17 +3,21 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.operator.common.similarity.TrainType; import 
com.alibaba.alink.operator.common.similarity.dataConverter.StringModelDataConverter; import com.alibaba.alink.params.similarity.StringTextNearestNeighborTrainParams; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; /** * Find the approximate nearest neighbor of query texts. */ @ParamSelectColumnSpec(name = "selectedCol", allowedTypeCollections = TypeCollections.STRING_TYPES) @NameCn("文本最近邻训练") +@NameEn("Text Nearest Neighbor Training") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.similarity.TextNearestNeighbor") public class TextNearestNeighborTrainBatchOp extends BaseNearestNeighborTrainBatchOp implements StringTextNearestNeighborTrainParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/similarity/TextSimilarityPairwiseBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/similarity/TextSimilarityPairwiseBatchOp.java index 1c856dac1..a33702bdf 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/similarity/TextSimilarityPairwiseBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/similarity/TextSimilarityPairwiseBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.operator.batch.utils.MapBatchOp; @@ -29,6 +30,7 @@ */ @ParamSelectColumnSpec(name = "selectedCols", allowedTypeCollections = TypeCollections.STRING_TYPES) @NameCn("文本两两相似度计算") +@NameEn("Text Similarity Pairwise") public final class TextSimilarityPairwiseBatchOp extends MapBatchOp implements StringTextPairwiseParams { private static final long serialVersionUID = 627852765048346223L; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/similarity/VectorApproxNearestNeighborPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/similarity/VectorApproxNearestNeighborPredictBatchOp.java index 95611e6dd..b2d2309ef 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/similarity/VectorApproxNearestNeighborPredictBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/similarity/VectorApproxNearestNeighborPredictBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.batch.utils.ModelMapBatchOp; import com.alibaba.alink.operator.common.similarity.NearestNeighborsMapper; import com.alibaba.alink.params.similarity.NearestNeighborPredictParams; @@ -11,6 +12,7 @@ * Find the approximate nearest neighbor of query vectors. 
*/ @NameCn("向量近似最近邻预测") +@NameEn("Vector Approx Nearest Neighbor Prediction") public class VectorApproxNearestNeighborPredictBatchOp extends ModelMapBatchOp implements NearestNeighborPredictParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/similarity/VectorApproxNearestNeighborTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/similarity/VectorApproxNearestNeighborTrainBatchOp.java index 38091d8fe..1942ed195 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/similarity/VectorApproxNearestNeighborTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/similarity/VectorApproxNearestNeighborTrainBatchOp.java @@ -3,16 +3,20 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.operator.common.similarity.TrainType; import com.alibaba.alink.params.similarity.VectorApproxNearestNeighborTrainParams; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; /** * Find the approximate nearest neighbor of query vectors. */ @ParamSelectColumnSpec(name = "selectedCol", allowedTypeCollections = TypeCollections.VECTOR_TYPES) @NameCn("向量近似最近邻训练") +@NameEn("Vector Approx Nearest Neighbor Training") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.similarity.VectorApproxNearestNeighbor") public class VectorApproxNearestNeighborTrainBatchOp extends BaseNearestNeighborTrainBatchOp implements VectorApproxNearestNeighborTrainParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/similarity/VectorNearestNeighborPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/similarity/VectorNearestNeighborPredictBatchOp.java index ec1783097..4c759262a 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/similarity/VectorNearestNeighborPredictBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/similarity/VectorNearestNeighborPredictBatchOp.java @@ -4,6 +4,7 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortSpec; import com.alibaba.alink.common.annotation.PortType; @@ -18,6 +19,7 @@ @InputPorts(values = {@PortSpec(value = PortType.MODEL, suggestions = VectorNearestNeighborTrainBatchOp.class), @PortSpec(PortType.DATA)}) @ParamSelectColumnSpec(name = "selectedCol", allowedTypeCollections = TypeCollections.VECTOR_TYPES) @NameCn("向量最近邻预测") +@NameEn("Vector Nearest Neighbor Prediction") public class VectorNearestNeighborPredictBatchOp extends ModelMapBatchOp implements NearestNeighborPredictParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/similarity/VectorNearestNeighborTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/similarity/VectorNearestNeighborTrainBatchOp.java index a310707f4..93a7dec73 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/similarity/VectorNearestNeighborTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/similarity/VectorNearestNeighborTrainBatchOp.java @@ -3,16 +3,20 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import 
com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.operator.common.similarity.TrainType; import com.alibaba.alink.params.similarity.VectorNearestNeighborTrainParams; +import com.alibaba.alink.pipeline.EstimatorTrainerAnnotation; /** * Find the nearest neighbor of query vectors. */ @ParamSelectColumnSpec(name = "selectedCol", allowedTypeCollections = TypeCollections.VECTOR_TYPES) @NameCn("向量最近邻训练") +@NameEn("Vector Nearest Neighbor Training") +@EstimatorTrainerAnnotation(estimatorName = "com.alibaba.alink.pipeline.similarity.VectorNearestNeighbor") public class VectorNearestNeighborTrainBatchOp extends BaseNearestNeighborTrainBatchOp implements VectorNearestNeighborTrainParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/sink/AkSinkBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/sink/AkSinkBatchOp.java index 36d563cb0..90ed45d0c 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/sink/AkSinkBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/sink/AkSinkBatchOp.java @@ -5,6 +5,7 @@ import org.apache.flink.table.api.TableSchema; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.io.annotations.AnnotationUtils; import com.alibaba.alink.common.io.annotations.IOType; import com.alibaba.alink.common.io.annotations.IoOpAnnotation; @@ -13,8 +14,8 @@ import com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.batch.source.AkSourceBatchOp; +import com.alibaba.alink.operator.batch.utils.DataSetUtil; import com.alibaba.alink.operator.common.io.partition.AkSinkCollectorCreator; -import com.alibaba.alink.operator.common.io.partition.Utils; import com.alibaba.alink.params.io.AkSinkBatchParams; import org.apache.commons.lang3.ArrayUtils; @@ -23,6 +24,7 @@ */ @IoOpAnnotation(name = "ak", ioType = IOType.SinkBatch) @NameCn("AK文件导出") +@NameEn("Ak Sink") public final class AkSinkBatchOp extends BaseSinkBatchOp implements AkSinkBatchParams { @@ -45,7 +47,7 @@ public AkSinkBatchOp linkFrom(BatchOperator ... 
inputs) { public AkSinkBatchOp sinkFrom(BatchOperator in) { if (getPartitionCols() != null) { - Utils.partitionAndWriteFile( + DataSetUtil.partitionAndWriteFile( in, new AkSinkCollectorCreator( new AkMeta( diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/sink/AppendModelStreamFileSinkBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/sink/AppendModelStreamFileSinkBatchOp.java index 26a8b710b..6f5223d1b 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/sink/AppendModelStreamFileSinkBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/sink/AppendModelStreamFileSinkBatchOp.java @@ -12,6 +12,7 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.PortSpec; import com.alibaba.alink.common.annotation.PortType; import com.alibaba.alink.common.exceptions.AkUnclassifiedErrorException; @@ -21,8 +22,8 @@ import com.alibaba.alink.common.io.filesystem.FilePath; import com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.operator.batch.BatchOperator; -import com.alibaba.alink.operator.common.stream.model.FileModelStreamSink; -import com.alibaba.alink.operator.common.stream.model.ModelStreamUtils; +import com.alibaba.alink.operator.common.modelstream.FileModelStreamSink; +import com.alibaba.alink.operator.common.modelstream.ModelStreamUtils; import com.alibaba.alink.params.io.AppendModelStreamFileSinkParams; import java.io.IOException; @@ -31,6 +32,7 @@ @IoOpAnnotation(name = "append_model_stream", ioType = IOType.SinkBatch) @InputPorts(values = {@PortSpec(PortType.MODEL)}) @NameCn("模型流导出") +@NameEn("Append Model Stream File Sink") public class AppendModelStreamFileSinkBatchOp extends BaseSinkBatchOp implements AppendModelStreamFileSinkParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/sink/CatalogSinkBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/sink/CatalogSinkBatchOp.java index ae06a8799..da85aeb96 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/sink/CatalogSinkBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/sink/CatalogSinkBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.io.annotations.AnnotationUtils; import com.alibaba.alink.common.io.annotations.IOType; import com.alibaba.alink.common.io.annotations.IoOpAnnotation; @@ -15,6 +16,7 @@ */ @IoOpAnnotation(name = "catalog", ioType = IOType.SinkBatch) @NameCn("Catalog数据表导出") +@NameEn("Catalog Sink") public class CatalogSinkBatchOp extends BaseSinkBatchOp implements HasCatalogObject { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/sink/CsvSinkBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/sink/CsvSinkBatchOp.java index b86d2c0b1..50c6963b3 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/sink/CsvSinkBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/sink/CsvSinkBatchOp.java @@ -7,15 +7,16 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.io.annotations.AnnotationUtils; import com.alibaba.alink.common.io.annotations.IOType; import 
com.alibaba.alink.common.io.annotations.IoOpAnnotation; import com.alibaba.alink.common.io.filesystem.copy.csv.TextOutputFormat; import com.alibaba.alink.operator.batch.BatchOperator; +import com.alibaba.alink.operator.batch.utils.DataSetUtil; import com.alibaba.alink.operator.common.io.csv.CsvUtil.FlattenCsvFromRow; import com.alibaba.alink.operator.common.io.csv.CsvUtil.FormatCsvFunc; import com.alibaba.alink.operator.common.io.partition.CsvSinkCollectorCreator; -import com.alibaba.alink.operator.common.io.partition.Utils; import com.alibaba.alink.params.io.CsvSinkBatchParams; /** @@ -24,6 +25,7 @@ @IoOpAnnotation(name = "csv", ioType = IOType.SinkBatch) @NameCn("CSV文件导出") +@NameEn("Csv Sink") public final class CsvSinkBatchOp extends BaseSinkBatchOp implements CsvSinkBatchParams { @@ -41,10 +43,10 @@ public CsvSinkBatchOp(Params params) { public CsvSinkBatchOp sinkFrom(BatchOperator in) { if (getPartitionCols() != null) { - Utils.partitionAndWriteFile( + DataSetUtil.partitionAndWriteFile( in, new CsvSinkCollectorCreator( - new FormatCsvFunc(in.getColTypes(), getFieldDelimiter(), getQuoteChar()), + new FormatCsvFunc(in.getColTypes(), getFieldDelimiter(), getRowDelimiter(), getQuoteChar()), new FlattenCsvFromRow(getRowDelimiter()), getRowDelimiter() ), @@ -55,6 +57,7 @@ public CsvSinkBatchOp sinkFrom(BatchOperator in) { final String filePath = getFilePath().getPathStr(); final String fieldDelim = getFieldDelimiter(); + final String rowDelim = getRowDelimiter(); final int numFiles = getNumFiles(); final TypeInformation [] types = in.getColTypes(); final Character quoteChar = getQuoteChar(); @@ -65,7 +68,7 @@ public CsvSinkBatchOp sinkFrom(BatchOperator in) { } DataSet textLines = in.getDataSet() - .map(new FormatCsvFunc(types, fieldDelim, quoteChar)) + .map(new FormatCsvFunc(types, fieldDelim, rowDelim, quoteChar)) .map(new FlattenCsvFromRow(getRowDelimiter())); TextOutputFormat tof = new TextOutputFormat <>( diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/sink/HBaseSinkBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/sink/HBaseSinkBatchOp.java index a7ad05fa9..508d0e7d6 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/sink/HBaseSinkBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/sink/HBaseSinkBatchOp.java @@ -4,6 +4,7 @@ import org.apache.flink.table.api.TableSchema; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.common.io.annotations.AnnotationUtils; @@ -18,7 +19,8 @@ */ @IoOpAnnotation(name = "hbase_batch_sink", ioType = IOType.SinkBatch) @ParamSelectColumnSpec(name = "rowKeyCol", allowedTypeCollections = TypeCollections.STRING_TYPE) -@NameCn("导出到HBase") +@NameCn("HBase导出") +@NameEn("HBase Sink") public final class HBaseSinkBatchOp extends BaseSinkBatchOp implements HBaseSinkParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/sink/LibSvmSinkBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/sink/LibSvmSinkBatchOp.java index 7f834eaa2..ce4731293 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/sink/LibSvmSinkBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/sink/LibSvmSinkBatchOp.java @@ -8,12 +8,13 @@ import org.apache.flink.types.Row; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; 
import com.alibaba.alink.common.io.annotations.AnnotationUtils; import com.alibaba.alink.common.io.annotations.IOType; import com.alibaba.alink.common.io.annotations.IoOpAnnotation; import com.alibaba.alink.common.linalg.DenseVector; import com.alibaba.alink.common.linalg.SparseVector; -import com.alibaba.alink.common.utils.DataSetConversionUtil; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; import com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.params.io.LibSvmSinkBatchParams; @@ -26,6 +27,7 @@ */ @IoOpAnnotation(name = "libsvm", ioType = IOType.SinkBatch) @NameCn("LibSvm文件导出") +@NameEn("LibSvm Sink") public final class LibSvmSinkBatchOp extends BaseSinkBatchOp implements LibSvmSinkBatchParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/sink/TFRecordDatasetSinkBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/sink/TFRecordDatasetSinkBatchOp.java index 109e8c4b2..2756576b2 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/sink/TFRecordDatasetSinkBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/sink/TFRecordDatasetSinkBatchOp.java @@ -4,6 +4,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.io.annotations.AnnotationUtils; import com.alibaba.alink.common.io.annotations.IOType; import com.alibaba.alink.common.io.annotations.IoOpAnnotation; @@ -16,7 +17,8 @@ * Sink batch op data to a file system with TFRecordDataset format. */ @IoOpAnnotation(name = "tfrecord", ioType = IOType.SinkBatch) -@NameCn("TFRecordDataset文件导出") +@NameCn("TFRecord Dataset文件导出") +@NameEn("TFRecord Dataset Sink") public final class TFRecordDatasetSinkBatchOp extends BaseSinkBatchOp implements TFRecordDatasetSinkParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/sink/TextSinkBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/sink/TextSinkBatchOp.java index d4411a841..9ce57c65f 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/sink/TextSinkBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/sink/TextSinkBatchOp.java @@ -5,6 +5,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.exceptions.AkIllegalArgumentException; import com.alibaba.alink.common.io.annotations.AnnotationUtils; import com.alibaba.alink.common.io.annotations.IOType; @@ -17,6 +18,7 @@ */ @IoOpAnnotation(name = "text", ioType = IOType.SinkBatch) @NameCn("Text文件导出") +@NameEn("Text Sink") public final class TextSinkBatchOp extends BaseSinkBatchOp implements TextSinkBatchParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/sink/TsvSinkBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/sink/TsvSinkBatchOp.java index bf91160a1..bc69b5973 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/sink/TsvSinkBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/sink/TsvSinkBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.io.annotations.AnnotationUtils; import com.alibaba.alink.common.io.annotations.IOType; import 
com.alibaba.alink.common.io.annotations.IoOpAnnotation; @@ -14,6 +15,7 @@ */ @IoOpAnnotation(name = "tsv", ioType = IOType.SinkBatch) @NameCn("TSV文件导出") +@NameEn("TSV Sink") public final class TsvSinkBatchOp extends BaseSinkBatchOp implements TsvSinkBatchParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/sink/XlsSinkBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/sink/XlsSinkBatchOp.java new file mode 100644 index 000000000..ebc7ed5f3 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/batch/sink/XlsSinkBatchOp.java @@ -0,0 +1,48 @@ +package com.alibaba.alink.operator.batch.sink; + +import org.apache.flink.api.common.io.FileOutputFormat; +import org.apache.flink.core.fs.FileSystem.WriteMode; +import org.apache.flink.ml.api.misc.param.Params; + +import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.common.io.annotations.AnnotationUtils; +import com.alibaba.alink.common.io.annotations.IOType; +import com.alibaba.alink.common.io.annotations.IoOpAnnotation; +import com.alibaba.alink.common.io.xls.XlsReaderClassLoader; +import com.alibaba.alink.operator.batch.BatchOperator; +import com.alibaba.alink.params.io.XlsSinkParams; + +@IoOpAnnotation(name = "xls_sink", ioType = IOType.SinkBatch) +@NameCn("Xlsx表格写出") +@NameEn("Xls Sink") +public class XlsSinkBatchOp extends BaseSinkBatchOp implements XlsSinkParams { + public XlsSinkBatchOp() { + this(new Params()); + } + + private final XlsReaderClassLoader factory; + + public XlsSinkBatchOp(Params params) { + super(AnnotationUtils.annotatedName(XlsSinkBatchOp.class), params); + factory = new XlsReaderClassLoader("0.11"); + } + + @Override + protected XlsSinkBatchOp sinkFrom(BatchOperator in) { + FileOutputFormat outputFormat = XlsReaderClassLoader + .create(factory).createOutputFormat(getParams(), in.getSchema()); + + if (getOverwriteSink()) { + outputFormat.setWriteMode(WriteMode.OVERWRITE); + } else { + outputFormat.setWriteMode(WriteMode.NO_OVERWRITE); + } + + in.getDataSet().output(outputFormat) + .name("xls-file-sink") + .setParallelism(getNumFiles()); + + return this; + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/source/AkSourceBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/source/AkSourceBatchOp.java index 2833e764a..3b4bff561 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/source/AkSourceBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/source/AkSourceBatchOp.java @@ -9,15 +9,16 @@ import com.alibaba.alink.common.MLEnvironmentFactory; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.exceptions.AkIllegalOperatorParameterException; import com.alibaba.alink.common.io.annotations.AnnotationUtils; import com.alibaba.alink.common.io.annotations.IOType; import com.alibaba.alink.common.io.annotations.IoOpAnnotation; import com.alibaba.alink.common.io.filesystem.AkUtils; -import com.alibaba.alink.common.utils.DataSetConversionUtil; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; import com.alibaba.alink.common.utils.TableUtil; +import com.alibaba.alink.operator.batch.utils.DataSetUtil; import com.alibaba.alink.operator.common.io.partition.AkSourceCollectorCreator; -import com.alibaba.alink.operator.common.io.partition.Utils; import com.alibaba.alink.params.io.AkSourceParams; import java.io.IOException; @@ -27,6 +28,7 @@ */ 
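// Editor's note: a hedged usage sketch for the new XlsSinkBatchOp introduced above. The setters are
// assumed from the getters the class calls (getOverwriteSink, getNumFiles) and from the usual
// file-sink params; the path and source names are illustrative only.
//   new XlsSinkBatchOp()
//       .setFilePath("/tmp/result.xlsx")
//       .setOverwriteSink(true)
//       .setNumFiles(1)
//       .linkFrom(source);
//   BatchOperator.execute();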
@IoOpAnnotation(name = "ak", ioType = IOType.SourceBatch) @NameCn("AK文件读入") +@NameEn("AK Source") public final class AkSourceBatchOp extends BaseSourceBatchOp implements AkSourceParams { @@ -67,7 +69,7 @@ public Table initializeDataSource() { } else { try { Tuple2 , TableSchema> schemaAndData = - Utils.readFromPartitionBatch( + DataSetUtil.readFromPartitionBatch( getParams(), getMLEnvironmentId(), new AkSourceCollectorCreator(meta) ); diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/source/CatalogSourceBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/source/CatalogSourceBatchOp.java index 0ee80d3c9..91a893470 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/source/CatalogSourceBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/source/CatalogSourceBatchOp.java @@ -4,6 +4,7 @@ import org.apache.flink.table.api.Table; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.io.annotations.AnnotationUtils; import com.alibaba.alink.common.io.annotations.IOType; import com.alibaba.alink.common.io.annotations.IoOpAnnotation; @@ -15,6 +16,7 @@ */ @IoOpAnnotation(name = "catalog", ioType = IOType.SourceBatch) @NameCn("Catalog读入") +@NameEn("Catalog Source") public class CatalogSourceBatchOp extends BaseSourceBatchOp implements HasCatalogObject { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/source/CsvSourceBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/source/CsvSourceBatchOp.java index 3a678900d..a95e0d624 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/source/CsvSourceBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/source/CsvSourceBatchOp.java @@ -6,6 +6,7 @@ import org.apache.flink.table.api.TableSchema; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.io.annotations.AnnotationUtils; import com.alibaba.alink.common.io.annotations.IOType; import com.alibaba.alink.common.io.annotations.IoOpAnnotation; @@ -13,7 +14,7 @@ import com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.common.io.csv.CsvTypeConverter; -import com.alibaba.alink.operator.common.io.csv.InternalCsvSourceBatchOp; +import com.alibaba.alink.operator.common.io.csv.InternalCsvSourceBetaBatchOp; import com.alibaba.alink.params.io.CsvSourceParams; /** @@ -28,6 +29,7 @@ */ @IoOpAnnotation(name = "csv", ioType = IOType.SourceBatch) @NameCn("CSV文件读入") +@NameEn("CSV Source") public class CsvSourceBatchOp extends BaseSourceBatchOp implements CsvSourceParams { @@ -75,7 +77,7 @@ protected Table initializeDataSource() { TableUtil.schema2SchemaStr(new TableSchema(colNames, CsvTypeConverter.rewriteColTypes(colTypes))) ); - BatchOperator source = new InternalCsvSourceBatchOp(rawCsvParams); + BatchOperator source = new InternalCsvSourceBetaBatchOp(rawCsvParams); source = CsvTypeConverter.toTensorPipelineModel(getParams(), colNames, colTypes).transform(source); source = CsvTypeConverter.toVectorPipelineModel(getParams(), colNames, colTypes).transform(source); diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/source/LibSvmSourceBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/source/LibSvmSourceBatchOp.java index 83a4c2906..13335a944 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/source/LibSvmSourceBatchOp.java +++ 
b/core/src/main/java/com/alibaba/alink/operator/batch/source/LibSvmSourceBatchOp.java @@ -8,11 +8,12 @@ import org.apache.flink.types.Row; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.io.annotations.AnnotationUtils; import com.alibaba.alink.common.io.annotations.IOType; import com.alibaba.alink.common.io.annotations.IoOpAnnotation; import com.alibaba.alink.common.linalg.Vector; -import com.alibaba.alink.common.utils.DataSetConversionUtil; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.local.source.LibSvmSourceLocalOp; import com.alibaba.alink.params.io.LibSvmSourceParams; @@ -22,6 +23,7 @@ */ @IoOpAnnotation(name = "libsvm", ioType = IOType.SourceBatch) @NameCn("LibSvm文件读入") +@NameEn("LibSvm Source") public final class LibSvmSourceBatchOp extends BaseSourceBatchOp implements LibSvmSourceParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/source/MemSourceBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/source/MemSourceBatchOp.java index 8fcfd6202..3ba8d35f9 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/source/MemSourceBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/source/MemSourceBatchOp.java @@ -10,9 +10,10 @@ import com.alibaba.alink.common.MLEnvironmentFactory; import com.alibaba.alink.common.MTable; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.io.annotations.IOType; import com.alibaba.alink.common.io.annotations.IoOpAnnotation; -import com.alibaba.alink.common.utils.DataSetConversionUtil; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; import com.alibaba.alink.common.utils.TableUtil; import java.util.Arrays; @@ -23,6 +24,7 @@ */ @IoOpAnnotation(name = MemSourceBatchOp.NAME, ioType = IOType.SourceBatch) @NameCn("内存数据读入") +@NameEn("Memory Source") public final class MemSourceBatchOp extends BaseSourceBatchOp { static final String NAME = "memory"; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/source/NumSeqSourceBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/source/NumSeqSourceBatchOp.java index a8309bac4..880461efe 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/source/NumSeqSourceBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/source/NumSeqSourceBatchOp.java @@ -14,10 +14,11 @@ import com.alibaba.alink.common.MLEnvironmentFactory; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.io.annotations.AnnotationUtils; import com.alibaba.alink.common.io.annotations.IOType; import com.alibaba.alink.common.io.annotations.IoOpAnnotation; -import com.alibaba.alink.common.utils.DataSetConversionUtil; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; import com.alibaba.alink.params.io.NumSeqSourceParams; /** @@ -25,6 +26,7 @@ */ @IoOpAnnotation(name = "num_seq", ioType = IOType.SourceBatch) @NameCn("数值队列数据源") +@NameEn("Number Sequence Source") public final class NumSeqSourceBatchOp extends BaseSourceBatchOp implements NumSeqSourceParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/source/ParquetSourceBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/source/ParquetSourceBatchOp.java index 
182b4f4d5..1988b32f6 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/source/ParquetSourceBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/source/ParquetSourceBatchOp.java @@ -16,6 +16,7 @@ import com.alibaba.alink.common.MLEnvironmentFactory; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.exceptions.AkIllegalOperatorParameterException; import com.alibaba.alink.common.io.annotations.AnnotationUtils; import com.alibaba.alink.common.io.filesystem.AkUtils; @@ -25,13 +26,14 @@ import com.alibaba.alink.common.io.parquet.ParquetClassLoaderFactory; import com.alibaba.alink.common.io.parquet.ParquetReaderFactory; import com.alibaba.alink.common.io.plugin.wrapper.RichInputFormatGenericWithClassLoader; -import com.alibaba.alink.common.utils.DataSetConversionUtil; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.params.io.ParquetSourceParams; import java.io.IOException; -@NameCn("parquet文件读入") +@NameCn("Parquet文件读入") +@NameEn("Parquet Source") public class ParquetSourceBatchOp extends BaseSourceBatchOp implements ParquetSourceParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/source/RandomTableSourceBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/source/RandomTableSourceBatchOp.java index 49ecfa105..36bc8c976 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/source/RandomTableSourceBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/source/RandomTableSourceBatchOp.java @@ -5,10 +5,11 @@ import org.apache.flink.table.api.Table; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.io.annotations.AnnotationUtils; import com.alibaba.alink.common.io.annotations.IOType; import com.alibaba.alink.common.io.annotations.IoOpAnnotation; -import com.alibaba.alink.common.source.RandomTableSourceUtils; +import com.alibaba.alink.operator.common.utils.RandomTableSourceUtils; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.common.dataproc.RandomTable; import com.alibaba.alink.params.io.RandomTableSourceBatchParams; @@ -26,6 +27,7 @@ */ @IoOpAnnotation(name = "random_table", ioType = IOType.SourceBatch) @NameCn("随机生成结构数据源") +@NameEn("Random Table Source") public final class RandomTableSourceBatchOp extends BaseSourceBatchOp implements RandomTableSourceBatchParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/source/RandomVectorSourceBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/source/RandomVectorSourceBatchOp.java index 2c4a32c4c..a05cc5e03 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/source/RandomVectorSourceBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/source/RandomVectorSourceBatchOp.java @@ -4,6 +4,7 @@ import org.apache.flink.table.api.Table; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.io.annotations.AnnotationUtils; import com.alibaba.alink.common.io.annotations.IOType; import com.alibaba.alink.common.io.annotations.IoOpAnnotation; @@ -16,6 +17,7 @@ */ @IoOpAnnotation(name = "random_vector", ioType = IOType.SourceBatch) @NameCn("随机生成向量数据源") +@NameEn("Random Vector Source") public final class RandomVectorSourceBatchOp 
extends BaseSourceBatchOp implements RandomVectorSourceBatchParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/source/TFRecordDatasetSourceBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/source/TFRecordDatasetSourceBatchOp.java index 27d17b2ff..17c777df4 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/source/TFRecordDatasetSourceBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/source/TFRecordDatasetSourceBatchOp.java @@ -6,11 +6,12 @@ import com.alibaba.alink.common.MLEnvironmentFactory; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.io.annotations.AnnotationUtils; import com.alibaba.alink.common.io.annotations.IOType; import com.alibaba.alink.common.io.annotations.IoOpAnnotation; import com.alibaba.alink.common.io.filesystem.TFRecordDatasetUtils.TFRecordDatasetInputFormat; -import com.alibaba.alink.common.utils.DataSetConversionUtil; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; import com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.params.io.TFRecordDatasetSourceParams; @@ -19,6 +20,7 @@ */ @IoOpAnnotation(name = "tfrecord", ioType = IOType.SourceBatch) @NameCn("TFRecordDataset文件读入") +@NameEn("TF Record Dataset Source") public final class TFRecordDatasetSourceBatchOp extends BaseSourceBatchOp implements TFRecordDatasetSourceParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/source/TableSourceBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/source/TableSourceBatchOp.java index f0e0c9f6b..2bab141ed 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/source/TableSourceBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/source/TableSourceBatchOp.java @@ -4,6 +4,7 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.PortSpec; import com.alibaba.alink.common.annotation.PortType; @@ -17,6 +18,7 @@ @InputPorts() @OutputPorts(values = @PortSpec(PortType.DATA)) @NameCn("Table数据读入") +@NameEn("Table Source") public final class TableSourceBatchOp extends BatchOperator { private static final long serialVersionUID = -5220231513565199001L; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/source/TextSourceBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/source/TextSourceBatchOp.java index c40283f0d..58eeb923d 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/source/TextSourceBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/source/TextSourceBatchOp.java @@ -4,6 +4,7 @@ import org.apache.flink.table.api.Table; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.io.annotations.AnnotationUtils; import com.alibaba.alink.common.io.annotations.IOType; import com.alibaba.alink.common.io.annotations.IoOpAnnotation; @@ -14,6 +15,7 @@ */ @IoOpAnnotation(name = "text", ioType = IOType.SourceBatch) @NameCn("Text文件读入") +@NameEn("Text Source") public final class TextSourceBatchOp extends BaseSourceBatchOp implements TextSourceParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/source/TsvSourceBatchOp.java 
b/core/src/main/java/com/alibaba/alink/operator/batch/source/TsvSourceBatchOp.java index b227b5ac8..321bfce06 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/source/TsvSourceBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/source/TsvSourceBatchOp.java @@ -4,6 +4,7 @@ import org.apache.flink.table.api.Table; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.io.annotations.AnnotationUtils; import com.alibaba.alink.common.io.annotations.IOType; import com.alibaba.alink.common.io.annotations.IoOpAnnotation; @@ -14,6 +15,7 @@ */ @IoOpAnnotation(name = "tsv", ioType = IOType.SourceBatch) @NameCn("TSV文件读入") +@NameEn("Tsv Source") public final class TsvSourceBatchOp extends BaseSourceBatchOp implements TsvSourceParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/source/XlsSourceBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/source/XlsSourceBatchOp.java index 7c5ab9683..c1f608f87 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/source/XlsSourceBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/source/XlsSourceBatchOp.java @@ -13,16 +13,18 @@ import com.alibaba.alink.common.MLEnvironmentFactory; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.io.annotations.AnnotationUtils; import com.alibaba.alink.common.io.annotations.IOType; import com.alibaba.alink.common.io.annotations.IoOpAnnotation; import com.alibaba.alink.common.io.plugin.wrapper.RichInputFormatGenericWithClassLoader; import com.alibaba.alink.common.io.xls.XlsReaderClassLoader; -import com.alibaba.alink.common.utils.DataSetConversionUtil; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; import com.alibaba.alink.params.io.XlsSourceParams; -@IoOpAnnotation(name = "xls", ioType = IOType.SourceBatch) +@IoOpAnnotation(name = "xls_source", ioType = IOType.SourceBatch) @NameCn("Xls和Xlsx表格读入") +@NameEn("Xls and Xlsx File Source") public class XlsSourceBatchOp extends BaseSourceBatchOp implements XlsSourceParams { public XlsSourceBatchOp() { @@ -41,7 +43,7 @@ protected Table initializeDataSource() { Tuple2 , TableSchema> sourceFunction = XlsReaderClassLoader .create(factory) - .create(getParams()); + .createInputFormat(getParams()); RichInputFormat inputFormat = new RichInputFormatGenericWithClassLoader <>(factory, sourceFunction.f0); @@ -49,7 +51,9 @@ protected Table initializeDataSource() { DataSet data = MLEnvironmentFactory .get(getMLEnvironmentId()) .getExecutionEnvironment() - .createInput(inputFormat, new RowTypeInfo(sourceFunction.f1.getFieldTypes())); + .createInput(inputFormat, new RowTypeInfo(sourceFunction.f1.getFieldTypes())) + .name("xls-file-source") + .rebalance();; return DataSetConversionUtil.toTable(getMLEnvironmentId(), data, sourceFunction.f1); } diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/sql/AsBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/sql/AsBatchOp.java index e44280891..b411895c7 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/sql/AsBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/sql/AsBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.batch.BatchOperator; import 
com.alibaba.alink.params.sql.AsParams; @@ -10,6 +11,7 @@ * Rename the fields of a batch operator. */ @NameCn("SQL操作:As") +@NameEn("SQL As Operation") public final class AsBatchOp extends BaseSqlApiBatchOp implements AsParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/common/sql/BatchSqlOperators.java b/core/src/main/java/com/alibaba/alink/operator/batch/sql/BatchSqlOperators.java similarity index 99% rename from core/src/main/java/com/alibaba/alink/operator/common/sql/BatchSqlOperators.java rename to core/src/main/java/com/alibaba/alink/operator/batch/sql/BatchSqlOperators.java index 9cc9a40a7..fd1b69102 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/sql/BatchSqlOperators.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/sql/BatchSqlOperators.java @@ -1,4 +1,4 @@ -package com.alibaba.alink.operator.common.sql; +package com.alibaba.alink.operator.batch.sql; import org.apache.flink.table.api.bridge.java.BatchTableEnvironment; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/sql/DistinctBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/sql/DistinctBatchOp.java index b82184ba9..07d68a5a8 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/sql/DistinctBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/sql/DistinctBatchOp.java @@ -3,12 +3,14 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.batch.BatchOperator; /** * Remove duplicated records. */ @NameCn("SQL操作:Distinct") +@NameEn("SQL Distinct Operation") public final class DistinctBatchOp extends BaseSqlApiBatchOp { private static final long serialVersionUID = 2774293287356122519L; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/sql/FilterBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/sql/FilterBatchOp.java index 8795961ad..a802286a1 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/sql/FilterBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/sql/FilterBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.params.sql.FilterParams; @@ -10,6 +11,7 @@ * Filter records in the batch operator. 
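* A minimal usage sketch (hypothetical column name; relies on the clause parameter from FilterParams):
* new FilterBatchOp().setClause("f0 > 10").linkFrom(source) keeps only the rows whose column f0 is greater than 10.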
*/ @NameCn("SQL操作:Filter") +@NameEn("SQL Filter Operation") public final class FilterBatchOp extends BaseSqlApiBatchOp implements FilterParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/sql/FullOuterJoinBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/sql/FullOuterJoinBatchOp.java index 9f35a847a..748d004ac 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/sql/FullOuterJoinBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/sql/FullOuterJoinBatchOp.java @@ -4,11 +4,11 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.PortSpec; import com.alibaba.alink.common.annotation.PortType; import com.alibaba.alink.operator.batch.BatchOperator; -import com.alibaba.alink.operator.common.sql.BatchSqlOperators; import com.alibaba.alink.params.sql.JoinParams; /** @@ -18,6 +18,7 @@ @InputPorts(values = {@PortSpec(PortType.DATA), @PortSpec(PortType.DATA)}) @OutputPorts(values = @PortSpec(PortType.DATA)) @NameCn("SQL操作:FullOuterJoin") +@NameEn("SQL FullOuterJoin Operation") public final class FullOuterJoinBatchOp extends BaseSqlApiBatchOp implements JoinParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/sql/GroupByBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/sql/GroupByBatchOp.java index 9274b2d44..61fd98f66 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/sql/GroupByBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/sql/GroupByBatchOp.java @@ -6,10 +6,10 @@ import com.alibaba.alink.common.MLEnvironmentFactory; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.sql.builtin.agg.MTableAgg; import com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.operator.batch.BatchOperator; -import com.alibaba.alink.operator.common.sql.BatchSqlOperators; import com.alibaba.alink.params.sql.GroupByParams; import org.apache.commons.lang3.StringUtils; @@ -23,6 +23,7 @@ * Apply the "group by" operation on the input batch operator. 
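* A minimal usage sketch (hypothetical column names; assumes the groupByPredicate and selectClause parameters of GroupByParams):
* new GroupByBatchOp().setGroupByPredicate("category").setSelectClause("category, COUNT(*) AS cnt").linkFrom(source)
* returns one row per category together with its record count.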
*/ @NameCn("SQL操作:GroupBy") +@NameEn("SQL GroupBy Operation") public final class GroupByBatchOp extends BaseSqlApiBatchOp implements GroupByParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/sql/IntersectAllBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/sql/IntersectAllBatchOp.java index ca08f7d04..5f03c78c4 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/sql/IntersectAllBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/sql/IntersectAllBatchOp.java @@ -4,6 +4,7 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.PortSpec; import com.alibaba.alink.common.annotation.PortType; @@ -16,6 +17,7 @@ @InputPorts(values = {@PortSpec(PortType.DATA), @PortSpec(PortType.DATA)}) @OutputPorts(values = @PortSpec(PortType.DATA)) @NameCn("SQL操作:IntersectAll") +@NameEn("SQL IntersectAll Operation") public final class IntersectAllBatchOp extends BatchOperator { private static final long serialVersionUID = -8644196260740789294L; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/sql/IntersectBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/sql/IntersectBatchOp.java index a801422df..4c2b5f400 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/sql/IntersectBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/sql/IntersectBatchOp.java @@ -4,6 +4,7 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.PortSpec; import com.alibaba.alink.common.annotation.PortType; @@ -16,6 +17,7 @@ @InputPorts(values = {@PortSpec(PortType.DATA), @PortSpec(PortType.DATA)}) @OutputPorts(values = @PortSpec(PortType.DATA)) @NameCn("SQL操作:Intersect") +@NameEn("SQL Intersect Operation") public final class IntersectBatchOp extends BatchOperator { private static final long serialVersionUID = -2981473236917210647L; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/sql/JoinBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/sql/JoinBatchOp.java index 24ac297fa..a095d05c2 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/sql/JoinBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/sql/JoinBatchOp.java @@ -4,12 +4,12 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.PortSpec; import com.alibaba.alink.common.annotation.PortType; import com.alibaba.alink.common.exceptions.AkUnsupportedOperationException; import com.alibaba.alink.operator.batch.BatchOperator; -import com.alibaba.alink.operator.common.sql.BatchSqlOperators; import com.alibaba.alink.params.sql.JoinParams; /** @@ -18,6 +18,7 @@ @InputPorts(values = {@PortSpec(PortType.DATA), @PortSpec(PortType.DATA)}) @OutputPorts(values = @PortSpec(PortType.DATA)) @NameCn("SQL操作:Join") +@NameEn("SQL Join Operation") public final class JoinBatchOp extends BaseSqlApiBatchOp implements JoinParams { diff --git 
a/core/src/main/java/com/alibaba/alink/operator/batch/sql/LeftOuterJoinBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/sql/LeftOuterJoinBatchOp.java index 2ef1a3f26..a1e229803 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/sql/LeftOuterJoinBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/sql/LeftOuterJoinBatchOp.java @@ -4,11 +4,11 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.PortSpec; import com.alibaba.alink.common.annotation.PortType; import com.alibaba.alink.operator.batch.BatchOperator; -import com.alibaba.alink.operator.common.sql.BatchSqlOperators; import com.alibaba.alink.params.sql.JoinParams; /** @@ -18,6 +18,7 @@ @InputPorts(values = {@PortSpec(PortType.DATA), @PortSpec(PortType.DATA)}) @OutputPorts(values = @PortSpec(PortType.DATA)) @NameCn("SQL操作:LeftOuterJoin") +@NameEn("SQL LeftOuterJoin Operation") public final class LeftOuterJoinBatchOp extends BaseSqlApiBatchOp implements JoinParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/sql/MinusAllBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/sql/MinusAllBatchOp.java index ce22fe19f..0917d974e 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/sql/MinusAllBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/sql/MinusAllBatchOp.java @@ -4,6 +4,7 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.PortSpec; import com.alibaba.alink.common.annotation.PortType; @@ -15,6 +16,7 @@ @InputPorts(values = {@PortSpec(PortType.DATA), @PortSpec(PortType.DATA)}) @OutputPorts(values = @PortSpec(PortType.DATA)) @NameCn("SQL操作:MinusAll") +@NameEn("SQL MinusAll Operation") public final class MinusAllBatchOp extends BatchOperator { private static final long serialVersionUID = -7582100858266866075L; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/sql/MinusBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/sql/MinusBatchOp.java index 95cf009d2..0b69cc812 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/sql/MinusBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/sql/MinusBatchOp.java @@ -4,6 +4,7 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.PortSpec; import com.alibaba.alink.common.annotation.PortType; @@ -15,6 +16,7 @@ @InputPorts(values = {@PortSpec(PortType.DATA), @PortSpec(PortType.DATA)}) @OutputPorts(values = @PortSpec(PortType.DATA)) @NameCn("SQL操作:Minus") +@NameEn("SQL Minus Operation") public final class MinusBatchOp extends BatchOperator { private static final long serialVersionUID = 5643333177043157438L; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/sql/OrderByBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/sql/OrderByBatchOp.java index b52e2b208..b1b0f50ec 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/sql/OrderByBatchOp.java +++ 
b/core/src/main/java/com/alibaba/alink/operator/batch/sql/OrderByBatchOp.java @@ -4,6 +4,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.params.sql.OrderByParams; @@ -11,6 +12,7 @@ * Order the batch operator. */ @NameCn("SQL操作:OrderBy") +@NameEn("SQL OrderBy Operation") public final class OrderByBatchOp extends BaseSqlApiBatchOp implements OrderByParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/sql/RightOuterJoinBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/sql/RightOuterJoinBatchOp.java index fcb970dcd..43be0da6a 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/sql/RightOuterJoinBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/sql/RightOuterJoinBatchOp.java @@ -4,11 +4,11 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.PortSpec; import com.alibaba.alink.common.annotation.PortType; import com.alibaba.alink.operator.batch.BatchOperator; -import com.alibaba.alink.operator.common.sql.BatchSqlOperators; import com.alibaba.alink.params.sql.JoinParams; /** @@ -17,6 +17,7 @@ @InputPorts(values = {@PortSpec(PortType.DATA), @PortSpec(PortType.DATA)}) @OutputPorts(values = @PortSpec(PortType.DATA)) @NameCn("SQL操作:RightOuterJoin") +@NameEn("SQL RightOuterJoin Operation") public final class RightOuterJoinBatchOp extends BaseSqlApiBatchOp implements JoinParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/sql/SelectBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/sql/SelectBatchOp.java index 92f9c637b..14e2ba212 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/sql/SelectBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/sql/SelectBatchOp.java @@ -2,10 +2,11 @@ import org.apache.flink.ml.api.misc.param.Params; +import com.alibaba.alink.common.annotation.Internal; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.batch.utils.MapBatchOp; -import com.alibaba.alink.operator.common.sql.BatchSqlOperators; import com.alibaba.alink.operator.common.sql.SelectUtils; import com.alibaba.alink.operator.common.sql.SimpleSelectMapper; import com.alibaba.alink.params.sql.SelectParams; @@ -14,6 +15,7 @@ * Select the fields of a batch operator. */ @NameCn("SQL操作:Select") +@NameEn("SQL Select Operation") public final class SelectBatchOp extends BaseSqlApiBatchOp implements SelectParams { @@ -51,6 +53,7 @@ public SelectBatchOp linkFrom(BatchOperator ... 
inputs) { return this; } + @Internal private class SimpleSelectBatchOp extends MapBatchOp implements SelectParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/sql/UnionAllBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/sql/UnionAllBatchOp.java index 5c1b2b061..fcc959ad4 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/sql/UnionAllBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/sql/UnionAllBatchOp.java @@ -4,6 +4,7 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.PortSpec; import com.alibaba.alink.common.annotation.PortType; @@ -15,6 +16,7 @@ @InputPorts(values = @PortSpec(value = PortType.DATA, isRepeated = true)) @OutputPorts(values = @PortSpec(value = PortType.DATA)) @NameCn("SQL操作:UnionAll") +@NameEn("SQL UnionAll Operation") public final class UnionAllBatchOp extends BatchOperator { private static final long serialVersionUID = 2468662188701775196L; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/sql/UnionBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/sql/UnionBatchOp.java index 0928643c2..b07320dea 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/sql/UnionBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/sql/UnionBatchOp.java @@ -4,6 +4,7 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.PortSpec; import com.alibaba.alink.common.annotation.PortType; @@ -16,6 +17,7 @@ @InputPorts(values = @PortSpec(value = PortType.DATA, isRepeated = true)) @OutputPorts(values = @PortSpec(value = PortType.DATA)) @NameCn("SQL操作:Union") +@NameEn("SQL Union Operation") public final class UnionBatchOp extends BatchOperator { private static final long serialVersionUID = 6141413513148024360L; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/sql/WhereBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/sql/WhereBatchOp.java index c4768fb50..f6ae9f172 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/sql/WhereBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/sql/WhereBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.params.dataproc.HasClause; import com.alibaba.alink.params.sql.WhereParams; @@ -11,6 +12,7 @@ * Filter records in the batch operator. 
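* Functionally this mirrors FilterBatchOp; a sketch with a hypothetical column name:
* new WhereBatchOp().setClause("age >= 18").linkFrom(source) keeps only the matching rows.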
*/ @NameCn("SQL操作:Where") +@NameEn("SQL Where Operation") public final class WhereBatchOp extends BaseSqlApiBatchOp implements WhereParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/statistics/ChiSquareTestBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/statistics/ChiSquareTestBatchOp.java index f5488db27..7fcb9e957 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/statistics/ChiSquareTestBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/statistics/ChiSquareTestBatchOp.java @@ -5,6 +5,7 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortDesc; @@ -35,6 +36,7 @@ @ParamSelectColumnSpec(name = "selectedCols") @ParamSelectColumnSpec(name = "labelCol") @NameCn("卡方检验") +@NameEn("ChiSquare Test") public final class ChiSquareTestBatchOp extends BatchOperator implements ChiSquareTestParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/statistics/CorrelationBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/statistics/CorrelationBatchOp.java index 004a39b77..b8a6044a8 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/statistics/CorrelationBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/statistics/CorrelationBatchOp.java @@ -11,6 +11,7 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortDesc; @@ -18,11 +19,11 @@ import com.alibaba.alink.common.annotation.PortType; import com.alibaba.alink.common.exceptions.AkIllegalOperationException; import com.alibaba.alink.common.exceptions.AkPreconditions; -import com.alibaba.alink.common.utils.DataSetConversionUtil; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; import com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.batch.source.TableSourceBatchOp; -import com.alibaba.alink.operator.common.statistics.StatisticsHelper; +import com.alibaba.alink.operator.batch.statistics.utils.StatisticsHelper; import com.alibaba.alink.operator.common.statistics.basicstatistic.CorrelationDataConverter; import com.alibaba.alink.operator.common.statistics.basicstatistic.CorrelationResult; import com.alibaba.alink.operator.common.statistics.basicstatistic.SpearmanCorrelation; @@ -40,6 +41,7 @@ @OutputPorts(values = {@PortSpec(value = PortType.DATA, desc = PortDesc.OUTPUT_RESULT)}) @ParamSelectColumnSpec(name = "selectedCols") @NameCn("相关系数") +@NameEn("Correlation") public final class CorrelationBatchOp extends BatchOperator implements CorrelationParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/statistics/InternalFullStatsBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/statistics/InternalFullStatsBatchOp.java index ca8e14a75..8e9b89fbf 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/statistics/InternalFullStatsBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/statistics/InternalFullStatsBatchOp.java @@ 
-16,11 +16,12 @@ import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.batch.utils.StatsVisualizer; import com.alibaba.alink.operator.common.io.types.FlinkTypeConverter; -import com.alibaba.alink.operator.common.statistics.StatisticUtil; +import com.alibaba.alink.operator.batch.statistics.utils.StatisticsHelper; import com.alibaba.alink.operator.common.statistics.statistics.FullStats; import com.alibaba.alink.operator.common.statistics.statistics.FullStatsConverter; import com.alibaba.alink.operator.common.statistics.statistics.SummaryResultTable; import com.alibaba.alink.params.statistics.HasStatLevel_L1.StatLevel; +import com.alibaba.alink.params.statistics.HasTableNames; import java.util.Arrays; import java.util.List; @@ -28,7 +29,7 @@ @SuppressWarnings({"UnusedReturnValue", "unused"}) @Internal -public class InternalFullStatsBatchOp extends BatchOperator { +public class InternalFullStatsBatchOp extends BatchOperator implements HasTableNames { public InternalFullStatsBatchOp() { this(new Params()); @@ -46,7 +47,7 @@ public InternalFullStatsBatchOp linkFrom(BatchOperator ... inputs) { DataSet > unionSrtDataSet = null; for (int i = 0; i < n; i += 1) { final int index = i; - DataSet srtDataSet = StatisticUtil.getSRT(inputs[i], StatLevel.L3); + DataSet srtDataSet = StatisticsHelper.getSRT(inputs[i], StatLevel.L3); //noinspection Convert2Lambda DataSet > indexedSrtDataSet = srtDataSet .map(new MapFunction >() { @@ -59,9 +60,18 @@ public Tuple2 map(SummaryResultTable value) { ? indexedSrtDataSet : unionSrtDataSet.union(indexedSrtDataSet); } - String[] tableNames = Arrays.stream(inputs) - .map(d -> d.getOutputTable().toString()) - .toArray(String[]::new); + + String[] tableNames = new String[n]; + for (int i = 0; i < n; i++) { + tableNames[i] = "table" + String.valueOf(i + 1); + } + if (getParams().contains(HasTableNames.TABLE_NAMES)) { + String[] inputNames = getTableNames(); + for (int i = 0; i < Math.min(n, inputNames.length); i++) { + tableNames[i] = inputNames[i]; + } + } + // assume all datasets have same schemas final TypeInformation [] colTypes = inputs[0].getColTypes(); //noinspection Convert2Lambda diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/statistics/MdsBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/statistics/MdsBatchOp.java new file mode 100644 index 000000000..25fd53dde --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/batch/statistics/MdsBatchOp.java @@ -0,0 +1,237 @@ +package com.alibaba.alink.operator.batch.statistics; + +import org.apache.flink.api.common.functions.RichMapPartitionFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.DataSet; +import org.apache.flink.types.Row; + +import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.common.linalg.DenseMatrix; +import com.alibaba.alink.common.linalg.DenseVector; +import com.alibaba.alink.common.linalg.EigenSolver; +import com.alibaba.alink.operator.batch.BatchOperator; +import com.alibaba.alink.common.utils.TableUtil; +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.util.Collector; + +import com.alibaba.alink.params.statistics.MdsParams; + +import java.util.ArrayList; + +/** + * @author Fan Hong + */ + +/** + * Multi-Dimensional Scaling, MDS, is a dimension reduction techniques for high-dimensional data. 
+ * MDS reduces (projects or embeds) data into a lower-dimensional, usually 2D, space. + * The object of MDS is to keep the distances between data items in the original space as much as possible. + * Therefore, MDS can be used to perceive clusters or outliers. + */ +@NameCn("Multi-Dimensional Scaling") +@NameEn("Multi-Dimensional Scaling") +public class MdsBatchOp extends BatchOperator implements MdsParams { + + private static final long serialVersionUID = 7353869732042122439L; + + /** + * Default constructor + */ + public MdsBatchOp() { + super(null); + } + + /** + * Constructor + * + * @param params: parameters + */ + public MdsBatchOp(Params params) { + super(params); + } + + @Override + public MdsBatchOp linkFrom(BatchOperator... inputs) { + BatchOperator in = checkAndGetFirst(inputs); + String[] selectedColNames = getSelectedCols(); + if (selectedColNames == null) { + selectedColNames = TableUtil.getNumericCols(in.getSchema()); + } + + String[] keepColNames = getReservedCols(); + if (keepColNames == null) { + keepColNames = in.getSchema().getFieldNames(); + } + + final String coordsColNamePrefix = getOutputColPrefix(); + final Integer numDimensions = getDim(); + + // Map column names to column indices + String[] allColNames = in.getColNames(); + TypeInformation [] allColTypes = in.getColTypes(); + + int numSelectedColNames = selectedColNames.length; + final int[] selectedColIndices = new int[numSelectedColNames]; + for (int i = 0; i < numSelectedColNames; i += 1) { + selectedColIndices[i] = TableUtil.findColIndexWithAssertAndHint(allColNames, selectedColNames[i]); + } + + int numKeepColNames = keepColNames.length; + final int[] keepColIndices = new int[numKeepColNames]; + for (int i = 0; i < numKeepColNames; i += 1) { + keepColIndices[i] = TableUtil.findColIndexWithAssertAndHint(allColNames, keepColNames[i]); + } + + DataSet out = in.getDataSet() + .mapPartition(new MdsComputationMapPartitionFunction(numDimensions, selectedColIndices, keepColIndices)) + .setParallelism(1); + + ArrayList colNames = new ArrayList <>(); + ArrayList > colTypes = new ArrayList <>(); + for (int i = 0; i < numDimensions; i += 1) { + colNames.add(new StringBuilder().append(coordsColNamePrefix).append(i).toString()); + colTypes.add(Types.DOUBLE); + } + + for (int i = 0; i < numKeepColNames; i += 1) { + colNames.add(allColNames[keepColIndices[i]]); + colTypes.add(allColTypes[keepColIndices[i]]); + } + + setOutput(out, colNames.toArray(new String[0]), colTypes.toArray(new TypeInformation [0])); + + return this; + } + + public static class MdsComputationMapPartitionFunction extends RichMapPartitionFunction { + + private static final long serialVersionUID = 5257680310195705244L; + private int numDimensions; + private int[] selectedColIndices; + private int[] keepColIndices; + + public MdsComputationMapPartitionFunction(int numDimensions, int[] selectedColIndices, int[] keepColIndices) { + this.numDimensions = numDimensions; + this.selectedColIndices = selectedColIndices; + this.keepColIndices = keepColIndices; + } + + @Override + public void mapPartition(Iterable iterable, Collector collector) throws Exception { + ArrayList rowList = new ArrayList <>(); + for (Row row : iterable) { + rowList.add(row); + } + + // extract data for computation + int numCols = this.selectedColIndices.length; + ArrayList dataList = new ArrayList <>(); + for (Row row : rowList) { + double[] item = new double[numCols]; + for (int i = 0; i < numCols; i += 1) { + item[i] = (double) row.getField(this.selectedColIndices[i]); + } + 
dataList.add(item); + } + + int n = dataList.size(); + System.out.println(n); + + double[][] data = new double[n][numCols]; + data = dataList.toArray(data); + + // Perform MDS computation + MdsComputation mdsComputation = new MdsComputation(n, numCols, data, numDimensions); + DenseVector[] coordinates = mdsComputation.compute(); + + // Collect results + int k = 0; + for (Row row : rowList) { + Row out = new Row(numDimensions + keepColIndices.length); + for (int i = 0; i < numDimensions; i += 1) { + out.setField(i, coordinates[i].get(k)); + } + for (int i = 0; i < keepColIndices.length; i += 1) { + out.setField(numDimensions + i, row.getField(keepColIndices[i])); + } + // System.out.println(row); + collector.collect(out); + k += 1; + } + } + + class MdsComputation { + private int n; + private int m; + private double[][] data; + private int k; + + MdsComputation(int n, int m, double[][] data, int k) { + this.n = n; + this.m = m; + this.data = data; + this.k = k; + } + + double computeDistance(int n, double[] d0, double[] d1) { + double sum = 0; + for (int i = 0; i < n; i += 1) { + sum += Math.pow(d0[i] - d1[i], 2.); + } + return Math.sqrt(sum); + } + + double[][] computeDistanceMatrix(int n, int m, double[][] data) { + double[][] dist = new double[n][n]; + for (int i = 0; i < n; i += 1) { + dist[i] = new double[n]; + for (int j = 0; j < i; j += 1) { + dist[i][j] = dist[j][i] = computeDistance(m, data[i], data[j]); + } + dist[i][i] = 0; + } + return dist; + } + + DenseVector[] compute() { + // STEP 0: compute the distances matrix between data items + double[][] dist = computeDistanceMatrix(n, m, data); + + // STEP 1: double-center the matrix + double rowSum[] = new double[n]; + double colSum[] = new double[n]; + double totalSum = 0; + for (int i = 0; i < n; i += 1) { + for (int j = 0; j < n; j += 1) { + rowSum[i] += dist[i][j]; + colSum[j] += dist[i][j]; + totalSum += dist[i][j]; + } + } + for (int i = 0; i < n; i += 1) { + for (int j = 0; j < n; j += 1) { + dist[i][j] += rowSum[i] / n + colSum[j] / n - totalSum / n / n; + } + } + + // STEP 2: get the k largest eigenvalues and eigenvectors + //get eig values and eig vectors + DenseMatrix dm = new DenseMatrix(dist); + double epsilon = 1e-6; + scala.Tuple2 eigens = EigenSolver.solve(dm, k, epsilon, 300); + + // STEP 3: obtain the embedding coordinates + DenseVector[] coords = new DenseVector[eigens._2.numCols()]; + for (int i = 0; i < coords.length; i++) { + coords[i] = new DenseVector(eigens._2.getColumn(i).clone()); + } + for (int i = 0; i < k; i += 1) { + coords[i].scaleEqual(eigens._1.get(i)); + } + return coords; + } + } + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/statistics/MultiCollinearityBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/statistics/MultiCollinearityBatchOp.java new file mode 100644 index 000000000..a8e7d6dfa --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/batch/statistics/MultiCollinearityBatchOp.java @@ -0,0 +1,477 @@ +package com.alibaba.alink.operator.batch.statistics; + +import org.apache.flink.api.common.functions.FlatMapFunction; +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.DataSet; +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; + +import com.alibaba.alink.common.annotation.InputPorts; +import 
com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.common.annotation.OutputPorts; +import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; +import com.alibaba.alink.common.annotation.PortSpec; +import com.alibaba.alink.common.annotation.PortType; +import com.alibaba.alink.common.annotation.TypeCollections; +import com.alibaba.alink.common.linalg.DenseMatrix; +import com.alibaba.alink.common.linalg.jama.JMatrixFunc; +import com.alibaba.alink.common.utils.AlinkSerializable; +import com.alibaba.alink.common.utils.TableUtil; +import com.alibaba.alink.common.viz.VizData; +import com.alibaba.alink.common.viz.VizDataWriterInterface; +import com.alibaba.alink.operator.batch.BatchOperator; +import com.alibaba.alink.operator.batch.statistics.utils.StatisticsHelper; +import com.alibaba.alink.operator.common.statistics.statistics.SummaryResultTable; +import com.alibaba.alink.params.statistics.HasStatLevel_L1.StatLevel; +import com.alibaba.alink.params.statistics.MultiCollinearityBatchParams; + +import java.util.ArrayList; +import java.util.Arrays; + +import static com.alibaba.alink.common.utils.JsonConverter.gson; + +@InputPorts(values = @PortSpec(PortType.DATA)) +@OutputPorts(values = { + @PortSpec(PortType.DATA) +}) + +@ParamSelectColumnSpec(name = "selectedCols", + allowedTypeCollections = TypeCollections.NUMERIC_TYPES) + +@NameCn("多重共线性") +@NameEn("MultiCollinearity") +public class MultiCollinearityBatchOp extends BatchOperator + implements MultiCollinearityBatchParams { + + private static final long serialVersionUID = -3276749170439192468L; + + public MultiCollinearityBatchOp() { + super(null); + } + + public MultiCollinearityBatchOp(Params params) { + super(params); + } + + @Override + public MultiCollinearityBatchOp linkFrom(BatchOperator ... 
inputs) { + BatchOperator in = checkAndGetFirst(inputs); + final String[] selectedColNames = getSelectedCols(); + + TableUtil.assertNumericalCols(in.getSchema(), selectedColNames); + + DataSet multiDataSet = StatisticsHelper.getSRT(in.select(selectedColNames), StatLevel.L3) + .map(new MapFunction () { + @Override + public Multicollinearity map(SummaryResultTable srt) throws Exception { + return Multicollinearity.calc(srt, selectedColNames); + } + }); + + this.setOutput( + multiDataSet.flatMap(new FlatMapFunction () { + @Override + public void flatMap(Multicollinearity multi, Collector out) throws Exception { + for (int i = 0; i < multi.nameX.length; i++) { + Row row = new Row(5 + selectedColNames.length); + row.setField(0, multi.nameX[i]); + row.setField(1, multi.VIF[i]); + row.setField(2, multi.TOL[i]); + row.setField(3, multi.eigenValues[i]); + row.setField(4, multi.CI[i]); + for (int j = 0; j < selectedColNames.length; j++) { + row.setField(5 + j, multi.VarProp[i][j]); + } + out.collect(row); + } + } + }), + mergeCols(new String[] {"feature_name", "vif", "tof", "eigenvalue", "condition_index"}, + selectedColNames), + mergeColTypes(new TypeInformation[] {Types.STRING, Types.DOUBLE, Types.DOUBLE, Types.DOUBLE, + Types.DOUBLE}, Types.DOUBLE, selectedColNames.length) + ); + + return this; + } + + private String[] mergeCols(String[] left, String[] right) { + String[] result = new String[left.length + right.length]; + System.arraycopy(left, 0, result, 0, left.length); + System.arraycopy(right, 0, result, left.length, right.length); + return result; + } + + private TypeInformation [] mergeColTypes(TypeInformation [] left, + TypeInformation right, + int rightLength) { + TypeInformation [] result = new TypeInformation [left.length + rightLength]; + System.arraycopy(left, 0, result, 0, left.length); + for (int i = 0; i < rightLength; i++) { + result[left.length + i] = right; + } + return result; + } + + public static class MulticollinearityFlatMap implements FlatMapFunction { + + private static final long serialVersionUID = 7050574386992532014L; + private String functionName; + private VizDataWriterInterface node; + + public MulticollinearityFlatMap(VizDataWriterInterface node, String functionName) { + this.functionName = functionName; + this.node = node; + } + + @Override + public void flatMap(SummaryResultTable srt, Collector collector) throws Exception { + try { + long timestamp = System.currentTimeMillis(); + Multicollinearity mcl = Multicollinearity.calc(srt, null); + + MulticollinearityResult mlr = new MulticollinearityResult(); + int rowLen = mcl.nameX.length; + int colLen = rowLen + 4; + mlr.rowNames = mcl.nameX; + mlr.colNames = new String[colLen]; + mlr.colNames[0] = "vif"; + mlr.colNames[1] = "tol"; + mlr.colNames[2] = "eigenvalue"; + mlr.colNames[3] = "condition_indx"; + System.arraycopy(mcl.nameX, 0, mlr.colNames, 4, rowLen); + mlr.data = new double[rowLen][colLen]; + for (int i = 0; i < rowLen; i++) { + mlr.data[i][0] = mcl.VIF[i]; + mlr.data[i][1] = mcl.TOL[i]; + mlr.data[i][2] = mcl.eigenValues[i]; + mlr.data[i][3] = mcl.CI[i]; + System.arraycopy(mcl.VarProp[i], 0, mlr.data[i], 4, rowLen); + } + + //simple summary + Row row = new Row(7); + String json = gson.toJson(mlr); + String resultType = ""; + row.setField(0, functionName); + row.setField(1, resultType); + row.setField(2, json); + row.setField(3, timestamp); + + collector.collect(row); + + { + int dataId = 0; + if (functionName.equals("AllStat")) { + dataId = 1; + } + + ArrayList vizDataList = new ArrayList <>(); + 
vizDataList.add(new VizData(dataId, json, timestamp)); + + node.writeStreamData(vizDataList); + } + + } catch (Exception ex) { + ex.printStackTrace(); + } + } + } + + public static class Multicollinearity { + /** + * * + * 变量名称 + */ + public String[] nameX; + /** + * * + * 方差膨胀因子(Variance Inflation Factor,VIF) + */ + public double[] VIF; + /** + * * + * 容忍度 + */ + public double[] TOL; + /** + * * + * 相关系数矩阵的特征值 + */ + public double[] eigenValues; + /** + * * + * 条件指数 + */ + public double[] CI; + /** + * * + * 方差比例(Variance Proportions) + */ + public double[][] VarProp; + /** + * * + * 相关系数矩阵 + */ + double[][] correlation; + /** + * * + * 相关系数矩阵的条件数 + */ + double kappa; + /** + * * + * 相关系数矩阵的最小特征值 + */ + double lambdaMin; + /** + * * + * 相关系数矩阵的最大特征值 + */ + double lambdaMax; + /** + * * + * 相关系数矩阵的最小特征向量 + */ + double[] vectorMin; + + public static Multicollinearity calc(SummaryResultTable srt, String[] nameX) throws Exception { + if (srt == null) { + throw new Exception("srt is null!"); + } + if (nameX == null) { + nameX = srt.colNames; + } + int nx = nameX.length; + int[] indexX = new int[nx]; + for (int i = 0; i < nx; i++) { + indexX[i] = TableUtil.findColIndexWithAssert(srt.colNames, nameX[i]); + Class type = srt.col(indexX[i]).dataType; + if (type != Double.class && type != Long.class && type != Boolean.class) { + throw new Exception("col type must be double, bigint , boolean!"); + } + if (srt.col(indexX[i]).count == 0) { + throw new Exception(nameX[i] + " count is zero, please choose cols again!"); + } + if (srt.col(indexX[i]).countMissValue > 0 || srt.col(indexX[i]).countNanValue > 0) { + throw new Exception("col " + nameX[i] + " has null value or nan value!"); + } + } + + double[][] matCorr = srt.getCorr(); + + double[][] correlation = new double[nx][nx]; + for (int i = 0; i < nx; i++) { + for (int j = 0; j < nx; j++) { + correlation[i][j] = matCorr[indexX[i]][indexX[j]]; + } + } + + DenseMatrix[] ed = JMatrixFunc.eig(new DenseMatrix(correlation)); + + Multicollinearity mcr = new Multicollinearity(); + mcr.correlation = correlation; + mcr.nameX = new String[nx]; + mcr.vectorMin = new double[nx]; + for (int i = 0; i < nx; i++) { + mcr.nameX[i] = nameX[i]; + mcr.vectorMin[i] = ed[0].get(i, 0); + } + + mcr.eigenValues = new double[nx]; + for (int i = 0; i < nx; i++) { + double d = ed[1].get(i, i); + if (d < 1E-12) { + d = 1E-12; + } + mcr.eigenValues[nx - 1 - i] = d; + } + + mcr.lambdaMax = mcr.eigenValues[0]; + mcr.lambdaMin = mcr.eigenValues[nx - 1]; + mcr.kappa = mcr.lambdaMax / mcr.lambdaMin; + + mcr.CI = new double[nx]; + for (int i = 0; i < nx; i++) { + mcr.CI[i] = Math.sqrt(mcr.lambdaMax / mcr.eigenValues[i]); + } + + double[][] q = new double[nx][nx]; + double[] sq = new double[nx]; + for (int j = 0; j < nx; j++) { + sq[j] = 0; + for (int i = 0; i < nx; i++) { + q[i][j] = ed[0].get(j, nx - 1 - i); + q[i][j] = q[i][j] * q[i][j] / mcr.eigenValues[i]; + sq[j] += q[i][j]; + } + } + + mcr.VarProp = new double[nx][nx]; + for (int i = 0; i < nx; i++) { + for (int j = 0; j < nx; j++) { + mcr.VarProp[i][j] = q[i][j] / sq[j]; + } + } + + mcr.VIF = new double[nx]; + mcr.TOL = new double[nx]; + + double thresholdVIF = 100000; + ArrayList restcols = new ArrayList (); + restcols.addAll(Arrays.asList(nameX)); + double[][] cov = srt.getCov(); + for (int i = nameX.length - 1; i >= 0; i--) { + ArrayList cols = new ArrayList (); + cols.addAll(Arrays.asList(nameX)); + cols.remove(i); + double r2 = getR2(srt, nameX[i], cols.toArray(new String[0]), cov); + mcr.VIF[i] = 1.0 / (1.0 - r2); + 
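+ // Note: VIF_i = 1 / (1 - R_i^2), where R_i^2 is obtained by regressing column i on the remaining selected columns;
+ // the tolerance assigned next is simply its reciprocal.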
mcr.TOL[i] = 1.0 / mcr.VIF[i]; + if (mcr.VIF[i] > thresholdVIF) { + restcols.remove(nameX[i]); + } + } + for (int i = nameX.length - 1; i >= 0; i--) { + if (mcr.VIF[i] <= thresholdVIF) { + ArrayList cols = new ArrayList (); + cols.addAll(restcols); + cols.remove(nameX[i]); + double r2 = getR2(srt, nameX[i], cols.toArray(new String[0]), cov); + mcr.VIF[i] = 1.0 / (1.0 - r2); + mcr.TOL[i] = 1.0 / mcr.VIF[i]; + } + } + + return mcr; + } + + static double getR2(SummaryResultTable srt, int indexY, int[] indexX, String nameY, String[] nameX, double[][] cov) + throws Exception { + if (srt.col(0).countTotal == 0) { + throw new Exception("table is empty!"); + } + if (srt.col(0).countTotal < nameX.length) { + throw new Exception("record size Less than features size!"); + } + + int nx = indexX.length; + long N = srt.col(indexY).count; + if (N == 0) { + throw new Exception("Y valid value num is zero!"); + } + + //将count == 0 and cov == 0的去掉 + ArrayList nameXList = new ArrayList (); + for (int anIndexX : indexX) { + if (srt.col(anIndexX).count != 0 && cov[anIndexX][indexY] != 0) { + // if (srt.col(indexX[i]).count != 0) { + nameXList.add(anIndexX); + } + } + indexX = new int[nameXList.size()]; + for (int i = 0; i < indexX.length; i++) { + indexX[i] = nameXList.get(i); + } + nx = indexX.length; + + double[] XBar = new double[nx]; + for (int i = 0; i < nx; i++) { + XBar[i] = srt.col(indexX[i]).mean(); + } + double yBar = srt.col(indexY).mean(); + + DenseMatrix A = new DenseMatrix(nx, nx); + for (int i = 0; i < nx; i++) { + for (int j = 0; j < nx; j++) { + A.set(i, j, cov[indexX[i]][indexX[j]]); + } + } + DenseMatrix C = new DenseMatrix(nx, 1); + for (int i = 0; i < nx; i++) { + C.set(i, 0, cov[indexX[i]][indexY]); + } + + // JMatrix BetaMatrix = A.solveLS(C); + DenseMatrix BetaMatrix = null; + try { + BetaMatrix = A.solve(C); + } catch (Exception ex) { + BetaMatrix = A.solveLS(C); + } + + double[] beta = new double[nx + 1]; + double d = yBar; + for (int i = 0; i < nx; i++) { + beta[i + 1] = BetaMatrix.get(i, 0); + d -= XBar[i] * beta[i + 1]; + } + beta[0] = d; + + double S = srt.col(nameY).variance() * (srt.col(nameY).count - 1); + double alpha = beta[0] - yBar; + double U = 0.0; + U += alpha * alpha * N; + for (int i = 0; i < nx; i++) { + U += 2 * alpha * srt.col(indexX[i]).sum * beta[i + 1]; + } + for (int i = 0; i < nx; i++) { + for (int j = 0; j < nx; j++) { + U += beta[i + 1] * beta[j + 1] * (cov[indexX[i]][indexX[j]] * (N - 1) + + srt.col(indexX[i]).mean() * srt.col(indexX[j]).mean() * N); + } + } + double s = U / S; + if (s < 0) { + s = 0; + } else if (s > 1) { + s = 1; + } + return s; + } + + static double getR2(SummaryResultTable srt, String nameY, String[] nameX, double[][] cov) + throws Exception { + if (srt == null) { + throw new Exception("srt must not null!"); + } + String[] colNames = srt.colNames; + Class[] types = new Class[colNames.length]; + for (int i = 0; i < colNames.length; i++) { + types[i] = srt.col(i).dataType; + } + int indexY = TableUtil.findColIndexWithAssertAndHint(colNames, nameY); + Class typeY = types[indexY]; + if (typeY != Double.class && typeY != Long.class) { + throw new Exception("col type must be double or bigint!"); + } + if (nameX.length == 0) { + throw new Exception("nameX must input!"); + } + for (String aNameX : nameX) { + int indexX = TableUtil.findColIndexWithAssertAndHint(colNames, aNameX); + Class typeX = types[indexX]; + if (typeX != Double.class && typeX != Long.class) { + throw new Exception("col type must be double or bigint!"); + } + } + int nx = 
nameX.length; + int[] indexX = new int[nx]; + for (int i = 0; i < nx; i++) { + indexX[i] = TableUtil.findColIndexWithAssert(srt.colNames, nameX[i]); + } + + return getR2(srt, indexY, indexX, nameY, nameX, cov); + } + } + + static class MulticollinearityResult implements AlinkSerializable { + String[] colNames; //vif, tol,eigenvalue, condition_index,col0, col1, .... + String[] rowNames; //col0, col1, col2 + double[][] data; + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/statistics/QuantileBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/statistics/QuantileBatchOp.java new file mode 100644 index 000000000..58845559f --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/batch/statistics/QuantileBatchOp.java @@ -0,0 +1,215 @@ +package com.alibaba.alink.operator.batch.statistics; + +import org.apache.flink.api.common.functions.BroadcastVariableInitializer; +import org.apache.flink.api.common.functions.RichGroupReduceFunction; +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.table.api.TableSchema; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; + +import com.alibaba.alink.common.annotation.InputPorts; +import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.common.annotation.OutputPorts; +import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; +import com.alibaba.alink.common.annotation.PortSpec; +import com.alibaba.alink.common.annotation.PortType; +import com.alibaba.alink.common.annotation.TypeCollections; +import com.alibaba.alink.common.utils.RowUtil; +import com.alibaba.alink.common.utils.TableUtil; +import com.alibaba.alink.operator.batch.BatchOperator; +import com.alibaba.alink.operator.batch.feature.QuantileDiscretizerTrainBatchOp; +import com.alibaba.alink.operator.common.dataproc.SortUtils; +import com.alibaba.alink.params.statistics.QuantileBatchParams; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; + +/** + * In statistics and probability quantiles are cut points dividing + * the range of a probability distribution into contiguous intervals + * with equal probabilities, or dividing the observations in a sample + * in the same way. + * (https://en.wikipedia.org/wiki/Quantile) + *

+ * reference: Yang, X. (2014). Chong gou da shu ju tong ji (1st ed., pp. 25-29).
+ *
+ * Note: This algorithm improves on parallel sorting by regular sampling (PSRS); the following steps
+ * are added to PSRS:
+ *   1. replace (val) with (val, task id) to distinguish the same value on different machines;
+ *   2. the index of the q-quantiles is index = roundMode((n - 1) * k / q), where n is the count of
+ *      samples and k satisfies 0 < k < q.
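+ *
+ * Illustrative example (numbers not from the original patch): with n = 11 samples, q = 4 and k = 1,
+ * the cut-point index is roundMode((11 - 1) * 1 / 4) = roundMode(2.5), i.e. 2 or 3 depending on the
+ * configured RoundMode.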
+ */ +@InputPorts(values = @PortSpec(PortType.DATA)) +@OutputPorts(values = @PortSpec(PortType.DATA)) +@ParamSelectColumnSpec(name = "selectedCol", allowedTypeCollections = TypeCollections.NUMERIC_TYPES) +@NameCn("分位数") +@NameEn("Quantile") +public final class QuantileBatchOp extends BatchOperator + implements QuantileBatchParams { + + private static final long serialVersionUID = -86119177892147044L; + + public QuantileBatchOp() { + super(null); + } + + public QuantileBatchOp(Params params) { + super(params); + } + + @Override + public QuantileBatchOp linkFrom(BatchOperator ... inputs) { + BatchOperator in = checkAndGetFirst(inputs); + + TableSchema tableSchema = in.getSchema(); + + String quantileColName = getSelectedCol(); + + int index = TableUtil.findColIndexWithAssertAndHint(tableSchema.getFieldNames(), quantileColName); + + /* filter the selected column from input */ + DataSet input = in.select(quantileColName).getDataSet(); + + /* sort data */ + Tuple2 >, DataSet >> sortedData + = SortUtils.pSort(input, 0); + + /* calculate quantile */ + DataSet quantile = sortedData.f0. + groupBy(0) + .reduceGroup(new Quantile( + 0, getQuantileNum(), + getRoundMode())) + .withBroadcastSet(sortedData.f1, "counts"); + + /* set output */ + setOutput(quantile, + new String[] {tableSchema.getFieldNames()[index], "quantile"}, + new TypeInformation [] {tableSchema.getFieldTypes()[index], BasicTypeInfo.LONG_TYPE_INFO}); + + return this; + } + + /** + * + */ + public static class Quantile extends RichGroupReduceFunction , Row> { + private static final long serialVersionUID = -6101513604891658021L; + private int index; + private List > counts; + private long countSum = 0; + private int quantileNum; + private RoundMode roundType; + + public Quantile(int index, int quantileNum, RoundMode roundType) { + this.index = index; + this.quantileNum = quantileNum; + this.roundType = roundType; + } + + @Override + public void open(Configuration parameters) throws Exception { + this.counts = getRuntimeContext().getBroadcastVariableWithInitializer( + "counts", + new BroadcastVariableInitializer , List >>() { + @Override + public List > initializeBroadcastVariable( + Iterable > data) { + // sort the list by task id to calculate the correct offset + List > sortedData = new ArrayList <>(); + for (Tuple2 datum : data) { + sortedData.add(datum); + } + Collections.sort(sortedData, new Comparator >() { + @Override + public int compare(Tuple2 o1, Tuple2 o2) { + return o1.f0.compareTo(o2.f0); + } + }); + + return sortedData; + } + }); + + for (int i = 0; i < this.counts.size(); ++i) { + countSum += this.counts.get(i).f1; + } + } + + @Override + public void reduce(Iterable > values, Collector out) throws Exception { + ArrayList allRows = new ArrayList <>(); + int id = -1; + long start = 0; + long end = 0; + + for (Tuple2 value : values) { + id = value.f0; + allRows.add(Row.copy(value.f1)); + } + + if (id < 0) { + throw new Exception("Error key. key: " + id); + } + + int curListIndex = -1; + int size = counts.size(); + + for (int i = 0; i < size; ++i) { + int curId = counts.get(i).f0; + + if (curId == id) { + curListIndex = i; + break; + } + + if (curId > id) { + throw new Exception("Error curId: " + curId + + ". id: " + id); + } + + start += counts.get(i).f1; + } + + end = start + counts.get(curListIndex).f1; + + if (allRows.size() != end - start) { + throw new Exception("Error start end." + + " start: " + start + + ". end: " + end + + ". 
size: " + allRows.size()); + } + + SortUtils.RowComparator rowComparator = new SortUtils.RowComparator(this.index); + Collections.sort(allRows, rowComparator); + + QuantileDiscretizerTrainBatchOp.QIndex qIndex = new QuantileDiscretizerTrainBatchOp.QIndex( + countSum, quantileNum, roundType); + + for (int i = 0; i <= quantileNum; ++i) { + long index = qIndex.genIndex(i); + + if (index >= start && index < end) { + out.collect( + RowUtil.merge(allRows.get((int) (index - start)), Long.valueOf(i))); + } + } + } + + } + +} diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/statistics/RankingListBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/statistics/RankingListBatchOp.java new file mode 100644 index 000000000..264bc6d36 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/batch/statistics/RankingListBatchOp.java @@ -0,0 +1,414 @@ +package com.alibaba.alink.operator.batch.statistics; + +import org.apache.flink.api.common.functions.GroupReduceFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.table.api.TableSchema; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; + +import com.alibaba.alink.common.AlinkGlobalConfiguration; +import com.alibaba.alink.common.type.AlinkTypes; +import com.alibaba.alink.common.annotation.InputPorts; +import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.common.annotation.OutputPorts; +import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; +import com.alibaba.alink.common.annotation.PortDesc; +import com.alibaba.alink.common.annotation.PortSpec; +import com.alibaba.alink.common.annotation.PortType; +import com.alibaba.alink.common.annotation.TypeCollections; +import com.alibaba.alink.common.utils.RowUtil; +import com.alibaba.alink.common.utils.TableUtil; +import com.alibaba.alink.operator.batch.BatchOperator; +import com.alibaba.alink.params.ParamUtil; +import com.alibaba.alink.params.statistics.RankingListParams; +import org.apache.commons.lang3.ArrayUtils; + +import java.util.ArrayList; +import java.util.Comparator; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; + +/** + * @author yangxu + */ +@InputPorts(values = {@PortSpec(PortType.DATA)}) +@OutputPorts(values = {@PortSpec(value = PortType.DATA, desc = PortDesc.OUTPUT_RESULT)}) +@ParamSelectColumnSpec(name = "objectCol") +@ParamSelectColumnSpec(name = "statCol", allowedTypeCollections = TypeCollections.NUMERIC_TYPES) +@ParamSelectColumnSpec(name = "addedCols") +@NameCn("排行榜") +@NameEn("Ranking List") +public final class RankingListBatchOp extends BatchOperator + implements RankingListParams { + + private static final Double PRECISION = 1e-16; + private static final long serialVersionUID = 2673682618359897321L; + + public RankingListBatchOp() { + super(null); + } + + public RankingListBatchOp(Params params) { + super(params); + } + + @Override + public RankingListBatchOp linkFrom(BatchOperator ... 
inputs) { + + BatchOperator in = checkAndGetFirst(inputs); + + String groupCol = getGroupCol(); + String[] values = getGroupValues(); + String objectCol = getObjectCol(); + String statCol = getStatCol(); + StatType statFunc = getStatType(); + String[] addedCols = getAddedCols(); + String[] addedStatFuncs = getAddedStatTypes(); + Boolean isDescending = getIsDescending(); + + int topN = super.getParams().get(TOP_N); + + if (in.getColNames().length == 0) { + throw new RuntimeException("table col num must be larger than 0."); + } + + //check param + if (objectCol == null || objectCol.isEmpty()) { + throw new RuntimeException("objectCol must be set."); + } + + if ((statCol == null || statCol.isEmpty()) && + statFunc != StatType.count) { + throw new RuntimeException("if stat col is null, then statFunc must be count."); + } + if (statCol == null || statCol.isEmpty()) { + statCol = in.getColNames()[0]; + } + + if ((addedCols == null && addedStatFuncs != null) || + (addedCols != null && addedStatFuncs == null)) { + throw new RuntimeException("addedCols and addedStatFuncs length must be same."); + } + + if (addedCols != null && addedStatFuncs != null && + addedCols.length != addedStatFuncs.length) { + throw new RuntimeException("addedCols and addedStatFuncs length must be same."); + } + + if (groupCol != null && !groupCol.isEmpty()) { + if (values == null || values.length == 0) { + throw new RuntimeException("values must be set."); + } + } + + //check type + TableSchema tableSchema = in.getSchema(); + TypeInformation statColType = null; + if (statCol != null && !statCol.isEmpty()) { + statColType = tableSchema.getFieldType(statCol).get(); + if (statColType != AlinkTypes.INT && statColType != AlinkTypes.LONG && + statColType != AlinkTypes.DOUBLE && + statFunc != StatType.count) { + throw new RuntimeException("only support count when type not double and long."); + } + } + + TypeInformation [] addedColTypes = null; + if (addedCols != null && addedCols.length != 0) { + addedColTypes = new TypeInformation [addedCols.length]; + int len = addedCols.length; + for (int i = 0; i < len; ++i) { + TableUtil.assertSelectedColExist(in.getColNames(), addedCols[i]); + TypeInformation type = tableSchema.getFieldType(addedCols[i]).get(); + if (type != AlinkTypes.INT && type != AlinkTypes.LONG && type != AlinkTypes.DOUBLE && + !addedStatFuncs[i].equals("count")) { + throw new RuntimeException("only support count when type not double and long."); + } + addedColTypes[i] = type; + } + } + + TypeInformation groupColType = null; + if (groupCol != null && !groupCol.isEmpty()) { + groupColType = tableSchema.getFieldType(groupCol).get(); + if (groupColType != AlinkTypes.STRING) { + throw new RuntimeException("group col must be string."); + } + } + + if (objectCol == null || objectCol.isEmpty()) { + throw new RuntimeException("object col must exist."); + } + + TypeInformation objColType = null; + if (objectCol != null && !objectCol.isEmpty()) { + objColType = tableSchema.getFieldType(objectCol).get(); + if (objColType != AlinkTypes.STRING && objColType != AlinkTypes.LONG && objColType != AlinkTypes.INT) { + throw new RuntimeException("objectCol must be string or bigint."); + } + } + + int groupColIndex = -1; + if (groupCol != null) { + StringBuilder tmp = new StringBuilder(groupCol); + tmp.append("=").append("'").append(values[0]).append("'"); + for (int i = 1; i < values.length; ++i) { + tmp.append(" or ").append(groupCol).append("="). 
+ append("'").append(values[i]).append("'"); + } + String filter = tmp.toString(); + if (AlinkGlobalConfiguration.isPrintProcessInfo()) { + System.out.println("filter: " + filter); + } + in = in.filter(filter); + groupColIndex = TableUtil.findColIndexWithAssertAndHint(in.getColNames(), groupCol); + } + + int addedLen = addedCols == null ? 0 : addedCols.length; + int[] colIdx = new int[addedLen + 2]; + colIdx[0] = TableUtil.findColIndex(in.getColNames(), objectCol); + colIdx[1] = TableUtil.findColIndex(in.getColNames(), statCol); + for (int i = 0; i < addedLen; ++i) { + colIdx[i + 2] = TableUtil.findColIndexWithAssertAndHint(in.getColNames(), addedCols[i]); + } + StatType[] funcs = {statFunc}; + + if (addedLen > 0) { + funcs = new StatType[1 + addedLen]; + funcs[0] = statFunc; + for (int i = 0; i < addedStatFuncs.length; i++) { + funcs[1 + i] = ParamUtil.searchEnum(RankingListParams.STAT_TYPE, addedStatFuncs[i]); + } + } + + TypeInformation [] statTypes = new TypeInformation[addedLen]; + for (int i = 0; i < addedLen; ++i) { + statTypes[i] = AlinkTypes.DOUBLE; + } + if (groupCol == null) { + String[] outputColNames = {objectCol, statCol}; + TypeInformation [] outputColTypes = {objColType, AlinkTypes.DOUBLE}; + if (addedCols != null && addedLen > 0) { + outputColNames = ArrayUtils.addAll(outputColNames, addedCols); + outputColTypes = ArrayUtils.addAll(outputColTypes, statTypes); + } + outputColNames = ArrayUtils.add(outputColNames, "rank"); + outputColTypes = ArrayUtils.add(outputColTypes, AlinkTypes.LONG); + DataSet sorted = in.getDataSet().reduceGroup(new SortByStatCol( + colIdx, groupColIndex, funcs, in.getColTypes(), + isDescending, topN)); + setOutput(sorted, outputColNames, outputColTypes); + } else { + String[] outputColNames = {groupCol, objectCol, statCol}; + TypeInformation [] outputColTypes = {groupColType, objColType, AlinkTypes.DOUBLE}; + if (addedCols != null && addedLen > 0) { + outputColNames = ArrayUtils.addAll(outputColNames, addedCols); + outputColTypes = ArrayUtils.addAll(outputColTypes, statTypes); + } + outputColNames = ArrayUtils.add(outputColNames, "rank"); + outputColTypes = ArrayUtils.add(outputColTypes, AlinkTypes.LONG); + DataSet sorted = in.getDataSet().groupBy(groupColIndex). 
+ reduceGroup(new SortByStatCol(colIdx, groupColIndex, funcs, in.getColTypes(), + isDescending, topN)); + this.setOutput(sorted, outputColNames, outputColTypes); + } + return this; + } + + public static class SortByStatCol implements GroupReduceFunction { + private static final long serialVersionUID = -8278621481729657224L; + private int[] colIdx; + private int groupColIndex; + private StatType[] funcs; + private TypeInformation [] types; + private boolean hasAdded; + private boolean isDesending; + private int topN; + + public SortByStatCol(int[] colIdx, int groupColIndex, + StatType[] funcs, TypeInformation [] types, + boolean isDesending, int topN) { + this.colIdx = colIdx; + this.groupColIndex = groupColIndex; + this.funcs = funcs; + this.types = types; + this.isDesending = isDesending; + this.topN = topN; + this.hasAdded = funcs.length != 1; + } + + @Override + public void reduce(Iterable rows, Collector out) throws Exception { + int statColNum = this.funcs.length; + int objIndex = colIdx[0]; + Map >> basicStats = new HashMap <>(); + + for (Row row : rows) { + Object objCol = row.getField(objIndex); + Object groupCol = null; + if (groupColIndex != -1) { + groupCol = row.getField(groupColIndex); + } + if (basicStats.keySet().contains(objCol)) { + for (int i = 0; i < statColNum; ++i) { + basicStats.get(objCol).f1.get(i).add(row); + } + } else { + List stats = new ArrayList <>(); + for (int i = 0; i < statColNum; ++i) { + StatCal one = new StatCal(this.types[colIdx[i + 1]], colIdx[i + 1]); + one.add(row); + stats.add(one); + } + basicStats.put(objCol, new Tuple2 <>(groupCol, stats)); + } + } + + Map result = new HashMap <>(); + Object[] pre = new Object[2]; + for (Object objCol : basicStats.keySet()) { + pre[0] = basicStats.get(objCol).f0; + pre[1] = objCol; + Row row = toRow(basicStats.get(objCol).f1, pre, this.funcs); + result.put(objCol, row); + } + List > entryList = new ArrayList <>(result.entrySet()); + int statIndex = this.groupColIndex == -1 ? 1 : 2; + entryList.sort(new StatComparator(statIndex, this.isDesending)); + + Iterator > iter = entryList.iterator(); + Map.Entry tmp = null; + long rank = 1L; + while (iter.hasNext() && rank <= this.topN) { + tmp = iter.next(); + out.collect(RowUtil.merge(tmp.getValue(), rank)); + ++rank; + } + } + + private Row toRow(List stats, Object[] pre, StatType[] funcs) { + if (pre == null || pre.length < 1) { + throw new RuntimeException("No Object col info."); + } + //pre[0] = groupCol, pre[1] = objCol + int preLen = pre[0] == null ? 
pre.length - 1 : pre.length; + int len = stats.size() + preLen; + Row row = new Row(len); + if (pre[0] == null) { + row.setField(0, pre[1]); + } else { + row.setField(0, pre[0]); + row.setField(1, pre[1]); + } + for (int i = preLen; i < len; ++i) { + double value = stats.get(i - preLen).calc(funcs[i - preLen]); + row.setField(i, value); + } + return row; + } + + static class StatComparator implements Comparator > { + private int index; + private boolean isDesending; + + private StatComparator(int index, boolean isDesending) { + this.index = index; + this.isDesending = isDesending; + } + + @Override + public int compare(Map.Entry o1, Map.Entry o2) { + double left = (double) o1.getValue().getField(this.index); + double right = (double) o2.getValue().getField(this.index); + if (left < right) { + if (this.isDesending) { + return 1; + } else { + return -1; + } + } else if (Math.abs(left - right) < PRECISION) { + return 0; + } else { + if (this.isDesending) { + return -1; + } else { + return 1; + } + } + } + } + + public class StatCal { + private TypeInformation type; + private int colIndex; + private long countTotal; + private long count; + private double sum; + private double sum2; + private double min; + private double max; + + public StatCal(TypeInformation type, int colIndex) { + this.type = type; + this.colIndex = colIndex; + this.count = 0; + this.countTotal = 0; + this.sum = 0; + this.sum2 = 0; + this.min = Double.POSITIVE_INFINITY; + this.max = Double.NEGATIVE_INFINITY; + } + + public void add(Row row) { + countTotal++; + if (type.equals(AlinkTypes.DOUBLE) | type.equals(AlinkTypes.LONG) || type.equals(AlinkTypes.INT)) { + if (row.getField(this.colIndex) != null) { + double val = Double.parseDouble(row.getField(this.colIndex).toString()); + count++; + sum += val; + sum2 += val * val; + max = val > max ? val : max; + min = val < min ? val : min; + } + } else { + if (row.getField(this.colIndex) != null) { + count++; + } + } + } + + public double calc(StatType statFunc) { + switch (statFunc) { + case count: + return (double) count; + case countTotal: + return (double) countTotal; + case min: + return count == 0 ? 0 : min; + case max: + return count == 0 ? 0 : max; + case sum: + return sum; + case mean: + return count == 0 ? 
0 : sum / count; + case variance: + if (0 == count || 1 == count || max == min) { + return 0.0; + } else { + return Math.max(0.0, (sum2 - sum / count * sum) / (count - 1)); + } + default: + throw new RuntimeException("statFunc " + statFunc + " not support."); + } + } + } + } +} + + diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/statistics/SomBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/statistics/SomBatchOp.java new file mode 100644 index 000000000..4c45d6a3d --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/batch/statistics/SomBatchOp.java @@ -0,0 +1,634 @@ +package com.alibaba.alink.operator.batch.statistics; + +import org.apache.flink.api.common.functions.MapPartitionFunction; +import org.apache.flink.api.common.functions.ReduceFunction; +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.functions.RichMapPartitionFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.operators.IterativeDataSet; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.table.api.Table; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; + +import com.alibaba.alink.common.AlinkGlobalConfiguration; +import com.alibaba.alink.common.type.AlinkTypes; +import com.alibaba.alink.common.annotation.InputPorts; +import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.common.annotation.OutputPorts; +import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; +import com.alibaba.alink.common.annotation.PortSpec; +import com.alibaba.alink.common.annotation.PortType; +import com.alibaba.alink.common.annotation.TypeCollections; +import com.alibaba.alink.common.linalg.DenseVector; +import com.alibaba.alink.common.linalg.SparseVector; +import com.alibaba.alink.common.linalg.Tensor; +import com.alibaba.alink.common.linalg.VectorUtil; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; +import com.alibaba.alink.operator.batch.BatchOperator; +import com.alibaba.alink.operator.common.statistics.SomJni; +import com.alibaba.alink.params.statistics.SomParams; +import org.apache.commons.lang3.ArrayUtils; + +import java.io.FileOutputStream; +import java.util.ArrayList; +import java.util.Date; +import java.util.List; +import java.util.Random; + +/** + * Self-Organized Map algorithm. + *
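+ * The map is trained iteratively: for each input vector the best matching unit (BMU) on the
+ * xdim-by-ydim grid is found, and the BMU together with its neighbours is pulled towards the
+ * input, while the learning rate and the neighbourhood radius (sigma) decay over the iterations.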

+ * references: + * 1. http://davis.wpi.edu/~matt/courses/soms/ + * 2. https://github.com/JustGlowing/minisom + * 3. https://clarkdatalabs.github.io/soms/SOM_NBA +

+ * A rule of thumb to set the size of the grid for a dimensionality + * reduction task is that it should contain 5*Sqrt(N) neurons + * where N is the number of samples in the dataset to analyze. + * E.g. if your dataset has 150 samples, 5*Sqrt(150) = 61.23 + * hence a map 8-by-8 should perform well. + */ + +@InputPorts(values = @PortSpec(PortType.DATA)) +@OutputPorts(values = { + @PortSpec(PortType.DATA), + @PortSpec(PortType.DATA) +}) + +@ParamSelectColumnSpec(name = "vectorCol", + allowedTypeCollections = TypeCollections.VECTOR_TYPES) +@NameCn("Som") +@NameEn("Som") +public final class SomBatchOp extends BatchOperator + implements SomParams { + + public final static String[] COL_NAMES = new String[] {"meta", "xidx", "yidx", "weights", "cnt"}; + public final static TypeInformation[] COL_TYPES = new TypeInformation[] {AlinkTypes.STRING, AlinkTypes.LONG, AlinkTypes.LONG, + AlinkTypes.STRING, AlinkTypes.LONG}; + public final static boolean DO_PREDICTION = true; + private static final long serialVersionUID = -6014481798410706652L; + + public SomBatchOp() { + this(new Params()); + } + + public SomBatchOp(Params params) { + super(params); + } + + @Override + public SomBatchOp linkFrom(BatchOperator ... inputs) { + BatchOperator in = checkAndGetFirst(inputs); + + String tensorColName = getVectorCol(); + final int numIters = getNumIters(); + final int xdim = getXdim(); + final int ydim = getYdim(); + final int vdim = getVdim(); + final String meta = String.format("%d,%d,%d,r", xdim, ydim, vdim); + final boolean eval = getEvaluation(); + + // count number of traning samples + DataSet numSamples = in.getDataSet() + .mapPartition(new MapPartitionFunction () { + private static final long serialVersionUID = -4852925590649190739L; + + @Override + public void mapPartition(Iterable iterable, Collector collector) throws Exception { + long cnt = 0L; + for (Row r : iterable) { + cnt++; + } + collector.collect(cnt); + } + }) + .reduce(new ReduceFunction () { + private static final long serialVersionUID = -6343518193952236485L; + + @Override + public Long reduce(Long aLong, Long t1) throws Exception { + return aLong + t1; + } + }); + + // initialize the model by randomly select xdim * ydim samples from the training data + DataSet > initModel = in.select(tensorColName).getDataSet() + .mapPartition(new RichMapPartitionFunction >() { + private static final long serialVersionUID = -1154161394939821199L; + List selectedRows; + + @Override + public void open(Configuration parameters) throws Exception { + selectedRows = new ArrayList <>(xdim * ydim); + + long n = (long) getRuntimeContext().getBroadcastVariable("numSamples").get(0); + if (n < xdim * ydim) { + throw new RuntimeException("xdim * ydim > num training samples"); + } + + if (AlinkGlobalConfiguration.isPrintProcessInfo()) { + System.out.println("Initializing model, num training samples: " + n); + } + } + + @Override + public void mapPartition(Iterable iterable, Collector > collector) + throws Exception { + Random random = new Random(); + int cnt = 0; + for (Row r : iterable) { + if (cnt < xdim * ydim) { + selectedRows.add(r); + } else { + boolean keep = random.nextDouble() < (double) (xdim * ydim) / (cnt + 1); + if (keep) { + int pick = random.nextInt(xdim * ydim); + selectedRows.set(pick, r); + } + } + cnt++; + } + + int pos = 0; + for (int i = 0; i < xdim; i++) { + for (int j = 0; j < ydim; j++) { + Object object = selectedRows.get(pos).getField(0); + if (object instanceof DenseVector) { + object = VectorUtil.serialize(object); + } else if (object 
instanceof SparseVector) { + object = VectorUtil.serialize(object); + } + collector.collect(Tuple3.of((long) i, (long) j, (String) (object))); + pos++; + } + } + } + }) + .withBroadcastSet(numSamples, "numSamples") + .setParallelism(1) + .name("init_model"); + + IterativeDataSet > loop = initModel.iterate(numIters).setParallelism(1); + + DataSet > updatedModel; + updatedModel = in.select(tensorColName).getDataSet() + .mapPartition(new SomTask(getParams())) + .withBroadcastSet(loop, "initModel") + .withBroadcastSet(numSamples, "numSamples") + .setParallelism(1) + .name("som_train"); + + DataSet > finalModel = loop.closeWith(updatedModel); + + // output the model + DataSet model = in.select(tensorColName).getDataSet() + .mapPartition(new RichMapPartitionFunction () { + private static final long serialVersionUID = -7426117483145285343L; + + @Override + public void open(Configuration parameters) throws Exception { + if (getRuntimeContext().getNumberOfParallelSubtasks() != 1) { + throw new RuntimeException("parallelism should be 1"); + } + } + + @Override + public void mapPartition(Iterable iterable, Collector collector) throws Exception { + List > bcModel = getRuntimeContext().getBroadcastVariable("somModel"); + if (bcModel.size() != xdim * ydim) { + throw new RuntimeException("unexpected"); + } + SomModel model = new SomModel(xdim, ydim, vdim); + model.init(bcModel); + model.initCount(); + + float[] v = new float[vdim]; + int[] bmu = new int[2]; + + for (Row r : iterable) { + DenseVector tensor = VectorUtil.getDenseVector(r.getField(0)); + double[] data = tensor.getData(); + for (int i = 0; i < vdim; i++) { + v[i] = (float) data[i]; + } + model.findBMU(v, bmu); + model.increaseCount(bmu, 1L); + } + + long[][] counts = model.getCounts(); + List > weights = model.getWeights(); + + for (Tuple3 w : weights) { + collector.collect(Row.of(meta, w.f0, w.f1, w.f2, counts[w.f0.intValue()][w.f1.intValue()])); + } + } + }) + .withBroadcastSet(finalModel, "somModel") + .setParallelism(1) + .name("count"); + + setOutput(model, COL_NAMES, COL_TYPES); + + // do prediction + if (DO_PREDICTION) { + DataSet pred = in.getDataSet() + .map(new RichMapFunction () { + private static final long serialVersionUID = 5561628297750641436L; + transient SomModel model; + + @Override + public void open(Configuration parameters) throws Exception { + List > bcModel = getRuntimeContext().getBroadcastVariable( + "somModel"); + if (bcModel.size() != xdim * ydim) { + throw new RuntimeException("unexpected"); + } + model = new SomModel(xdim, ydim, vdim); + model.init(bcModel); + } + + @Override + public Row map(Row r) throws Exception { + float[] v = new float[vdim]; + int[] bmu = new int[2]; + DenseVector tensor = VectorUtil.getDenseVector(r.getField(0)); + double[] data = tensor.getData(); + for (int i = 0; i < vdim; i++) { + v[i] = (float) data[i]; + } + model.findBMU(v, bmu); + + Row o = new Row(r.getArity() + 2); + for (int i = 0; i < r.getArity(); i++) { + o.setField(i, r.getField(i)); + } + for (int i = 0; i < 2; i++) { + o.setField(i + r.getArity(), (long) bmu[i]); + } + return o; + } + }) + .withBroadcastSet(finalModel, "somModel"); + + Table table = DataSetConversionUtil.toTable(getMLEnvironmentId(), pred, + ArrayUtils.addAll(in.getColNames(), new String[] {"xidx", "yidx"}), + ArrayUtils.addAll(in.getColTypes(), new TypeInformation [] {AlinkTypes.LONG, AlinkTypes.LONG})); + this.setSideOutputTables(new Table[] {table}); + } + + return this; + } + + private static class SomTask extends RichMapPartitionFunction > { + 
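+ // One training superstep executed per partition: the current neuron weights are read from the
+ // "initModel" broadcast variable, the input vectors are pushed to the solver in mini-batches of
+ // 64 * 1024 rows, and the updated weights are emitted to close the iteration.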
private static final long serialVersionUID = 6117856294526477050L; + Params params; + + transient SomSolver solver = null; + + public SomTask(Params params) { + this.params = params; + } + + @Override + public void open(Configuration parameters) throws Exception { + super.open(parameters); + + if (AlinkGlobalConfiguration.isPrintProcessInfo()) { + System.out.println("\n ** step " + getIterationRuntimeContext().getSuperstepNumber()); + System.out.println(new Date().toString() + ": " + "start step ..."); + } + + if (solver == null) { + final int numIters = params.get(SomParams.NUM_ITERS); + final int xdim = params.get(SomParams.XDIM); + final int ydim = params.get(SomParams.YDIM); + final int vdim = params.get(SomParams.VDIM); + final double learnRate = params.get(SomParams.LEARN_RATE); + final double sigma = params.get(SomParams.SIGMA); + List > bcModel = getRuntimeContext().getBroadcastVariable("initModel"); + if (bcModel.size() != xdim * ydim) { + throw new RuntimeException("unexpected"); + } + List bcNumSamples = getRuntimeContext().getBroadcastVariable("numSamples"); + long numSamples = bcNumSamples.get(0); + solver = new SomSolver(xdim, ydim, vdim, learnRate, sigma, (long) numIters * numSamples); + solver.init(bcModel); + } + } + + @Override + public void close() throws Exception { + if (AlinkGlobalConfiguration.isPrintProcessInfo()) { + System.out.println(new Date().toString() + ": " + "close step ..."); + } + } + + @Override + public void mapPartition(Iterable iterable, Collector > collector) + throws Exception { + final int vdim = params.get(SomParams.VDIM); + final int BATCH_SIZE = 64 * 1024; + float[] batch = new float[BATCH_SIZE * vdim]; + int batchCnt = 0; + + for (Row row : iterable) { + DenseVector tensor = VectorUtil.getDenseVector(row.getField(0)); + double[] data = tensor.getData(); + int pos = batchCnt * vdim; + for (int i = 0; i < vdim; i++) { + batch[pos + i] = (float) data[i]; + } + batchCnt++; + if (batchCnt >= BATCH_SIZE) { + solver.updateBatch(batch, batchCnt); + batchCnt = 0; + } + } + + if (batchCnt > 0) { + solver.updateBatch(batch, batchCnt); + batchCnt = 0; + } + + List > weights = solver.getWeights(); + for (Tuple3 w : weights) { + collector.collect(w); + } + } + } + + public static class SomSolver { + private int xdim; + private int ydim; + private int vdim; + private double learnRate; + private double sigma; + private long currStepNo = 0L; + private long maxStepNo; + private float[] weights; + private SomJni somJni; + + public SomSolver(int xdim, int ydim, int vdim, double learnRate, double sigma, long maxStepNo) { + this.xdim = xdim; + this.ydim = ydim; + this.vdim = vdim; + this.learnRate = learnRate; + this.sigma = sigma; + this.maxStepNo = maxStepNo; + this.currStepNo = 0L; + weights = new float[xdim * ydim * vdim]; + + somJni = new SomJni(); + + if (maxStepNo > 0) { + if (AlinkGlobalConfiguration.isPrintProcessInfo()) { + System.out.println(String.format("xdim=%d,ydim=%d,vdim=%d,learnRate=%f,sigma=%f,maxStepNo=%d", + xdim, ydim, vdim, learnRate, sigma, maxStepNo)); + } + } + } + + private static double decayFunction(double v, double currStepNo, double maxStepNo) { + return v / (1.0 + 2.0 * currStepNo / maxStepNo); + } + + public static int getNeuronPos(int x, int y, int xdim, int ydim, int vdim) { + return (y * xdim + x) * vdim; + } + + public void init(List > model) { + for (Tuple3 neron : model) { + int x = neron.f0.intValue(); + int y = neron.f1.intValue(); + String tensorStr = neron.f2; + DenseVector tensor = VectorUtil.getDenseVector(tensorStr); + 
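+ // copy this neuron's vector into the flat weight array at offset (y * xdim + x) * vdim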
double[] data = tensor.getData(); + if (data.length != vdim) { + throw new RuntimeException("Invalid data length: " + data.length); + } + int pos = getNeuronPos(x, y, xdim, ydim, vdim); + for (int i = 0; i < vdim; i++) { + this.weights[pos + i] = (float) data[i]; + } + } + } + + public void updateBatch(float[] batch, int cnt) { + float lr = (float) decayFunction(this.learnRate, this.currStepNo, this.maxStepNo); + float sig = (float) decayFunction(this.sigma, this.currStepNo, this.maxStepNo); + + somJni.updateBatchJava(weights, batch, cnt, lr, sig, xdim, ydim, vdim); + + this.currStepNo += cnt; + } + + public List > getWeights() { + List > ret = new ArrayList <>(xdim * ydim); + + for (int i = 0; i < xdim; i++) { + for (int j = 0; j < ydim; j++) { + int pos = getNeuronPos(i, j, xdim, ydim, vdim); + StringBuilder sbd = new StringBuilder(); + for (int k = 0; k < vdim; k++) { + if (k > 0) { + sbd.append(","); + } + sbd.append(weights[pos + k]); + } + ret.add(Tuple3.of((long) i, (long) j, sbd.toString())); + } + } + + return ret; + } + } + + public static class SomModel { + private int xdim; + private int ydim; + private int vdim; + + private float[] weights; + private int[] bmu0; + + private long[][] counts; + private float[][] umatrix; + + private SomJni somJni; + + public SomModel(int xdim, int ydim, int vdim) { + this.xdim = xdim; + this.ydim = ydim; + this.vdim = vdim; + weights = new float[xdim * ydim * vdim]; + bmu0 = new int[2]; + + somJni = new SomJni(); + } + + private static float squaredDistance(float[] v1, int s1, float[] v2, int s2, int n) { + float s = 0.F; + for (int i = 0; i < n; i++) { + s += (v1[s1 + i] - v2[s2 + i]) * (v1[s1 + i] - v2[s2 + i]); + } + return s; + } + + public int getNeuronPos(int x, int y) { + return (y * xdim + x) * vdim; + } + + public void setNeuron(int x, int y, String tensorStr) { + Tensor tensor = Tensor.parse(tensorStr); + double[] data = tensor.getData(); + if (data.length != vdim) { + throw new RuntimeException("invalid data length: " + data.length); + } + int pos = getNeuronPos(x, y); + for (int i = 0; i < vdim; i++) { + this.weights[pos + i] = (float) data[i]; + } + } + + public void init(List > model) { + for (Tuple3 neron : model) { + int x = neron.f0.intValue(); + int y = neron.f1.intValue(); + String tensorStr = neron.f2; + DenseVector tensor = VectorUtil.parseDense(tensorStr); + double[] data = tensor.getData(); + if (data.length != vdim) { + throw new RuntimeException("Invalid data length: " + data.length); + } + int pos = getNeuronPos(x, y); + for (int i = 0; i < vdim; i++) { + this.weights[pos + i] = (float) data[i]; + } + } + } + + public void initCount() { + this.counts = new long[xdim][ydim]; + } + + public void increaseCount(int[] who, long c) { + this.counts[who[0]][who[1]] += c; + } + + public long[][] getCounts() { + return this.counts; + } + + private float findBMU(final float[] v, int[] bmu) { + float minD2 = somJni.findBmuJava(weights, new float[xdim * ydim], v, bmu0, xdim, ydim, vdim); + + if (bmu != null) { + bmu[0] = bmu0[0]; + bmu[1] = bmu0[1]; + } + return minD2; + } + + public List > getWeights() { + List > ret = new ArrayList <>(xdim * ydim); + + for (int i = 0; i < xdim; i++) { + for (int j = 0; j < ydim; j++) { + int pos = getNeuronPos(i, j); + StringBuilder sbd = new StringBuilder(); + for (int k = 0; k < vdim; k++) { + if (k > 0) { + sbd.append(","); + } + sbd.append(weights[pos + k]); + } + ret.add(Tuple3.of((long) i, (long) j, sbd.toString())); + } + } + + return ret; + } + + public void createUMatrix() { + 
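+ // U-matrix: for each neuron, average the squared distances to its (up to four) diagonal
+ // neighbours on the grid and store the square root; large values indicate cluster boundaries.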
this.umatrix = new float[xdim][ydim]; + for (int i = 0; i < xdim; i++) { + for (int j = 0; j < ydim; j++) { + float sum = 0.F; + int cnt = 0; + + for (int k = -1; k <= 1; k += 2) { + for (int l = -1; l <= 1; l += 2) { + int xtarget = i + k; + int ytarget = j + l; + if (xtarget < 0 || xtarget >= xdim) { + continue; + } + if (ytarget < 0 || ytarget >= ydim) { + continue; + } + sum += squaredDistance(this.weights, getNeuronPos(i, j), + this.weights, getNeuronPos(xtarget, ytarget), vdim); + cnt++; + } + } + + this.umatrix[i][j] = (float) Math.sqrt(sum / cnt); + } + } + } + + public float getUMatrixValue(int x, int y) { + return this.umatrix[x][y]; + } + + public void writeToFile(String xFn, String yFn) throws Exception { + StringBuilder sbdx = new StringBuilder(); + StringBuilder sbdy = new StringBuilder(); + for (int i = 0; i < ydim; i++) { + for (int j = 0; j < xdim; j++) { + if (j > 0) { + sbdx.append(" "); + sbdy.append(" "); + } + int pos = getNeuronPos(j, i); + sbdx.append(weights[pos + 0]); + sbdy.append(weights[pos + 1]); + } + sbdx.append("\n"); + sbdy.append("\n"); + } + + { + FileOutputStream out = new FileOutputStream(xFn); + out.write(sbdx.toString().getBytes()); + out.close(); + } + + { + FileOutputStream out = new FileOutputStream(yFn); + out.write(sbdy.toString().getBytes()); + out.close(); + } + } + + public void writePointsToFile(String fn, List x, List y) throws Exception { + int n = x.size(); + StringBuilder sbd = new StringBuilder(); + + for (int i = 0; i < n; i++) { + sbd.append(x.get(i)).append(" ").append(y.get(i)).append("\n"); + } + + FileOutputStream out = new FileOutputStream(fn); + out.write(sbd.toString().getBytes()); + out.close(); + } + } + +} diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/statistics/SummarizerBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/statistics/SummarizerBatchOp.java index 06d006ed2..58d7969ff 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/statistics/SummarizerBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/statistics/SummarizerBatchOp.java @@ -8,6 +8,7 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortDesc; @@ -15,7 +16,7 @@ import com.alibaba.alink.common.annotation.PortType; import com.alibaba.alink.common.exceptions.AkPreconditions; import com.alibaba.alink.operator.batch.BatchOperator; -import com.alibaba.alink.operator.common.statistics.StatisticsHelper; +import com.alibaba.alink.operator.batch.statistics.utils.StatisticsHelper; import com.alibaba.alink.operator.common.statistics.basicstatistic.SummaryDataConverter; import com.alibaba.alink.operator.common.statistics.basicstatistic.TableSummary; import com.alibaba.alink.params.statistics.SummarizerParams; @@ -31,6 +32,7 @@ @OutputPorts(values = {@PortSpec(value = PortType.DATA, desc = PortDesc.OUTPUT_RESULT)}) @ParamSelectColumnSpec(name = "selectedCols") @NameCn("全表统计") +@NameEn("Summarizer") public class SummarizerBatchOp extends BatchOperator implements SummarizerParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/statistics/VectorChiSquareTestBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/statistics/VectorChiSquareTestBatchOp.java index 252e02c18..8fee76ef2 100644 --- 
a/core/src/main/java/com/alibaba/alink/operator/batch/statistics/VectorChiSquareTestBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/statistics/VectorChiSquareTestBatchOp.java @@ -5,6 +5,7 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortDesc; @@ -35,6 +36,7 @@ @ParamSelectColumnSpec(name = "selectedCol", allowedTypeCollections = TypeCollections.VECTOR_TYPES) @ParamSelectColumnSpec(name = "labelCol") @NameCn("向量卡方检验") +@NameEn("Vector ChiSquare Test") public final class VectorChiSquareTestBatchOp extends BatchOperator implements VectorChiSquareTestParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/statistics/VectorCorrelationBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/statistics/VectorCorrelationBatchOp.java index 0c05ae5db..28848ae6e 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/statistics/VectorCorrelationBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/statistics/VectorCorrelationBatchOp.java @@ -11,6 +11,7 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortDesc; @@ -18,10 +19,10 @@ import com.alibaba.alink.common.annotation.PortType; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.common.exceptions.AkPreconditions; -import com.alibaba.alink.common.utils.DataSetConversionUtil; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.batch.source.TableSourceBatchOp; -import com.alibaba.alink.operator.common.statistics.StatisticsHelper; +import com.alibaba.alink.operator.batch.statistics.utils.StatisticsHelper; import com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary; import com.alibaba.alink.operator.common.statistics.basicstatistic.CorrelationDataConverter; import com.alibaba.alink.operator.common.statistics.basicstatistic.CorrelationResult; @@ -39,6 +40,7 @@ @OutputPorts(values = {@PortSpec(value = PortType.DATA, desc = PortDesc.OUTPUT_RESULT)}) @ParamSelectColumnSpec(name = "selectedCol", allowedTypeCollections = TypeCollections.VECTOR_TYPES) @NameCn("向量相关系数") +@NameEn("Vector Correlation") public final class VectorCorrelationBatchOp extends BatchOperator implements VectorCorrelationParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/statistics/VectorSummarizerBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/statistics/VectorSummarizerBatchOp.java index 385560a10..5a417acfc 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/statistics/VectorSummarizerBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/statistics/VectorSummarizerBatchOp.java @@ -8,6 +8,7 @@ import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import 
com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.PortDesc; @@ -16,7 +17,7 @@ import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.common.exceptions.AkPreconditions; import com.alibaba.alink.operator.batch.BatchOperator; -import com.alibaba.alink.operator.common.statistics.StatisticsHelper; +import com.alibaba.alink.operator.batch.statistics.utils.StatisticsHelper; import com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary; import com.alibaba.alink.operator.common.statistics.basicstatistic.VectorSummaryDataConverter; import com.alibaba.alink.params.statistics.VectorSummarizerParams; @@ -32,6 +33,7 @@ @OutputPorts(values = {@PortSpec(value = PortType.DATA, desc = PortDesc.OUTPUT_RESULT)}) @ParamSelectColumnSpec(name = "selectedCol", allowedTypeCollections = TypeCollections.VECTOR_TYPES) @NameCn("向量全表统计") +@NameEn("Vector Summarizer") public class VectorSummarizerBatchOp extends BatchOperator implements VectorSummarizerParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/common/statistics/StatisticUtil.java b/core/src/main/java/com/alibaba/alink/operator/batch/statistics/utils/StatisticUtil.java similarity index 62% rename from core/src/main/java/com/alibaba/alink/operator/common/statistics/StatisticUtil.java rename to core/src/main/java/com/alibaba/alink/operator/batch/statistics/utils/StatisticUtil.java index a4e93ea28..528f2b985 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/statistics/StatisticUtil.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/statistics/utils/StatisticUtil.java @@ -1,15 +1,9 @@ -package com.alibaba.alink.operator.common.statistics; +package com.alibaba.alink.operator.batch.statistics.utils; -import org.apache.flink.api.common.functions.ReduceFunction; import org.apache.flink.api.common.typeinfo.TypeInformation; -import org.apache.flink.api.java.DataSet; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.linalg.DenseVector; -import com.alibaba.alink.common.AlinkTypes; -import com.alibaba.alink.operator.batch.BatchOperator; -import com.alibaba.alink.operator.common.statistics.basicstat.SetPartitionBasicStat; -import com.alibaba.alink.operator.common.statistics.statistics.SummaryResultTable; -import com.alibaba.alink.params.statistics.HasStatLevel_L1; public class StatisticUtil { /* @@ -28,25 +22,6 @@ public class StatisticUtil { AlinkTypes.STRING, AlinkTypes.LONG, AlinkTypes.STRING, AlinkTypes.STRING, AlinkTypes.STRING}; final static public String COLUMN_NAME_OF_TIMESTAMP = "timestamp"; - public static DataSet getSRT(BatchOperator in, HasStatLevel_L1.StatLevel statLevel) { - return in.getDataSet() - .mapPartition(new SetPartitionBasicStat(in.getSchema(), statLevel)) - .reduce(new ReduceFunction () { - private static final long serialVersionUID = 6050967884386340459L; - - @Override - public SummaryResultTable reduce(SummaryResultTable a, SummaryResultTable b) { - if (null == a) { - return b; - } else if (null == b) { - return a; - } else { - return SummaryResultTable.combine(a, b); - } - } - }); - } - /** * determine whether it is a string type. 
*/ @@ -75,11 +50,7 @@ public static boolean isDatetime(String dataType) { * @return the transformed {@link DenseVector} */ public static double getDoubleValue(Object obj, Class type) { - if (Double.class == type || - Integer.class == type || - Long.class == type || - Float.class == type || - Short.class == type) { + if (Number.class.isAssignableFrom(type)) { return ((Number) obj).doubleValue(); } else if (Boolean.class == type) { return (Boolean) obj ? 1.0 : 0.0; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/statistics/StatisticsHelper.java b/core/src/main/java/com/alibaba/alink/operator/batch/statistics/utils/StatisticsHelper.java similarity index 74% rename from core/src/main/java/com/alibaba/alink/operator/common/statistics/StatisticsHelper.java rename to core/src/main/java/com/alibaba/alink/operator/batch/statistics/utils/StatisticsHelper.java index 684dce42d..6f839c2c4 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/statistics/StatisticsHelper.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/statistics/utils/StatisticsHelper.java @@ -1,4 +1,4 @@ -package com.alibaba.alink.operator.common.statistics; +package com.alibaba.alink.operator.batch.statistics.utils; import org.apache.flink.api.common.functions.MapFunction; import org.apache.flink.api.common.functions.MapPartitionFunction; @@ -6,6 +6,7 @@ import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.DataSet; import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.table.api.TableSchema; import org.apache.flink.types.Row; import org.apache.flink.util.Collector; @@ -24,11 +25,11 @@ import com.alibaba.alink.operator.common.statistics.basicstatistic.TableSummarizer; import com.alibaba.alink.operator.common.statistics.basicstatistic.TableSummary; import com.alibaba.alink.operator.common.statistics.basicstatistic.VectorSummarizerUtil; +import com.alibaba.alink.operator.common.statistics.statistics.SrtUtil; +import com.alibaba.alink.operator.common.statistics.statistics.SummaryResultTable; +import com.alibaba.alink.operator.common.statistics.statistics.WindowTable; import com.alibaba.alink.operator.common.tree.Preprocessing; - -import java.security.InvalidParameterException; -import java.util.ArrayList; -import java.util.List; +import com.alibaba.alink.params.statistics.HasStatLevel_L1; /** * Util for batch statistical calculation. @@ -36,16 +37,15 @@ public class StatisticsHelper { /** * Transform cols or vector col to vector and compute summary. - * it deal with table which only have number cols and not has missing value. - * It return tuple2, f0 is data, f1 is summary. data f0 is vector, data f1 is row which is reserved cols. + * it deals with table which only have number cols and not has missing value. + * It returns tuple2, f0 is data, f1 is summary. data f0 is vector, data f1 is row which is reserved cols. */ - public static Tuple2 >, DataSet > summaryHelper(BatchOperator in, - String[] - selectedColNames, - String - vectorColName, - String[] - reservedColNames) { + public static Tuple2 >, DataSet > summaryHelper( + BatchOperator in, + String[] selectedColNames, + String vectorColName, + String[] reservedColNames) { + DataSet > data = transformToVector(in, selectedColNames, vectorColName, reservedColNames); DataSet vectorDataSet = data @@ -63,9 +63,9 @@ public Vector map(Tuple2 tuple2) { /** * Transform to vector without reservedColNames and compute summary. 
- * it deal with table which only have number cols and has no missing value. + * it deals with table which only have number cols and has no missing value. */ - public static Tuple2 , DataSet > summaryHelper(BatchOperator in, + public static Tuple2 , DataSet > summaryHelper(BatchOperator in, String[] selectedColNames, String vectorColName) { DataSet data = transformToVector(in, selectedColNames, vectorColName); @@ -75,13 +75,13 @@ public static Tuple2 , DataSet > summaryHelp /** * table summary, selectedColNames must be set. */ - public static DataSet summary(BatchOperator in, String[] selectedColNames) { + public static DataSet summary(BatchOperator in, String[] selectedColNames) { return summarizer(in, selectedColNames, false) .map(new MapFunction () { private static final long serialVersionUID = 8876418210242735806L; @Override - public TableSummary map(TableSummarizer summarizer) throws Exception { + public TableSummary map(TableSummarizer summarizer) { return summarizer.toSummary(); } }).name("toSummary"); @@ -90,7 +90,7 @@ public TableSummary map(TableSummarizer summarizer) throws Exception { /** * vector stat, selectedColName must be set. */ - public static DataSet vectorSummary(BatchOperator in, String selectedColName) { + public static DataSet vectorSummary(BatchOperator in, String selectedColName) { return vectorSummarizer(in, selectedColName, false) .map(new MapFunction () { private static final long serialVersionUID = -6426572658193278213L; @@ -111,7 +111,7 @@ public static DataSet summary(DataSet data) { private static final long serialVersionUID = -2082435777065038687L; @Override - public BaseVectorSummary map(BaseVectorSummarizer summarizer) throws Exception { + public BaseVectorSummary map(BaseVectorSummarizer summarizer) { return summarizer.toSummary(); } }).name("toSummary"); @@ -120,7 +120,7 @@ public BaseVectorSummary map(BaseVectorSummarizer summarizer) throws Exception { /** * calculate correlation. result is tuple2, f0 is summary, f1 is correlation. */ - public static DataSet > pearsonCorrelation(BatchOperator in, + public static DataSet > pearsonCorrelation(BatchOperator in, String[] selectedColNames) { return summarizer(in, selectedColNames, true) .map(new MapFunction >() { @@ -136,9 +136,10 @@ public Tuple2 map(TableSummarizer summarizer) /** * calculate correlation. result is tuple2, f0 is summary, f1 is correlation. */ - public static DataSet > vectorPearsonCorrelation(BatchOperator in, - String - selectedColName) { + public static DataSet > vectorPearsonCorrelation( + BatchOperator in, + String selectedColName) { + return vectorSummarizer(in, selectedColName, true) .map(new MapFunction >() { private static final long serialVersionUID = -1745468840082156193L; @@ -197,14 +198,14 @@ public static DataSet > transformToVector(BatchOperator * Transform vector or table columns to table columns. * selectedCols and vectorCol only be one non-empty, reserved cols can be empty. *

- * If selected cols is not null, it will combines selected cols and reserved cols, + * If selected cols is not null, it will combine selected cols and reserved cols, * and selected cols will transform to double type. *

- * If vector col is not null, it will splits vector to cols and combines with reserved cols. + * If vector col is not null, it will split vector to cols and combines with reserved cols. *

* If selected cols and vector col both set, it will use vector col. */ - public static DataSet transformToColumns(BatchOperator in, + public static DataSet transformToColumns(BatchOperator in, String[] selectedColNames, String vectorColName, String[] reservedColNames) { @@ -231,7 +232,7 @@ public static DataSet transformToColumns(BatchOperator in, /** * check parameters is invalid. */ - private static void checkSimpleStatParameter(BatchOperator in, + private static void checkSimpleStatParameter(BatchOperator in, String[] selectedColNames, String vectorColName, String[] reservedColNames) { @@ -251,7 +252,8 @@ private static void checkSimpleStatParameter(BatchOperator in, /** * table stat */ - private static DataSet summarizer(BatchOperator in, String[] selectedColNames, + private static DataSet summarizer(BatchOperator in, + String[] selectedColNames, boolean calculateOuterProduct) { if (selectedColNames == null || selectedColNames.length == 0) { throw new AkIllegalOperatorParameterException("selectedColNames must be set."); @@ -259,14 +261,14 @@ private static DataSet summarizer(BatchOperator in, String[] s in = Preprocessing.select(in, selectedColNames); - return summarizer(in.getDataSet(), calculateOuterProduct, getNumericalColIndices(in.getColTypes()), - selectedColNames); + return summarizer(in.getDataSet(), in.getSchema(), calculateOuterProduct); } /** * vector stat */ - private static DataSet vectorSummarizer(BatchOperator in, String selectedColName, + private static DataSet vectorSummarizer(BatchOperator in, + String selectedColName, boolean calculateOuterProduct) { TableUtil.assertSelectedColExist(in.getColNames(), selectedColName); @@ -284,8 +286,7 @@ public static DataSet summarizer(DataSet data, b private static final long serialVersionUID = 5993118429985684366L; @Override - public BaseVectorSummarizer reduce(BaseVectorSummarizer value1, BaseVectorSummarizer value2) - throws Exception { + public BaseVectorSummarizer reduce(BaseVectorSummarizer value1, BaseVectorSummarizer value2) { return VectorSummarizerUtil.merge(value1, value2); } }) @@ -295,10 +296,9 @@ public BaseVectorSummarizer reduce(BaseVectorSummarizer value1, BaseVectorSummar /** * given data, return summary. numberIndices is the indices of cols which are number type in selected cols. 
*/ - private static DataSet summarizer(DataSet data, boolean bCov, int[] numberIndices, - String[] selectedColNames) { + private static DataSet summarizer(DataSet data, TableSchema tableSchema, boolean bCov) { return data - .mapPartition(new TableSummarizerPartition(bCov, numberIndices, selectedColNames)) + .mapPartition(new TableSummarizerPartition(tableSchema, bCov)) .reduce(new ReduceFunction () { private static final long serialVersionUID = 964700189305139868L; @@ -309,22 +309,23 @@ public TableSummarizer reduce(TableSummarizer left, TableSummarizer right) { }); } - /** - * get indices which type is number - */ - private static int[] getNumericalColIndices(TypeInformation[] colTypes) { - List numberColIndicesList = new ArrayList <>(); - for (int i = 0; i < colTypes.length; i++) { - if (TableUtil.isSupportedNumericType(colTypes[i])) { - numberColIndicesList.add(i); - } - } + public static DataSet getSRT(BatchOperator in, HasStatLevel_L1.StatLevel statLevel) { + return in.getDataSet() + .mapPartition(new SetPartitionBasicStat(in.getSchema(), statLevel)) + .reduce(new ReduceFunction () { + private static final long serialVersionUID = 6050967884386340459L; - int[] numberColIndices = new int[numberColIndicesList.size()]; - for (int i = 0; i < numberColIndices.length; i++) { - numberColIndices[i] = numberColIndicesList.get(i); - } - return numberColIndices; + @Override + public SummaryResultTable reduce(SummaryResultTable a, SummaryResultTable b) { + if (null == a) { + return b; + } else if (null == b) { + return a; + } else { + return SummaryResultTable.combine(a, b); + } + } + }); } /** @@ -332,8 +333,8 @@ private static int[] getNumericalColIndices(TypeInformation[] colTypes) { */ private static class ColsToVectorWithReservedColsMap implements MapFunction > { private static final long serialVersionUID = -7292044433828115396L; - private int[] selectedColIndices; - private int[] reservedColIndices; + private final int[] selectedColIndices; + private final int[] reservedColIndices; ColsToVectorWithReservedColsMap(int[] selectedColIndices, int[] reservedColIndices) { this.selectedColIndices = selectedColIndices; @@ -360,7 +361,7 @@ public Tuple2 map(Row in) throws Exception { */ private static class ColsToVectorWithoutReservedColsMap implements MapFunction { private static final long serialVersionUID = -8479361651447801687L; - private int[] selectedColIndices; + private final int[] selectedColIndices; ColsToVectorWithoutReservedColsMap(int[] selectedColIndices) { this.selectedColIndices = selectedColIndices; @@ -386,8 +387,8 @@ public Vector map(Row in) throws Exception { */ private static class VectorColToVectorWithReservedColsMap implements MapFunction > { private static final long serialVersionUID = -3222351920500305742L; - private int vectorColIndex; - private int[] reservedColIndices; + private final int vectorColIndex; + private final int[] reservedColIndices; VectorColToVectorWithReservedColsMap(int vectorColIndex, int[] reservedColIndices) { this.vectorColIndex = vectorColIndex; @@ -413,7 +414,7 @@ public Tuple2 map(Row in) throws Exception { */ private static class VectorCoToVectorWithoutReservedColsMap implements MapFunction { private static final long serialVersionUID = -6220416346174572528L; - private int vectorColIndex; + private final int vectorColIndex; VectorCoToVectorWithoutReservedColsMap(int vectorColIndex) { this.vectorColIndex = vectorColIndex; @@ -434,18 +435,15 @@ private static class VectorColToTableMap implements MapFunction { /** * vector col index. 
*/ - private int vectorColIndex; + private final int vectorColIndex; /** * reserved col indices. */ - private int[] reservedColIndices; + private final int[] reservedColIndices; VectorColToTableMap(int vectorColIndex, int[] reservedColIndices) { this.vectorColIndex = vectorColIndex; - this.reservedColIndices = reservedColIndices; - if (reservedColIndices == null) { - this.reservedColIndices = new int[0]; - } + this.reservedColIndices = null == reservedColIndices ? new int[0] : reservedColIndices; } @Override @@ -486,20 +484,17 @@ public Row map(Row in) throws Exception { */ private static class ColsToDoubleColsMap implements MapFunction { private static final long serialVersionUID = 2021889928304298454L; - private int[] selectedColIndices; - private int[] reservedColIndices; + private final int[] selectedColIndices; + private final int[] reservedColIndices; ColsToDoubleColsMap(int[] selectedColIndices, int[] reservedColIndices) { this.selectedColIndices = selectedColIndices; - this.reservedColIndices = reservedColIndices; - if (reservedColIndices == null) { - this.reservedColIndices = new int[0]; - } + this.reservedColIndices = null == reservedColIndices ? new int[0] : reservedColIndices; } @Override public Row map(Row in) throws Exception { - //table cols and reserved cols, table cols will be transform to double type. + //table cols and reserved cols, table cols will be transformed to double type. Row out = new Row(selectedColIndices.length + reservedColIndices.length); for (int i = 0; i < this.selectedColIndices.length; ++i) { out.setField(i, ((Number) in.getField(this.selectedColIndices[i])).doubleValue()); @@ -514,22 +509,21 @@ public Row map(Row in) throws Exception { /** * It is table summary partition of one worker, will merge result later. */ - public static class TableSummarizerPartition implements MapPartitionFunction { + private static class TableSummarizerPartition implements MapPartitionFunction { private static final long serialVersionUID = -1625614901816383530L; - private boolean outerProduct; - private int[] numericalIndices; - private String[] selectedColNames; + private final boolean outerProduct; + private final String[] colNames; + private final TypeInformation [] colTypes; - TableSummarizerPartition(boolean outerProduct, int[] numericalIndices, String[] selectedColNames) { + TableSummarizerPartition(TableSchema tableSchema, boolean outerProduct) { this.outerProduct = outerProduct; - this.numericalIndices = numericalIndices; - this.selectedColNames = selectedColNames; + this.colNames = tableSchema.getFieldNames(); + this.colTypes = tableSchema.getFieldTypes(); } @Override public void mapPartition(Iterable iterable, Collector collector) throws Exception { - TableSummarizer srt = new TableSummarizer(selectedColNames, numericalIndices, outerProduct); - srt.colNames = selectedColNames; + TableSummarizer srt = new TableSummarizer(new TableSchema(colNames, colTypes), outerProduct); for (Row sv : iterable) { srt = (TableSummarizer) srt.visit(sv); } @@ -543,7 +537,7 @@ public void mapPartition(Iterable iterable, Collector co */ public static class VectorSummarizerPartition implements MapPartitionFunction { private static final long serialVersionUID = 1065284716432882945L; - private boolean outerProduct; + private final boolean outerProduct; public VectorSummarizerPartition(boolean outerProduct) { this.outerProduct = outerProduct; @@ -561,4 +555,66 @@ public void mapPartition(Iterable iterable, Collector { + + private static final long serialVersionUID = -5607403479996476267L; + 
private String[] colNames; + private Class[] colTypes; + private HasStatLevel_L1.StatLevel statLevel; + private String[] selectedColNames = null; + + public SetPartitionBasicStat(TableSchema schema) { + this(schema, HasStatLevel_L1.StatLevel.L1); + } + + /** + * @param schema + * @param statLevel: L1,L2,L3: 默认是L1 + * L1 has basic statistic; + * L2 has simple statistic and cov/corr; + * L3 has simple statistic, cov/corr, histogram, freq, topk, bottomk; + */ + public SetPartitionBasicStat(TableSchema schema, HasStatLevel_L1.StatLevel statLevel) { + this.colNames = schema.getFieldNames(); + int n = this.colNames.length; + this.colTypes = new Class[n]; + for (int i = 0; i < n; i++) { + colTypes[i] = schema.getFieldTypes()[i].getTypeClass(); + } + this.statLevel = statLevel; + this.selectedColNames = this.colNames; + } + + /** + * @param schema + * @param statLevel: L1,L2,L3: 默认是L1 + * L1 has basic statistic; + * L2 has simple statistic and cov/corr; + * L3 has simple statistic, cov/corr, histogram, freq, topk, bottomk; + * @param selectedColNames + */ + public SetPartitionBasicStat(TableSchema schema, String[] selectedColNames, HasStatLevel_L1.StatLevel statLevel) { + this.colNames = schema.getFieldNames(); + int n = this.colNames.length; + this.colTypes = new Class[n]; + for (int i = 0; i < n; i++) { + colTypes[i] = schema.getFieldTypes()[i].getTypeClass(); + } + this.statLevel = statLevel; + this.selectedColNames = selectedColNames; + } + + @Override + public void mapPartition(Iterable itrbl, Collector clctr) throws Exception { + WindowTable wt = new WindowTable(this.colNames, this.colTypes, itrbl); + SummaryResultTable srt = SrtUtil.batchSummary(wt, this.selectedColNames, 10, 10, 1000, 10, this.statLevel); + if (srt != null) { + clctr.collect(srt); + } + } + + } } \ No newline at end of file diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/tensorflow/TF2TableModelTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/tensorflow/TF2TableModelTrainBatchOp.java index 907ede02b..c6b90292f 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/tensorflow/TF2TableModelTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/tensorflow/TF2TableModelTrainBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.dl.BaseDLTableModelTrainBatchOp; import com.alibaba.alink.common.dl.DLEnvConfig.Version; import com.alibaba.alink.params.tensorflow.TFTableModelTrainParams; @@ -16,6 +17,7 @@ * timestamps) is zipped and returned to Alink side as a two-column Alink Model. 
*/ @NameCn("TF2表模型训练") +@NameEn("TF2 TableModel Training") public class TF2TableModelTrainBatchOp extends BaseDLTableModelTrainBatchOp implements TFTableModelTrainParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/tensorflow/TFSavedModelPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/tensorflow/TFSavedModelPredictBatchOp.java index 8cd0eb505..8db47b242 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/tensorflow/TFSavedModelPredictBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/tensorflow/TFSavedModelPredictBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.operator.batch.utils.MapBatchOp; import com.alibaba.alink.operator.common.tensorflow.TFSavedModelPredictMapper; @@ -13,6 +14,7 @@ */ @ParamSelectColumnSpec(name = "selectedCols") @NameCn("TF SavedModel模型预测") +@NameEn("TF SaveModel Prediction") public final class TFSavedModelPredictBatchOp extends MapBatchOp implements TFSavedModelPredictParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/tensorflow/TFTableModelPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/tensorflow/TFTableModelPredictBatchOp.java index db601f9ff..2a174f817 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/tensorflow/TFTableModelPredictBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/tensorflow/TFTableModelPredictBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.operator.batch.utils.FlatModelMapBatchOp; import com.alibaba.alink.operator.common.tensorflow.TFTableModelPredictFlatModelMapper; @@ -14,6 +15,7 @@ */ @ParamSelectColumnSpec(name = "selectedCols") @NameCn("TF表模型预测") +@NameEn("TF TableModel Prediction") public final class TFTableModelPredictBatchOp extends FlatModelMapBatchOp implements TFTableModelPredictParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/tensorflow/TFTableModelTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/tensorflow/TFTableModelTrainBatchOp.java index 7770940c0..3d025bc2c 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/tensorflow/TFTableModelTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/tensorflow/TFTableModelTrainBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.dl.BaseDLTableModelTrainBatchOp; import com.alibaba.alink.common.dl.DLEnvConfig.Version; import com.alibaba.alink.params.tensorflow.TFTableModelTrainParams; @@ -16,6 +17,7 @@ * timestamps) is zipped and returned back to Alink side as a two-column Alink Model. 
*/ @NameCn("TF表模型训练") +@NameEn("TF TableModel Training") public class TFTableModelTrainBatchOp extends BaseDLTableModelTrainBatchOp implements TFTableModelTrainParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/tensorflow/TensorFlow2BatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/tensorflow/TensorFlow2BatchOp.java index a48553cec..9a3760cfa 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/tensorflow/TensorFlow2BatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/tensorflow/TensorFlow2BatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.dl.BaseDLBatchOp; import com.alibaba.alink.common.dl.DLEnvConfig.Version; import com.alibaba.alink.params.tensorflow.TensorFlowParams; @@ -10,10 +11,11 @@ import java.util.Collections; /** - * A general stream op to run custom TensorFlow (version 2.3.1) scripts for stream datasets. + * A general batch op to run custom TensorFlow (version 2.3.1) scripts for batch datasets. * Any number of outputs are allowed from TF scripts, even no outputs. */ @NameCn("TensorFlow2自定义脚本") +@NameEn("TensorFlow2 Script") public class TensorFlow2BatchOp extends BaseDLBatchOp implements TensorFlowParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/tensorflow/TensorFlowBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/tensorflow/TensorFlowBatchOp.java index 432cba3ad..0c8016856 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/tensorflow/TensorFlowBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/tensorflow/TensorFlowBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.dl.BaseDLBatchOp; import com.alibaba.alink.common.dl.DLEnvConfig.Version; import com.alibaba.alink.params.tensorflow.TensorFlowParams; @@ -10,10 +11,11 @@ import java.util.Collections; /** - * A general stream op to run custom TensorFlow (version 1.15) scripts for stream datasets. + * A general batch op to run custom TensorFlow (version 1.15) scripts for batch datasets. * Any number of outputs are allowed from TF scripts, even no outputs. 
*/ @NameCn("TensorFlow自定义脚本") +@NameEn("TensorFlow Script") public class TensorFlowBatchOp extends BaseDLBatchOp implements TensorFlowParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/timeseries/ArimaBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/timeseries/ArimaBatchOp.java index 05921e6b6..2cb7ca9e1 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/timeseries/ArimaBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/timeseries/ArimaBatchOp.java @@ -3,11 +3,13 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.batch.utils.MapBatchOp; import com.alibaba.alink.operator.common.timeseries.ArimaMapper; import com.alibaba.alink.params.timeseries.ArimaParams; @NameCn("Arima") +@NameEn("Arima") public class ArimaBatchOp extends MapBatchOp implements ArimaParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/timeseries/AutoArimaBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/timeseries/AutoArimaBatchOp.java index 808f07c93..f3c3b7c92 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/timeseries/AutoArimaBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/timeseries/AutoArimaBatchOp.java @@ -3,11 +3,13 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.batch.utils.MapBatchOp; import com.alibaba.alink.operator.common.timeseries.AutoArimaMapper; import com.alibaba.alink.params.timeseries.AutoArimaParams; @NameCn("AutoArima") +@NameEn("Auto Arima") public final class AutoArimaBatchOp extends MapBatchOp implements AutoArimaParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/timeseries/AutoGarchBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/timeseries/AutoGarchBatchOp.java index d26c24399..d96245318 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/timeseries/AutoGarchBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/timeseries/AutoGarchBatchOp.java @@ -3,11 +3,13 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.batch.utils.MapBatchOp; import com.alibaba.alink.operator.common.timeseries.AutoGarchMapper; import com.alibaba.alink.params.timeseries.AutoGarchParams; @NameCn("AutoGarch") +@NameEn("Auto Graph") public final class AutoGarchBatchOp extends MapBatchOp implements AutoGarchParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/timeseries/DeepARPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/timeseries/DeepARPredictBatchOp.java index d3df90e00..17de0f4dd 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/timeseries/DeepARPredictBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/timeseries/DeepARPredictBatchOp.java @@ -3,11 +3,13 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.batch.utils.ModelMapBatchOp; import com.alibaba.alink.operator.common.timeseries.DeepARModelMapper; import com.alibaba.alink.params.timeseries.DeepARPredictParams; 
@NameCn("DeepAR预测") +@NameEn("Deep AR Prediction") public class DeepARPredictBatchOp extends ModelMapBatchOp implements DeepARPredictParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/timeseries/DeepARTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/timeseries/DeepARTrainBatchOp.java index 7e824a365..e748980e8 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/timeseries/DeepARTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/timeseries/DeepARTrainBatchOp.java @@ -18,9 +18,11 @@ import org.apache.flink.types.Row; import org.apache.flink.util.Collector; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.annotation.InputPorts; +import com.alibaba.alink.common.annotation.Internal; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamCond; import com.alibaba.alink.common.annotation.ParamCond.CondType; @@ -37,7 +39,7 @@ import com.alibaba.alink.common.linalg.tensor.Shape; import com.alibaba.alink.common.linalg.tensor.Tensor; import com.alibaba.alink.common.linalg.tensor.TensorUtil; -import com.alibaba.alink.common.utils.DataSetConversionUtil; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; import com.alibaba.alink.common.utils.JsonConverter; import com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.operator.batch.BatchOperator; @@ -81,6 +83,7 @@ ) ) @NameCn("DeepAR训练") +@NameEn("Deep AR Training") public class DeepARTrainBatchOp extends BatchOperator implements DeepARTrainParams { @@ -175,6 +178,7 @@ public TimeFrequency map(Row value) throws Exception { return this; } + @Internal private static class DeepARPreProcessBatchOp extends BatchOperator implements DeepARPreProcessParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/timeseries/HoltWintersBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/timeseries/HoltWintersBatchOp.java index 3c1a70745..7aac27e46 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/timeseries/HoltWintersBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/timeseries/HoltWintersBatchOp.java @@ -3,11 +3,13 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.batch.utils.MapBatchOp; import com.alibaba.alink.operator.common.timeseries.HoltWintersMapper; import com.alibaba.alink.params.timeseries.HoltWintersParams; @NameCn("HoltWinters") +@NameEn("Holt Winters") public class HoltWintersBatchOp extends MapBatchOp implements HoltWintersParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/timeseries/LSTNetPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/timeseries/LSTNetPredictBatchOp.java index 7a54256b1..fdf00321f 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/timeseries/LSTNetPredictBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/timeseries/LSTNetPredictBatchOp.java @@ -3,11 +3,13 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.batch.utils.ModelMapBatchOp; import 
com.alibaba.alink.operator.common.timeseries.LSTNetModelMapper; import com.alibaba.alink.params.timeseries.LSTNetPredictParams; @NameCn("LSTNet预测") +@NameEn("LSTNet Prediction") public class LSTNetPredictBatchOp extends ModelMapBatchOp implements LSTNetPredictParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/timeseries/LSTNetTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/timeseries/LSTNetTrainBatchOp.java index 5aced4cb3..aafa3e430 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/timeseries/LSTNetTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/timeseries/LSTNetTrainBatchOp.java @@ -8,9 +8,11 @@ import org.apache.flink.types.Row; import org.apache.flink.util.Collector; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.annotation.InputPorts; +import com.alibaba.alink.common.annotation.Internal; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.ParamCond; import com.alibaba.alink.common.annotation.ParamCond.CondType; @@ -58,6 +60,7 @@ ) ) @NameCn("LSTNet训练") +@NameEn("LSTNet Training") public class LSTNetTrainBatchOp extends BatchOperator implements LSTNetTrainParams { @@ -102,6 +105,7 @@ public LSTNetTrainBatchOp linkFrom(BatchOperator ... inputs) { return this; } + @Internal private static class LSTNetPreProcessBatchOp extends BatchOperator implements LSTNetPreProcessParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/timeseries/LookupValueInTimeSeriesBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/timeseries/LookupValueInTimeSeriesBatchOp.java index aa9af3dd9..f90bad713 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/timeseries/LookupValueInTimeSeriesBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/timeseries/LookupValueInTimeSeriesBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.operator.batch.utils.MapBatchOp; @@ -12,6 +13,7 @@ @ParamSelectColumnSpec(name="timeCol",allowedTypeCollections = TypeCollections.TIMESTAMP_TYPES) @ParamSelectColumnSpec(name="timeSeriesCol",allowedTypeCollections = TypeCollections.MTABLE_TYPES) @NameCn("时间序列插值") +@NameEn("Lookup Value In Time Series") public class LookupValueInTimeSeriesBatchOp extends MapBatchOp implements LookupValueInTimeSeriesParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/timeseries/LookupVectorInTimeSeriesBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/timeseries/LookupVectorInTimeSeriesBatchOp.java index 2cc2e89b9..9f616144e 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/timeseries/LookupVectorInTimeSeriesBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/timeseries/LookupVectorInTimeSeriesBatchOp.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; import 
com.alibaba.alink.common.annotation.TypeCollections; import com.alibaba.alink.operator.batch.utils.MapBatchOp; @@ -12,6 +13,7 @@ @ParamSelectColumnSpec(name="timeCol",allowedTypeCollections = TypeCollections.TIMESTAMP_TYPES) @ParamSelectColumnSpec(name="timeSeriesCol",allowedTypeCollections = TypeCollections.MTABLE_TYPES) @NameCn("时间序列向量插值") +@NameEn("Lookup Vector In Time Series") public class LookupVectorInTimeSeriesBatchOp extends MapBatchOp implements LookupVectorInTimeSeriesParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/timeseries/ProphetBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/timeseries/ProphetBatchOp.java index 0f5e2f1b3..2fe125334 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/timeseries/ProphetBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/timeseries/ProphetBatchOp.java @@ -3,11 +3,13 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.batch.utils.MapBatchOp; import com.alibaba.alink.operator.common.timeseries.ProphetMapper; import com.alibaba.alink.params.timeseries.ProphetParams; @NameCn("Prophet") +@NameEn("Prophet") public class ProphetBatchOp extends MapBatchOp implements ProphetParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/timeseries/ProphetPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/timeseries/ProphetPredictBatchOp.java index 3508d1a72..091b19d11 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/timeseries/ProphetPredictBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/timeseries/ProphetPredictBatchOp.java @@ -3,12 +3,14 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.batch.utils.ModelMapBatchOp; import com.alibaba.alink.operator.common.timeseries.ProphetModelMapper; import com.alibaba.alink.params.timeseries.ProphetPredictParams; @NameCn("Prophet预测") +@NameEn("Prophet Prediction") public class ProphetPredictBatchOp extends ModelMapBatchOp implements ProphetPredictParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/timeseries/ProphetTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/timeseries/ProphetTrainBatchOp.java index 862167f7d..e52191d79 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/timeseries/ProphetTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/timeseries/ProphetTrainBatchOp.java @@ -11,6 +11,7 @@ import com.alibaba.alink.common.MLEnvironmentFactory; import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.PortSpec; import com.alibaba.alink.common.pyrunner.PythonMIMOUdaf; @@ -30,6 +31,7 @@ @InputPorts(values = @PortSpec(value = DATA)) @OutputPorts(values = @PortSpec(value = MODEL)) @NameCn("Prophet训练") +@NameEn("Prophet Training") public class ProphetTrainBatchOp extends BatchOperator implements ProphetTrainParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/timeseries/ShiftBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/timeseries/ShiftBatchOp.java index 4a76a10da..62792341e 
100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/timeseries/ShiftBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/timeseries/ShiftBatchOp.java @@ -3,11 +3,13 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.batch.utils.MapBatchOp; import com.alibaba.alink.operator.common.timeseries.ShiftMapper; import com.alibaba.alink.params.timeseries.ShiftParams; @NameCn("Shift") +@NameEn("Shift") public class ShiftBatchOp extends MapBatchOp implements ShiftParams { diff --git a/core/src/main/java/com/alibaba/alink/common/utils/DataSetConversionUtil.java b/core/src/main/java/com/alibaba/alink/operator/batch/utils/DataSetConversionUtil.java similarity index 99% rename from core/src/main/java/com/alibaba/alink/common/utils/DataSetConversionUtil.java rename to core/src/main/java/com/alibaba/alink/operator/batch/utils/DataSetConversionUtil.java index 4cc36634e..fcb1fb306 100644 --- a/core/src/main/java/com/alibaba/alink/common/utils/DataSetConversionUtil.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/utils/DataSetConversionUtil.java @@ -1,4 +1,4 @@ -package com.alibaba.alink.common.utils; +package com.alibaba.alink.operator.batch.utils; import org.apache.flink.api.common.functions.MapFunction; import org.apache.flink.api.common.typeinfo.TypeInformation; diff --git a/core/src/main/java/com/alibaba/alink/common/utils/DataSetUtil.java b/core/src/main/java/com/alibaba/alink/operator/batch/utils/DataSetUtil.java similarity index 57% rename from core/src/main/java/com/alibaba/alink/common/utils/DataSetUtil.java rename to core/src/main/java/com/alibaba/alink/operator/batch/utils/DataSetUtil.java index 489437892..a8687c34f 100644 --- a/core/src/main/java/com/alibaba/alink/common/utils/DataSetUtil.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/utils/DataSetUtil.java @@ -1,11 +1,10 @@ -package com.alibaba.alink.common.utils; +package com.alibaba.alink.operator.batch.utils; import org.apache.flink.api.common.functions.FlatMapFunction; +import org.apache.flink.api.common.functions.GroupReduceFunction; import org.apache.flink.api.common.functions.MapFunction; import org.apache.flink.api.common.functions.MapPartitionFunction; -import org.apache.flink.api.common.functions.Partitioner; import org.apache.flink.api.common.functions.ReduceFunction; -import org.apache.flink.api.common.functions.RichMapFunction; import org.apache.flink.api.common.functions.RichMapPartitionFunction; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeinfo.Types; @@ -13,13 +12,27 @@ import org.apache.flink.api.java.ExecutionEnvironment; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.core.fs.Path; +import org.apache.flink.ml.api.misc.param.Params; import org.apache.flink.table.api.TableSchema; import org.apache.flink.types.Row; import org.apache.flink.util.Collector; +import com.alibaba.alink.common.io.filesystem.AkUtils2; +import com.alibaba.alink.common.io.filesystem.BaseFileSystem; +import com.alibaba.alink.common.io.filesystem.FilePath; +import com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.operator.batch.BatchOperator; +import com.alibaba.alink.operator.common.io.dummy.DummyOutputFormat; +import com.alibaba.alink.operator.common.io.partition.SinkCollectorCreator; +import 
com.alibaba.alink.operator.common.io.partition.SourceCollectorCreator; +import com.alibaba.alink.operator.stream.sink.Export2FileOutputFormat; +import com.alibaba.alink.params.io.HasFilePath; +import com.alibaba.alink.params.io.shared.HasPartitionColsDefaultAsNull; +import com.alibaba.alink.params.io.shared.HasPartitions; import org.apache.commons.lang.ArrayUtils; +import java.io.IOException; import java.util.ArrayList; import java.util.List; @@ -190,4 +203,108 @@ public void mapPartition(Iterable values, Collector > out) thro } }); } + + public static void linkDummySink(DataSet dataSet) { + dataSet.output(new DummyOutputFormat ()); + } + + public static Tuple2 , TableSchema> readFromPartitionBatch( + final Params params, final Long sessionId, + final SourceCollectorCreator sourceCollectorCreator) throws IOException { + + return readFromPartitionBatch(params, sessionId, sourceCollectorCreator, null); + } + + public static Tuple2 , TableSchema> readFromPartitionBatch( + final Params params, final Long sessionId, + final SourceCollectorCreator sourceCollectorCreator, String[] partitionCols) throws IOException { + + final FilePath filePath = FilePath.deserialize(params.get(HasFilePath.FILE_PATH)); + final String partitions = params.get(HasPartitions.PARTITIONS); + + BatchOperator selected = AkUtils2 + .selectPartitionBatchOp(sessionId, filePath, partitions, partitionCols); + + final String[] colNames = selected.getColNames(); + + return Tuple2.of( + selected + .getDataSet() + .rebalance() + .flatMap(new FlatMapFunction () { + @Override + public void flatMap(Row value, Collector out) throws Exception { + Path path = filePath.getPath(); + + for (int i = 0; i < value.getArity(); ++i) { + path = new Path(path, String.format("%s=%s", colNames[i], value.getField(i))); + } + + sourceCollectorCreator.collect(new FilePath(path, filePath.getFileSystem()), out); + } + }), + sourceCollectorCreator.schema() + ); + } + + public static void partitionAndWriteFile( + BatchOperator input, SinkCollectorCreator sinkCollectorCreator, Params params) { + + TableSchema schema = input.getSchema(); + + String[] partitionCols = params.get(HasPartitionColsDefaultAsNull.PARTITION_COLS); + final int[] partitionColIndices = TableUtil.findColIndicesWithAssertAndHint(schema, partitionCols); + final String[] reservedCols = org.apache.commons.lang3.ArrayUtils.removeElements(schema.getFieldNames(), partitionCols); + final int[] reservedColIndices = TableUtil.findColIndices(schema.getFieldNames(), reservedCols); + final FilePath localFilePath = FilePath.deserialize(params.get(HasFilePath.FILE_PATH)); + + input + .getDataSet() + .groupBy(partitionCols) + .reduceGroup(new GroupReduceFunction () { + @Override + public void reduce(Iterable values, Collector out) throws IOException { + Path root = localFilePath.getPath(); + BaseFileSystem fileSystem = localFilePath.getFileSystem(); + + Collector collector = null; + Path localPath = null; + + for (Row row : values) { + if (collector == null) { + localPath = new Path(root.getPath()); + + for (int partitionColIndex : partitionColIndices) { + localPath = new Path(localPath, row.getField(partitionColIndex).toString()); + } + + fileSystem.mkdirs(localPath); + + collector = sinkCollectorCreator.createCollector(new FilePath( + new Path( + localPath, "0" + Export2FileOutputFormat.IN_PROGRESS_FILE_SUFFIX + ), + fileSystem + )); + } + + collector.collect(Row.project(row, reservedColIndices)); + } + + if (collector != null) { + collector.close(); + + fileSystem.rename( + new Path( + localPath, 
"0" + Export2FileOutputFormat.IN_PROGRESS_FILE_SUFFIX + ), + new Path( + localPath, "0" + ) + ); + } + } + }) + .output(new DummyOutputFormat <>()); + } } diff --git a/core/src/main/java/com/alibaba/alink/common/lazy/ExtractModelInfoBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/utils/ExtractModelInfoBatchOp.java similarity index 89% rename from core/src/main/java/com/alibaba/alink/common/lazy/ExtractModelInfoBatchOp.java rename to core/src/main/java/com/alibaba/alink/operator/batch/utils/ExtractModelInfoBatchOp.java index 145202b9a..f4e49baa6 100644 --- a/core/src/main/java/com/alibaba/alink/common/lazy/ExtractModelInfoBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/utils/ExtractModelInfoBatchOp.java @@ -1,8 +1,12 @@ -package com.alibaba.alink.common.lazy; +package com.alibaba.alink.operator.batch.utils; import org.apache.flink.ml.api.misc.param.Params; import org.apache.flink.types.Row; +import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.common.lazy.LazyEvaluation; +import com.alibaba.alink.common.lazy.LazyObjectsManager; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.batch.source.BaseSourceBatchOp; @@ -10,6 +14,8 @@ import java.util.List; import java.util.function.Consumer; +@NameCn("模型信息抽取") +@NameEn("Model Information Extraction") public abstract class ExtractModelInfoBatchOp> extends BatchOperator { private static final long serialVersionUID = 8426490988920758149L; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/utils/ModelMapBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/utils/ModelMapBatchOp.java index 45732b4bd..66a089a32 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/utils/ModelMapBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/utils/ModelMapBatchOp.java @@ -41,7 +41,7 @@ import com.alibaba.alink.common.model.BroadcastVariableModelSource; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.batch.source.AkSourceBatchOp; -import com.alibaba.alink.operator.common.stream.model.ModelStreamUtils; +import com.alibaba.alink.operator.common.modelstream.ModelStreamUtils; import com.alibaba.alink.params.mapper.ModelMapperParams; import com.alibaba.alink.params.shared.HasModelFilePath; import com.alibaba.alink.params.shared.HasPredictBatchSize; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/utils/PrintBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/utils/PrintBatchOp.java index f620a696b..e7fa51d10 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/utils/PrintBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/utils/PrintBatchOp.java @@ -4,6 +4,7 @@ import org.apache.flink.types.Row; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.exceptions.AkUnclassifiedErrorException; import com.alibaba.alink.common.exceptions.ExceptionWithErrorCode; import com.alibaba.alink.common.io.annotations.AnnotationUtils; @@ -21,6 +22,7 @@ */ @IoOpAnnotation(name = "print", ioType = IOType.SinkBatch) @NameCn("批式数据打印") +@NameEn("Print Operation") public class PrintBatchOp extends BaseSinkBatchOp { private static final long serialVersionUID = -8361687806231696283L; diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/utils/StatsVisualizer.java 
b/core/src/main/java/com/alibaba/alink/operator/batch/utils/StatsVisualizer.java index 13e6a12d3..0c16ec8a3 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/utils/StatsVisualizer.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/utils/StatsVisualizer.java @@ -44,7 +44,7 @@ private String generateHtmlTitle(DatasetFeatureStatisticsList datasetFeatureStat AkPreconditions.checkArgument( datasetFeatureStatisticsList.getDatasetsCount() == newTableNames.length, new AkIllegalDataException("The number of new table names must be equal to the number of datasets.")); - title = String.join(", ", newTableNames) + "'s Stats"; + title = "Stats"; } else { title = datasetFeatureStatisticsList.getDatasetsList() .stream().map(DatasetFeatureStatistics::getName) diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/utils/UDFBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/utils/UDFBatchOp.java index 00cb3f3ed..bd59ed87d 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/utils/UDFBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/utils/UDFBatchOp.java @@ -7,6 +7,7 @@ import com.alibaba.alink.common.MLEnvironmentFactory; import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.PortSpec; import com.alibaba.alink.common.annotation.PortType; @@ -32,6 +33,7 @@ @InputPorts(values = @PortSpec(PortType.DATA)) @OutputPorts(values = @PortSpec(PortType.DATA)) @NameCn("UDF") +@NameEn("UDF") public class UDFBatchOp extends BatchOperator implements UDFParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/utils/UDTFBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/utils/UDTFBatchOp.java index 6a2784b47..a9959af5a 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/utils/UDTFBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/utils/UDTFBatchOp.java @@ -7,6 +7,7 @@ import com.alibaba.alink.common.MLEnvironmentFactory; import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.common.annotation.OutputPorts; import com.alibaba.alink.common.annotation.PortSpec; import com.alibaba.alink.common.annotation.PortType; @@ -32,6 +33,7 @@ @InputPorts(values = @PortSpec(PortType.DATA)) @OutputPorts(values = @PortSpec(PortType.DATA)) @NameCn("UDTF") +@NameEn("UDTF") public class UDTFBatchOp extends BatchOperator implements UDTFParams { diff --git a/core/src/main/java/com/alibaba/alink/common/lazy/WithModelInfoBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/utils/WithModelInfoBatchOp.java similarity index 95% rename from core/src/main/java/com/alibaba/alink/common/lazy/WithModelInfoBatchOp.java rename to core/src/main/java/com/alibaba/alink/operator/batch/utils/WithModelInfoBatchOp.java index 84773a45c..d98ffc47b 100644 --- a/core/src/main/java/com/alibaba/alink/common/lazy/WithModelInfoBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/utils/WithModelInfoBatchOp.java @@ -1,4 +1,4 @@ -package com.alibaba.alink.common.lazy; +package com.alibaba.alink.operator.batch.utils; import com.alibaba.alink.params.shared.HasMLEnvironmentId; diff --git a/core/src/main/java/com/alibaba/alink/common/lazy/WithTrainInfo.java 
b/core/src/main/java/com/alibaba/alink/operator/batch/utils/WithTrainInfo.java similarity index 90% rename from core/src/main/java/com/alibaba/alink/common/lazy/WithTrainInfo.java rename to core/src/main/java/com/alibaba/alink/operator/batch/utils/WithTrainInfo.java index ca83a2a42..5637b8612 100644 --- a/core/src/main/java/com/alibaba/alink/common/lazy/WithTrainInfo.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/utils/WithTrainInfo.java @@ -1,7 +1,9 @@ -package com.alibaba.alink.common.lazy; +package com.alibaba.alink.operator.batch.utils; import org.apache.flink.types.Row; +import com.alibaba.alink.common.lazy.LazyEvaluation; +import com.alibaba.alink.common.lazy.LazyObjectsManager; import com.alibaba.alink.operator.batch.BatchOperator; import java.util.Arrays; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/aps/ApsContext.java b/core/src/main/java/com/alibaba/alink/operator/common/aps/ApsContext.java index dac49ac2d..2ae36436d 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/aps/ApsContext.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/aps/ApsContext.java @@ -24,7 +24,7 @@ import com.alibaba.alink.common.exceptions.AkPreconditions; import com.alibaba.alink.common.io.directreader.DefaultDistributedInfo; import com.alibaba.alink.common.io.directreader.DistributedInfo; -import com.alibaba.alink.common.utils.DataSetConversionUtil; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.batch.source.NumSeqSourceBatchOp; import com.alibaba.alink.operator.batch.source.TableSourceBatchOp; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/aps/ApsEnv.java b/core/src/main/java/com/alibaba/alink/operator/common/aps/ApsEnv.java index 84f3ffb0a..2f3d5c99e 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/aps/ApsEnv.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/aps/ApsEnv.java @@ -19,7 +19,7 @@ import com.alibaba.alink.common.MLEnvironmentFactory; import com.alibaba.alink.common.exceptions.AkUnclassifiedErrorException; -import com.alibaba.alink.common.utils.DataSetConversionUtil; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; import com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.batch.source.TableSourceBatchOp; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/aps/ApsSerializeData.java b/core/src/main/java/com/alibaba/alink/operator/common/aps/ApsSerializeData.java index fafa14ee5..50981cff6 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/aps/ApsSerializeData.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/aps/ApsSerializeData.java @@ -10,7 +10,7 @@ import org.apache.flink.types.Row; import com.alibaba.alink.common.MLEnvironmentFactory; -import com.alibaba.alink.common.utils.DataSetConversionUtil; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.batch.source.TableSourceBatchOp; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/aps/ApsSerializeModel.java b/core/src/main/java/com/alibaba/alink/operator/common/aps/ApsSerializeModel.java index bc23e84c3..04d8c3d1c 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/aps/ApsSerializeModel.java +++ 
b/core/src/main/java/com/alibaba/alink/operator/common/aps/ApsSerializeModel.java @@ -10,7 +10,7 @@ import org.apache.flink.types.Row; import com.alibaba.alink.common.MLEnvironmentFactory; -import com.alibaba.alink.common.utils.DataSetConversionUtil; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.batch.source.TableSourceBatchOp; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/associationrule/ApplyAssociationRuleModelMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/associationrule/ApplyAssociationRuleModelMapper.java new file mode 100644 index 000000000..a3c4681ac --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/associationrule/ApplyAssociationRuleModelMapper.java @@ -0,0 +1,71 @@ +package com.alibaba.alink.operator.common.associationrule; + +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.table.api.TableSchema; +import org.apache.flink.types.Row; + +import com.alibaba.alink.common.mapper.SISOModelMapper; +import com.alibaba.alink.operator.batch.associationrule.FpGrowthBatchOp; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +/** + * The model mapper of applying the Association Rules. + */ +public class ApplyAssociationRuleModelMapper extends SISOModelMapper { + private static final long serialVersionUID = 3709131767975976366L; + private final String sep = FpGrowthBatchOp.ITEM_SEPARATOR; + private transient List > antecedents; + private transient List consequences; + + public ApplyAssociationRuleModelMapper(TableSchema modelSchema, TableSchema dataSchema, Params params) { + super(modelSchema, dataSchema, params); + } + + @Override + protected TypeInformation initPredResultColType() { + return Types.STRING; + } + + @Override + protected Object predictResult(Object input) throws Exception { + Set items = new HashSet <>(Arrays.asList(((String) input).split(sep))); + Set prediction = new HashSet <>(); + for (int i = 0; i < antecedents.size(); i++) { + if (items.containsAll(antecedents.get(i))) { + String consequent = consequences.get(i); + if (!items.contains(consequent)) { + prediction.add(consequent); + } + } + } + StringBuilder sbd = new StringBuilder(); + int cnt = 0; + for (String p : prediction) { + if (cnt > 0) { + sbd.append(sep); + } + sbd.append(p); + cnt++; + } + return sbd.toString(); + } + + @Override + public void loadModel(List modelRows) { + final int numRules = modelRows.size(); + this.antecedents = new ArrayList <>(numRules); + this.consequences = new ArrayList <>(numRules); + modelRows.forEach(row -> { + String[] rule = ((String) row.getField(0)).split("=>"); + this.consequences.add(rule[1]); + this.antecedents.add(new HashSet <>(Arrays.asList(rule[0].split(sep)))); + }); + } +} \ No newline at end of file diff --git a/core/src/main/java/com/alibaba/alink/operator/common/associationrule/ApplySequenceRuleModelMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/associationrule/ApplySequenceRuleModelMapper.java new file mode 100644 index 000000000..bb811c741 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/associationrule/ApplySequenceRuleModelMapper.java @@ -0,0 +1,106 @@ +package com.alibaba.alink.operator.common.associationrule; + +import 
org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.table.api.TableSchema; +import org.apache.flink.types.Row; + +import com.alibaba.alink.common.mapper.SISOModelMapper; +import com.alibaba.alink.operator.batch.associationrule.PrefixSpanBatchOp; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +/** + * The model mapper of applying Sequence Rules. + */ +public class ApplySequenceRuleModelMapper extends SISOModelMapper { + private static final long serialVersionUID = 2458317592464336059L; + private transient List >> antecedents; + private transient List consequent; + + public ApplySequenceRuleModelMapper(TableSchema modelSchema, TableSchema dataSchema, Params params) { + super(modelSchema, dataSchema, params); + } + + private static List > parseSequenceString(String seqStr) { + String[] elements = seqStr.split(PrefixSpanBatchOp.ELEMENT_SEPARATOR); + List > sequence = new ArrayList <>(elements.length); + for (String element : elements) { + sequence.add(new HashSet <>(Arrays.asList(element.split(PrefixSpanBatchOp.ITEM_SEPARATOR)))); + } + return sequence; + } + + private static boolean isSubSequence(List > seq, List > sub) { + if (sub.size() > seq.size()) { + return false; + } + int subLen = sub.size(); + int seqPos = 0; + int numMatchedElement = 0; + for (Set aSub : sub) { + while (seqPos < seq.size()) { + if (isSubset(seq.get(seqPos++), aSub)) { + numMatchedElement++; + break; + } + } + } + return numMatchedElement == subLen; + } + + private static boolean isSubset(Set set, Set sub) { + return set.containsAll(sub); + } + + @Override + protected TypeInformation initPredResultColType() { + return Types.STRING; + } + + @Override + protected Object predictResult(Object input) throws Exception { + List > sequence = parseSequenceString((String) input); + Set prediction = new HashSet <>(); + for (int i = 0; i < antecedents.size(); i++) { + if (isSubSequence(sequence, antecedents.get(i))) { + String consequent = this.consequent.get(i); + List > pred = new ArrayList <>(); + pred.addAll(antecedents.get(i)); + pred.addAll(parseSequenceString(consequent)); + if (!isSubSequence(sequence, pred)) { + prediction.add(consequent); + } + } + } + StringBuilder sbd = new StringBuilder(); + int cnt = 0; + for (String p : prediction) { + if (cnt > 0) { + sbd.append(PrefixSpanBatchOp.ELEMENT_SEPARATOR); + } + sbd.append(p); + cnt++; + } + return sbd.toString(); + } + + @Override + public void loadModel(List modelRows) { + final int numRules = modelRows.size(); + this.antecedents = new ArrayList <>(numRules); + this.consequent = new ArrayList <>(numRules); + modelRows.forEach(row -> { + String[] rule = ((String) row.getField(0)).split(PrefixSpanBatchOp.RULE_SEPARATOR); + String consequentStr = rule[1]; + String antecedentStr = rule[0]; + this.antecedents.add(parseSequenceString(antecedentStr)); + this.consequent.add(consequentStr); + }); + } +} \ No newline at end of file diff --git a/core/src/main/java/com/alibaba/alink/operator/common/audio/ExtractMfccFeatureMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/audio/ExtractMfccFeatureMapper.java index 2316becfd..5c1a199e3 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/audio/ExtractMfccFeatureMapper.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/audio/ExtractMfccFeatureMapper.java @@ -9,7 +9,7 @@ import 
com.alibaba.alink.common.linalg.Vector; import com.alibaba.alink.common.linalg.tensor.FloatTensor; import com.alibaba.alink.common.linalg.tensor.Tensor; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.mapper.SISOMapper; import com.alibaba.alink.operator.common.dataproc.MFCC; import com.alibaba.alink.params.audio.ExtractMfccFeatureParams; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/audio/ReadAudioToTensorMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/audio/ReadAudioToTensorMapper.java index a1d250b7f..a091e5fd1 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/audio/ReadAudioToTensorMapper.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/audio/ReadAudioToTensorMapper.java @@ -7,7 +7,7 @@ import org.apache.flink.ml.api.misc.param.Params; import org.apache.flink.table.api.TableSchema; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.exceptions.AkIllegalDataException; import com.alibaba.alink.common.io.filesystem.BaseFileSystem; import com.alibaba.alink.common.io.filesystem.FilePath; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/classification/OneVsRestModelMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/classification/OneVsRestModelMapper.java index e7916558f..59880c2c4 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/classification/OneVsRestModelMapper.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/classification/OneVsRestModelMapper.java @@ -10,7 +10,7 @@ import org.apache.flink.table.api.TableSchema; import org.apache.flink.types.Row; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.exceptions.AkUnclassifiedErrorException; import com.alibaba.alink.common.linalg.DenseVector; import com.alibaba.alink.common.linalg.VectorUtil; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/classification/ann/AnnObjFunc.java b/core/src/main/java/com/alibaba/alink/operator/common/classification/ann/AnnObjFunc.java index 86cba8641..cfd5aa580 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/classification/ann/AnnObjFunc.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/classification/ann/AnnObjFunc.java @@ -29,7 +29,7 @@ public AnnObjFunc(Topology topology, } @Override - protected double calcLoss(Tuple3 labledVector, DenseVector coefVector) { + public double calcLoss(Tuple3 labledVector, DenseVector coefVector) { if (topologyModel == null) { topologyModel = topology.getModel(coefVector); } else { @@ -40,7 +40,7 @@ protected double calcLoss(Tuple3 labledVector, DenseVec } @Override - protected void updateGradient(Tuple3 labledVector, DenseVector coefVector, + public void updateGradient(Tuple3 labledVector, DenseVector coefVector, DenseVector updateGrad) { if (topologyModel == null) { topologyModel = topology.getModel(coefVector); @@ -52,7 +52,7 @@ protected void updateGradient(Tuple3 labledVector, Dens } @Override - protected void updateHessian(Tuple3 labledVector, DenseVector coefVector, + public void updateHessian(Tuple3 labledVector, DenseVector coefVector, DenseMatrix updateHessian) { throw new AkUnsupportedOperationException("not supported."); } diff --git a/core/src/main/java/com/alibaba/alink/operator/common/classification/tensorflow/TFTableModelClassificationFlatModelMapper.java 
b/core/src/main/java/com/alibaba/alink/operator/common/classification/tensorflow/TFTableModelClassificationFlatModelMapper.java index f5cc8568d..73a5f7389 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/classification/tensorflow/TFTableModelClassificationFlatModelMapper.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/classification/tensorflow/TFTableModelClassificationFlatModelMapper.java @@ -7,7 +7,7 @@ import org.apache.flink.types.Row; import org.apache.flink.util.Collector; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.dl.plugin.TFPredictorClassLoaderFactory; import com.alibaba.alink.common.exceptions.AkUnclassifiedErrorException; import com.alibaba.alink.common.linalg.tensor.FloatTensor; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/classification/tensorflow/TFTableModelClassificationModelMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/classification/tensorflow/TFTableModelClassificationModelMapper.java index 75ae70117..e5c5a94c8 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/classification/tensorflow/TFTableModelClassificationModelMapper.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/classification/tensorflow/TFTableModelClassificationModelMapper.java @@ -8,7 +8,7 @@ import com.alibaba.alink.common.dl.plugin.TFPredictorClassLoaderFactory; import com.alibaba.alink.common.linalg.tensor.FloatTensor; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.mapper.IterableModelLoader; import com.alibaba.alink.common.mapper.Mapper; import com.alibaba.alink.common.mapper.MapperChain; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/clustering/ClusterSummary.java b/core/src/main/java/com/alibaba/alink/operator/common/clustering/ClusterSummary.java new file mode 100644 index 000000000..8d9ac2b13 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/clustering/ClusterSummary.java @@ -0,0 +1,32 @@ +package com.alibaba.alink.operator.common.clustering; + +import java.io.Serializable; + +import static com.alibaba.alink.common.utils.JsonConverter.gson; + +public class ClusterSummary implements Serializable { + private static final long serialVersionUID = 8631961988528403010L; + public String center; + public Long clusterId; + public Double weight; + public String type; + + public ClusterSummary(String center, Long clusterId, Double weight, String type) { + this.center = center; + this.clusterId = clusterId; + this.weight = weight; + this.type = type; + } + + public ClusterSummary(String center, Long clusterId, Double weight) { + this(center, clusterId, weight, null); + } + + public static ClusterSummary deserialize(String json) { + return gson.fromJson(json, ClusterSummary.class); + } + + public String serialize() { + return gson.toJson(this, ClusterSummary.class); + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/clustering/DistanceType.java b/core/src/main/java/com/alibaba/alink/operator/common/clustering/DistanceType.java new file mode 100644 index 000000000..c8d9c1437 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/clustering/DistanceType.java @@ -0,0 +1,48 @@ +package com.alibaba.alink.operator.common.clustering; + +import com.alibaba.alink.operator.common.distance.CosineDistance; +import com.alibaba.alink.operator.common.distance.EuclideanDistance; 
+import com.alibaba.alink.operator.common.distance.FastDistance; +import com.alibaba.alink.operator.common.distance.HaversineDistance; +import com.alibaba.alink.operator.common.distance.JaccardDistance; +import com.alibaba.alink.operator.common.distance.ManHattanDistance; + +import java.io.Serializable; + +/** + * Various distance types. + */ +public enum DistanceType implements Serializable { + /** + * EUCLIDEAN + */ + EUCLIDEAN(new EuclideanDistance()), + /** + * COSINE + */ + COSINE(new CosineDistance()), + /** + * CITYBLOCK + */ + CITYBLOCK(new ManHattanDistance()), + /** + * HAVERSINE + */ + HAVERSINE(new HaversineDistance()), + + /** + * JACCARD + */ + JACCARD(new JaccardDistance()); + + public FastDistance getFastDistance() { + return fastDistance; + } + + private FastDistance fastDistance; + + DistanceType(FastDistance fastDistance) { + this.fastDistance = fastDistance; + } + +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/clustering/FindResult.java b/core/src/main/java/com/alibaba/alink/operator/common/clustering/FindResult.java new file mode 100644 index 000000000..0558e373c --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/clustering/FindResult.java @@ -0,0 +1,24 @@ +package com.alibaba.alink.operator.common.clustering; + +/** + * the result of findCluster function + * + * @author guotao.gt + */ +public class FindResult { + private Long clusterId; + private Double distance; + + public FindResult(Long clusterId, Double distance) { + this.clusterId = clusterId; + this.distance = distance; + } + + public Long getClusterId() { + return clusterId; + } + + public Double getDistance() { + return distance; + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/clustering/LocalKMeans.java b/core/src/main/java/com/alibaba/alink/operator/common/clustering/LocalKMeans.java new file mode 100644 index 000000000..8c77a0b43 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/clustering/LocalKMeans.java @@ -0,0 +1,131 @@ +package com.alibaba.alink.operator.common.clustering; + +import org.apache.flink.util.Collector; + +import com.alibaba.alink.common.linalg.DenseVector; +import com.alibaba.alink.operator.common.clustering.common.Center; +import com.alibaba.alink.operator.common.clustering.common.Cluster; +import com.alibaba.alink.operator.common.clustering.common.Sample; +import com.alibaba.alink.operator.common.distance.ContinuousDistance; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Random; + +/** + * @author guotao.gt + */ +public class LocalKMeans { + + public static void clustering(Iterable values, Collector out, int k, double epsilon, + int maxIter, ContinuousDistance distance) { + List samples = new ArrayList <>(); + for (Sample sample : values) { + samples.add(sample); + } + Sample[] outSamples = new LocalKMeans().clustering(samples.toArray(new Sample[samples.size()]), k, epsilon, + maxIter, distance); + for (Sample sample : outSamples) { + out.collect(sample); + } + } + + public static FindResult findCluster(Center[] centers, DenseVector sample, ContinuousDistance continuousDistance) { + long clusterId = -1; + double d = Double.POSITIVE_INFINITY; + + for (Center c : centers) { + if (null != c) { + double distance = continuousDistance.calc(sample, c.getVector()); + + if (distance < d) { + clusterId = c.getClusterId(); + d = distance; + } + } + } + return new FindResult(clusterId, d); + } + + public static Center[] getCentersFromClusters(Sample[] 
samples, int k) { + Cluster[] clusters = new Cluster[k]; + for (int i = 0; i < k; i++) { + clusters[i] = new Cluster(); + } + for (int i = 0; i < samples.length; i++) { + clusters[(int) samples[i].getClusterId()].addSample(samples[i]); + } + List

list = new ArrayList <>(); + for (int i = 0; i < clusters.length; i++) { + if (clusters[i].getCenter().getVector() != null) { + list.add(clusters[i].getCenter()); + } + } + for (int i = 0; i < list.size(); i++) { + list.get(i).setClusterId((long) i); + } + return list.toArray(new Center[0]); + } + + public Sample[] clustering(Sample[] samples, int k, double epsilon, int maxIter, ContinuousDistance distance) { + k = k > samples.length ? samples.length : k; + kMeansClustering(samples, epsilon, k, distance, maxIter); + return samples; + } + + /** + * get initial centers + * + * @param k + * @return + */ + private Center[] getInitialCenters(Sample[] samples, int k) { + Center[] centers = new Center[k]; + int size = samples.length; + boolean[] flags = new boolean[size]; + Arrays.fill(flags, false); + //random + if (k < size / 3) { + int clusterId = 0; + while (clusterId < k) { + int randomInt = new Random().nextInt(size); + if (!flags[randomInt]) { + centers[clusterId] = new Center(clusterId, 0, samples[randomInt].getVector()); + clusterId++; + } + flags[randomInt] = true; + } + //topK + } else { + for (int i = 0; i < k; i++) { + centers[i] = new Center(i, 0, samples[i].getVector()); + } + } + return centers; + } + + private Center[] kMeansClustering(Sample[] samples, double epsilon, int k, ContinuousDistance distance, + int maxIter) { + Center[] centers = getInitialCenters(samples, k); + int iter = 0; + double oldSsw = 0; + while (iter++ < maxIter) { + double newSsw = 0; + for (int i = 0; i < samples.length; i++) { + FindResult findResult = findCluster(centers, samples[i].getVector(), distance); + samples[i].setClusterId(findResult.getClusterId()); + newSsw += findResult.getDistance(); + } + centers = getCentersFromClusters(samples, k); + if (Math.abs(newSsw - oldSsw) / samples.length < epsilon) { + break; + } else { + oldSsw = newSsw; + } + } + + return centers; + } + +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/clustering/agnes/Agnes.java b/core/src/main/java/com/alibaba/alink/operator/common/clustering/agnes/Agnes.java new file mode 100644 index 000000000..c58763097 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/clustering/agnes/Agnes.java @@ -0,0 +1,170 @@ +package com.alibaba.alink.operator.common.clustering.agnes; + +import com.alibaba.alink.common.linalg.DenseVector; +import com.alibaba.alink.common.utils.AlinkSerializable; +import com.alibaba.alink.operator.common.clustering.kmeans.KMeansInitCentroids; +import com.alibaba.alink.operator.common.distance.ContinuousDistance; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.List; + +/** + * @author guotao.gt + */ +public class Agnes implements AlinkSerializable { + private static final Logger LOG = LoggerFactory.getLogger(Agnes.class); + + /** + * @param agnesSamples + * @param k default 1 + * @param distanceThreshold default double.MAX + * @return + */ + public static List startAnalysis(List agnesSamples, int k, double distanceThreshold, + Linkage linkage, ContinuousDistance distance) { + List originalClusters = initialCluster(agnesSamples); + List finalClusters = originalClusters; + int iter = 1; + while (true) { + if (finalClusters.size() <= k) { + break; + } + double min = Double.MAX_VALUE; + int mergeIndexA = 0; + int mergeIndexB = 0; + for (int i = 0; i < finalClusters.size(); i++) { + for (int j = 0; j < finalClusters.size(); j++) { + if (i != j) { + AgnesCluster clusterA = finalClusters.get(i); + AgnesCluster clusterB = 
finalClusters.get(j); + List dataPointsA = clusterA.getAgnesSamples(); + List dataPointsB = clusterB.getAgnesSamples(); + + switch (linkage) { + case MIN: + double minDistance = Double.MAX_VALUE; + for (int m = 0; m < dataPointsA.size(); m++) { + for (int n = 0; n < dataPointsB.size(); n++) { + double tempDis = distance.calc( + dataPointsA.get(m).getVector(), dataPointsB.get(n).getVector()); + if (tempDis < minDistance) { + minDistance = tempDis; + } + } + } + if (minDistance < min) { + min = minDistance; + mergeIndexA = i; + mergeIndexB = j; + } + break; + case MAX: + double maxDistance = Double.MIN_VALUE; + for (int m = 0; m < dataPointsA.size(); m++) { + for (int n = 0; n < dataPointsB.size(); n++) { + double tempDis = distance.calc( + dataPointsA.get(m).getVector(), dataPointsB.get(n).getVector()); + if (tempDis > maxDistance) { + maxDistance = tempDis; + } + } + } + if (maxDistance < min) { + min = maxDistance; + mergeIndexA = i; + mergeIndexB = j; + } + break; + case AVERAGE: + double averageDistance = 0; + for (int m = 0; m < dataPointsA.size(); m++) { + for (int n = 0; n < dataPointsB.size(); n++) { + averageDistance += distance.calc( + dataPointsA.get(m).getVector(), dataPointsB.get(n).getVector()); + } + } + averageDistance /= dataPointsA.size(); + if (averageDistance < min) { + min = averageDistance; + mergeIndexA = i; + mergeIndexB = j; + } + break; + case MEAN: + DenseVector vectorA = mean(dataPointsA, distance); + DenseVector vectorB = mean(dataPointsB, distance); + double meanDistance = distance.calc(vectorA, vectorB); + if (meanDistance < min) { + min = meanDistance; + mergeIndexA = i; + mergeIndexB = j; + } + break; + default: + throw new RuntimeException("linkage not support:" + linkage); + + } + } + } + } + finalClusters = mergeCluster(finalClusters, mergeIndexA, mergeIndexB, iter); + LOG.info("Iteration:" + iter + "; distance:" + min); + iter++; + if (min > distanceThreshold) { + break; + } + } + + return finalClusters; + } + + private static List mergeCluster(List clusters, int mergeIndexA, int mergeIndexB, + int mergeIter) { + if (mergeIndexA != mergeIndexB) { + AgnesCluster clusterA = clusters.get(mergeIndexA); + AgnesCluster clusterB = clusters.get(mergeIndexB); + List dpB = clusterB.getAgnesSamples(); + clusterB.getFirstSample().setParentId(clusterA.getFirstSample().getSampleId()); + clusterB.getFirstSample().setMergeIter(mergeIter); + for (AgnesSample dp : dpB) { + dp.setClusterId(clusterA.getClusterId()); + clusterA.addDataPoints(dp); + } + + clusters.remove(mergeIndexB); + } + + return clusters; + } + + private static List initialCluster(List agnesSamples) { + List originalClusters = new ArrayList <>(); + for (int i = 0; i < agnesSamples.size(); i++) { + agnesSamples.get(i).setClusterId(i); + originalClusters.add(new AgnesCluster(i, agnesSamples.get(i))); + } + return originalClusters; + } + + /** + * calc the center of cluster + * + * @param agnesSamples + * @return + */ + public static DenseVector mean(List agnesSamples, ContinuousDistance distance) { + if (null != agnesSamples && agnesSamples.size() > 0) { + int dim = agnesSamples.get(0).getVector().size(); + DenseVector r = new DenseVector(dim); + for (AgnesSample dp : agnesSamples) { + r.plusEqual((DenseVector) dp.getVector()); + } + r.scaleEqual(1.0 / agnesSamples.size()); + return r; + } + return null; + } + +} \ No newline at end of file diff --git a/core/src/main/java/com/alibaba/alink/operator/common/clustering/agnes/AgnesCluster.java 
b/core/src/main/java/com/alibaba/alink/operator/common/clustering/agnes/AgnesCluster.java new file mode 100644 index 000000000..c6722ce14 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/clustering/agnes/AgnesCluster.java @@ -0,0 +1,38 @@ +package com.alibaba.alink.operator.common.clustering.agnes; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.List; + +/** + * @author guotao.gt + */ +public class AgnesCluster implements Serializable { + private static final long serialVersionUID = -3460634266927981428L; + private int clusterId; + + private AgnesSample firstSample; + private List agnesSamples = new ArrayList <>(); + + public AgnesCluster(int clusterId, AgnesSample agnesSample) { + this.agnesSamples.add(agnesSample); + this.clusterId = clusterId; + this.firstSample = agnesSample; + } + + public AgnesSample getFirstSample() { + return firstSample; + } + + public List getAgnesSamples() { + return agnesSamples; + } + + public void addDataPoints(AgnesSample agnesSample) { + this.agnesSamples.add(agnesSample); + } + + public int getClusterId() { + return clusterId; + } +} \ No newline at end of file diff --git a/core/src/main/java/com/alibaba/alink/operator/common/clustering/agnes/AgnesModelData.java b/core/src/main/java/com/alibaba/alink/operator/common/clustering/agnes/AgnesModelData.java new file mode 100644 index 000000000..2e526459f --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/clustering/agnes/AgnesModelData.java @@ -0,0 +1,16 @@ +package com.alibaba.alink.operator.common.clustering.agnes; + +import com.alibaba.alink.params.shared.clustering.HasClusteringDistanceType; + +import java.util.List; + +public class AgnesModelData { + public int k; + public double distanceThreshold; + public HasClusteringDistanceType.DistanceType distanceType; + public Linkage linkage; + public String[] featureColNames; + public String idCol; + + public List centroids; +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/clustering/agnes/AgnesModelDataConverter.java b/core/src/main/java/com/alibaba/alink/operator/common/clustering/agnes/AgnesModelDataConverter.java new file mode 100644 index 000000000..b4d98715c --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/clustering/agnes/AgnesModelDataConverter.java @@ -0,0 +1,65 @@ +package com.alibaba.alink.operator.common.clustering.agnes; + +import com.alibaba.alink.common.linalg.DenseVector; +import com.alibaba.alink.common.model.SimpleModelDataConverter; +import com.alibaba.alink.operator.common.clustering.ClusterSummary; +import com.alibaba.alink.operator.common.clustering.DistanceType; +import com.alibaba.alink.operator.common.distance.ContinuousDistance; +import com.alibaba.alink.params.clustering.AgnesParams; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.ml.api.misc.param.Params; + +import java.util.ArrayList; +import java.util.List; + +import static com.alibaba.alink.common.utils.JsonConverter.gson; + +/** + * @author guotao.gt + */ +public class AgnesModelDataConverter extends SimpleModelDataConverter { + /** + * predict scope (after loading, before predicting) + */ + + public AgnesModelDataConverter() { + } + + @Override + public Tuple2> serializeModel(AgnesModelData modelData) { + List modelRows = new ArrayList <>(); + for (AgnesSample agnesSample : modelData.centroids) { + String str = gson.toJson(agnesSample.getVector()); + modelRows.add(new ClusterSummary(str, agnesSample.getClusterId(), 
agnesSample.getWeight()).serialize()); + } + Params meta = new Params() + .set(AgnesParams.K, modelData.k) + .set(AgnesParams.DISTANCE_THRESHOLD, modelData.distanceThreshold) + .set(AgnesParams.DISTANCE_TYPE, modelData.distanceType) + .set(AgnesParams.LINKAGE, modelData.linkage) + .set(AgnesParams.ID_COL, modelData.idCol); + return Tuple2.of(meta, modelRows); + } + + @Override + public AgnesModelData deserializeModel(Params meta, Iterable data) { + AgnesModelData modelData = new AgnesModelData(); + modelData.k = meta.get(AgnesParams.K); + modelData.distanceThreshold = meta.get(AgnesParams.DISTANCE_THRESHOLD); + modelData.linkage = meta.get(AgnesParams.LINKAGE); + modelData.idCol = meta.contains(AgnesParams.ID_COL) ? meta.get(AgnesParams.ID_COL) : ""; + modelData.centroids = new ArrayList (); + modelData.distanceType = meta.get(AgnesParams.DISTANCE_TYPE); + + // get the model data + for (String row : data) { + ClusterSummary c = ClusterSummary.deserialize(row); + long clusterId = c.clusterId; + double weight = c.weight; + AgnesSample agnesSample = new AgnesSample(null, clusterId, gson.fromJson(c.center, DenseVector.class), + weight); + modelData.centroids.add(agnesSample); + } + return modelData; + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/clustering/agnes/AgnesModelInfoBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/common/clustering/agnes/AgnesModelInfoBatchOp.java new file mode 100644 index 000000000..dbe7cad5e --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/clustering/agnes/AgnesModelInfoBatchOp.java @@ -0,0 +1,75 @@ +package com.alibaba.alink.operator.common.clustering.agnes; + +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.types.Row; + +import com.alibaba.alink.operator.batch.utils.ExtractModelInfoBatchOp; +import com.alibaba.alink.operator.common.utils.PrettyDisplayUtils; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class AgnesModelInfoBatchOp + extends ExtractModelInfoBatchOp { + private static final long serialVersionUID = 1735133462550836751L; + + public AgnesModelInfoBatchOp() { + this(null); + } + + public AgnesModelInfoBatchOp(Params params) { + super(params); + } + + @Override + public AgnesModelSummary createModelInfo(List rows) { + AgnesModelInfoBatchOp.AgnesModelSummary summary = new AgnesModelInfoBatchOp.AgnesModelSummary(); + summary.cluster = new HashMap <>(); + summary.totalSamples = 0; + for (Row row : rows) { + summary.totalSamples++; + Object id = row.getField(0); + int clusterId = ((Number) row.getField(1)).intValue(); + List list = summary.cluster.get(clusterId); + if (null == list) { + list = new ArrayList <>(); + } + list.add(id); + summary.cluster.put(clusterId, list); + } + return summary; + } + + public static class AgnesModelSummary { + private static final long serialVersionUID = 5349212648420863302L; + private Map > cluster; + + private int totalSamples = 0; + + public AgnesModelSummary() { + } + + public int getClusterNumber() { + return cluster.size(); + } + + public Object[] getPoints(int clusterId) { + return cluster.get(clusterId).toArray(new Object[0]); + } + + @Override + public String toString() { + StringBuilder sbd = new StringBuilder(PrettyDisplayUtils.displayHeadline("AgnesModelSummary", '-')); + sbd.append("Agnes clustering with ") + .append(getClusterNumber()) + .append(" clusters on ") + .append(totalSamples) + .append(" samples.\n"); + 
sbd.append(PrettyDisplayUtils.displayHeadline("Clusters", '=')); + sbd.append(PrettyDisplayUtils.displayMap(cluster, 10, true)).append("\n"); + return sbd.toString(); + } + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/clustering/agnes/AgnesSample.java b/core/src/main/java/com/alibaba/alink/operator/common/clustering/agnes/AgnesSample.java new file mode 100644 index 000000000..47c54f413 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/clustering/agnes/AgnesSample.java @@ -0,0 +1,66 @@ +package com.alibaba.alink.operator.common.clustering.agnes; + +import com.alibaba.alink.common.linalg.DenseVector; + +import java.io.Serializable; + +/** + * @author guotao.gt + */ +public class AgnesSample implements Serializable { + private static final long serialVersionUID = -7913806684961377173L; + + private String parentId; + + private Long mergeIter; + private String sampleId; + private long clusterId; + private double weight = 1.0; + private DenseVector vector; + + public AgnesSample() { + } + + public AgnesSample(String sampleId, long clusterId, DenseVector vector, double weight) { + this.sampleId = sampleId; + this.clusterId = clusterId; + this.vector = vector; + this.weight = weight; + } + + public DenseVector getVector() { + return vector; + } + + public long getClusterId() { + return clusterId; + } + + public void setClusterId(long clusterId) { + this.clusterId = clusterId; + } + + public void setParentId(String parentId) { + this.parentId = parentId; + } + + public String getParentId() { + return parentId; + } + + public Long getMergeIter() { + return mergeIter; + } + + public void setMergeIter(int mergeIter) { + this.mergeIter = (long) mergeIter; + } + + public String getSampleId() { + return sampleId; + } + + public double getWeight() { + return weight; + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/clustering/agnes/Linkage.java b/core/src/main/java/com/alibaba/alink/operator/common/clustering/agnes/Linkage.java new file mode 100644 index 000000000..f4560f4b5 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/clustering/agnes/Linkage.java @@ -0,0 +1,20 @@ +package com.alibaba.alink.operator.common.clustering.agnes; + +public enum Linkage { + /** + * MIN + */ + MIN, + /** + * MAX + */ + MAX, + /** + * MEAN + */ + MEAN, + /** + * AVERAGE + */ + AVERAGE +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/clustering/common/Center.java b/core/src/main/java/com/alibaba/alink/operator/common/clustering/common/Center.java new file mode 100644 index 000000000..ea38ed1c7 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/clustering/common/Center.java @@ -0,0 +1,47 @@ +package com.alibaba.alink.operator.common.clustering.common; + +import com.alibaba.alink.common.linalg.DenseVector; +import com.alibaba.alink.common.utils.AlinkSerializable; + +/** + * @author guotao.gt + */ +public class Center implements AlinkSerializable { + + public void setClusterId(long clusterId) { + this.clusterId = clusterId; + } + + /** + * which cluster the cluster belong to + */ + protected long clusterId = -1; + + /** + * how many sample belong to the cluster + */ + protected long count; + + /** + * the vector value of the sample + */ + protected DenseVector vector; + + public Center(long clusterId, long count, DenseVector vector) { + this.clusterId = clusterId; + this.count = count; + this.vector = vector; + } + + public long getClusterId() { + return clusterId; + } + + public long 
getCount() { + return count; + } + + public DenseVector getVector() { + return vector; + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/clustering/common/Cluster.java b/core/src/main/java/com/alibaba/alink/operator/common/clustering/common/Cluster.java new file mode 100644 index 000000000..f9be6881f --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/clustering/common/Cluster.java @@ -0,0 +1,58 @@ +package com.alibaba.alink.operator.common.clustering.common; + +import com.alibaba.alink.common.linalg.DenseVector; +import com.alibaba.alink.common.utils.AlinkSerializable; + +import java.util.ArrayList; +import java.util.List; + +/** + * @author guotao.gt + */ +public class Cluster implements AlinkSerializable { + protected long clusterId = -1; + protected List samples = new ArrayList <>(); + + public Cluster() { + } + + /** + * add + * + * @param sample + */ + public void addSample(Sample sample) { + this.samples.add(sample); + this.clusterId = sample.clusterId; + } + + /** + * calc the center vector of a cluster + * + * @return + */ + public DenseVector mean() { + if (null != samples && samples.size() > 0) { + int dim = samples.get(0).getVector().size(); + DenseVector r = new DenseVector(dim); + for (Sample dp : samples) { + r.plusEqual(dp.getVector()); + } + r.scaleEqual(1.0 / samples.size()); + return r; + } + return null; + } + + /** + * get the center of a cluster + * + * @return + */ + public Center getCenter() { + DenseVector denseVector = this.mean(); + long count = this.samples.size(); + return new Center(this.clusterId, count, denseVector); + } + +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/clustering/common/Sample.java b/core/src/main/java/com/alibaba/alink/operator/common/clustering/common/Sample.java new file mode 100644 index 000000000..977d567d9 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/clustering/common/Sample.java @@ -0,0 +1,77 @@ +package com.alibaba.alink.operator.common.clustering.common; + +import com.alibaba.alink.common.linalg.DenseVector; +import com.alibaba.alink.common.utils.AlinkSerializable; + +/** + * @author guotao.gt + */ +public class Sample implements AlinkSerializable { + + /** + * the identity of a sample + */ + protected String sampleId; + + /** + * the vector value of the sample + */ + protected DenseVector vector; + + /** + * which cluster the cluster belong to + */ + protected long clusterId = -1; + + /** + * for group clustering,not necessary + */ + private String[] groupColNames; + + public Sample() { + } + + public Sample(String sampleId, double[] vector) { + this(sampleId, new DenseVector(vector)); + } + + public Sample(String sampleId, DenseVector vector) { + this.sampleId = sampleId; + this.vector = vector; + } + + public Sample(String sampleId, DenseVector vector, long clusterId, String[] groupColNames) { + this.sampleId = sampleId; + this.vector = vector; + this.clusterId = clusterId; + this.groupColNames = groupColNames; + } + + public DenseVector getVector() { + return vector; + } + + public long getClusterId() { + return clusterId; + } + + public void setClusterId(long clusterId) { + this.clusterId = clusterId; + } + + public String getSampleId() { + return sampleId; + } + + public String[] getGroupColNames() { + return groupColNames; + } + + public String getGroupColNamesString() { + StringBuilder sb = new StringBuilder(); + for (String key : this.groupColNames) { + sb.append(key).append("\001"); + } + return sb.toString(); + } +} diff --git 
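The Sample / Cluster / Center classes above form a small container hierarchy for the non-vectorized clustering code: a Sample carries an id, a vector and its cluster assignment, a Cluster averages its members, and getCenter() snapshots that mean together with the member count. A hypothetical usage sketch (the ids and values are made up for illustration):

```java
import com.alibaba.alink.operator.common.clustering.common.Center;
import com.alibaba.alink.operator.common.clustering.common.Cluster;
import com.alibaba.alink.operator.common.clustering.common.Sample;

// Illustrative only: build a cluster from two samples and read back its center.
public class ClusterSketch {
    public static void main(String[] args) {
        Sample a = new Sample("id-1", new double[] {1.0, 2.0});
        Sample b = new Sample("id-2", new double[] {3.0, 4.0});
        a.setClusterId(0);
        b.setClusterId(0);

        Cluster cluster = new Cluster();
        cluster.addSample(a);   // addSample also copies the sample's clusterId onto the cluster
        cluster.addSample(b);

        Center center = cluster.getCenter();
        // center.getVector() is the element-wise mean, here (2.0, 3.0); getCount() is 2
        System.out.println(center.getClusterId() + " " + center.getCount() + " " + center.getVector());
    }
}
```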
a/core/src/main/java/com/alibaba/alink/operator/common/clustering/dbscan/Dbscan.java b/core/src/main/java/com/alibaba/alink/operator/common/clustering/dbscan/Dbscan.java new file mode 100644 index 000000000..8eee957a1 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/clustering/dbscan/Dbscan.java @@ -0,0 +1,86 @@ +package com.alibaba.alink.operator.common.clustering.dbscan; + +import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.operator.common.distance.FastDistance; + +import java.util.ArrayList; +import java.util.List; + +/** + * @author guotao.gt + */ +@NameCn("Dbscan训练") +public class Dbscan { + public static final int UNCLASSIFIED = -1; + public static final int NOISE = Integer.MIN_VALUE; + + public static List findNeighbors(Iterable values, DbscanNewSample sample, + double epsilon, + FastDistance baseDistance) { + List epsilonDenseVectorList = new ArrayList <>(); + for (DbscanNewSample record : values) { + if (baseDistance.calc(record.getVec(), sample.getVec()).get(0, 0) <= epsilon) { + epsilonDenseVectorList.add(record); + } + } + return epsilonDenseVectorList; + } + + /** + * Assigns this sample to a cluster or remains it as NOISE + * + * @param sample The DbscanNewSample that needs to be assigned + * @return true, if the DbscanNewSample could be assigned, else false + */ + public static boolean expandCluster(Iterable values, DbscanNewSample sample, int clusterId, + double epsilon, int minPoints, FastDistance distance) { + List neighbors = findNeighbors(values, sample, epsilon, distance); + /** sample is NOT CORE */ + if (neighbors.size() < minPoints) { + sample.setType(Type.NOISE); + sample.setClusterId(NOISE); + return false; + } else { + /** sample is CORE */ + sample.setType(Type.CORE); + for (int i = 0; i < neighbors.size(); i++) { + DbscanNewSample neighbor = neighbors.get(i); + /** label this neighbor with the current clusterId */ + neighbor.setClusterId(clusterId); + if (neighbor.equals(sample)) { + neighbors.remove(i); + i--; + } + } + + /** Iterate the neighbors, add the UNCLASSIFIED sample to neighbors */ + for (int j = 0; j < neighbors.size(); j++) { + List indirectNeighbours = findNeighbors(values, neighbors.get(j), epsilon, distance); + + /** neighbor is CORE */ + if (indirectNeighbours.size() >= minPoints) { + neighbors.get(j).setType(Type.CORE); + for (int k = 0; k < indirectNeighbours.size(); k++) { + DbscanNewSample indirectNeighbour = indirectNeighbours.get(k); + if (indirectNeighbour.getClusterId() == UNCLASSIFIED + || indirectNeighbour.getType() == Type.NOISE) { + if (indirectNeighbour.getClusterId() == UNCLASSIFIED) { + neighbors.add(indirectNeighbour); + } + if (indirectNeighbour.getType() == Type.NOISE) { + indirectNeighbour.setType(Type.LINKED); + } + indirectNeighbour.setClusterId(clusterId); + } + } + } else { + neighbors.get(j).setType(Type.LINKED); + } + neighbors.remove(j); + j--; + } + + return true; + } + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/clustering/dbscan/DbscanCenter.java b/core/src/main/java/com/alibaba/alink/operator/common/clustering/dbscan/DbscanCenter.java new file mode 100644 index 000000000..527c682f7 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/clustering/dbscan/DbscanCenter.java @@ -0,0 +1,38 @@ +package com.alibaba.alink.operator.common.clustering.dbscan; + +import org.apache.flink.types.Row; + +import java.io.Serializable; + +public class DbscanCenter implements Serializable { + + private static final long serialVersionUID = 
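Dbscan.expandCluster above only grows a single cluster from one seed point; the sweep over all samples lives in the batch operator and is not part of this excerpt. As a rough orientation, the static helpers are conventionally driven by a loop like the following sketch (the sample list and FastDistance instance are assumptions made for illustration, not code from this patch):

```java
import java.util.List;

import com.alibaba.alink.operator.common.clustering.dbscan.Dbscan;
import com.alibaba.alink.operator.common.clustering.dbscan.DbscanNewSample;
import com.alibaba.alink.operator.common.distance.FastDistance;

// Illustrative only: the classic DBSCAN sweep built on top of Dbscan.expandCluster.
public class DbscanDriverSketch {
    public static int run(List<DbscanNewSample> samples, double epsilon, int minPoints, FastDistance distance) {
        int clusterId = 0;
        for (DbscanNewSample sample : samples) {
            // Only unclassified points may seed a new cluster; points first marked NOISE can still
            // be relabeled LINKED inside expandCluster when reached from a later core point.
            if (sample.getClusterId() == Dbscan.UNCLASSIFIED
                && Dbscan.expandCluster(samples, sample, clusterId, epsilon, minPoints, distance)) {
                clusterId++;   // a new cluster was formed around this core point
            }
        }
        return clusterId;      // number of clusters found
    }
}
```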
-7360209191438865256L; + private Row groupColNames; + private long clusterId; + private long count; + private T value; + + public DbscanCenter(Row groupColNames, long clusterId, long count, + T value) { + this.groupColNames = groupColNames; + this.clusterId = clusterId; + this.count = count; + this.value = value; + } + + public Row getGroupColNames() { + return groupColNames; + } + + public long getClusterId() { + return clusterId; + } + + public long getCount() { + return count; + } + + public T getValue() { + return value; + } +} \ No newline at end of file diff --git a/core/src/main/java/com/alibaba/alink/operator/common/clustering/dbscan/DbscanConstant.java b/core/src/main/java/com/alibaba/alink/operator/common/clustering/dbscan/DbscanConstant.java new file mode 100644 index 000000000..3e7e5351d --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/clustering/dbscan/DbscanConstant.java @@ -0,0 +1,10 @@ +package com.alibaba.alink.operator.common.clustering.dbscan; + +/** + * @author guotaogt + */ +public class DbscanConstant { + public static final String TYPE = "type"; + public static final String COUNT = "count"; + public static final String FEATURE_COL_NAMES = "feature_col_names"; +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/clustering/dbscan/DbscanModelDataConverter.java b/core/src/main/java/com/alibaba/alink/operator/common/clustering/dbscan/DbscanModelDataConverter.java new file mode 100644 index 000000000..4a65792a9 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/clustering/dbscan/DbscanModelDataConverter.java @@ -0,0 +1,65 @@ +package com.alibaba.alink.operator.common.clustering.dbscan; + +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.types.Row; + +import com.alibaba.alink.common.linalg.Vector; +import com.alibaba.alink.common.linalg.VectorUtil; +import com.alibaba.alink.common.model.SimpleModelDataConverter; +import com.alibaba.alink.operator.common.clustering.ClusterSummary; +import com.alibaba.alink.params.clustering.DbscanParams; +import com.alibaba.alink.params.shared.clustering.HasClusteringDistanceType; + +import java.util.ArrayList; +import java.util.List; + +/** + * * + * + * @author guotao.gt + */ +public class DbscanModelDataConverter extends SimpleModelDataConverter { + + public DbscanModelDataConverter() { + } + + @Override + public Tuple2 > serializeModel(DbscanModelTrainData modelData) { + List modelRows = new ArrayList <>(); + for (Tuple2 centroid : modelData.coreObjects) { + ClusterSummary c = new ClusterSummary(VectorUtil.toString(centroid.f0), centroid.f1, null, + Type.CORE.name()); + modelRows.add(c.serialize()); + } + Params meta = new Params() + .set(DbscanParams.VECTOR_COL, modelData.vectorColName) + .set(DbscanParams.EPSILON, modelData.epsilon) + .set(DbscanParams.DISTANCE_TYPE, modelData.distanceType); + return Tuple2.of(meta, modelRows); + } + + @Override + public DbscanModelPredictData deserializeModel(Params meta, Iterable data) { + DbscanModelPredictData modelData = new DbscanModelPredictData(); + modelData.epsilon = meta.get(DbscanParams.EPSILON); + modelData.vectorColName = meta.get(DbscanParams.VECTOR_COL); + HasClusteringDistanceType.DistanceType distanceType = meta.get(DbscanParams.DISTANCE_TYPE); + modelData.baseDistance = distanceType.getFastDistance(); + + modelData.coreObjects = new ArrayList <>(); + + // get the model data + for (String row : data) { + try { + ClusterSummary c = 
ClusterSummary.deserialize(row); + Vector vec = VectorUtil.getVector(c.center); + modelData.coreObjects.add( + modelData.baseDistance.prepareVectorData(Tuple2.of(vec, Row.of(c.clusterId)))); + } catch (Exception e) { + e.printStackTrace(); + } + } + return modelData; + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/clustering/dbscan/DbscanModelInfoBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/common/clustering/dbscan/DbscanModelInfoBatchOp.java new file mode 100644 index 000000000..682499b9f --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/clustering/dbscan/DbscanModelInfoBatchOp.java @@ -0,0 +1,114 @@ +package com.alibaba.alink.operator.common.clustering.dbscan; + +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.types.Row; + +import com.alibaba.alink.operator.batch.utils.ExtractModelInfoBatchOp; +import com.alibaba.alink.operator.common.utils.PrettyDisplayUtils; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class DbscanModelInfoBatchOp + extends ExtractModelInfoBatchOp { + private static final long serialVersionUID = 1735133462550836751L; + + public DbscanModelInfoBatchOp() { + this(null); + } + + public DbscanModelInfoBatchOp(Params params) { + super(params); + } + + @Override + public DbscanModelInfo createModelInfo(List rows) { + return new DbscanModelInfo(rows); + } + + public static class DbscanModelInfo implements Serializable { + private static final long serialVersionUID = 5349212648420863302L; + private Map > core; + private Map > linked; + private List noise; + + private int totalSamples; + + public DbscanModelInfo(List rows) { + core = new HashMap <>(); + linked = new HashMap <>(); + noise = new ArrayList <>(); + totalSamples = 0; + for (Row row : rows) { + totalSamples++; + Type type = Type.valueOf((String) row.getField(1)); + Object id = row.getField(0); + long clusterId = (long) row.getField(2); + switch (type) { + case CORE: { + List list = core.get((int) clusterId); + if (null == list) { + list = new ArrayList <>(); + } + list.add(id); + core.put((int) clusterId, list); + break; + } + case LINKED: { + List list = linked.get((int) clusterId); + if (null == list) { + list = new ArrayList <>(); + } + list.add(id); + linked.put((int) clusterId, list); + break; + } + case NOISE: { + noise.add(id); + break; + } + } + } + for (int i = 0; i < getClusterNumber(); i++) { + if (!linked.containsKey(i)) { + linked.put(i, new ArrayList <>()); + } + } + } + + public int getClusterNumber() { + return core.size(); + } + + public Object[] getCorePoints(int clusterId) { + return core.get(clusterId).toArray(new Object[0]); + } + + public Object[] getLinkedPoints(int clusterId) { + return linked.get(clusterId).toArray(new Object[0]); + } + + @Override + public String toString() { + StringBuilder sbd = new StringBuilder(PrettyDisplayUtils.displayHeadline("DbscanModelInfo", '-')); + sbd.append("Dbcan clustering with ") + .append(getClusterNumber()) + .append(" clusters on ") + .append(totalSamples) + .append(" samples.\n") + .append(PrettyDisplayUtils.displayHeadline("CorePoints", '=')) + .append(PrettyDisplayUtils.displayMap(core, 3, true)) + .append("\n") + .append(PrettyDisplayUtils.displayHeadline("LinkedPoints", '=')) + .append(PrettyDisplayUtils.displayMap(linked, 3, true)) + .append("\n") + .append(PrettyDisplayUtils.displayHeadline("NoisePoints", '=')) + .append(PrettyDisplayUtils.displayList(noise, 3, 
false)) + .append("\n"); + return sbd.toString(); + } + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/clustering/dbscan/DbscanModelMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/clustering/dbscan/DbscanModelMapper.java new file mode 100644 index 000000000..8ea6f980a --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/clustering/dbscan/DbscanModelMapper.java @@ -0,0 +1,66 @@ +package com.alibaba.alink.operator.common.clustering.dbscan; + +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.tuple.Tuple4; +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.table.api.TableSchema; +import org.apache.flink.types.Row; + +import com.alibaba.alink.common.linalg.Vector; +import com.alibaba.alink.common.linalg.VectorUtil; +import com.alibaba.alink.common.mapper.ModelMapper; +import com.alibaba.alink.common.utils.TableUtil; +import com.alibaba.alink.operator.common.distance.FastDistanceVectorData; +import com.alibaba.alink.params.clustering.ClusteringPredictParams; + +import java.util.List; + +public class DbscanModelMapper extends ModelMapper { + private static final long serialVersionUID = -3771648601253028057L; + private DbscanModelPredictData modelData = null; + private int colIdx; + + public DbscanModelMapper(TableSchema modelSchema, TableSchema dataSchema, Params params) { + super(modelSchema, dataSchema, params); + } + + @Override + public void loadModel(List modelRows) { + this.modelData = new DbscanModelDataConverter().load(modelRows); + colIdx = TableUtil.findColIndexWithAssert(getDataSchema().getFieldNames(), modelData.vectorColName); + } + + @Override + protected void map(SlicedSelectedSample selection, SlicedResult result) throws Exception { + long clusterId = findCluster(VectorUtil.getVector(selection.get(colIdx))); + result.set(0, clusterId); + } + + @Override + protected Tuple4 [], String[]> prepareIoSchema( + TableSchema modelSchema, TableSchema dataSchema, Params params) { + return Tuple4.of(dataSchema.getFieldNames(), new String[] {params.get(ClusteringPredictParams.PREDICTION_COL)}, + new TypeInformation[] {Types.LONG}, params.get(ClusteringPredictParams.RESERVED_COLS)); + } + + private long findCluster(Vector vec) { + long clusterId = -1; + double d = Double.POSITIVE_INFINITY; + + FastDistanceVectorData sample = modelData.baseDistance.prepareVectorData(Row.of(vec), 0); + + for (FastDistanceVectorData c : modelData.coreObjects) { + double distance = modelData.baseDistance.calc(c, sample).get(0, 0); + if (distance < d) { + clusterId = (long) c.getRows()[0].getField(0); + d = distance; + } + } + if (d > modelData.epsilon) { + //noise + clusterId = Integer.MIN_VALUE; + } + return clusterId; + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/clustering/dbscan/DbscanModelPredictData.java b/core/src/main/java/com/alibaba/alink/operator/common/clustering/dbscan/DbscanModelPredictData.java new file mode 100644 index 000000000..a900753ae --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/clustering/dbscan/DbscanModelPredictData.java @@ -0,0 +1,23 @@ +package com.alibaba.alink.operator.common.clustering.dbscan; + +import com.alibaba.alink.operator.common.clustering.DistanceType; +import com.alibaba.alink.operator.common.distance.FastDistance; +import com.alibaba.alink.operator.common.distance.FastDistanceVectorData; + +import java.util.List; + +public class 
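DbscanModelMapper.findCluster above reduces prediction to a nearest-core-object lookup with an epsilon cutoff: the record inherits the cluster id of its closest core point, unless even that one is farther than epsilon away, in which case it is marked as noise (Integer.MIN_VALUE). Stripped of the FastDistance plumbing, the decision rule looks like this hypothetical standalone form (Euclidean distance chosen only for illustration):

```java
// Illustrative only: the nearest-core-object rule used at prediction time.
public final class DbscanAssignSketch {
    // cores[i] is a core point, coreClusterIds[i] its cluster id; Integer.MIN_VALUE marks noise.
    static long assign(double[][] cores, long[] coreClusterIds, double[] x, double epsilon) {
        long clusterId = -1;
        double best = Double.POSITIVE_INFINITY;
        for (int i = 0; i < cores.length; i++) {
            double s = 0;
            for (int d = 0; d < x.length; d++) {
                double diff = cores[i][d] - x[d];
                s += diff * diff;
            }
            double dist = Math.sqrt(s);
            if (dist < best) {
                best = dist;
                clusterId = coreClusterIds[i];
            }
        }
        return best > epsilon ? Integer.MIN_VALUE : clusterId;   // farther than epsilon from every core -> noise
    }
}
```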
DbscanModelPredictData { + /** + * predict scope (after loading, before predicting) + */ + public double epsilon; + public FastDistance baseDistance; + public String vectorColName; + public DistanceType distanceType; + + /** + * Tuple3: clusterId, clusterCentroid + */ + public List coreObjects; + +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/clustering/dbscan/DbscanModelTrainData.java b/core/src/main/java/com/alibaba/alink/operator/common/clustering/dbscan/DbscanModelTrainData.java new file mode 100644 index 000000000..5b535bcda --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/clustering/dbscan/DbscanModelTrainData.java @@ -0,0 +1,21 @@ +package com.alibaba.alink.operator.common.clustering.dbscan; + +import org.apache.flink.api.java.tuple.Tuple2; + +import com.alibaba.alink.common.linalg.Vector; +import com.alibaba.alink.params.shared.clustering.HasClusteringDistanceType; + +public class DbscanModelTrainData { + /** + * predict scope (after loading, before predicting) + */ + public double epsilon; + public String vectorColName; + public HasClusteringDistanceType.DistanceType distanceType; + + /** + * Tuple3: clusterId, clusterCentroid + */ + public Iterable > coreObjects; + +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/clustering/dbscan/DbscanNewSample.java b/core/src/main/java/com/alibaba/alink/operator/common/clustering/dbscan/DbscanNewSample.java new file mode 100644 index 000000000..fde6dcd43 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/clustering/dbscan/DbscanNewSample.java @@ -0,0 +1,49 @@ +package com.alibaba.alink.operator.common.clustering.dbscan; + +import com.alibaba.alink.operator.common.distance.FastDistanceVectorData; +import scala.util.hashing.MurmurHash3; + +import java.io.Serializable; + +import static com.alibaba.alink.operator.common.clustering.dbscan.Dbscan.UNCLASSIFIED; + +public class DbscanNewSample implements Serializable { + + private static final long serialVersionUID = 6828060936541551678L; + protected FastDistanceVectorData vec; + protected long clusterId; + protected Type type; + protected int groupHashKey; + + public DbscanNewSample(FastDistanceVectorData vec, String[] groupColName) { + this.vec = vec; + this.clusterId = UNCLASSIFIED; + this.type = null; + this.groupHashKey = new MurmurHash3().arrayHash(groupColName, 0); + } + + public int getGroupHashKey() { + return groupHashKey; + } + + public FastDistanceVectorData getVec() { + return vec; + } + + public long getClusterId() { + return clusterId; + } + + public void setClusterId(long clusterId) { + this.clusterId = clusterId; + } + + public Type getType() { + return type; + } + + public void setType(Type type) { + this.type = type; + } + +} \ No newline at end of file diff --git a/core/src/main/java/com/alibaba/alink/operator/common/clustering/dbscan/LocalCluster.java b/core/src/main/java/com/alibaba/alink/operator/common/clustering/dbscan/LocalCluster.java new file mode 100644 index 000000000..09d71e2ab --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/clustering/dbscan/LocalCluster.java @@ -0,0 +1,22 @@ +package com.alibaba.alink.operator.common.clustering.dbscan; + +import java.io.Serializable; + +public class LocalCluster implements Serializable { + private static final long serialVersionUID = 1286966569356492688L; + private int[] keys; + private int[] clusterIds; + + public int[] getKeys() { + return keys; + } + + public int[] getClusterIds() { + return clusterIds; + } + + public 
LocalCluster(int[] keys, int[] clusterIds) { + this.keys = keys; + this.clusterIds = clusterIds; + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/clustering/dbscan/Type.java b/core/src/main/java/com/alibaba/alink/operator/common/clustering/dbscan/Type.java new file mode 100644 index 000000000..844ed9676 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/clustering/dbscan/Type.java @@ -0,0 +1,16 @@ +package com.alibaba.alink.operator.common.clustering.dbscan; + +public enum Type { + /** + * CORE + */ + CORE, + /** + * LINKED + */ + LINKED, + /** + * NOISE + */ + NOISE; +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/clustering/kmodes/KModesModel.java b/core/src/main/java/com/alibaba/alink/operator/common/clustering/kmodes/KModesModel.java new file mode 100644 index 000000000..0d2f2c399 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/clustering/kmodes/KModesModel.java @@ -0,0 +1,68 @@ +package com.alibaba.alink.operator.common.clustering.kmodes; + +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.ml.api.misc.param.Params; + +import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.model.SimpleModelDataConverter; +import com.alibaba.alink.operator.common.clustering.ClusterSummary; +import com.alibaba.alink.operator.common.distance.OneZeroDistance; +import com.alibaba.alink.params.clustering.KModesTrainParams; + +import java.util.ArrayList; +import java.util.List; + +import static com.alibaba.alink.common.utils.JsonConverter.gson; + +/** + * @author guotao.gt + */ +@NameCn("Kmodes模型") +public class KModesModel extends SimpleModelDataConverter { + public KModesModel() { + } + + public static long findCluster(Iterable > centroids, String[] sample, + OneZeroDistance oneZeroDistance) { + long clusterId = -1; + double d = Double.POSITIVE_INFINITY; + + for (Tuple3 c : centroids) { + double distance = oneZeroDistance.calc(sample, c.f2); + if (distance < d) { + clusterId = c.f0; + d = distance; + } + } + return clusterId; + } + + @Override + public Tuple2 > serializeModel(KModesModelData modelData) { + List modelRows = new ArrayList <>(); + for (Tuple3 centroid : modelData.centroids) { + ClusterSummary c = new ClusterSummary(gson.toJson(centroid.f2), centroid.f0, centroid.f1); + modelRows.add(c.serialize()); + } + Params meta = new Params() + .set(KModesTrainParams.FEATURE_COLS, modelData.featureColNames); + return Tuple2.of(meta, modelRows); + } + + @Override + public KModesModelData deserializeModel(Params meta, Iterable rows) { + KModesModelData modelData = new KModesModelData(); + modelData.centroids = new ArrayList <>(); + modelData.featureColNames = meta.get(KModesTrainParams.FEATURE_COLS); + + for (String row : rows) { + ClusterSummary c = ClusterSummary.deserialize(row); + Tuple3 centroid = + new Tuple3 <>(c.clusterId, c.weight, + gson.fromJson(c.center, String[].class)); + modelData.centroids.add(centroid); + } + return modelData; + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/clustering/kmodes/KModesModelData.java b/core/src/main/java/com/alibaba/alink/operator/common/clustering/kmodes/KModesModelData.java new file mode 100644 index 000000000..5bf73b9ae --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/clustering/kmodes/KModesModelData.java @@ -0,0 +1,10 @@ +package com.alibaba.alink.operator.common.clustering.kmodes; + +import 
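KModesModel.findCluster above scores a categorical record against every centroid with OneZeroDistance (the per-feature mismatch count defined later in this diff) and returns the id of the closest one. A small usage sketch with made-up centroids and feature values:

```java
import java.util.Arrays;
import java.util.List;

import org.apache.flink.api.java.tuple.Tuple3;

import com.alibaba.alink.operator.common.clustering.kmodes.KModesModel;
import com.alibaba.alink.operator.common.distance.OneZeroDistance;

// Illustrative only: two categorical centroids (clusterId, weight, modes) and one record to assign.
public class KModesAssignSketch {
    public static void main(String[] args) {
        List<Tuple3<Long, Double, String[]>> centroids = Arrays.asList(
            Tuple3.of(0L, 3.0, new String[] {"red", "small", "round"}),
            Tuple3.of(1L, 5.0, new String[] {"blue", "large", "square"}));

        // One mismatch against cluster 0, two mismatches against cluster 1.
        String[] record = {"red", "large", "round"};
        long clusterId = KModesModel.findCluster(centroids, record, new OneZeroDistance());
        System.out.println(clusterId);   // prints 0
    }
}
```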
org.apache.flink.api.java.tuple.Tuple3; + +import java.util.List; + +public class KModesModelData { + public List > centroids; + public String[] featureColNames; +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/clustering/kmodes/KModesModelMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/clustering/kmodes/KModesModelMapper.java new file mode 100644 index 000000000..8e9292d5b --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/clustering/kmodes/KModesModelMapper.java @@ -0,0 +1,87 @@ +package com.alibaba.alink.operator.common.clustering.kmodes; + +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.tuple.Tuple4; +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.table.api.TableSchema; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.types.Row; + +import com.alibaba.alink.common.mapper.ModelMapper; +import com.alibaba.alink.common.utils.TableUtil; +import com.alibaba.alink.operator.common.distance.OneZeroDistance; +import com.alibaba.alink.params.clustering.ClusteringPredictParams; + +import java.util.List; + +public class KModesModelMapper extends ModelMapper { + + private static final long serialVersionUID = 1212257106447281392L; + private final boolean isPredDetail; + private KModesModelData modelData; + private int[] colIdx; + private final OneZeroDistance distance = new OneZeroDistance(); + + public KModesModelMapper(TableSchema modelSchema, TableSchema dataSchema, Params params) { + super(modelSchema, dataSchema, params); + isPredDetail = params.contains(ClusteringPredictParams.PREDICTION_DETAIL_COL); + } + + @Override + public void loadModel(List modelRows) { + this.modelData = new KModesModel().load(modelRows); + + colIdx = new int[modelData.featureColNames.length]; + + for (int i = 0; i < modelData.featureColNames.length; i++) { + colIdx[i] = TableUtil.findColIndexWithAssert(getDataSchema().getFieldNames(), modelData + .featureColNames[i]); + } + } + + @Override + protected void map(SlicedSelectedSample selection, SlicedResult result) throws Exception { + String[] record = new String[colIdx.length]; + for (int i = 0; i < record.length; i++) { + record[i] = (String.valueOf(selection.get(colIdx[i]))); + } + + Tuple2 tuple2 = getCluster(modelData.centroids, record, distance); + result.set(0, tuple2.f0); + if (isPredDetail) { + result.set(1, tuple2.f1); + } + } + + @Override + protected Tuple4 [], String[]> prepareIoSchema( + TableSchema modelSchema, TableSchema dataSchema, Params params) { + boolean isPredDetail = params.contains(ClusteringPredictParams.PREDICTION_DETAIL_COL); + String[] resultCols = isPredDetail ? new String[] {params.get(ClusteringPredictParams.PREDICTION_COL), + params.get(ClusteringPredictParams.PREDICTION_DETAIL_COL)} : new String[] {params.get( + ClusteringPredictParams.PREDICTION_COL)}; + TypeInformation[] resultTypes = isPredDetail ? 
new TypeInformation[] {Types.LONG, Types.DOUBLE} + : new TypeInformation[] {Types.LONG}; + + return Tuple4.of(dataSchema.getFieldNames(), resultCols, resultTypes, + params.get(ClusteringPredictParams.RESERVED_COLS)); + } + + private Tuple2 getCluster(Iterable > centroids, + String[] sample, + OneZeroDistance oneZeroDistance) { + long clusterId = -1; + double d = Double.POSITIVE_INFINITY; + + for (Tuple3 c : centroids) { + double distance = oneZeroDistance.calc(sample, c.f2); + if (distance < d) { + clusterId = c.f0; + d = distance; + } + } + return new Tuple2 <>(clusterId, d); + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/dataproc/AggLookupModelMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/dataproc/AggLookupModelMapper.java index bd1f1553e..77c93e627 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/dataproc/AggLookupModelMapper.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/dataproc/AggLookupModelMapper.java @@ -6,7 +6,7 @@ import org.apache.flink.table.api.TableSchema; import org.apache.flink.types.Row; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.exceptions.AkUnsupportedOperationException; import com.alibaba.alink.common.linalg.DenseVector; import com.alibaba.alink.common.linalg.MatVecOp; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/dataproc/ImputerModelDataConverter.java b/core/src/main/java/com/alibaba/alink/operator/common/dataproc/ImputerModelDataConverter.java index 662e0d2e4..609785d3d 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/dataproc/ImputerModelDataConverter.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/dataproc/ImputerModelDataConverter.java @@ -62,13 +62,13 @@ public Tuple3 , Iterable > serializeModel( case MIN: values = new double[selectedColNames.length]; for (int i = 0; i < selectedColNames.length; i++) { - values[i] = summary.min(selectedColNames[i]); + values[i] = summary.minDouble(selectedColNames[i]); } break; case MAX: values = new double[selectedColNames.length]; for (int i = 0; i < selectedColNames.length; i++) { - values[i] = summary.max(selectedColNames[i]); + values[i] = summary.maxDouble(selectedColNames[i]); } break; case MEAN: diff --git a/core/src/main/java/com/alibaba/alink/operator/common/dataproc/ImputerModelMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/dataproc/ImputerModelMapper.java index 5539c310a..5c9b4b802 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/dataproc/ImputerModelMapper.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/dataproc/ImputerModelMapper.java @@ -65,7 +65,7 @@ protected Tuple4[], String[]> prepareIoSc TypeInformation[] selectedColTypes = ImputerModelDataConverter.extractSelectedColTypes(modelSchema); String[] outputColNames = params.get(SrtPredictMapperParams.OUTPUT_COLS); - if (outputColNames == null) { + if (outputColNames == null || outputColNames.length == 0) { outputColNames = selectedColNames; } return Tuple4.of(selectedColNames, outputColNames, selectedColTypes, null); diff --git a/core/src/main/java/com/alibaba/alink/operator/common/dataproc/LookupRedisStringMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/dataproc/LookupRedisStringMapper.java index e913c385d..33daaaec9 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/dataproc/LookupRedisStringMapper.java +++ 
b/core/src/main/java/com/alibaba/alink/operator/common/dataproc/LookupRedisStringMapper.java @@ -4,7 +4,7 @@ import org.apache.flink.ml.api.misc.param.Params; import org.apache.flink.table.api.TableSchema; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.io.redis.Redis; import com.alibaba.alink.common.io.redis.RedisClassLoaderFactory; import com.alibaba.alink.common.mapper.SISOMapper; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/dataproc/MaxAbsScalerModelDataConverter.java b/core/src/main/java/com/alibaba/alink/operator/common/dataproc/MaxAbsScalerModelDataConverter.java index fce4f0a91..677a174bd 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/dataproc/MaxAbsScalerModelDataConverter.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/dataproc/MaxAbsScalerModelDataConverter.java @@ -56,7 +56,7 @@ public Tuple3 , Iterable > serializeModel(TableSu double[] maxAbs = new double[colNames.length]; for (int i = 0; i < colNames.length; i++) { //max(|min, max|) - maxAbs[i] = Math.max(Math.abs(modelData.min(colNames[i])), Math.abs(modelData.max(colNames[i]))); + maxAbs[i] = Math.max(Math.abs(modelData.minDouble(colNames[i])), Math.abs(modelData.maxDouble(colNames[i]))); } List data = new ArrayList <>(); diff --git a/core/src/main/java/com/alibaba/alink/operator/common/dataproc/MinMaxScalerModelDataConverter.java b/core/src/main/java/com/alibaba/alink/operator/common/dataproc/MinMaxScalerModelDataConverter.java index cf09c69d5..35a2f8399 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/dataproc/MinMaxScalerModelDataConverter.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/dataproc/MinMaxScalerModelDataConverter.java @@ -67,8 +67,8 @@ public Tuple3 , Iterable > serializeModel( double[] eMins = new double[colNames.length]; for (int i = 0; i < colNames.length; i++) { - eMaxs[i] = summary.max(colNames[i]); - eMins[i] = summary.min(colNames[i]); + eMaxs[i] = summary.maxDouble(colNames[i]); + eMins[i] = summary.minDouble(colNames[i]); } List data = new ArrayList <>(); diff --git a/core/src/main/java/com/alibaba/alink/operator/common/dataproc/MinMaxScalerModelMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/dataproc/MinMaxScalerModelMapper.java index 61553d358..9959354c2 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/dataproc/MinMaxScalerModelMapper.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/dataproc/MinMaxScalerModelMapper.java @@ -43,7 +43,7 @@ protected Tuple4[], String[]> prepareIoSc TypeInformation[] selectedColTypes = ImputerModelDataConverter.extractSelectedColTypes(modelSchema); String[] outputColNames = params.get(SrtPredictMapperParams.OUTPUT_COLS); - if (outputColNames == null) { + if (outputColNames == null || outputColNames.length == 0) { outputColNames = selectedColNames; } return Tuple4.of(selectedColNames, outputColNames, selectedColTypes, null); diff --git a/core/src/main/java/com/alibaba/alink/operator/common/dataproc/SparseFeatureIndexerModelMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/dataproc/SparseFeatureIndexerModelMapper.java index 75e9d8613..6e9e6bf7c 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/dataproc/SparseFeatureIndexerModelMapper.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/dataproc/SparseFeatureIndexerModelMapper.java @@ -6,7 +6,7 @@ import org.apache.flink.table.api.TableSchema; import 
org.apache.flink.types.Row; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.exceptions.AkIllegalDataException; import com.alibaba.alink.common.exceptions.AkIllegalModelException; import com.alibaba.alink.common.linalg.SparseVector; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/dataproc/StandardScalerModelMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/dataproc/StandardScalerModelMapper.java index deeef00f9..b00b888f1 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/dataproc/StandardScalerModelMapper.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/dataproc/StandardScalerModelMapper.java @@ -26,7 +26,7 @@ protected Tuple4[], String[]> prepareIoSc TypeInformation[] selectedColTypes = ImputerModelDataConverter.extractSelectedColTypes(modelSchema); String[] outputColNames = params.get(SrtPredictMapperParams.OUTPUT_COLS); - if (outputColNames == null) { + if (outputColNames == null || outputColNames.length == 0) { outputColNames = selectedColNames; } return Tuple4.of(selectedColNames, outputColNames, selectedColTypes, null); diff --git a/core/src/main/java/com/alibaba/alink/operator/common/dataproc/TensorSerializeMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/dataproc/TensorSerializeMapper.java index aa73033be..04c410059 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/dataproc/TensorSerializeMapper.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/dataproc/TensorSerializeMapper.java @@ -7,7 +7,7 @@ import org.apache.flink.table.api.TableSchema; import com.alibaba.alink.common.linalg.tensor.Tensor; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.linalg.tensor.TensorUtil; import com.alibaba.alink.common.mapper.Mapper; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/dataproc/TensorToVectorMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/dataproc/TensorToVectorMapper.java index 9a02dfe5c..35372b633 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/dataproc/TensorToVectorMapper.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/dataproc/TensorToVectorMapper.java @@ -4,7 +4,7 @@ import org.apache.flink.ml.api.misc.param.Params; import org.apache.flink.table.api.TableSchema; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.exceptions.AkUnsupportedOperationException; import com.alibaba.alink.common.linalg.tensor.DoubleTensor; import com.alibaba.alink.common.linalg.tensor.NumericalTensor; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/dataproc/ToMTableMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/dataproc/ToMTableMapper.java index d85661a6f..46b253689 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/dataproc/ToMTableMapper.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/dataproc/ToMTableMapper.java @@ -4,7 +4,7 @@ import org.apache.flink.ml.api.misc.param.Params; import org.apache.flink.table.api.TableSchema; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.MTable; import com.alibaba.alink.common.exceptions.AkUnclassifiedErrorException; import com.alibaba.alink.common.exceptions.AkUnsupportedOperationException; diff --git 
a/core/src/main/java/com/alibaba/alink/operator/common/dataproc/ToTensorMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/dataproc/ToTensorMapper.java index f2548a610..66372143c 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/dataproc/ToTensorMapper.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/dataproc/ToTensorMapper.java @@ -4,7 +4,7 @@ import org.apache.flink.ml.api.misc.param.Params; import org.apache.flink.table.api.TableSchema; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.exceptions.AkIllegalOperatorParameterException; import com.alibaba.alink.common.exceptions.AkUnsupportedOperationException; import com.alibaba.alink.common.linalg.tensor.DataType; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/dataproc/ToVectorMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/dataproc/ToVectorMapper.java index 35056c1db..1d3c682ee 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/dataproc/ToVectorMapper.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/dataproc/ToVectorMapper.java @@ -4,7 +4,7 @@ import org.apache.flink.ml.api.misc.param.Params; import org.apache.flink.table.api.TableSchema; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.exceptions.AkUnclassifiedErrorException; import com.alibaba.alink.common.linalg.DenseVector; import com.alibaba.alink.common.linalg.SparseVector; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/dataproc/format/FormatTransMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/dataproc/format/FormatTransMapper.java index 7ff48cc9d..47370cba4 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/dataproc/format/FormatTransMapper.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/dataproc/format/FormatTransMapper.java @@ -9,7 +9,7 @@ import org.apache.flink.table.api.TableSchema; import org.apache.flink.types.Row; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.exceptions.AkIllegalDataException; import com.alibaba.alink.common.exceptions.AkUnsupportedOperationException; import com.alibaba.alink.common.mapper.Mapper; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/dataproc/tensor/TensorReshapeMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/dataproc/tensor/TensorReshapeMapper.java index a673dbc8f..4e31b950b 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/dataproc/tensor/TensorReshapeMapper.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/dataproc/tensor/TensorReshapeMapper.java @@ -4,7 +4,7 @@ import org.apache.flink.ml.api.misc.param.Params; import org.apache.flink.table.api.TableSchema; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.linalg.tensor.Shape; import com.alibaba.alink.common.linalg.tensor.Tensor; import com.alibaba.alink.common.linalg.tensor.TensorUtil; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/dataproc/vector/MTableSerializeMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/dataproc/vector/MTableSerializeMapper.java index 4981a7316..53e759367 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/dataproc/vector/MTableSerializeMapper.java +++ 
b/core/src/main/java/com/alibaba/alink/operator/common/dataproc/vector/MTableSerializeMapper.java @@ -7,8 +7,7 @@ import org.apache.flink.table.api.TableSchema; import com.alibaba.alink.common.MTable; -import com.alibaba.alink.common.AlinkTypes; -import com.alibaba.alink.common.MTableUtil; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.mapper.Mapper; import com.alibaba.alink.common.utils.JsonConverter; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/dataproc/vector/PolynomialExpansionMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/dataproc/vector/PolynomialExpansionMapper.java index 765734c2f..9ff87c876 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/dataproc/vector/PolynomialExpansionMapper.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/dataproc/vector/PolynomialExpansionMapper.java @@ -1,12 +1,11 @@ package com.alibaba.alink.operator.common.dataproc.vector; -import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.ml.api.misc.param.Params; import org.apache.flink.table.api.TableSchema; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.exceptions.AkIllegalDataException; import com.alibaba.alink.common.linalg.DenseVector; import com.alibaba.alink.common.linalg.SparseVector; @@ -40,7 +39,6 @@ public PolynomialExpansionMapper(TableSchema dataSchema, Params params) { * @param degree the degree of the polynomial. * @return the polynomial size. */ - @VisibleForTesting static int getPolySize(int num, int degree) { if (num == 0) { return 1; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/dataproc/vector/VectorAssemblerMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/dataproc/vector/VectorAssemblerMapper.java index 1372ebe7a..8c6462ef2 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/dataproc/vector/VectorAssemblerMapper.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/dataproc/vector/VectorAssemblerMapper.java @@ -4,7 +4,7 @@ import org.apache.flink.ml.api.misc.param.Params; import org.apache.flink.table.api.TableSchema; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.exceptions.AkIllegalDataException; import com.alibaba.alink.common.exceptions.AkUnsupportedOperationException; import com.alibaba.alink.common.linalg.DenseVector; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/dataproc/vector/VectorBiFunctionMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/dataproc/vector/VectorBiFunctionMapper.java index c06c54599..f17390f67 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/dataproc/vector/VectorBiFunctionMapper.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/dataproc/vector/VectorBiFunctionMapper.java @@ -5,7 +5,7 @@ import org.apache.flink.ml.api.misc.param.Params; import org.apache.flink.table.api.TableSchema; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.exceptions.AkUnsupportedOperationException; import com.alibaba.alink.common.linalg.MatVecOp; import com.alibaba.alink.common.linalg.Vector; diff --git 
a/core/src/main/java/com/alibaba/alink/operator/common/dataproc/vector/VectorElementwiseProductMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/dataproc/vector/VectorElementwiseProductMapper.java index 317ffdecb..7bbad8029 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/dataproc/vector/VectorElementwiseProductMapper.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/dataproc/vector/VectorElementwiseProductMapper.java @@ -4,7 +4,7 @@ import org.apache.flink.ml.api.misc.param.Params; import org.apache.flink.table.api.TableSchema; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.linalg.DenseVector; import com.alibaba.alink.common.linalg.SparseVector; import com.alibaba.alink.common.linalg.Vector; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/dataproc/vector/VectorFunctionMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/dataproc/vector/VectorFunctionMapper.java index 9ef274fc8..d121969ad 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/dataproc/vector/VectorFunctionMapper.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/dataproc/vector/VectorFunctionMapper.java @@ -5,7 +5,7 @@ import org.apache.flink.ml.api.misc.param.Params; import org.apache.flink.table.api.TableSchema; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.exceptions.AkUnimplementedOperationException; import com.alibaba.alink.common.linalg.DenseVector; import com.alibaba.alink.common.linalg.SparseVector; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/dataproc/vector/VectorImputerModelMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/dataproc/vector/VectorImputerModelMapper.java index dd435714b..7c99b940e 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/dataproc/vector/VectorImputerModelMapper.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/dataproc/vector/VectorImputerModelMapper.java @@ -6,7 +6,7 @@ import org.apache.flink.table.api.TableSchema; import org.apache.flink.types.Row; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.exceptions.AkIllegalOperatorParameterException; import com.alibaba.alink.common.linalg.DenseVector; import com.alibaba.alink.common.linalg.SparseVector; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/dataproc/vector/VectorInteractionMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/dataproc/vector/VectorInteractionMapper.java index 2b4fb1986..846ddc377 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/dataproc/vector/VectorInteractionMapper.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/dataproc/vector/VectorInteractionMapper.java @@ -4,7 +4,7 @@ import org.apache.flink.ml.api.misc.param.Params; import org.apache.flink.table.api.TableSchema; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.exceptions.AkIllegalDataException; import com.alibaba.alink.common.linalg.DenseVector; import com.alibaba.alink.common.linalg.SparseVector; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/dataproc/vector/VectorMaxAbsScalerModelMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/dataproc/vector/VectorMaxAbsScalerModelMapper.java 
index 03b8f107c..702c5ba58 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/dataproc/vector/VectorMaxAbsScalerModelMapper.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/dataproc/vector/VectorMaxAbsScalerModelMapper.java @@ -5,7 +5,7 @@ import org.apache.flink.table.api.TableSchema; import org.apache.flink.types.Row; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.linalg.DenseVector; import com.alibaba.alink.common.linalg.SparseVector; import com.alibaba.alink.common.linalg.Vector; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/dataproc/vector/VectorMinMaxScalerModelMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/dataproc/vector/VectorMinMaxScalerModelMapper.java index a5441fbe1..4a2ba732a 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/dataproc/vector/VectorMinMaxScalerModelMapper.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/dataproc/vector/VectorMinMaxScalerModelMapper.java @@ -6,7 +6,7 @@ import org.apache.flink.table.api.TableSchema; import org.apache.flink.types.Row; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.linalg.DenseVector; import com.alibaba.alink.common.linalg.SparseVector; import com.alibaba.alink.common.linalg.Vector; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/dataproc/vector/VectorNormalizeMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/dataproc/vector/VectorNormalizeMapper.java index 0df21cfd7..3c98c40ba 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/dataproc/vector/VectorNormalizeMapper.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/dataproc/vector/VectorNormalizeMapper.java @@ -4,7 +4,7 @@ import org.apache.flink.ml.api.misc.param.Params; import org.apache.flink.table.api.TableSchema; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.linalg.Vector; import com.alibaba.alink.common.linalg.VectorUtil; import com.alibaba.alink.common.mapper.SISOMapper; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/dataproc/vector/VectorSerializeMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/dataproc/vector/VectorSerializeMapper.java index 8c715198c..de32a5b86 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/dataproc/vector/VectorSerializeMapper.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/dataproc/vector/VectorSerializeMapper.java @@ -6,7 +6,7 @@ import org.apache.flink.ml.api.misc.param.Params; import org.apache.flink.table.api.TableSchema; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.linalg.Vector; import com.alibaba.alink.common.linalg.VectorUtil; import com.alibaba.alink.common.mapper.Mapper; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/dataproc/vector/VectorSizeHintMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/dataproc/vector/VectorSizeHintMapper.java index 1250995f5..9fa2b75e2 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/dataproc/vector/VectorSizeHintMapper.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/dataproc/vector/VectorSizeHintMapper.java @@ -4,8 +4,7 @@ import org.apache.flink.ml.api.misc.param.Params; import 
org.apache.flink.table.api.TableSchema; -import com.alibaba.alink.common.AlinkTypes; -import com.alibaba.alink.common.exceptions.AkIllegalArgumentException; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.exceptions.AkIllegalDataException; import com.alibaba.alink.common.exceptions.AkIllegalOperatorParameterException; import com.alibaba.alink.common.linalg.Vector; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/dataproc/vector/VectorSliceMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/dataproc/vector/VectorSliceMapper.java index 33d48c1a0..2ab5e2520 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/dataproc/vector/VectorSliceMapper.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/dataproc/vector/VectorSliceMapper.java @@ -4,7 +4,7 @@ import org.apache.flink.ml.api.misc.param.Params; import org.apache.flink.table.api.TableSchema; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.linalg.Vector; import com.alibaba.alink.common.linalg.VectorUtil; import com.alibaba.alink.common.mapper.SISOMapper; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/dataproc/vector/VectorStandardScalerModelMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/dataproc/vector/VectorStandardScalerModelMapper.java index 357c58718..c73586de2 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/dataproc/vector/VectorStandardScalerModelMapper.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/dataproc/vector/VectorStandardScalerModelMapper.java @@ -6,7 +6,7 @@ import org.apache.flink.table.api.TableSchema; import org.apache.flink.types.Row; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.linalg.DenseVector; import com.alibaba.alink.common.linalg.SparseVector; import com.alibaba.alink.common.linalg.Vector; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/distance/OneZeroDistance.java b/core/src/main/java/com/alibaba/alink/operator/common/distance/OneZeroDistance.java new file mode 100644 index 000000000..78adc8d84 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/distance/OneZeroDistance.java @@ -0,0 +1,25 @@ +package com.alibaba.alink.operator.common.distance; + +import org.apache.flink.util.Preconditions; + +import java.util.Objects; + +public class OneZeroDistance implements CategoricalDistance { + private static final long serialVersionUID = -6375080752955133016L; + + @Override + public int calc(String str1, String str2) { + return Objects.equals(str1, str2) ? 
0 : 1; + } + + @Override + public int calc(String[] str1, String[] str2) { + int distance = 0; + Preconditions.checkArgument(str1.length == str2.length, + "For OneZeroDistance, the categorical feature number must be equal!"); + for (int i = 0; i < str1.length; i++) { + distance += calc(str1[i], str2[i]); + } + return distance; + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/evaluation/BinaryMetricsSummary.java b/core/src/main/java/com/alibaba/alink/operator/common/evaluation/BinaryMetricsSummary.java index 59c895c54..326436ebd 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/evaluation/BinaryMetricsSummary.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/evaluation/BinaryMetricsSummary.java @@ -112,10 +112,6 @@ public BinaryClassMetrics toMetrics() { setCurveAreaParams(params, matrixThreCurve.f2); - if(Arrays.stream(negativeBin).sum() == 0 || Arrays.stream(positiveBin).sum() == 0){ - params.set(BinaryClassMetrics.AUC, Double.NaN); - } - Tuple3 sampledMatrixThreCurve = sample( PROBABILITY_INTERVAL, matrixThreCurve); diff --git a/core/src/main/java/com/alibaba/alink/operator/common/evaluation/ClusterMetricsSummary.java b/core/src/main/java/com/alibaba/alink/operator/common/evaluation/ClusterMetricsSummary.java index ef2661faa..b52a13004 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/evaluation/ClusterMetricsSummary.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/evaluation/ClusterMetricsSummary.java @@ -134,18 +134,34 @@ public ClusterMetrics toMetrics() { } } - double DBIndex = StatUtils.sum(DBIndexArray) / k; + double DBIndex = k > 1 ? StatUtils.sum(DBIndexArray) / k : Double.POSITIVE_INFINITY; params.set(ClusterMetrics.SSB, ssb); params.set(ClusterMetrics.SSW, ssw); params.set(ClusterMetrics.CP, compactness / k); params.set(ClusterMetrics.K, k); params.set(ClusterMetrics.COUNT, total); - params.set(ClusterMetrics.SP, 2 * seperation / (k * k - k)); + params.set(ClusterMetrics.SP, k > 1 ? 2 * seperation / (k * k - k) : 0.); params.set(ClusterMetrics.DB, DBIndex); - params.set(ClusterMetrics.VRC, ssb * (total - k) / ssw / (k - 1)); + params.set(ClusterMetrics.VRC, k > 1 ? 
ssb * (total - k) / ssw / (k - 1) : 0); params.set(ClusterMetrics.CLUSTER_ARRAY, clusters); params.set(ClusterMetrics.COUNT_ARRAY, countArray); return new ClusterMetrics(params); } + + public static ClusterMetrics createForEmptyDataset() { + Params params = new Params(); + params.set(ClusterMetrics.SSB, 0.); + params.set(ClusterMetrics.SSW, Double.POSITIVE_INFINITY); + params.set(ClusterMetrics.CP, Double.POSITIVE_INFINITY); + params.set(ClusterMetrics.K, 0); + params.set(ClusterMetrics.COUNT, 0); + params.set(ClusterMetrics.SP, 0.); + params.set(ClusterMetrics.DB, Double.POSITIVE_INFINITY); + params.set(ClusterMetrics.VRC, 0.); + params.set(ClusterMetrics.CLUSTER_ARRAY, new String[0]); + params.set(ClusterMetrics.COUNT_ARRAY, new double[0]); + params.set(ClusterMetrics.SILHOUETTE_COEFFICIENT, -1.); + return new ClusterMetrics(params); + } } diff --git a/core/src/main/java/com/alibaba/alink/operator/common/feature/AutoCross/AutoCrossAlgoModelMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/feature/AutoCross/AutoCrossAlgoModelMapper.java new file mode 100644 index 000000000..682062679 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/feature/AutoCross/AutoCrossAlgoModelMapper.java @@ -0,0 +1,198 @@ +package com.alibaba.alink.operator.common.feature.AutoCross; + +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.tuple.Tuple4; +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.table.api.TableSchema; +import org.apache.flink.types.Row; +import org.apache.flink.util.function.TriFunction; + +import com.alibaba.alink.common.type.AlinkTypes; +import com.alibaba.alink.common.linalg.DenseVector; +import com.alibaba.alink.common.linalg.SparseVector; +import com.alibaba.alink.common.linalg.VectorUtil; +import com.alibaba.alink.common.mapper.ModelMapper; +import com.alibaba.alink.common.utils.JsonConverter; +import com.alibaba.alink.common.utils.TableUtil; +import com.alibaba.alink.params.feature.AutoCrossPredictParams; +import com.alibaba.alink.params.feature.featuregenerator.HasAppendOriginalData; +import com.alibaba.alink.params.shared.colname.HasReservedColsDefaultAsNull; + +import java.io.Serializable; +import java.util.List; +import java.util.stream.Collectors; + +public class AutoCrossAlgoModelMapper extends ModelMapper { + private static final long serialVersionUID = -4500389710522943248L; + private String[] dataCols; + private int[] numericalIndices; + private int vecIndex = -2; + private OneHotOperator operator; + private int[] cumsumIndex; + private final AutoCrossPredictParams.OutputFormat outputFormat; + TriFunction, OneHotOperator, SlicedResult, Row> mapOperator; + + boolean appendOriginalVec = true; + + public static class FeatureSet implements Serializable { + private static final long serialVersionUID = 3402906686076385472L; + public int numRawFeatures; + public String[] numericalCols; + public String vecColName; + public List crossFeatureSet; + public List scores; + public int[] indexSize; + public boolean hasDiscrete; + } + + //todo may support input vector col. 
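A note on how this FeatureSet metadata travels: loadModel(...) below reads the model row whose first field is 0L, takes its second field as a JSON string, and rebuilds the FeatureSet with JsonConverter. A minimal sketch of that round trip, assuming the Alink core classes above are on the classpath; the field values here are invented purely for illustration and are not taken from any real model:

import java.util.Arrays;

import com.alibaba.alink.common.utils.JsonConverter;
import com.alibaba.alink.operator.common.feature.AutoCross.AutoCrossAlgoModelMapper.FeatureSet;

public class FeatureSetJsonSketch {
    public static void main(String[] args) {
        FeatureSet fs = new FeatureSet();
        fs.numRawFeatures = 3;                                // three one-hot (categorical) features
        fs.numericalCols = new String[] {"age"};              // numerical columns kept as-is (hypothetical name)
        fs.vecColName = "vec";                                // assembled one-hot vector column (hypothetical name)
        fs.crossFeatureSet = Arrays.asList(new int[] {0, 1}); // one accepted cross of features 0 and 1
        fs.scores = Arrays.asList(0.71);                      // its evaluation score (made up)
        fs.indexSize = new int[] {4, 3, 5};                   // bucket count per raw feature (made up)
        fs.hasDiscrete = false;

        String json = JsonConverter.toJson(fs);               // what ends up in model row 0
        FeatureSet restored = JsonConverter.fromJson(json, FeatureSet.class);
        System.out.println(restored.crossFeatureSet.size());  // 1
    }
}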
+ public AutoCrossAlgoModelMapper(TableSchema modelSchema, TableSchema dataSchema, Params params) { + super(modelSchema, dataSchema, params); + outputFormat = params.get(AutoCrossPredictParams.OUTPUT_FORMAT); + appendOriginalVec = params.get(HasAppendOriginalData.APPEND_ORIGINAL_DATA); + } + + @Override + protected Tuple4 [], String[]> prepareIoSchema(TableSchema modelSchema, + TableSchema dataSchema, + Params params) { + dataCols = dataSchema.getFieldNames(); + String[] outputCols = new String[] {params.get(AutoCrossPredictParams.OUTPUT_COL)}; + TypeInformation[] outputTypes = new TypeInformation[] {AlinkTypes.VECTOR}; + return Tuple4.of(dataCols, outputCols, outputTypes, params.get(HasReservedColsDefaultAsNull.RESERVED_COLS)); + + } + + @Override + public void loadModel(List modelRows) { + String jsonStr = modelRows.stream().filter(row -> row.getField(0).equals(0L)) + .map(row -> (String) row.getField(1)) + .collect(Collectors.toList()).get(0); + FeatureSet fs = JsonConverter.fromJson(jsonStr, FeatureSet.class); + numericalIndices = TableUtil.findColIndices(dataCols, fs.numericalCols); + if (vecIndex == -2) { + vecIndex = TableUtil.findColIndex(dataCols, fs.vecColName); + } + int crossFeatureSetSize = fs.crossFeatureSet.size(); + operator = new OneHotOperator(fs.numRawFeatures, fs.crossFeatureSet, fs.indexSize); + + if (outputFormat == AutoCrossPredictParams.OutputFormat.Dense) { + cumsumIndex = new int[fs.indexSize.length+fs.crossFeatureSet.size()-1]; + cumsumIndex[0] = fs.indexSize[0]; + for (int i = 1; i < fs.indexSize.length; i++) { + cumsumIndex[i] = cumsumIndex[i-1]+fs.indexSize[i]; + } + for (int i = 0; i < fs.crossFeatureSet.size() - 1; i++) { + int tempSize = 1; + for (int v : fs.crossFeatureSet.get(i)) { + tempSize*=v; + } + cumsumIndex[fs.indexSize.length+i] = cumsumIndex[fs.indexSize.length+i-1]+tempSize; + } + } + if (outputFormat == AutoCrossPredictParams.OutputFormat.Sparse) { + if (appendOriginalVec) { + mapOperator = AutoCrossAlgoModelMapper::mapSparse; + } else { + mapOperator = AutoCrossAlgoModelMapper::mapSparseWithoutOriginal; + } + } else if (outputFormat == AutoCrossPredictParams.OutputFormat.Dense) { + if (appendOriginalVec) { + mapOperator = AutoCrossAlgoModelMapper::mapDense; + } else { + mapOperator = AutoCrossAlgoModelMapper::mapDenseWithoutOriginal; + } + } + //inputBufferThreadLocal = ThreadLocal.withInitial(() -> new Row(ioSchema.f0.length)); + } + + @Override + protected void map(SlicedSelectedSample selection, SlicedResult result) throws Exception { + //Row row = inputBufferThreadLocal.get(); + mapOperator.apply(Tuple4.of(selection, vecIndex, cumsumIndex, numericalIndices), operator, result); + } + + + private static Row mapSparseWithoutOriginal(Tuple4 rowAndIndex, OneHotOperator operator, SlicedResult result) { + SlicedSelectedSample row = rowAndIndex.f0; + int vecIndex = rowAndIndex.f1; + SparseVector inputData = VectorUtil.getSparseVector(row.get(vecIndex)); + int inputDataSize = inputData.size(); + int inputDataNum = inputData.numberOfValues(); + SparseVector crossedVec = operator.oneHotData(inputData); + int resSize = crossedVec.size() - inputDataSize; + int resDataNum = crossedVec.numberOfValues() - inputDataNum; + int[] resIndices = new int[resDataNum]; + double[] resValues = new double[resDataNum]; + System.arraycopy(crossedVec.getIndices(), inputDataNum, resIndices, 0, resDataNum); + System.arraycopy(crossedVec.getValues(), inputDataNum, resValues, 0, resDataNum); + for (int i = 0; i < resDataNum; i++) { + resIndices[i] -= inputDataSize; + } + + 
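The tail of mapSparseWithoutOriginal keeps only the crossed-feature block of the one-hot output and rebases its indices so the result vector starts at zero. A tiny standalone illustration of that rebasing, with made-up sizes (not the Alink class itself):

// Keep only the crossed-feature block of a one-hot vector and shift it to start at index 0.
public class RebaseCrossedIndicesSketch {
    public static void main(String[] args) {
        int inputDataSize = 6;                 // size of the original one-hot block
        int[] crossedIndices = {0, 3, 7, 9};   // non-zero positions of the full crossed vector
        int inputDataNum = 2;                  // how many of them belong to the original block

        int resDataNum = crossedIndices.length - inputDataNum;
        int[] resIndices = new int[resDataNum];
        System.arraycopy(crossedIndices, inputDataNum, resIndices, 0, resDataNum);
        for (int i = 0; i < resDataNum; i++) {
            resIndices[i] -= inputDataSize;    // 7 -> 1, 9 -> 3
        }
        System.out.println(java.util.Arrays.toString(resIndices)); // [1, 3]
    }
}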
result.set(0, new SparseVector(resSize, resIndices, resValues)); + return null; + } + + private static Row mapSparse(Tuple4 rowAndIndex, OneHotOperator operator, SlicedResult result) { + SlicedSelectedSample row = rowAndIndex.f0; + int vecIndex = rowAndIndex.f1; + int[] numericalIndices = rowAndIndex.f3; + SparseVector inVec = operator.oneHotData(VectorUtil.getSparseVector(row.get(vecIndex))); + int svSize = inVec.size() + numericalIndices.length; + int[] svIndices = new int[inVec.getIndices().length + numericalIndices.length]; + double[] svValues = new double[inVec.getIndices().length + numericalIndices.length]; + + for (int i = 0; i < numericalIndices.length; i++) { + svIndices[i] = i; + svValues[i] = ((Number) row.get(numericalIndices[i])).doubleValue(); + } + for (int i = 0; i < inVec.getIndices().length; i++) { + inVec.getIndices()[i] += numericalIndices.length; + } + System.arraycopy(inVec.getIndices(), 0, svIndices, numericalIndices.length, inVec.getIndices().length); + System.arraycopy(inVec.getValues(), 0, svValues, numericalIndices.length, inVec.getValues().length); + result.set(0, new SparseVector(svSize, svIndices, svValues)); + return null; + } + + private static Row mapDenseWithoutOriginal(Tuple4 rowAndIndex, OneHotOperator operator, SlicedResult result) { + SlicedSelectedSample row = rowAndIndex.f0; + int vecIndex = rowAndIndex.f1; + SparseVector inputData = (SparseVector) VectorUtil.getVector(row.get(vecIndex)); + int inputDataSize = inputData.size(); + int inputDataNum = inputData.numberOfValues(); + SparseVector crossedVec = operator.oneHotData(inputData); + int[] indices = inputData.getIndices(); + int resDataNum = crossedVec.numberOfValues() - inputDataNum; + double[] resData = new double[resDataNum]; + for (int i = 0; i < resDataNum; i++) { + resData[i] = indices[i + inputDataNum] - inputDataSize; + } + result.set(0, new DenseVector(resData)); + return null; + } + + private static Row mapDense(Tuple4 rowAndIndex, OneHotOperator operator, SlicedResult result) { + SlicedSelectedSample row = rowAndIndex.f0; + int vecIndex = rowAndIndex.f1; + //int[] cumsumIndex = rowAndIndex.f2; + int[] numericalIndices = rowAndIndex.f3; + SparseVector inVec = (SparseVector) VectorUtil.getVector(row.get(vecIndex)); + inVec = operator.oneHotData(inVec); + int[] indices = inVec.getIndices(); + //for (int i = 0; i < cumsumIndex.length; i++) { + // indices[i+1] -= cumsumIndex[i]; + //} + double[] dIndices = new double[numericalIndices.length + indices.length]; + for (int i = 0; i < numericalIndices.length; i++) { + dIndices[i] = ((Number) row.get(numericalIndices[i])).doubleValue(); + } + for (int i = 0; i < indices.length; i++) { + dIndices[i + numericalIndices.length] = indices[i]; + } + result.set(0, new DenseVector(dIndices)); + return null; + } + +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/feature/AutoCross/AutoCrossModelMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/feature/AutoCross/AutoCrossModelMapper.java new file mode 100644 index 000000000..7ebbe4e1e --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/feature/AutoCross/AutoCrossModelMapper.java @@ -0,0 +1,70 @@ +package com.alibaba.alink.operator.common.feature.AutoCross; + +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.tuple.Tuple4; +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.table.api.TableSchema; +import org.apache.flink.types.Row; + 
+import com.alibaba.alink.common.type.AlinkTypes; +import com.alibaba.alink.common.mapper.ComboModelMapper; +import com.alibaba.alink.common.mapper.Mapper; +import com.alibaba.alink.params.feature.AutoCrossPredictParams; +import com.alibaba.alink.params.feature.featuregenerator.HasAppendOriginalData; +import com.alibaba.alink.params.shared.colname.HasOutputCol; +import com.alibaba.alink.params.shared.colname.HasReservedColsDefaultAsNull; +import com.alibaba.alink.pipeline.ModelExporterUtils; +import com.alibaba.alink.pipeline.PipelineStageBase; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +public class AutoCrossModelMapper extends ComboModelMapper { + private static final long serialVersionUID = 4498117230717789425L; + private String[] reversedCols; + private String outputCol; + private final AutoCrossPredictParams.OutputFormat outputFormat; + boolean appendOriginalVec = true; + + //todo may support input vector col. + public AutoCrossModelMapper(TableSchema modelSchema, TableSchema dataSchema, Params params) { + super(modelSchema, dataSchema, params); + outputFormat = params.get(AutoCrossPredictParams.OUTPUT_FORMAT); + } + + @Override + public List getLoadedMapperList() { + List mapperList = new ArrayList <>(); + Collections.addAll(mapperList, this.mapperList.getMappers()); + return mapperList; + } + + @Override + protected Tuple4 [], String[]> prepareIoSchema(TableSchema modelSchema, + TableSchema dataSchema, + Params params) { + outputCol = params.get(AutoCrossPredictParams.OUTPUT_COL); + String[] outputCols = new String[] {outputCol}; + TypeInformation[] outputTypes = new TypeInformation[] {AlinkTypes.VECTOR}; + reversedCols = params.get(HasReservedColsDefaultAsNull.RESERVED_COLS); + return Tuple4.of(dataSchema.getFieldNames(), outputCols, outputTypes, reversedCols); + } + + @Override + public void loadModel(List modelRows) { + List , TableSchema, List >> stages = + ModelExporterUtils.loadStagesFromPipelineModel(modelRows, getModelSchema()); + stages.get(1).f0.set(HasOutputCol.OUTPUT_COL, outputCol); + stages.get(1).f0.set(AutoCrossPredictParams.OUTPUT_FORMAT, outputFormat); + stages.get(1).f0.set(HasAppendOriginalData.APPEND_ORIGINAL_DATA, appendOriginalVec); + if (reversedCols != null) { + stages.get(1).f0.set(HasReservedColsDefaultAsNull.RESERVED_COLS, reversedCols); + } + + this.mapperList = ModelExporterUtils + .loadMapperListFromStages(stages, getDataSchema()); + } + +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/feature/AutoCross/AutoCrossObjFunc.java b/core/src/main/java/com/alibaba/alink/operator/common/feature/AutoCross/AutoCrossObjFunc.java new file mode 100644 index 000000000..a6f8f6957 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/feature/AutoCross/AutoCrossObjFunc.java @@ -0,0 +1,95 @@ +package com.alibaba.alink.operator.common.feature.AutoCross; + +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.ml.api.misc.param.Params; + +import com.alibaba.alink.common.linalg.DenseMatrix; +import com.alibaba.alink.common.linalg.DenseVector; +import com.alibaba.alink.common.linalg.MatVecOp; +import com.alibaba.alink.common.linalg.SparseVector; +import com.alibaba.alink.common.linalg.Vector; +import com.alibaba.alink.operator.common.linear.unarylossfunc.LogLossFunc; +import com.alibaba.alink.operator.common.optim.objfunc.OptimObjFunc; + +public class AutoCrossObjFunc extends OptimObjFunc { + private static final long serialVersionUID = 4901712689789156832L; + private 
LogLossFunc lossFunc; + + public AutoCrossObjFunc(Params params) { + super(params); + lossFunc = new LogLossFunc(); + } + + //index, label, vectorData. index indices the start of calculating coef. allCoefVector saves all the coefs. + @Override + public double calcLoss(Tuple3 labelVector, DenseVector allCoefVector) { + double eta = getEta(labelVector, allCoefVector); + return this.lossFunc.loss(eta, labelVector.f1); + } + + //需要将fixed的param的数目记录下来,从而可以稀疏地更新grad。 + @Override + public void updateGradient(Tuple3 labelVector, + DenseVector allCoefVector, + DenseVector updateGrad) { + double eta = getEta(labelVector, allCoefVector); + double div = lossFunc.derivative(eta, labelVector.f1); + int fixedCoefSize = allCoefVector.size() - updateGrad.size(); + double[] grad = updateGrad.getData(); + SparseVector sv = (SparseVector) labelVector.f2; + //grad只是需要更新的梯度部分。 + for (int i = 0; i < sv.getIndices().length; i++) { + int index = sv.getIndices()[i]; + if (index >= fixedCoefSize) { + grad[index - fixedCoefSize] += div * sv.getValues()[i]; + } + } + } + + @Override + public void updateHessian(Tuple3 labelVector, + DenseVector coefVector, + DenseMatrix updateHessian) { + throw new RuntimeException("do not support hessian."); + } + + @Override + public boolean hasSecondDerivative() { + return false; + } + + @Override + public double[] calcSearchValues(Iterable > labelVectors, + DenseVector allCoefVector,//fixed+candidate + DenseVector dirVec,//candidate + double beta, + int numStep) { + double[] vec = new double[numStep + 1]; + + int fixedSize = allCoefVector.size() - dirVec.size(); + double[] realDir = new double[allCoefVector.size()]; + System.arraycopy(dirVec.getData(), 0, realDir, fixedSize, dirVec.size()); + DenseVector realDirVec = new DenseVector(realDir); + for (Tuple3 labelVector : labelVectors) { + // //cancat + // if (index == -1) { + // index = (int) Math.round(labelVector.f0); + // double[] realDir = allCoefVector.getData(); + // System.arraycopy(dirVec.getData(), 0, realDir, index, dirVec.size()); + // realDirVec = new DenseVector(realDir); + // } + + double weight = labelVector.f0; + double etaCoef = getEta(labelVector, allCoefVector); + double etaDelta = getEta(labelVector, realDirVec) * beta; + for (int i = 0; i < numStep + 1; ++i) { + vec[i] += weight * lossFunc.loss(etaCoef - i * etaDelta, labelVector.f1); + } + } + return vec; + } + + private double getEta(Tuple3 labelVector, DenseVector coefVector) { + return MatVecOp.dot(labelVector.f2, coefVector); + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/feature/AutoCross/BuildSideOutput.java b/core/src/main/java/com/alibaba/alink/operator/common/feature/AutoCross/BuildSideOutput.java new file mode 100644 index 000000000..7d44d9462 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/feature/AutoCross/BuildSideOutput.java @@ -0,0 +1,272 @@ +package com.alibaba.alink.operator.common.feature.AutoCross; + +import org.apache.flink.api.common.functions.RichMapPartitionFunction; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; + +import com.alibaba.alink.common.utils.JsonConverter; +import com.alibaba.alink.operator.common.dataproc.MultiStringIndexerModelData; +import com.alibaba.alink.operator.common.feature.OneHotModelDataConverter; +import com.alibaba.alink.params.shared.colname.HasSelectedCols; +import com.google.common.collect.Lists; + +import java.util.HashMap; +import 
java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +public class BuildSideOutput extends RichMapPartitionFunction { + int numericalSize; + + public BuildSideOutput(int numericalSize) { + this.numericalSize = numericalSize; + } + + public static void buildModel(List oneHotModelRow, List autoCrossModelRow, + Collector out) { + final String LOW_FREQUENCY_VALUE = "lowFrequencyValue"; + final String NULL_VALUE = "null"; + + String jsonStr = autoCrossModelRow.stream().filter(row -> row.getField(0).equals(0L)) + .map(row -> (String) row.getField(1)) + .collect(Collectors.toList()).get(0); + AutoCrossAlgoModelMapper.FeatureSet fs = JsonConverter.fromJson(jsonStr, AutoCrossAlgoModelMapper.FeatureSet.class); + List crossFeatureSet = fs.crossFeatureSet; + boolean hasDiscrete = fs.hasDiscrete; + int numericalSize = fs.numericalCols.length; + + MultiStringIndexerModelData data = new OneHotModelDataConverter().load(oneHotModelRow).modelData; + String[] featureCols = data.meta.get(HasSelectedCols.SELECTED_COLS); + int featureNumber = data.tokenNumber.size(); + int[] featureSize = new int[featureNumber]; + int[] cunsum = new int[featureNumber + 1]; + for (int i = 0; i < featureNumber; i++) { + featureSize[i] = (int) (data.tokenNumber.get(i) + (hasDiscrete ? 2 : 1)); + cunsum[i + 1] = cunsum[i] + featureSize[i]; + } + + //the first array holds the normal (frequent) values, the second the low-frequency ones. + Map > featureValueMap = new HashMap <>(); + Set crossSingleFeature = new HashSet <>(); + for (int[] ints : crossFeatureSet) { + for (int i : ints) { + crossSingleFeature.add(i); + } + } + + //the HashMap stores a Tuple2: one array for the normal values, another for the low-frequency ones. + //when building the output, low-frequency values are mapped to their corresponding index. + //feature index, feature value, feature value index. + if (hasDiscrete) { + for (Tuple3 tokens : data.tokenAndIndex) { + int featureIndex = tokens.f2.intValue(); + Tuple2 featureValues; + if (crossSingleFeature.contains(tokens.f0)) { + if (!featureValueMap.containsKey(tokens.f0)) { + featureValues = Tuple2.of(new String[featureSize[tokens.f0]], new String[0]); + } else { + featureValues = featureValueMap.get(tokens.f0); + } + featureValues.f0[featureIndex] = tokens.f1; + featureValueMap.put(tokens.f0, featureValues); + } + out.collect(Row.of(cunsum[tokens.f0] + featureIndex + numericalSize, featureCols[tokens.f0], tokens.f1)); + } + //emit the null and low-frequency entries. + for (int key = 0; key < featureSize.length; key++) { + out.collect(Row.of(cunsum[key + 1] - 2 + numericalSize, featureCols[key], NULL_VALUE)); + Row rareData = Row.of(cunsum[key + 1] - 1 + numericalSize, featureCols[key], LOW_FREQUENCY_VALUE); + out.collect(rareData); + if (featureValueMap.containsKey(key)) { + Tuple2 featureValues = featureValueMap.get(key); + featureValues.f1 = new String[] {LOW_FREQUENCY_VALUE}; + //save the rare feature values. + // for (String s : featureValues.f1) { + // Row rareData = Row.of(cunsum[key + 1] - 1, featureCols[key], s); + // out.collect(rareData); + // } + featureValueMap.put(key, featureValues); + } + } + + //cross feature + //cross the high-frequency values first, then append the low-frequency ones. + int startIndex = cunsum[featureNumber] + numericalSize; + for (int[] crossFeature : crossFeatureSet) { + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < crossFeature.length; i++) { + if (i == 0) { + sb.append(featureCols[crossFeature[i]]); + } else { + sb.append(" and " + featureCols[crossFeature[i]]); + } + } + String crossedFeatureName = sb.toString(); + + int totalSize = 1; + for (int i : crossFeature) { + totalSize *= (featureValueMap.get(i).f0.length); + } + //count acts as a mixed-radix counter. + int[] count = new int[crossFeature.length + 1]; + count[0] = -1; + for (int i = 0; i < totalSize; i++) { + //handle the carry. + int start = 0; + count[start]++; + while (true) { + if (count[start] == featureValueMap.get(crossFeature[start]).f0.length) { + count[start++] = 0; + count[start] += 1; + } else { + break; + } + } + + String[][] crossValues = new String[crossFeature.length][]; + boolean toCalc = true; + for (int j = 0; j < crossFeature.length; j++) { + int crossFeatureIndex = crossFeature[j]; + if (count[j] == featureSize[crossFeatureIndex] - 1) { + crossValues[j] = featureValueMap.get(crossFeature[j]).f1; + if (crossValues[j].length == 0) { + toCalc = false; + break; + } + } else if (count[j] == featureSize[crossFeatureIndex] - 2) { + crossValues[j] = new String[] {NULL_VALUE}; + } else { + crossValues[j] = new String[] {featureValueMap.get(crossFeature[j]).f0[count[j]]}; + } + } + if (toCalc) { + startIndex = concatValue(out, startIndex, crossedFeatureName, crossValues); + } else { + ++startIndex; + } + } + } + } else { + for (Tuple3 tokens : data.tokenAndIndex) { + int featureIndex = tokens.f2.intValue(); + Tuple2 featureValues; + if (crossSingleFeature.contains(tokens.f0)) { + if (!featureValueMap.containsKey(tokens.f0)) { + featureValues = Tuple2.of(new String[featureSize[tokens.f0]], new String[0]); + } else { + featureValues = featureValueMap.get(tokens.f0); + } + featureValues.f0[featureIndex] = tokens.f1; + featureValueMap.put(tokens.f0, featureValues); + } + out.collect(Row.of(cunsum[tokens.f0] + featureIndex + numericalSize, featureCols[tokens.f0], tokens.f1)); + } + for (int key = 0; key < featureSize.length; key++) { + out.collect(Row.of(cunsum[key + 1] - 1 + numericalSize, featureCols[key], NULL_VALUE)); + } + + int startIndex = cunsum[featureNumber] + numericalSize; + for (int[] crossFeature : crossFeatureSet) { + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < crossFeature.length; i++) { + if (i == 0) { + sb.append(featureCols[crossFeature[i]]); + } else { + sb.append(" and " + featureCols[crossFeature[i]]); + } + } + String crossedFeatureName = sb.toString(); + + int[] count = new int[crossFeature.length + 1]; + count[0] = -1; + int totalSize = 1; + for (int i : crossFeature) { + totalSize *= featureValueMap.get(i).f0.length; + } + + for (int i = 0; i < totalSize; i++) { + int start = 0; + count[start]++; + while (true) { + if (count[start] == featureValueMap.get(crossFeature[start]).f0.length) { + count[start++] = 0; + count[start] += 1; + } else { + break; + } + } + sb = new StringBuilder(); + for (int j = 0; j < crossFeature.length; j++) { + String value = "null"; + + if (featureValueMap.get(crossFeature[j]).f0[count[j]] != null) { + value = featureValueMap.get(crossFeature[j]).f0[count[j]]; + } + + if (j == 0) { + sb.append(value); + }
else { + sb.append(", " + value); + } + } + Row res = Row.of(startIndex++, crossedFeatureName, sb.toString()); + out.collect(res); + } + } + } + } + + @Override + public void mapPartition(Iterable values, Collector out) throws Exception { + List oneHotModelRow = Lists.newArrayList(values); + + List autoCrossModelRow = getRuntimeContext().getBroadcastVariable("autocrossModel"); + + buildModel(oneHotModelRow, autoCrossModelRow, out); + + } + + private static int concatValue(Collector out, int startIndex, + String crossedFeatureName, String[][] values) { + int valueSize = values.length; + int[] maxSize = new int[valueSize]; + int[] countSize = new int[valueSize]; + countSize[0] = -1; + int allNumber = 1; + for (int i = 0; i < valueSize; i++) { + maxSize[i] = values[i].length; + allNumber *= maxSize[i]; + } + for (int i = 0; i < allNumber; i++) { + int start = 0; + countSize[start]++; + while (true) { + if (countSize[start] == maxSize[start]) { + countSize[start++] = 0; + countSize[start] += 1; + } else { + break; + } + } + String res = concat(values, countSize); + out.collect(Row.of(startIndex, crossedFeatureName, res)); + } + return ++startIndex; + } + + private static String concat(String[][] values, int[] countSize) { + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < values.length; i++) { + if (i == 0) { + sb.append(values[i][countSize[i]]); + } else { + sb.append(", " + values[i][countSize[i]]); + } + } + return sb.toString(); + } +} \ No newline at end of file diff --git a/core/src/main/java/com/alibaba/alink/operator/common/feature/AutoCross/CrossCandidateSelectorModelMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/feature/AutoCross/CrossCandidateSelectorModelMapper.java new file mode 100644 index 000000000..47588dc25 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/feature/AutoCross/CrossCandidateSelectorModelMapper.java @@ -0,0 +1,16 @@ +package com.alibaba.alink.operator.common.feature.AutoCross; + +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.table.api.TableSchema; + +import com.alibaba.alink.params.feature.featuregenerator.HasAppendOriginalData; + +public class CrossCandidateSelectorModelMapper extends AutoCrossModelMapper { + + public CrossCandidateSelectorModelMapper(TableSchema modelSchema, + TableSchema dataSchema, + Params params) { + super(modelSchema, dataSchema, params); + appendOriginalVec = params.get(HasAppendOriginalData.APPEND_ORIGINAL_DATA); + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/feature/AutoCross/DataProfile.java b/core/src/main/java/com/alibaba/alink/operator/common/feature/AutoCross/DataProfile.java new file mode 100644 index 000000000..ab8fbad68 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/feature/AutoCross/DataProfile.java @@ -0,0 +1,14 @@ +package com.alibaba.alink.operator.common.feature.AutoCross; + +import com.alibaba.alink.operator.common.linear.LinearModelType; + +public class DataProfile { + public int numDistinctLabels;//the label numbers. if linear reg, it is 0; if lr, it is 2. + public boolean hasIntercept; + + public DataProfile(LinearModelType linearModelType, boolean hasIntercept) { + this.numDistinctLabels = linearModelType == LinearModelType.LinearReg ? 
0 : 2; + this.hasIntercept = hasIntercept; + } + +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/feature/AutoCross/FeatureEvaluator.java b/core/src/main/java/com/alibaba/alink/operator/common/feature/AutoCross/FeatureEvaluator.java new file mode 100644 index 000000000..9e21358be --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/feature/AutoCross/FeatureEvaluator.java @@ -0,0 +1,327 @@ +package com.alibaba.alink.operator.common.feature.AutoCross; + +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.table.api.TableSchema; +import org.apache.flink.types.Row; + +import com.alibaba.alink.common.linalg.DenseVector; +import com.alibaba.alink.common.linalg.SparseVector; +import com.alibaba.alink.common.linalg.Vector; +import com.alibaba.alink.common.model.ModelParamName; +import com.alibaba.alink.common.utils.JsonConverter; +import com.alibaba.alink.operator.common.evaluation.BinaryClassMetrics; +import com.alibaba.alink.operator.common.evaluation.BinaryMetricsSummary; +import com.alibaba.alink.operator.common.evaluation.EvaluationUtil; +import com.alibaba.alink.operator.common.linear.LinearModelData; +import com.alibaba.alink.operator.common.linear.LinearModelDataConverter; +import com.alibaba.alink.operator.common.linear.LinearModelMapper; +import com.alibaba.alink.operator.common.linear.LinearModelType; +import com.alibaba.alink.operator.common.optim.LocalOptimizer; +import com.alibaba.alink.operator.common.optim.objfunc.OptimObjFunc; +import com.alibaba.alink.operator.local.classification.BaseLinearModelTrainLocalOp; +import com.alibaba.alink.params.classification.LinearModelMapperParams; +import com.alibaba.alink.params.shared.HasNumThreads; +import com.alibaba.alink.params.shared.linear.LinearTrainParams; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Random; +import java.util.stream.Collectors; + +public class FeatureEvaluator { + private static final boolean HAS_INTERCEPT = true; + private LinearModelType linearModelType; + private final List > data; + private double[] fixedCoefs; + private int[] featureSize; + private double fraction; + private boolean toFixCoef; + private int kCross; + + public FeatureEvaluator(LinearModelType linearModelType, + List > data, + int[] featureSize, + double[] fixedCoefs, + double fraction, + boolean toFixCoef,//这边是从外面传进来的,少不了。 + int kCross) { + this.linearModelType = linearModelType; + this.data = data; + this.featureSize = featureSize; + this.fixedCoefs = fixedCoefs; + this.fraction = fraction; + this.toFixCoef = toFixCoef; + this.kCross = kCross; + } + + //calculate one group of features with all data. + public Tuple2 score(List crossFeatures, int numericalSize) { + DataProfile profile = new DataProfile(linearModelType, HAS_INTERCEPT); + List > dataWithCrossFea = expandFeatures(data, crossFeatures, featureSize, + numericalSize); + System.out.println(JsonConverter.toJson(crossFeatures) + ",vector size: " + dataWithCrossFea.get(0).f2.size()); + LinearModelData model = null; + double score = 0.; + for (int i = 0; i < kCross; i++) { + Tuple2 >, List >> splited = + split(dataWithCrossFea, fraction, i);// i is the random seed. + //train the model with half data and evaluation with another half. 
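The scoring loop here repeats a random train/evaluate split kCross times and averages the held-out AUC, so a candidate cross feature is judged by its mean score rather than a single split. A simplified, self-contained sketch of that control flow, with the Alink-specific train/evaluate calls replaced by a stand-in callback (illustrative only, not the actual implementation):

import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.function.BiFunction;

// Illustrative only: average a held-out score over k random splits,
// mirroring the structure of FeatureEvaluator.score(...).
public class KSplitScoreSketch {
    public static double score(List<double[]> samples, int kCross, double fraction,
                               BiFunction<List<double[]>, List<double[]>, Double> trainAndEval) {
        double sum = 0.;
        for (int i = 0; i < kCross; i++) {                    // seed i makes each split different
            Random rand = new Random(i);
            List<double[]> train = new ArrayList<>();
            List<double[]> test = new ArrayList<>();
            for (double[] s : samples) {
                (rand.nextDouble() <= fraction ? train : test).add(s);
            }
            sum += trainAndEval.apply(train, test);           // e.g. fit LR on train, AUC on test
        }
        return sum / kCross;                                  // mean held-out score
    }
}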
+ model = train(splited.f0, profile, toFixCoef, fixedCoefs); + score += evaluate(model, splited.f1);//return the auc of predict data. + } + + return Tuple2.of(score / kCross, model.coefVector.getData()); + } + + public static LinearModelData train(List > samples, DataProfile profile, + boolean toFixCoef, double[] fixedCoefs) { + return train(samples, profile); + } + + public static LinearModelData train(List > samples, DataProfile profile) { + final LinearModelType linearModelType = LinearModelType.LR; + + final OptimObjFunc objFunc = OptimObjFunc.getObjFunction(linearModelType, new Params()); + final boolean remapLabelValue = linearModelType.equals(LinearModelType.LR); + + samples = samples.stream() + .map(sample -> { + Vector features = profile.hasIntercept ? sample.f2.prefix(1.0) : sample.f2; + double label = sample.f1; + if (remapLabelValue) { + label = label == 0. ? 1.0 : -1.0; + } + return Tuple3.of(sample.f0, label, features); + }) + .collect(Collectors.toList()); + + Params optParams = new Params() + .set(HasNumThreads.NUM_THREADS, 1) + .set(LinearTrainParams.WITH_INTERCEPT, profile.hasIntercept) + .set(LinearTrainParams.STANDARDIZATION, false); + + + int weightSize = samples.get(0).f2.size();//intercept has been added before. + double[] iniWeight = new double[weightSize]; + Arrays.fill(iniWeight, 1e-4); + DenseVector initialWeights = new DenseVector(iniWeight); + + Tuple2 weightsAndLoss = LocalOptimizer.optimize(objFunc, samples, initialWeights, optParams); + + Params meta = new Params(); + Double[] labelValues = new Double[profile.numDistinctLabels]; + for (int i = 0; i < labelValues.length; i++) { + labelValues[i] = (double) i; + } + meta.set(ModelParamName.MODEL_NAME, "model"); + meta.set(ModelParamName.LINEAR_MODEL_TYPE, linearModelType); + meta.set(ModelParamName.LABEL_VALUES, labelValues); + meta.set(ModelParamName.HAS_INTERCEPT_ITEM, profile.hasIntercept); + meta.set(ModelParamName.VECTOR_COL_NAME, "features"); + //those params are set just for build linear model data. 
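Two conventions in the training helper here are easy to miss: when the model is logistic regression the 0/1 labels are remapped to +1/-1 (class 0 becomes +1), and when an intercept is used every feature vector is prefixed with a constant 1.0 so the first coefficient acts as the bias. A small illustration on plain arrays (this only mirrors the mapping above, it is not Alink API):

// Label remap and intercept prefix as done in the training helper, shown on plain arrays.
public class SampleAugmentSketch {
    public static void main(String[] args) {
        double rawLabel = 0.0;
        double[] rawFeatures = {0.5, 2.0};

        double label = rawLabel == 0. ? 1.0 : -1.0;     // LR convention: class 0 -> +1, class 1 -> -1
        double[] augmented = new double[rawFeatures.length + 1];
        augmented[0] = 1.0;                             // intercept term
        System.arraycopy(rawFeatures, 0, augmented, 1, rawFeatures.length);

        System.out.println(label + " " + java.util.Arrays.toString(augmented)); // 1.0 [1.0, 0.5, 2.0]
    }
}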
+ meta.set(LinearTrainParams.LABEL_COL, null); + meta.set(ModelParamName.FEATURE_TYPES, null); + + return BaseLinearModelTrainLocalOp.buildLinearModelData(meta, + null, + Types.DOUBLE, + null, + profile.hasIntercept, + false, + weightsAndLoss + ); + } + + public static double evaluate(LinearModelData model, List > samples) { + + LinearModelMapper modelMapper = new LinearModelMapper( + new LinearModelDataConverter(Types.DOUBLE).getModelSchema(), + new TableSchema(new String[] {"features", "label"}, new TypeInformation[] {Types.STRING, Types.DOUBLE}), + new Params() + .set(LinearModelMapperParams.VECTOR_COL, "features") + .set(LinearModelMapperParams.PREDICTION_COL, "prediction_result") + .set(LinearModelMapperParams.PREDICTION_DETAIL_COL, "prediction_detail") + .set(LinearModelMapperParams.RESERVED_COLS, new String[] {"label"}) + ); + modelMapper.loadModel(model); + + List predictions = samples.stream() + .map(t3 -> { + Row row = Row.of(t3.f2, t3.f1); + try { + Row pred = modelMapper.map(row); + return Row.of(pred.getField(0), pred.getField(2)); + } catch (Exception e) { + throw new RuntimeException("Fail to predict.", e); + } + }) + .collect(Collectors.toList()); + + if (model.linearModelType.equals(LinearModelType.LR)) { + BinaryMetricsSummary metricsSummary = (BinaryMetricsSummary) EvaluationUtil.getDetailStatistics + (predictions, + "1.0", true, Types.DOUBLE); + BinaryClassMetrics metrics = metricsSummary.toMetrics(); + return metrics.getAuc(); + } else { + throw new UnsupportedOperationException("Not yet supported model type: " + model.linearModelType); + } + } + + //split all the samples to two parts. + public static Tuple2 >, List >> + split(List > samples, + double fraction, + int seed) { + List > part1 = new ArrayList <>(samples.size() / 2 + 1); + List > part2 = new ArrayList <>(samples.size() / 2 + 1); + int dataSize = samples.size(); + if (dataSize < 2) { + throw new RuntimeException("Data size is too small!"); + } + part1.add(samples.get(0)); + part2.add(samples.get(1)); + Random rand = new Random(seed); + for (int i = 2; i < samples.size(); i++) { + if (rand.nextDouble() <= fraction) { + part1.add(samples.get(i)); + } else { + part2.add(samples.get(i)); + } + } + return Tuple2.of(part1, part2); + } + + /** + * Generate data for training and evaluation by onehoting the raw features and crossed features. 
+ */ + public static List > expandFeatures(List > data, + List crossFeatures, + int[] featureSize, int numericalSize) { + int featureNum = featureSize.length;//number of input categorical features. + int originVectorSize = data.get(0).f2.size(); + int crossFeatureSize = crossFeatures.size(); + int[] cumsumFeatureSize = new int[featureSize.length];//cumulative offsets, used later in dot(). + for (int i = 0; i < featureSize.length; i++) { + if (i == 0) { + cumsumFeatureSize[i] = 0; + } else { + cumsumFeatureSize[i] = cumsumFeatureSize[i - 1] + featureSize[i - 1]; + } + } + int[][] carry = new int[crossFeatureSize][];//mixed radix ("carry") of the one-hot encoding. + for (int i = 0; i < crossFeatureSize; i++) { + int[] candidateFeature = crossFeatures.get(i); + for (int j = 0; j < candidateFeature.length; j++) { + if (j == 0) { + carry[i] = new int[candidateFeature.length]; + carry[i][j] = 1; + } else { + carry[i][j] = carry[i][j - 1] * featureSize[candidateFeature[j - 1]]; + } + } + } + List > resData = new ArrayList <>(data.size()); + + //we need to know how many distinct values each candidate has, then concatenate them all. + //this approach avoids the earlier problem of the training data being too sparse. + //the first array stores the number of possible values of each candidate; the second stores, for each sample, its index within each candidate. + int[] candidateNumber = new int[crossFeatureSize]; + int[][] featureIndices = new int[data.size()][crossFeatureSize]; + for (int i = 0; i < crossFeatureSize; i++) { + //Map tokenToInt = new HashMap <>(); + //int count = 0; + for (int j = 0; j < data.size(); j++) { + int[] originIndices = ((SparseVector) data.get(j).f2).getIndices().clone(); + for (int k = numericalSize; k < originIndices.length; k++) { + originIndices[k] -= numericalSize; + } + //index is the position of this value combination within the current cross feature. + int index = dot(carry[i], crossFeatures.get(i), originIndices, numericalSize, cumsumFeatureSize); + if (index < 0) { + System.out.println(); + } + //if (!tokenToInt.containsKey(index)) { + // tokenToInt.put(index, count++); + //} + featureIndices[j][i] = index; + //featureIndices[j][i] = tokenToInt.get(index);//purpose no longer clear, kept from an earlier version. + } + int count = 1; + for (int featureIndex : crossFeatures.get(i)) { + count *= featureSize[featureIndex]; + } + candidateNumber[i] = count;//the number of possible values of each candidate. + } + + //candidateCrossIndex is the start index of each cross-feature group. + int[] candidateCrossIndex = new int[crossFeatureSize]; + candidateCrossIndex[0] = originVectorSize; + for (int i = 1; i < crossFeatureSize; i++) { + candidateCrossIndex[i] = candidateCrossIndex[i - 1] + candidateNumber[i - 1]; + } + + //featureIndices holds the position of each crossed feature. + for (int i = 0; i < data.size(); i++) { + for (int j = 0; j < crossFeatureSize; j++) { + featureIndices[i][j] += candidateCrossIndex[j]; + } + } + + //start index of the last cross-feature group; adding its size gives the result vector size. + int resVecSize = candidateCrossIndex[crossFeatureSize - 1] + candidateNumber[crossFeatureSize - 1]; + //number of non-zero indices. + int newVecElementSize = numericalSize + featureNum + crossFeatureSize; + for (int i = 0; i < data.size(); i++) { + Tuple3 datum = data.get(i); + int[] originIndices = ((SparseVector) datum.f2).getIndices(); + int[] newIndices = new int[newVecElementSize]; + System.arraycopy(originIndices, 0, newIndices, 0, featureNum + numericalSize); + System.arraycopy(featureIndices[i], 0, newIndices, featureNum + numericalSize, crossFeatureSize); + double[] newValues = new double[newVecElementSize]; + Arrays.fill(newValues, 1.0); + System.arraycopy(((SparseVector) datum.f2).getValues(), 0, newValues, 0, numericalSize); + //todo: store the fixed part. + resData.add(Tuple3.of(datum.f0, datum.f1, + new SparseVector(resVecSize, newIndices, newValues))); + } + + // + // for (Tuple3 datum : data) { + // int[] originIndices = ((SparseVector) datum.f2).getIndices(); + // int indicesSize = originIndices.length + crossFeatureSize; + // int[] newIndices = new int[indicesSize]; + // System.arraycopy(originIndices, 0, newIndices, 0, originIndices.length); + // for (int i = 0; i < crossFeatureSize; i++) { + // newIndices[i + originIndices.length] = cunsumCrossFeatureSize[i] + // + dot(carry[i], crossFeatures.get(i), originIndices, cunsumFeatureSize); + // } + // double[] newValues = new double[indicesSize]; + // Arrays.fill(newValues, 1.0); + // resData.add(Tuple3.of(datum.f0, datum.f1, + // new SparseVector(dataSize, newIndices, newValues))); + // + // } + return resData; + } + + /* + * calculate and consider the carry. + * if carry is 1, 3, 9; if the input is 1, 2, 0, then it is the 7th one; the input 0, 1, 0 is the 3rd one. + * //considering its real index in the vector, we need to subtract the cumulative count of the preceding elements.
+ */ + private static int dot(int[] carry, int[] crossFeatures, int[] originIndices, int numericalSize, + int[] cumsumFeatureSize) { + int res = 0; + for (int i = 0; i < carry.length; i++) { + res += carry[i] * (originIndices[crossFeatures[i] + numericalSize] - cumsumFeatureSize[crossFeatures[i]]); + } + return res; + } + +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/feature/AutoCross/FeatureSet.java b/core/src/main/java/com/alibaba/alink/operator/common/feature/AutoCross/FeatureSet.java new file mode 100644 index 000000000..510407467 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/feature/AutoCross/FeatureSet.java @@ -0,0 +1,145 @@ +package com.alibaba.alink.operator.common.feature.AutoCross; + +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonIgnore; + +import com.alibaba.alink.common.utils.JsonConverter; +import org.apache.commons.lang3.ArrayUtils; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +public class FeatureSet implements Serializable { + private static final long serialVersionUID = 3050538250196533180L; + public int numRawFeatures;//the feature number of all the input feature, including numerical and categorical. + public String[] numericalCols; + public String vecColName;//the name of all the input features. + public List crossFeatureSet;//save all the crossed features. + public List scores; + public int[] indexSize; + public boolean hasDiscrete; + @JsonIgnore + public double[] fixedCoefs = new double[0]; + + /** + * Create initial set composing of all non-crossed features. + */ + public FeatureSet(int[] featureSize) { + //featureSize is each value of onehot feature, so numRawFeatures is the sum of numRawFeatures. + this.numRawFeatures = featureSize.length; + crossFeatureSet = new ArrayList <>(); + scores = new ArrayList <>(); + } + + + public void updateFixedCoefs(double[] fixedCoefs) { + this.fixedCoefs = fixedCoefs; + } + + public List generateCandidateCrossFeatures() { + int n = crossFeatureSet.size() + numRawFeatures; + List existingFeatures = new ArrayList <>();//saves the index of existing features. + for (int i = 0; i < numRawFeatures; i++) { + existingFeatures.add(new int[] {i}); + } + existingFeatures.addAll(crossFeatureSet); + Set pairwiseCrossFea = new HashSet <>(n * (n - 1) / 2); + for (int i = 0; i < n; i++) { + for (int j = i + 1; j < n; j++) { + int[] crossIJ = ArrayUtils.addAll(existingFeatures.get(i), + existingFeatures.get(j));//cross two features. + //remove duplicate features. If the crossed features are "1 1 2", then it will be transformed to "1 2" + crossIJ = removeDuplicate(crossIJ); + Arrays.sort(crossIJ); + if (!this.contains(crossIJ)) { + pairwiseCrossFea.add(new IntArrayComparator(crossIJ)); + } + } + } + List res = new ArrayList <>(pairwiseCrossFea.size()); + for (IntArrayComparator i : pairwiseCrossFea) { + res.add(i.data); + } + return res; + } + + //judge whether crossFeatureSet contains crossFea. 
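The dot(...) helpers in FeatureEvaluator and OneHotOperator both flatten the per-feature bucket indices of a crossed feature into a single index with the mixed-radix ("carry") encoding described in the comment above. A tiny self-contained check of the worked example given there (carry {1, 3, 9}); this is a plain illustration, not the Alink method:

// Mixed-radix flattening used for crossed features: index = sum(carry[i] * bucket[i]).
public class MixedRadixSketch {
    static int flatten(int[] carry, int[] bucketIndexPerFeature) {
        int res = 0;
        for (int i = 0; i < carry.length; i++) {
            res += carry[i] * bucketIndexPerFeature[i];
        }
        return res;
    }

    public static void main(String[] args) {
        int[] carry = {1, 3, 9};                                  // radices built from the feature sizes
        System.out.println(flatten(carry, new int[] {1, 2, 0}));  // 7, as in the comment above
        System.out.println(flatten(carry, new int[] {0, 1, 0}));  // 3
    }
}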
+ private boolean contains(int[] crossFea) { + for (int[] target : crossFeatureSet) { + if (crossFea.length != target.length) { + continue; + } + boolean same = true; + for (int j = 0; j < target.length; j++) { + if (target[j] != crossFea[j]) { + same = false; + break; + } + } + if (same) { + return true; + } + } + return false; + } + + private static int[] removeDuplicate(int[] input) { + Set set = new HashSet <>(); + for (int anInput : input) { + set.add(anInput); + } + int[] ret = new int[set.size()]; + int cnt = 0; + for (Integer v : set) { + ret[cnt++] = v; + } + return ret; + } + + public void addOneCrossFeature(int[] crossFea, double score) { + this.crossFeatureSet.add(crossFea); + this.scores.add(score); + } + + public double[] getFixedCoefs() { + return fixedCoefs; + } + + @Override + public String toString() { + assert crossFeatureSet.size() == scores.size(); + return JsonConverter.toJson(this); + } + + + private class IntArrayComparator { + int[] data; + + IntArrayComparator() {} + + IntArrayComparator(int[] data) { + this.data = data; + } + + + @Override + public int hashCode() { + StringBuilder sbd = new StringBuilder(); + for (int datum : data) { + sbd.append(datum).append(","); + } + return sbd.toString().hashCode(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null) return false; + if (this.getClass() != o.getClass()) return false; + return o.hashCode() == this.hashCode(); + } + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/feature/AutoCross/OneHotExtractor.java b/core/src/main/java/com/alibaba/alink/operator/common/feature/AutoCross/OneHotExtractor.java new file mode 100644 index 000000000..602731ec9 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/feature/AutoCross/OneHotExtractor.java @@ -0,0 +1,36 @@ +package com.alibaba.alink.operator.common.feature.AutoCross; + +import org.apache.flink.api.common.functions.RichFlatMapFunction; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; + +import com.alibaba.alink.operator.common.feature.OneHotModelData; +import com.alibaba.alink.operator.common.feature.OneHotModelDataConverter; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class OneHotExtractor extends RichFlatMapFunction > { + private Map map; + + @Override + public void open(Configuration parameters) throws Exception { + super.open(parameters); + int parallelism = getRuntimeContext().getNumberOfParallelSubtasks(); + if (parallelism != 1) { + throw new RuntimeException("parallelism of token number extraction must be set as 1."); + } + map = new HashMap <>(); + List modelRows = getRuntimeContext().getBroadcastVariable("model"); + OneHotModelData model = new OneHotModelDataConverter().load(modelRows); + map = model.modelData.tokenNumber; + map.keySet().forEach(x -> map.compute(x, (k, v) -> v += 1));//consider null. 
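generateCandidateCrossFeatures above pairs every existing feature (raw features plus already-accepted crosses) with every other one, removes duplicate indices inside a pair, sorts it, and drops candidates that are already in the set. A plain-Java sketch of the same enumeration on a toy setup (three raw features and one accepted cross {0, 1}); this is a simplified illustration, not the Alink class:

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

// Toy candidate generation: pairwise-cross existing features, dedupe indices, skip known crosses.
public class CandidateCrossSketch {
    public static void main(String[] args) {
        List<int[]> existing = new ArrayList<>();
        existing.add(new int[] {0});
        existing.add(new int[] {1});
        existing.add(new int[] {2});
        existing.add(new int[] {0, 1});               // an already-accepted cross

        Set<String> seen = new HashSet<>();
        seen.add(Arrays.toString(new int[] {0, 1}));  // do not propose what is already selected

        List<int[]> candidates = new ArrayList<>();
        for (int i = 0; i < existing.size(); i++) {
            for (int j = i + 1; j < existing.size(); j++) {
                int[] merged = Arrays.copyOf(existing.get(i),
                        existing.get(i).length + existing.get(j).length);
                System.arraycopy(existing.get(j), 0, merged, existing.get(i).length, existing.get(j).length);
                merged = Arrays.stream(merged).distinct().sorted().toArray();
                if (seen.add(Arrays.toString(merged))) {
                    candidates.add(merged);
                }
            }
        }
        candidates.forEach(c -> System.out.println(Arrays.toString(c)));
        // yields [0, 2], [1, 2], [0, 1, 2]; the already-selected [0, 1] is skipped
    }
}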
+ } + + @Override + public void flatMap(Integer value, Collector > out) throws Exception { + out.collect(map); + } +} \ No newline at end of file diff --git a/core/src/main/java/com/alibaba/alink/operator/common/feature/AutoCross/OneHotOperator.java b/core/src/main/java/com/alibaba/alink/operator/common/feature/AutoCross/OneHotOperator.java new file mode 100644 index 000000000..4ce6c56c1 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/feature/AutoCross/OneHotOperator.java @@ -0,0 +1,69 @@ +package com.alibaba.alink.operator.common.feature.AutoCross; + +import com.alibaba.alink.common.linalg.SparseVector; + +import java.util.Arrays; +import java.util.List; + +public class OneHotOperator { + List crossFeatures; + int[][] carry; + int[] cunsumCrossFeatureSize; + int[] cunsumFeatureSize; + int dataSize; + + //originSize is the number of values of the input data. + OneHotOperator(int originSize, List crossFeatures, int[] indexSize) { + this.crossFeatures = crossFeatures; + cunsumCrossFeatureSize = new int[crossFeatures.size() + 1];//cumulative feature sizes: element 0 stores the original size, the following elements the cross features. + cunsumFeatureSize = new int[indexSize.length + 1]; + for (int i = 0; i < indexSize.length + 1; i++) { + if (i == 0) { + cunsumFeatureSize[i] = 0; + } else { + cunsumFeatureSize[i] = cunsumFeatureSize[i - 1] + indexSize[i - 1]; + } + } + Arrays.fill(cunsumCrossFeatureSize, 1); + carry = new int[crossFeatures.size()][];//mixed radix ("carry") of the one-hot encoding. + for (int i = 0; i < crossFeatures.size(); i++) { + int[] candidateFeature = crossFeatures.get(i); + for (int j = 0; j < candidateFeature.length; j++) { + if (j == 0) { + carry[i] = new int[candidateFeature.length]; + carry[i][j] = 1; + } else { + carry[i][j] = carry[i][j - 1] * indexSize[candidateFeature[j - 1]]; + } + cunsumCrossFeatureSize[i + 1] *= indexSize[candidateFeature[j]]; + } + } + cunsumCrossFeatureSize[0] = cunsumFeatureSize[cunsumFeatureSize.length - 1]; + for (int i = 1; i <= crossFeatures.size(); i++) { + cunsumCrossFeatureSize[i] += cunsumCrossFeatureSize[i - 1]; + } + dataSize = cunsumCrossFeatureSize[cunsumCrossFeatureSize.length - 1]; + } + + SparseVector oneHotData(SparseVector data) { + int[] originIndices = data.getIndices(); + int indicesSize = originIndices.length + crossFeatures.size(); + int[] newIndices = new int[indicesSize]; + System.arraycopy(originIndices, 0, newIndices, 0, originIndices.length); + for (int i = 0; i < crossFeatures.size(); i++) { + newIndices[originIndices.length + i] = cunsumCrossFeatureSize[i] + + dot(carry[i], crossFeatures.get(i), originIndices, cunsumFeatureSize); + } + double[] newValues = new double[indicesSize]; + Arrays.fill(newValues, 1.0); + return new SparseVector(dataSize, newIndices, newValues); + } + + static int dot(int[] carry, int[] crossFeatures, int[] originIndices, int[] cunsumFeatureSize) { + int res = 0; + for (int i = 0; i < carry.length; i++) { + res += carry[i] * (originIndices[crossFeatures[i]] - cunsumFeatureSize[crossFeatures[i]]); + } + return res; + } +} \ No newline at end of file diff --git a/core/src/main/java/com/alibaba/alink/operator/common/feature/AutoCross/PartitionData.java b/core/src/main/java/com/alibaba/alink/operator/common/feature/AutoCross/PartitionData.java new file mode 100644 index 000000000..b43cd6bce --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/feature/AutoCross/PartitionData.java @@ -0,0 +1,61 @@ +package com.alibaba.alink.operator.common.feature.AutoCross; + +import org.apache.flink.api.java.tuple.Tuple3; + +import
com.alibaba.alink.common.linalg.Vector; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.locks.ReadWriteLock; +import java.util.concurrent.locks.ReentrantReadWriteLock; + +public class PartitionData { + private static Map >> data; + + static { + data = new HashMap <>(); + } + + private static ReadWriteLock rwlock = new ReentrantReadWriteLock(); + + public static void addData(int taskId, Tuple3 s) { + rwlock.writeLock().lock(); + try { + data.compute(taskId, (k, v) -> { + if (v == null) { + v = new ArrayList <>(); + } + v.add(s); + return v; + }); + } finally { + rwlock.writeLock().unlock(); + } + } + + public static void addData(int taskId, List > s) { + rwlock.writeLock().lock(); + try { + data.compute(taskId, (k, v) -> { + if (v == null) { + v = new ArrayList <>(); + } + v.addAll(s); + return v; + }); + } finally { + rwlock.writeLock().unlock(); + } + } + + public static List > getData(int taskId) { + rwlock.readLock().lock(); + try { + return data.get(taskId); + } finally { + rwlock.readLock().unlock(); + } + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/feature/BinarizerMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/feature/BinarizerMapper.java index d262a7445..828e5e829 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/feature/BinarizerMapper.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/feature/BinarizerMapper.java @@ -4,7 +4,7 @@ import org.apache.flink.ml.api.misc.param.Params; import org.apache.flink.table.api.TableSchema; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.exceptions.AkIllegalOperatorParameterException; import com.alibaba.alink.common.linalg.DenseVector; import com.alibaba.alink.common.linalg.SparseVector; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/feature/BinningModelInfo.java b/core/src/main/java/com/alibaba/alink/operator/common/feature/BinningModelInfo.java new file mode 100644 index 000000000..97c3d1005 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/feature/BinningModelInfo.java @@ -0,0 +1,119 @@ +package com.alibaba.alink.operator.common.feature; + +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import com.alibaba.alink.operator.common.feature.binning.Bins; +import com.alibaba.alink.operator.common.feature.binning.FeatureBinsCalculator; +import com.alibaba.alink.operator.common.feature.binning.FeatureBinsUtil; +import com.alibaba.alink.operator.common.utils.PrettyDisplayUtils; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * BinningModelSummary. 
+ */ +public class BinningModelInfo implements Serializable { + private static final long serialVersionUID = -5091603455134141015L; + private Map >> categoricalScores; + private Map >> numericScores; + + private Map > cutsArray; + + public BinningModelInfo(List list) { + List featureBins = new BinningModelDataConverter().load(list); + cutsArray = new HashMap <>(); + numericScores = new HashMap <>(); + categoricalScores = new HashMap <>(); + + for (FeatureBinsCalculator featureBinsCalculator : featureBins) { + featureBinsCalculator.splitsArrayToInterval(); + List > map = new ArrayList <>(); + if (null != featureBinsCalculator.bin.normBins) { + for (Bins.BaseBin bin : featureBinsCalculator.bin.normBins) { + map.add(Tuple2.of(bin.getValueStr(featureBinsCalculator.getColType()), bin.getIndex())); + } + } + if (null != featureBinsCalculator.bin.nullBin) { + map.add(Tuple2.of(FeatureBinsUtil.NULL_LABEL, featureBinsCalculator.bin.nullBin.getIndex())); + } + if (null != featureBinsCalculator.bin.elseBin) { + map.add(Tuple2.of(FeatureBinsUtil.ELSE_LABEL, featureBinsCalculator.bin.elseBin.getIndex())); + } + if (featureBinsCalculator.isNumeric()) { + numericScores.put(featureBinsCalculator.getFeatureName(), map); + cutsArray.put(featureBinsCalculator.getFeatureName(), + Arrays.asList(featureBinsCalculator.getSplitsArray())); + } else { + categoricalScores.put(featureBinsCalculator.getFeatureName(), map); + } + } + } + + public String[] getCategoryColumns() { + return categoricalScores.keySet().toArray(new String[0]); + } + + public String[] getContinuousColumns() { + return cutsArray.keySet().toArray(new String[0]); + } + + public int getCategorySize(String columnName) { + Integer binCount = categoricalScores.get(columnName).size() - 2; + Preconditions.checkNotNull(binCount, columnName + "is not discrete column!"); + return binCount; + } + + public Number[] getCutsArray(String columnName) { + Number[] cuts = cutsArray.get(columnName).toArray(new Number[0]); + Preconditions.checkNotNull(cuts, columnName + "is not continuous column!"); + return cuts; + } + + @Override + public String toString() { + StringBuilder sbd = new StringBuilder(PrettyDisplayUtils.displayHeadline("BinningModelSummary", '-')); + sbd.append("Binning on ") + .append(categoricalScores.size() + numericScores.size()) + .append(" features.\n") + .append("Categorical features:") + .append(PrettyDisplayUtils.displayList(new ArrayList <>(categoricalScores.keySet()), 3, false)) + .append("\nNumeric features:") + .append(PrettyDisplayUtils.displayList(new ArrayList <>(numericScores.keySet()), 3, false)) + .append("\n") + .append(PrettyDisplayUtils.displayHeadline("Details", '-')); + int size = categoricalScores.values().stream().mapToInt(m -> m.size()).sum() + + numericScores.values().stream().mapToInt(m -> m.size()).sum(); + String[][] table = new String[size][3]; + int cnt = 0; + for (Map.Entry >> entry : categoricalScores.entrySet()) { + table[cnt][0] = entry.getKey(); + for (Tuple2 entry1 : entry.getValue()) { + if (table[cnt][0] == null) { + table[cnt][0] = ""; + } + table[cnt][1] = entry1.f0; + table[cnt++][2] = entry1.f1.toString(); + } + } + for (Map.Entry >> entry : numericScores.entrySet()) { + table[cnt][0] = entry.getKey(); + for (Tuple2 entry1 : entry.getValue()) { + if (table[cnt][0] == null) { + table[cnt][0] = ""; + } + table[cnt][1] = entry1.f0; + table[cnt++][2] = entry1.f1.toString(); + } + } + sbd.append(PrettyDisplayUtils.displayTable(table, table.length, 3, null, + new String[] {"featureName", "FeatureBin", 
"BinIndex"}, null, 20, 3)); + return sbd.toString(); + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/feature/BinningModelMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/feature/BinningModelMapper.java new file mode 100644 index 000000000..a6574ce65 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/feature/BinningModelMapper.java @@ -0,0 +1,525 @@ +package com.alibaba.alink.operator.common.feature; + +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple4; +import org.apache.flink.ml.api.misc.param.ParamInfo; +import org.apache.flink.ml.api.misc.param.ParamInfoFactory; +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.table.api.TableSchema; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +import com.alibaba.alink.common.type.AlinkTypes; +import com.alibaba.alink.common.mapper.ComboModelMapper; +import com.alibaba.alink.common.mapper.Mapper; +import com.alibaba.alink.common.utils.RowCollector; +import com.alibaba.alink.common.utils.TableUtil; +import com.alibaba.alink.operator.batch.feature.OneHotTrainBatchOp; +import com.alibaba.alink.operator.batch.feature.QuantileDiscretizerTrainBatchOp; +import com.alibaba.alink.operator.batch.feature.WoeTrainBatchOp; +import com.alibaba.alink.operator.common.dataproc.vector.VectorAssemblerMapper; +import com.alibaba.alink.operator.common.feature.binning.Bins; +import com.alibaba.alink.operator.common.feature.binning.FeatureBinsCalculator; +import com.alibaba.alink.params.dataproc.HasHandleInvalid; +import com.alibaba.alink.params.dataproc.vector.VectorAssemblerParams; +import com.alibaba.alink.params.feature.HasEncode; +import com.alibaba.alink.params.feature.HasEncode.Encode; +import com.alibaba.alink.params.feature.HasEncodeWithoutWoe; +import com.alibaba.alink.params.feature.OneHotPredictParams; +import com.alibaba.alink.params.feature.QuantileDiscretizerPredictParams; +import com.alibaba.alink.params.finance.BinningPredictParams; +import com.alibaba.alink.params.finance.WoePredictParams; +import com.alibaba.alink.params.shared.colname.HasOutputCol; +import com.alibaba.alink.params.shared.colname.HasSelectedCols; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; + +import static com.alibaba.alink.params.feature.HasEncode.Encode.ASSEMBLED_VECTOR; + +public class BinningModelMapper extends ComboModelMapper { + + private List mappers; + private BinningPredictParamsBuilder paramsBuilder; + + public BinningModelMapper(TableSchema modelSchema, TableSchema dataSchema, Params params) { + super(modelSchema, dataSchema, params); + paramsBuilder = new BinningPredictParamsBuilder(params.clone(), dataSchema); + } + + public static class BinningPredictParamsBuilder implements Serializable { + public HasEncode.Encode encode; + public String[] selectedCols; + public String[] oneHotInputCols; + public String[] oneHotOutputCols; + public String[] quantileInputCols; + public String[] quantileOutputCols; + public String[] assemblerSelectedCols; + + public String[] resultCols; + public TypeInformation[] resultColTypes; + public String[] reservedCols; + + public Params params; + + public BinningPredictParamsBuilder(Params params, TableSchema dataSchema) { + selectedCols = params.get(BinningPredictParams.SELECTED_COLS); + 
resultCols = params.get(BinningPredictParams.OUTPUT_COLS); + reservedCols = params.get(BinningPredictParams.RESERVED_COLS); + encode = params.get(BinningPredictParams.ENCODE); + + assemblerSelectedCols = new String[selectedCols.length]; + + switch (encode) { + case WOE: { + params.set(BinningPredictParams.HANDLE_INVALID, HasHandleInvalid.HandleInvalid.KEEP); + params.set(HasEncodeWithoutWoe.ENCODE, HasEncodeWithoutWoe.Encode.INDEX); + if (resultCols == null) { + resultCols = selectedCols; + } + Preconditions.checkArgument(resultCols.length == selectedCols.length, + "OutputCols length must be equal to SelectedCols length!"); + resultColTypes = new TypeInformation[selectedCols.length]; + Arrays.fill(resultColTypes, AlinkTypes.DOUBLE); + break; + } + case INDEX: { + if (resultCols == null) { + resultCols = selectedCols; + } + Preconditions.checkArgument(resultCols.length == selectedCols.length, + "OutputCols length must be equal to SelectedCols length!"); + params.set(HasEncodeWithoutWoe.ENCODE, HasEncodeWithoutWoe.Encode.INDEX); + resultColTypes = new TypeInformation[selectedCols.length]; + Arrays.fill(resultColTypes, AlinkTypes.LONG); + break; + } + case VECTOR: { + if (resultCols == null) { + resultCols = selectedCols; + } + Preconditions.checkArgument(resultCols.length == selectedCols.length, + "OutputCols length must be equal to SelectedCols length!"); + params.set(HasEncodeWithoutWoe.ENCODE, HasEncodeWithoutWoe.Encode.VECTOR); + resultColTypes = new TypeInformation[selectedCols.length]; + Arrays.fill(resultColTypes, AlinkTypes.SPARSE_VECTOR); + break; + } + case ASSEMBLED_VECTOR: { + params.set(HasEncodeWithoutWoe.ENCODE, HasEncodeWithoutWoe.Encode.VECTOR); + Preconditions.checkArgument(null != resultCols && resultCols.length == 1, + "When encode is ASSEMBLED_VECTOR, outputCols must be given and the length must be 1!"); + params.set(HasOutputCol.OUTPUT_COL, resultCols[0]); + resultCols = new String[selectedCols.length]; + for (int i = 0; i < selectedCols.length; i++) { + resultCols[i] = selectedCols[i] + "_ASSEMBLED_VECTOR"; + } + + resultColTypes = new TypeInformation[] {AlinkTypes.SPARSE_VECTOR}; + break; + } + default: { + throw new RuntimeException("Not support " + encode.name() + " yet!"); + } + } + List numericIndices = new ArrayList <>(); + List discreteIndices = new ArrayList <>(); + for (int i = 0; i < selectedCols.length; i++) { + if (TableUtil.isSupportedNumericType(TableUtil.findColTypeWithAssert(dataSchema, selectedCols[i]))) { + numericIndices.add(i); + } else { + discreteIndices.add(i); + } + } + + quantileInputCols = new String[numericIndices.size()]; + quantileOutputCols = new String[numericIndices.size()]; + oneHotInputCols = new String[discreteIndices.size()]; + oneHotOutputCols = new String[discreteIndices.size()]; + + for (int i = 0; i < numericIndices.size(); i++) { + quantileInputCols[i] = selectedCols[numericIndices.get(i)]; + quantileOutputCols[i] = quantileInputCols[i] + "_QUANTILE"; + assemblerSelectedCols[TableUtil.findColIndexWithAssertAndHint(selectedCols, quantileInputCols[i])] + = quantileOutputCols[i]; + } + + for (int i = 0; i < discreteIndices.size(); i++) { + oneHotInputCols[i] = selectedCols[discreteIndices.get(i)]; + oneHotOutputCols[i] = oneHotInputCols[i] + "_ONE_HOT"; + assemblerSelectedCols[TableUtil.findColIndexWithAssertAndHint(selectedCols, oneHotInputCols[i])] + = oneHotOutputCols[i]; + } + this.params = params; + } + } + + @Override + public void loadModel(List modelRows) { + List featureBinsCalculatorList = new 
BinningModelDataConverter().load(modelRows); + HashSet numeric = new HashSet <>(); + HashSet discrete = new HashSet <>(); + //numeric, discrete + Tuple2 , List > featureBorders = + distinguishNumericDiscrete( + featureBinsCalculatorList.toArray(new FeatureBinsCalculator[0]), + params.get(BinningPredictParams.SELECTED_COLS), + numeric, + discrete); + + for (String s : params.get(BinningPredictParams.SELECTED_COLS)) { + TypeInformation type = TableUtil.findColTypeWithAssert(this.getDataSchema(), s); + Preconditions.checkNotNull(type, "%s is not found in data!", s); + Preconditions.checkState((TableUtil.isSupportedNumericType(type) && numeric.contains(s)) || + (!TableUtil.isSupportedNumericType(type) && discrete.contains(s)), "%s is not found in model!", s); + } + + mappers = new ArrayList <>(); + TableSchema dataSchema = getDataSchema(); + //discrete + if (numeric.size() > 0) { + Preconditions.checkState(featureBorders.f0.size() > 0, + "There is numeric col that is not included in model, please check selectedCols!"); + QuantileDiscretizerModelMapper quantileDiscretizerModelMapper = getQuantileModelMapper(featureBorders.f0, + dataSchema, paramsBuilder); + dataSchema = quantileDiscretizerModelMapper.getOutputSchema(); + mappers.add(quantileDiscretizerModelMapper); + } + + //numeric + if (discrete.size() > 0) { + Preconditions.checkState(featureBorders.f1.size() > 0, + "There is discrete col that is not included in model, please check selectedCols!"); + OneHotModelMapper oneHotModelMapper = getOneHotModelMapper( + featureBorders.f1, + dataSchema, + paramsBuilder); + dataSchema = oneHotModelMapper.getOutputSchema(); + mappers.add(oneHotModelMapper); + } + + //woe + if (params.get(BinningPredictParams.ENCODE).equals(HasEncode.Encode.WOE)) { + featureBorders.f0.addAll(featureBorders.f1); + WoeModelMapper woeModelMapper = getWoeModelMapper( + featureBorders.f0, + dataSchema, + paramsBuilder); + dataSchema = woeModelMapper.getOutputSchema(); + mappers.add(woeModelMapper); + } + + //vector assembler + if (params.get(BinningPredictParams.ENCODE).equals(ASSEMBLED_VECTOR)) { + VectorAssemblerMapper vectorAssemblerMapper = getVectorAssemblerMapper( + dataSchema, + paramsBuilder); + dataSchema = vectorAssemblerMapper.getOutputSchema(); + mappers.add(vectorAssemblerMapper); + } + + //final + params.set(BINNING_INPUT_COLS, this.getDataSchema().getFieldNames()); + mappers.add(new BinningResultMapper(dataSchema, params)); + } + + @Override + public List getLoadedMapperList() { + return mappers; + } + + @Override + protected Tuple4 [], String[]> prepareIoSchema( + TableSchema modelSchema, TableSchema dataSchema, Params params) { + String[] selectedCols = params.get(BinningPredictParams.SELECTED_COLS); + String[] resultCols = params.get(BinningPredictParams.OUTPUT_COLS); + if (resultCols == null) { + resultCols = selectedCols; + } + String[] reservedCols = params.get(BinningPredictParams.RESERVED_COLS); + + Encode encode = params.get(BinningPredictParams.ENCODE); + + TypeInformation[] resultColTypes = new TypeInformation[resultCols.length]; + switch (encode) { + case ASSEMBLED_VECTOR: + Arrays.fill(resultColTypes, AlinkTypes.SPARSE_VECTOR); + break; + case WOE: + Arrays.fill(resultColTypes, AlinkTypes.DOUBLE); + break; + case INDEX: + Arrays.fill(resultColTypes, AlinkTypes.LONG); + break; + default: + Arrays.fill(resultColTypes, AlinkTypes.SPARSE_VECTOR); + } + + return Tuple4.of(selectedCols, resultCols, resultColTypes, reservedCols); + } + + private static OneHotModelMapper getOneHotModelMapper(List 
featureBinsCalculatorList, + TableSchema oneHotInputSchema, + BinningPredictParamsBuilder paramsBuilder) { + + OneHotModelMapper mapper = new OneHotModelMapper(new OneHotModelDataConverter().getModelSchema(), + oneHotInputSchema, setOneHotModelParams(paramsBuilder)); + String[] selectedCols = new String[featureBinsCalculatorList.size()]; + List > featureBorders = new ArrayList <>(); + for (int i = 0; i < featureBinsCalculatorList.size(); i++) { + FeatureBinsCalculator featureBinsCalculator = featureBinsCalculatorList.get(i); + selectedCols[i] = featureBinsCalculator.getFeatureName(); + featureBorders.add(Tuple2.of((long) i, featureBinsCalculator)); + } + RowCollector modelRows = new RowCollector(); + Params meta = new Params().set(HasSelectedCols.SELECTED_COLS, selectedCols); + OneHotTrainBatchOp.transformFeatureBinsToModel(featureBorders, modelRows, meta); + mapper.loadModel(modelRows.getRows()); + return mapper; + } + + private static QuantileDiscretizerModelMapper getQuantileModelMapper( + List featureBinsCalculatorList, + TableSchema quantileInputSchema, + BinningPredictParamsBuilder paramsBuilder) { + + QuantileDiscretizerModelMapper mapper = new QuantileDiscretizerModelMapper( + new QuantileDiscretizerModelDataConverter().getModelSchema(), quantileInputSchema, + setQuantileModelParams(paramsBuilder)); + RowCollector modelRows = new RowCollector(); + QuantileDiscretizerTrainBatchOp.transformFeatureBinsToModel(featureBinsCalculatorList, modelRows); + mapper.loadModel(modelRows.getRows()); + return mapper; + } + + private static WoeModelMapper getWoeModelMapper(List featureBinsCalculatorList, + TableSchema woeInputSchema, + BinningPredictParamsBuilder paramsBuilder) { + RowCollector modelRows = new RowCollector(); + Params meta = new Params().set(WoeTrainBatchOp.SELECTED_COLS, paramsBuilder.assemblerSelectedCols); + transformFeatureBinsToModel(featureBinsCalculatorList, modelRows, meta, paramsBuilder.assemblerSelectedCols); + WoeModelMapper mapper = new WoeModelMapper(new WoeModelDataConverter().getModelSchema(), woeInputSchema, + setWoeModelParams(paramsBuilder, paramsBuilder.assemblerSelectedCols)); + mapper.loadModel(modelRows.getRows()); + return mapper; + } + + private static VectorAssemblerMapper getVectorAssemblerMapper(TableSchema inputSchema, + BinningPredictParamsBuilder paramsBuilder) { + return new VectorAssemblerMapper(inputSchema, + setVectorAssemblerParams(paramsBuilder, paramsBuilder.assemblerSelectedCols, inputSchema.getFieldNames())); + } + + private static Params setOneHotModelParams(BinningPredictParamsBuilder paramsBuilder) { + return new Params() + .merge(paramsBuilder.params) + .set(OneHotPredictParams.SELECTED_COLS, paramsBuilder.oneHotInputCols) + .set(OneHotPredictParams.OUTPUT_COLS, paramsBuilder.oneHotOutputCols) + .set(OneHotPredictParams.RESERVED_COLS, null); + } + + private static Params setQuantileModelParams(BinningPredictParamsBuilder paramsBuilder) { + Params params = new Params() + .merge(paramsBuilder.params) + .set(QuantileDiscretizerPredictParams.SELECTED_COLS, paramsBuilder.quantileInputCols) + .set(QuantileDiscretizerPredictParams.OUTPUT_COLS, paramsBuilder.quantileOutputCols) + .set(QuantileDiscretizerPredictParams.RESERVED_COLS, null); + return params; + } + + private static Params setWoeModelParams(BinningPredictParamsBuilder paramsBuilder, String[] woeSelectedCols) { + Params params = new Params() + .merge(paramsBuilder.params) + .set(WoePredictParams.SELECTED_COLS, woeSelectedCols) + .set(WoePredictParams.OUTPUT_COLS, null) + 
.set(WoePredictParams.RESERVED_COLS, null); + return params; + } + + private static Params setVectorAssemblerParams(BinningPredictParamsBuilder paramsBuilder, + String[] woeSelectedCols, + String[] binningInputCols) { + + Params params = new Params() + .merge(paramsBuilder.params) + .set(VectorAssemblerParams.SELECTED_COLS, woeSelectedCols) + .set(VectorAssemblerParams.RESERVED_COLS, binningInputCols); + switch (params.get(BinningPredictParams.HANDLE_INVALID)) { + case KEEP: + case SKIP: { + params.set(VectorAssemblerParams.HANDLE_INVALID, VectorAssemblerParams.HandleInvalidMethod.SKIP); + break; + } + case ERROR: { + params.set(VectorAssemblerParams.HANDLE_INVALID, VectorAssemblerParams.HandleInvalidMethod.ERROR); + break; + } + } + return params; + } + + public static Tuple2 , List > distinguishNumericDiscrete( + FeatureBinsCalculator[] featureBinsCalculators, + String[] selectedCols, + HashSet userDefinedNumeric, + HashSet userDefinedDiscrete) { + List numericFeatureBinsCalculator = new ArrayList <>(); + List discreteFeatureBinsCalculator = new ArrayList <>(); + + if (null != featureBinsCalculators) { + for (FeatureBinsCalculator featureBinsCalculator : featureBinsCalculators) { + if (TableUtil.findColIndex(selectedCols, featureBinsCalculator.getFeatureName()) >= 0) { + if (featureBinsCalculator.isNumeric()) { + numericFeatureBinsCalculator.add(featureBinsCalculator); + userDefinedNumeric.add(featureBinsCalculator.getFeatureName()); + } else { + discreteFeatureBinsCalculator.add(featureBinsCalculator); + userDefinedDiscrete.add(featureBinsCalculator.getFeatureName()); + } + } + } + } + + return Tuple2.of(numericFeatureBinsCalculator, discreteFeatureBinsCalculator); + } + + public static Tuple2 , List > distinguishNumericDiscrete( + String[] selectedCols, + TableSchema dataSchema) { + List numeric = new ArrayList <>(); + List discrete = new ArrayList <>(); + + for (String s : selectedCols) { + TypeInformation type = TableUtil.findColTypeWithAssert(dataSchema, s); + Preconditions.checkNotNull(type, "%s is not found in data", s); + if (TableUtil.isSupportedNumericType(type)) { + numeric.add(s); + } else { + discrete.add(s); + } + } + + return Tuple2.of(numeric, discrete); + } + + public static class BinningResultMapper extends Mapper { + + public BinningResultMapper(TableSchema dataSchema, Params params) { + super(dataSchema, params); + } + + @Override + protected void map(SlicedSelectedSample selection, SlicedResult result) throws Exception { + for (int i = 0; i < selection.length(); i++) { + result.set(i, selection.get(i)); + } + } + + @Override + protected Tuple4 [], String[]> prepareIoSchema( + TableSchema dataSchema, Params params) { + + //binning input + String[] binningInputCols = params.get(BINNING_INPUT_COLS); + String[] selectedCols = params.get(BinningPredictParams.SELECTED_COLS); + String[] reservedCols = params.get(BinningPredictParams.RESERVED_COLS); + String[] resultCols = params.get(BinningPredictParams.OUTPUT_COLS); + Encode encode = params.get(BinningPredictParams.ENCODE); + + reservedCols = reservedCols == null ? binningInputCols : reservedCols; + resultCols = resultCols == null ? 
selectedCols : resultCols; + + BinningPredictParamsBuilder paramsBuilder = new BinningPredictParamsBuilder(params, dataSchema); + + String[] binningResultSelectedCols = null; + switch (encode) { + case ASSEMBLED_VECTOR: + binningResultSelectedCols = new String[] { + dataSchema.getFieldName(dataSchema.getFieldNames().length - 1).get()}; + break; + case WOE: + case INDEX: + case VECTOR: + binningResultSelectedCols = paramsBuilder.assemblerSelectedCols; + break; + } + + TypeInformation[] resultColTypes = new TypeInformation[resultCols.length]; + switch (encode) { + case ASSEMBLED_VECTOR: + Arrays.fill(resultColTypes, AlinkTypes.SPARSE_VECTOR); + break; + case WOE: + Arrays.fill(resultColTypes, AlinkTypes.DOUBLE); + break; + case INDEX: + Arrays.fill(resultColTypes, AlinkTypes.LONG); + break; + case VECTOR: + Arrays.fill(resultColTypes, AlinkTypes.SPARSE_VECTOR); + break; + } + + return Tuple4.of(binningResultSelectedCols, + resultCols, + resultColTypes, + reservedCols); + } + } + + public static void transformFeatureBinsToModel(Iterable values, Collector out, + Params meta, + String[] featureNames) { + Long positiveTotal = null; + Long negativeTotal = null; + List > list = new ArrayList <>(); + for (FeatureBinsCalculator featureBinsCalculator : values) { + if (positiveTotal == null) { + positiveTotal = featureBinsCalculator.getPositiveTotal(); + Preconditions.checkNotNull(positiveTotal, "The label col of Binning is not set!"); + negativeTotal = featureBinsCalculator.getTotal() - positiveTotal; + } + String binFeatureName = featureBinsCalculator.getFeatureName(); + int featureIdx = TableUtil.findColIndex(featureNames, binFeatureName); + if (featureIdx == -1) { + featureIdx = TableUtil.findColIndex(featureNames, + binFeatureName + "_ONE_HOT"); + } + if (featureIdx == -1) { + featureIdx = TableUtil.findColIndex(featureNames, + binFeatureName + "_QUANTILE"); + } + if (null != featureBinsCalculator.bin.nullBin) { + list.add(getWoeModelTuple(featureBinsCalculator.bin.nullBin, featureIdx)); + } + if (null != featureBinsCalculator.bin.elseBin) { + list.add(getWoeModelTuple(featureBinsCalculator.bin.elseBin, featureIdx)); + } + for (Bins.BaseBin bin : featureBinsCalculator.bin.normBins) { + list.add(getWoeModelTuple(bin, featureIdx)); + } + } + if (null != meta) { + meta.set(WoeModelDataConverter.POSITIVE_TOTAL, positiveTotal) + .set(WoeModelDataConverter.NEGATIVE_TOTAL, negativeTotal); + } + new WoeModelDataConverter().save(Tuple2.of(meta, list), out); + } + + private static Tuple4 getWoeModelTuple(Bins.BaseBin bin, int featureIdx) { + Long total = bin.getTotal(); + Long positive = bin.getPositive(); + return Tuple4.of(featureIdx, String.valueOf(bin.getIndex()), null == total ? 0L : total, + null == positive ? 
0L : positive); + } + + static ParamInfo BINNING_INPUT_COLS = ParamInfoFactory + .createParamInfo("origin_data_cols", String[].class) + .setDescription("origin data cols") + .build(); +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/feature/CrossFeatureModelMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/feature/CrossFeatureModelMapper.java index 5ce2afec7..52add9845 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/feature/CrossFeatureModelMapper.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/feature/CrossFeatureModelMapper.java @@ -7,7 +7,7 @@ import org.apache.flink.table.api.TableSchema; import org.apache.flink.types.Row; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.linalg.SparseVector; import com.alibaba.alink.common.mapper.ModelMapper; import com.alibaba.alink.common.utils.TableUtil; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/feature/ExclusiveFeatureBundleModelDataConverter.java b/core/src/main/java/com/alibaba/alink/operator/common/feature/ExclusiveFeatureBundleModelDataConverter.java new file mode 100644 index 000000000..0bb91273e --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/feature/ExclusiveFeatureBundleModelDataConverter.java @@ -0,0 +1,68 @@ +package com.alibaba.alink.operator.common.feature; + +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.types.Row; + +import com.alibaba.alink.common.model.RichModelDataConverter; + +import java.util.ArrayList; + +/** + * This converter can help serialize and deserialize the model data. + */ +public class ExclusiveFeatureBundleModelDataConverter extends RichModelDataConverter { + + public String[] efbColNames; + public TypeInformation[] efbColTypes; + + /** + * Constructor. + */ + public ExclusiveFeatureBundleModelDataConverter() { + } + + /** + * Get the additional column names. + */ + @Override + protected String[] initAdditionalColNames() { + return efbColNames; + } + + /** + * Get the additional column types. + */ + @Override + protected TypeInformation[] initAdditionalColTypes() { + return efbColTypes; + } + + /** + * Serialize the model data to "Tuple3, List>". + * + * @return The serialization result. + */ + @Override + public Tuple3 , Iterable > serializeModel(FeatureBundles bundles) { + return new Tuple3 <>( + new Params().set("bundles_json", bundles.toJson()), + new ArrayList (), + new ArrayList () + ); + } + + /** + * Deserialize the model data. + * + * @param meta The model meta data. + * @param data The model concrete data. + * @param additionData The additional data. + * @return The deserialized model data. 
+ */ + @Override + public FeatureBundles deserializeModel(Params meta, Iterable data, Iterable additionData) { + return FeatureBundles.fromJson(meta.getString("bundles_json")); + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/feature/ExclusiveFeatureBundleModelInfo.java b/core/src/main/java/com/alibaba/alink/operator/common/feature/ExclusiveFeatureBundleModelInfo.java new file mode 100644 index 000000000..ef48bdb28 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/feature/ExclusiveFeatureBundleModelInfo.java @@ -0,0 +1,47 @@ +package com.alibaba.alink.operator.common.feature; + +import org.apache.flink.types.Row; + +import com.alibaba.alink.operator.common.utils.PrettyDisplayUtils; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.List; + +/** + * Summary of ExclusiveFeatureBundleModel. + */ +public class ExclusiveFeatureBundleModelInfo implements Serializable { + public FeatureBundles bundles; + public final ArrayList [] bundleIndexes; + + public ExclusiveFeatureBundleModelInfo(List rows) { + ExclusiveFeatureBundleModelDataConverter converter = new ExclusiveFeatureBundleModelDataConverter(); + bundles = converter.load(rows); + bundleIndexes = new ArrayList[bundles.numFeatures]; + for (int k = 0; k < bundles.numFeatures; k++) { + bundleIndexes[k] = new ArrayList <>(); + } + for (int i = 0; i < bundles.toEfbIndex.length; i++) { + bundleIndexes[bundles.toEfbIndex[i]].add(i); + } + } + + @Override + public String toString() { + StringBuilder sbd = new StringBuilder( + PrettyDisplayUtils.displayHeadline("ExclusiveFeatureBundleModelInfo", '-')); + sbd.append("Sparse Vector Dimension : ").append(bundles.dimVector) + .append(" , Number of Feature Bundles : ").append(bundles.numFeatures) + .append("\nFeature Schema String : ") + .append(PrettyDisplayUtils.display(bundles.schemaStr)) + .append("\nFeature Bundles Info : "); + for (int k = 0; k < bundles.numFeatures; k++) { + sbd.append("\nBundle " + k + " : ") + .append(PrettyDisplayUtils.displayList(bundleIndexes[k])); + } + sbd.append("\n"); + return sbd.toString(); + } + +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/feature/ExclusiveFeatureBundleModelMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/feature/ExclusiveFeatureBundleModelMapper.java new file mode 100644 index 000000000..6b4afc464 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/feature/ExclusiveFeatureBundleModelMapper.java @@ -0,0 +1,54 @@ +package com.alibaba.alink.operator.common.feature; + +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.tuple.Tuple4; +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.table.api.TableSchema; +import org.apache.flink.types.Row; + +import com.alibaba.alink.common.mapper.ModelMapper; +import com.alibaba.alink.params.feature.ExclusiveFeatureBundlePredictParams; +import scala.Array; + +import java.util.List; + +public class ExclusiveFeatureBundleModelMapper extends ModelMapper { + private FeatureBundles bundles; + + public ExclusiveFeatureBundleModelMapper(TableSchema modelSchema, TableSchema dataSchema, Params params) { + super(modelSchema, dataSchema, params); + } + + @Override + protected void map(SlicedSelectedSample selection, SlicedResult result) throws Exception { + Row row = bundles.map(selection.get(0)); + for (int i = 0; i < result.length(); i++) { + result.set(i, row.getField(i)); + } + } + + @Override + protected Tuple4 [], 
String[]> prepareIoSchema(TableSchema modelSchema, + TableSchema dataSchema, + Params params) { + int n = modelSchema.getFieldNames().length; + String[] resultColNames = new String[n - 2]; + TypeInformation [] resultColTypes = new TypeInformation[n - 2]; + + Array.copy(modelSchema.getFieldNames(), 2, resultColNames, 0, n - 2); + Array.copy(modelSchema.getFieldTypes(), 2, resultColTypes, 0, n - 2); + + return Tuple4.of( + new String[] {params.get(ExclusiveFeatureBundlePredictParams.SPARSE_VECTOR_COL)}, + resultColNames, + resultColTypes, + params.get(ExclusiveFeatureBundlePredictParams.RESERVED_COLS) + ); + } + + @Override + public void loadModel(List modelRows) { + ExclusiveFeatureBundleModelDataConverter converter = new ExclusiveFeatureBundleModelDataConverter(); + bundles = converter.load(modelRows); + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/feature/FeatureBundles.java b/core/src/main/java/com/alibaba/alink/operator/common/feature/FeatureBundles.java new file mode 100644 index 000000000..1b4aa0fae --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/feature/FeatureBundles.java @@ -0,0 +1,107 @@ +package com.alibaba.alink.operator.common.feature; + +import org.apache.flink.types.Row; + +import com.alibaba.alink.common.linalg.SparseVector; +import com.alibaba.alink.common.linalg.VectorUtil; +import com.alibaba.alink.common.utils.JsonConverter; + +import java.io.Serializable; +import java.util.List; + +public class FeatureBundles implements Serializable { + int dimVector; + int numFeatures; + String schemaStr; + int[] toEfbIndex; + boolean[] isNumeric; + + public int getDimVector() { + return dimVector; + } + + public void setDimVector(int dimVector) { + this.dimVector = dimVector; + } + + public int getNumFeatures() { + return numFeatures; + } + + public void setNumFeatures(int numFeatures) { + this.numFeatures = numFeatures; + } + + public String getSchemaStr() { + return schemaStr; + } + + public void setSchemaStr(String schemaStr) { + this.schemaStr = schemaStr; + } + + public int[] getToEfbIndex() { + return toEfbIndex; + } + + public void setToEfbIndex(int[] toEfbIndex) { + this.toEfbIndex = toEfbIndex; + } + + public boolean[] getIsNumeric() { + return isNumeric; + } + + public void setIsNumeric(boolean[] isNumeric) { + this.isNumeric = isNumeric; + } + + public FeatureBundles() {} + + public FeatureBundles(int dim, List bundles) { + this.dimVector = dim; + this.numFeatures = bundles.size(); + this.toEfbIndex = new int[dimVector]; + this.isNumeric = new boolean[numFeatures]; + StringBuilder sbd = new StringBuilder(); + for (int k = 0; k < bundles.size(); k++) { + boolean isNumericCol = (bundles.get(k).length <= 1); + sbd.append((k > 0) ? "," : "") + .append("efb_").append(k).append(" ") + .append(isNumericCol ? 
"double" : "string"); + isNumeric[k] = isNumericCol; + } + this.schemaStr = sbd.toString(); + for (int k = 0; k < bundles.size(); k++) { + for (int idx : bundles.get(k)) { + toEfbIndex[idx] = k; + } + } + } + + public Row map(Object obj) { + Row row = new Row(this.numFeatures); + SparseVector vec = VectorUtil.getSparseVector(obj); + int[] indices = vec.getIndices(); + double[] values = vec.getValues(); + for (int i = 0; i < indices.length; i++) { + if (indices[i] < this.dimVector) { + int efbIndex = toEfbIndex[indices[i]]; + if (isNumeric[efbIndex]) { + row.setField(efbIndex, values[i]); + } else { + row.setField(efbIndex, String.valueOf(indices[i])); + } + } + } + return row; + } + + public String toJson() { + return JsonConverter.toJson(this); + } + + public static FeatureBundles fromJson(String json) { + return JsonConverter.fromJson(json, FeatureBundles.class); + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/feature/FeatureHasherMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/feature/FeatureHasherMapper.java index 38693012f..c43e12838 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/feature/FeatureHasherMapper.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/feature/FeatureHasherMapper.java @@ -6,7 +6,7 @@ import org.apache.flink.shaded.guava18.com.google.common.hash.HashFunction; import org.apache.flink.table.api.TableSchema; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.linalg.SparseVector; import com.alibaba.alink.common.mapper.Mapper; import com.alibaba.alink.common.utils.TableUtil; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/feature/HashCrossFeatureMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/feature/HashCrossFeatureMapper.java index bdfc1c57c..a86d0c563 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/feature/HashCrossFeatureMapper.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/feature/HashCrossFeatureMapper.java @@ -5,7 +5,7 @@ import org.apache.flink.shaded.guava18.com.google.common.hash.HashFunction; import org.apache.flink.table.api.TableSchema; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.linalg.SparseVector; import com.alibaba.alink.common.mapper.MISOMapper; import com.alibaba.alink.params.feature.HashCrossFeatureParams; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/feature/MultiHotModelMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/feature/MultiHotModelMapper.java index 8fa038329..27146c104 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/feature/MultiHotModelMapper.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/feature/MultiHotModelMapper.java @@ -7,7 +7,8 @@ import org.apache.flink.table.api.TableSchema; import org.apache.flink.types.Row; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; +import com.alibaba.alink.common.exceptions.AkIllegalArgumentException; import com.alibaba.alink.common.exceptions.AkIllegalModelException; import com.alibaba.alink.common.linalg.SparseVector; import com.alibaba.alink.common.mapper.ModelMapper; @@ -30,12 +31,13 @@ public class MultiHotModelMapper extends ModelMapper { private static final long serialVersionUID = 7431062592310976413L; private MultiHotModelData model; - private String[] selectedCols; + private 
final String[] inputPredictColNames; private final HandleInvalid handleInvalid; private final Encode encode; private int offsetSize = 0; boolean enableElse = false; + public MultiHotModelMapper(TableSchema modelSchema, TableSchema dataSchema, Params params) { super(modelSchema, dataSchema, params); this.handleInvalid = params.get(MultiHotPredictParams.HANDLE_INVALID); @@ -43,12 +45,27 @@ public MultiHotModelMapper(TableSchema modelSchema, TableSchema dataSchema, Para if (handleInvalid.equals(HandleInvalid.KEEP)) { offsetSize = 1; } + if (params.contains(MultiHotPredictParams.SELECTED_COLS)) { + inputPredictColNames = params.get(MultiHotPredictParams.SELECTED_COLS); + } else { + inputPredictColNames = null; + } } @Override public void loadModel(List modelRows) { this.model = new MultiHotModelDataConverter().load(modelRows); - this.enableElse = this.model.getEnableElse(selectedCols); + this.enableElse = this.model.getEnableElse(inputPredictColNames); + + if (null != inputPredictColNames) { + Set <String> trainColSet = model.modelData.keySet(); + for (String predictColName : inputPredictColNames) { + if (!trainColSet.contains(predictColName)) { + throw new AkIllegalArgumentException( + "Column '" + predictColName + "' has not been processed in MultiHot model training."); + } + } + } } @Override @@ -59,12 +76,12 @@ protected Tuple4 [], String[]> prepareIo if (reservedCols == null) { reservedCols = dataSchema.getFieldNames(); } - this.selectedCols = params.get(MultiHotPredictParams.SELECTED_COLS); String[] outputCols = params.get(MultiHotPredictParams.OUTPUT_COLS); TypeInformation [] outputTypes = new TypeInformation [outputCols.length]; Arrays.fill(outputTypes, AlinkTypes.SPARSE_VECTOR); - return Tuple4.of(this.selectedCols, outputCols, outputTypes, reservedCols); + return Tuple4.of(params.get(MultiHotPredictParams.SELECTED_COLS), + outputCols, outputTypes, reservedCols); } @Override @@ -79,7 +96,7 @@ protected void map(SlicedSelectedSample selection, SlicedResult result) throws E } else if (encode.equals(Encode.VECTOR)) { for (int i = 0; i < selection.length(); ++i) { String str = (String) selection.get(i); - Tuple2 indices = getSingleIndicesAndSize(selectedCols[i], str); + Tuple2 indices = getSingleIndicesAndSize(inputPredictColNames[i], str); double[] vals = new double[indices.f1.length]; Arrays.fill(vals, 1.0); if (indices.f1.length != 0) { @@ -129,7 +146,7 @@ public Tuple2 getIndicesAndSize(SlicedSelectedSample selection) int cnt = 0; for (int i = 0; i < selection.length(); ++i) { String str = (String) selection.get(i); - Tuple2 t2 = getSingleIndicesAndSize(this.selectedCols[i], str); + Tuple2 t2 = getSingleIndicesAndSize(this.inputPredictColNames[i], str); for (int j = 0; j < t2.f1.length; ++j) { set.add(cnt + t2.f1[j]); } diff --git a/core/src/main/java/com/alibaba/alink/operator/common/feature/OneHotModelMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/feature/OneHotModelMapper.java index cfbab3d59..2bd46ba61 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/feature/OneHotModelMapper.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/feature/OneHotModelMapper.java @@ -7,20 +7,27 @@ import org.apache.flink.table.api.TableSchema; import org.apache.flink.types.Row; +import com.alibaba.alink.common.exceptions.AkIllegalArgumentException; import com.alibaba.alink.common.exceptions.AkPreconditions; import com.alibaba.alink.common.exceptions.AkUnclassifiedErrorException; import com.alibaba.alink.common.exceptions.AkUnsupportedOperationException; import
com.alibaba.alink.common.mapper.ModelMapper; import com.alibaba.alink.common.utils.Functional; import com.alibaba.alink.common.utils.TableUtil; +import com.alibaba.alink.operator.batch.feature.OneHotTrainBatchOp; import com.alibaba.alink.operator.common.feature.QuantileDiscretizerModelMapper.DiscreteParamsBuilder; import com.alibaba.alink.params.dataproc.HasHandleInvalid; import com.alibaba.alink.params.feature.HasEnableElse; import com.alibaba.alink.params.feature.OneHotPredictParams; +import com.alibaba.alink.params.feature.OneHotTrainParams; import com.alibaba.alink.params.shared.colname.HasSelectedCols; import java.io.Serializable; +import java.util.Arrays; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; @@ -30,6 +37,7 @@ public class OneHotModelMapper extends ModelMapper { private static final long serialVersionUID = -6192598346177373139L; OneHotMapperBuilder mapperBuilder; + private final String[] inputPredictColNames; /** * Deal with the abnormal cases. @@ -145,10 +153,16 @@ public static InvalidStrategy valueOf(boolean enableElse, */ public OneHotModelMapper(TableSchema modelSchema, TableSchema dataSchema, Params params) { super(modelSchema, dataSchema, params); + if (params.contains(OneHotPredictParams.SELECTED_COLS)) { + inputPredictColNames = params.get(OneHotPredictParams.SELECTED_COLS); + } else { + inputPredictColNames = null; + } } /** * get size of vector with all columns encoding. + * * @return size of vector. */ public int getSize() { @@ -165,6 +179,19 @@ public void loadModel(List modelRows) { OneHotModelData model = new OneHotModelDataConverter().load(modelRows); String[] trainColNames = model.modelData.meta.get(HasSelectedCols.SELECTED_COLS); + if (null != inputPredictColNames) { + HashSet <String> trainColSet = new HashSet <>(); + for (String trainColName : trainColNames) { + trainColSet.add(trainColName); + } + for (String predictColName : inputPredictColNames) { + if (!trainColSet.contains(predictColName)) { + throw new AkIllegalArgumentException( + "Column '" + predictColName + "' has not been processed in OneHot model training."); + } + } + } + //to be compatible with previous versions if (null == mapperBuilder.getSelectedCols()) { mapperBuilder.setSelectedCols(trainColNames); @@ -320,4 +347,25 @@ protected void map(SlicedSelectedSample selection, SlicedResult result) throws E mapperBuilder.map(selection, result); } + public static boolean isEnableElse(Params params) { + int[] thresholdArray; + + if (!params.contains(OneHotTrainParams.DISCRETE_THRESHOLDS_ARRAY)) { + thresholdArray = new int[] {params.get(OneHotTrainParams.DISCRETE_THRESHOLDS)}; + } else { + thresholdArray = Arrays.stream(params.get(OneHotTrainParams.DISCRETE_THRESHOLDS_ARRAY)).mapToInt( + Integer::intValue).toArray(); + } + + return isEnableElse(thresholdArray); + } + + public static boolean isEnableElse(int[] thresholdArray) { + for (int threshold : thresholdArray) { + if (threshold > 0) { + return true; + } + } + return false; + } } diff --git a/core/src/main/java/com/alibaba/alink/operator/common/feature/QuantileDiscretizerModelMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/feature/QuantileDiscretizerModelMapper.java index 79b0675f9..ccc0b76ec 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/feature/QuantileDiscretizerModelMapper.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/feature/QuantileDiscretizerModelMapper.java @@ -8,7 +8,8 @@ import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; +import com.alibaba.alink.common.exceptions.AkIllegalArgumentException; import com.alibaba.alink.common.exceptions.AkIllegalDataException; import com.alibaba.alink.common.exceptions.AkIllegalOperatorParameterException; import com.alibaba.alink.common.exceptions.AkPreconditions; @@ -20,11 +21,13 @@ import com.alibaba.alink.params.feature.HasEncodeWithoutWoe; import com.alibaba.alink.params.feature.QuantileDiscretizerPredictParams; import com.alibaba.alink.params.shared.colname.HasOutputCol; +import com.alibaba.alink.params.shared.colname.HasSelectedCols; import java.io.Serializable; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.stream.IntStream; @@ -37,9 +40,15 @@ public class QuantileDiscretizerModelMapper extends ModelMapper implements Cloneable { private static final long serialVersionUID = 5400967430347827818L; private DiscreteMapperBuilder mapperBuilder; + private final String[] inputPredictColNames; public QuantileDiscretizerModelMapper(TableSchema modelSchema, TableSchema dataSchema, Params params) { super(modelSchema, dataSchema, params); + if (params.contains(QuantileDiscretizerPredictParams.SELECTED_COLS)) { + inputPredictColNames = params.get(QuantileDiscretizerPredictParams.SELECTED_COLS); + } else { + inputPredictColNames = null; + } } @Override @@ -47,6 +56,20 @@ public void loadModel(List modelRows) { QuantileDiscretizerModelDataConverter model = new QuantileDiscretizerModelDataConverter(); model.load(modelRows); + String[] trainColNames = model.meta.get(HasSelectedCols.SELECTED_COLS); + if (null != inputPredictColNames) { + HashSet <String> trainColSet = new HashSet <>(); + for (String trainColName : trainColNames) { + trainColSet.add(trainColName); + } + for (String predictColName : inputPredictColNames) { + if (!trainColSet.contains(predictColName)) { + throw new AkIllegalArgumentException( + "Column '" + predictColName + "' has not been processed in QuantileDiscretizer model training."); + } + } + } + mapperBuilder = new DiscreteMapperBuilder(params, getDataSchema()); for (int i = 0; i < mapperBuilder.paramsBuilder.selectedCols.length; i++) { diff --git a/core/src/main/java/com/alibaba/alink/operator/common/feature/SelectorModelData.java b/core/src/main/java/com/alibaba/alink/operator/common/feature/SelectorModelData.java new file mode 100644 index 000000000..c92d02585 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/feature/SelectorModelData.java @@ -0,0 +1,9 @@ +package com.alibaba.alink.operator.common.feature; + +import com.alibaba.alink.common.utils.AlinkSerializable; + +public class SelectorModelData implements AlinkSerializable { + public int[] selectedIndices; + public String vectorColName; + public String[] vectorColNames; +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/feature/SelectorModelDataConverter.java b/core/src/main/java/com/alibaba/alink/operator/common/feature/SelectorModelDataConverter.java new file mode 100644 index 000000000..bbc24b835 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/feature/SelectorModelDataConverter.java @@ -0,0 +1,40 @@ +package com.alibaba.alink.operator.common.feature; + +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.ml.api.misc.param.Params; + +import
com.alibaba.alink.common.model.SimpleModelDataConverter; +import com.alibaba.alink.common.utils.JsonConverter; + +import java.util.Collections; + +/** + * Model data converter for the selector model, which stores the selected column indices. + */ +public class SelectorModelDataConverter extends SimpleModelDataConverter <SelectorModelData> { + + public SelectorModelDataConverter() { + } + + /** + * Serialize the model to "Tuple2<Params, Iterable<String>>". + * + * @param modelData: selected col indices + */ + @Override + public Tuple2 <Params, Iterable <String>> serializeModel(SelectorModelData modelData) { + return Tuple2.of(new Params(), Collections.singletonList(JsonConverter.toJson(modelData))); + } + + /** + * @param meta The model meta data. + * @param modelData: json + * @return The deserialized model data. + */ + @Override + public SelectorModelData deserializeModel(Params meta, Iterable <String> modelData) { + String json = modelData.iterator().next(); + return JsonConverter.fromJson(json, SelectorModelData.class); + } + +} \ No newline at end of file diff --git a/core/src/main/java/com/alibaba/alink/operator/common/feature/TargetEncoderConverter.java b/core/src/main/java/com/alibaba/alink/operator/common/feature/TargetEncoderConverter.java new file mode 100644 index 000000000..9cef115bb --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/feature/TargetEncoderConverter.java @@ -0,0 +1,49 @@ +package com.alibaba.alink.operator.common.feature; + +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.table.api.TableSchema; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; + +import com.alibaba.alink.common.model.ModelDataConverter; + +import java.util.HashMap; +import java.util.List; + +public class TargetEncoderConverter + implements ModelDataConverter <Row, TargetEncoderModelData> { + public static final String SEPARATOR = "_____"; + private String[] selectedCols = null; + + public TargetEncoderConverter() {} + + public TargetEncoderConverter(String[] selectedCols) { + this.selectedCols = selectedCols; + } + + @Override + public void save(Row modelData, Collector <Row> collector) { + collector.collect(modelData); + } + + @Override + public TargetEncoderModelData load(List <Row> rows) { + TargetEncoderModelData modelData = new TargetEncoderModelData(); + rows.forEach(modelData::setData); + return modelData; + } + + @Override + public TableSchema getModelSchema() { + StringBuilder sbd = new StringBuilder(); + sbd.append(selectedCols[0]); + int size = selectedCols.length; + for (int i = 1; i < size; i++) { + sbd.append(SEPARATOR).append(selectedCols[i]); + } + String[] resCol = new String[] {"colName", sbd.toString()}; + TypeInformation[] resType = new TypeInformation[] {Types.STRING, Types.STRING}; + return new TableSchema(resCol, resType); + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/feature/TargetEncoderModelData.java b/core/src/main/java/com/alibaba/alink/operator/common/feature/TargetEncoderModelData.java new file mode 100644 index 000000000..98a02750a --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/feature/TargetEncoderModelData.java @@ -0,0 +1,26 @@ +package com.alibaba.alink.operator.common.feature; + +import org.apache.flink.types.Row; + +import com.alibaba.alink.common.utils.JsonConverter; + +import java.util.HashMap; + +public class TargetEncoderModelData { + private HashMap > modelData; + + public TargetEncoderModelData() { + modelData = new HashMap <>(); + } + + public HashMap getData(String key) { + return modelData.get(key); + } + + public void setData(Row data) { + String key = (String)
data.getField(0); + String strMap = (String) data.getField(1); + this.modelData.put(key, JsonConverter.fromJson(strMap, HashMap.class)); + } + +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/feature/TargetEncoderModelMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/feature/TargetEncoderModelMapper.java new file mode 100644 index 000000000..6179ff740 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/feature/TargetEncoderModelMapper.java @@ -0,0 +1,60 @@ +package com.alibaba.alink.operator.common.feature; + +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.tuple.Tuple4; +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.table.api.TableSchema; +import org.apache.flink.types.Row; + +import com.alibaba.alink.common.mapper.ModelMapper; +import com.alibaba.alink.params.feature.TargetEncoderPredictParams; + +import java.util.Arrays; +import java.util.List; + +import static com.alibaba.alink.operator.common.feature.TargetEncoderConverter.SEPARATOR; + +public class TargetEncoderModelMapper extends ModelMapper { + + private String[] selectedCols; + private TargetEncoderModelData modelData; + int resSize; + + + public TargetEncoderModelMapper(TableSchema modelSchema, TableSchema dataSchema, Params params) { + super(modelSchema, dataSchema, params); + } + + @Override + protected void map(SlicedSelectedSample selection, SlicedResult result) throws Exception { + for (int i = 0; i < resSize; i++) { + String key = selection.get(i).toString(); + result.set(i, modelData.getData(selectedCols[i]).get(key)); + } + } + + @Override + protected Tuple4 [], String[]> prepareIoSchema(TableSchema modelSchema, + TableSchema dataSchema, + Params params) { + if (params.contains(TargetEncoderPredictParams.SELECTED_COLS)) { + selectedCols = params.get(TargetEncoderPredictParams.SELECTED_COLS); + } else { + selectedCols = modelSchema.getFieldNames()[1].split(SEPARATOR); + } + resSize = selectedCols.length; + String[] resCols = params.get(TargetEncoderPredictParams.OUTPUT_COLS); + if (resCols.length != resSize) { + throw new RuntimeException("Output column size must be equal to input column size."); + } + TypeInformation[] resTypes = new TypeInformation[resSize]; + Arrays.fill(resTypes, Types.DOUBLE); + return Tuple4.of(selectedCols, resCols, resTypes, params.get(TargetEncoderPredictParams.RESERVED_COLS)); + } + + @Override + public void loadModel(List modelRows) { + this.modelData = new TargetEncoderConverter().load(modelRows); + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/feature/WoeModelData.java b/core/src/main/java/com/alibaba/alink/operator/common/feature/WoeModelData.java new file mode 100644 index 000000000..6eaaaf98b --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/feature/WoeModelData.java @@ -0,0 +1,19 @@ +package com.alibaba.alink.operator.common.feature;/** + * @ClassName WoeModelData + * @description WoeModelData is + * @author lqb + * @date 2019/12/23 + */ + +import java.util.Map; + +/** + * @author lqb + * @ClassName WoeModelData + * @description WoeModelData is + * @date 2019/12/23 + */ +public class WoeModelData { + String[] selectedCols; + Map > featureBinIndexTotalMap; +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/feature/WoeModelDataConverter.java b/core/src/main/java/com/alibaba/alink/operator/common/feature/WoeModelDataConverter.java new file mode 
100644 index 000000000..403c4244e --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/feature/WoeModelDataConverter.java @@ -0,0 +1,97 @@ +package com.alibaba.alink.operator.common.feature; + +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple4; +import org.apache.flink.ml.api.misc.param.ParamInfo; +import org.apache.flink.ml.api.misc.param.ParamInfoFactory; +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.table.api.TableSchema; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; + +import com.alibaba.alink.common.model.ModelDataConverter; +import com.alibaba.alink.operator.common.feature.binning.FeatureBinsUtil; +import com.alibaba.alink.params.shared.colname.HasSelectedCols; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class WoeModelDataConverter implements + ModelDataConverter >>, Map >> { + + private static final String[] MODEL_COL_NAMES = new String[] {"featureIndex", "enumValue", "binTotal", + "binPositiveTotal"}; + + private static final TypeInformation[] MODEL_COL_TYPES = new TypeInformation[] { + Types.LONG, Types.STRING, Types.LONG, Types.LONG}; + + private static final TableSchema MODEL_SCHEMA = new TableSchema(MODEL_COL_NAMES, MODEL_COL_TYPES); + + public static ParamInfo POSITIVE_TOTAL = ParamInfoFactory + .createParamInfo("positiveTotal", Long.class) + .setDescription("positiveTotal") + .setRequired() + .build(); + + public static ParamInfo NEGATIVE_TOTAL = ParamInfoFactory + .createParamInfo("negativeTotal", Long.class) + .setDescription("negativeTotal") + .setRequired() + .build(); + + @Override + public TableSchema getModelSchema() { + return MODEL_SCHEMA; + } + + @Override + public Map > load(List rows) { + Map > featureBinIndexTotalMap = new HashMap <>(); + String[] selectedCols = null; + Long positiveTotal = null; + Long negativeTotal = null; + for (Row row : rows) { + long colIndex = (Long) row.getField(0); + if (colIndex < 0L) { + Params params = Params.fromJson((String) row.getField(1)); + selectedCols = params.get(HasSelectedCols.SELECTED_COLS); + positiveTotal = params.get(POSITIVE_TOTAL); + negativeTotal = params.get(NEGATIVE_TOTAL); + break; + } + } + for (Row row : rows) { + long colIndex = (Long) row.getField(0); + if (colIndex >= 0L) { + String featureName = selectedCols[(int) colIndex]; + Map map = featureBinIndexTotalMap.get(featureName); + if (null != map) { + map.put((String) row.getField(1), FeatureBinsUtil + .calcWoe((long) row.getField(2), (long) row.getField(3), positiveTotal, negativeTotal)); + } else { + map = new HashMap <>(); + map.put((String) row.getField(1), FeatureBinsUtil + .calcWoe((long) row.getField(2), (long) row.getField(3), positiveTotal, negativeTotal)); + featureBinIndexTotalMap.put(featureName, map); + } + } + } + + return featureBinIndexTotalMap; + } + + @Override + public void save(Tuple2 >> modelData, + Collector collector) { + if (modelData.f0 != null) { + collector.collect(Row.of(-1L, modelData.f0.toJson(), null, null)); + } + modelData.f1.forEach(tuple -> { + collector.collect(Row.of(tuple.f0.longValue(), tuple.f1, tuple.f2, tuple.f3)); + }); + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/feature/WoeModelMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/feature/WoeModelMapper.java new file mode 100644 index 
000000000..daff2942f --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/feature/WoeModelMapper.java @@ -0,0 +1,77 @@ +package com.alibaba.alink.operator.common.feature; + +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.tuple.Tuple4; +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.table.api.TableSchema; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import com.alibaba.alink.common.mapper.ModelMapper; +import com.alibaba.alink.operator.batch.feature.WoeTrainBatchOp; +import com.alibaba.alink.params.finance.WoePredictParams; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; + +public class WoeModelMapper extends ModelMapper { + private static final long serialVersionUID = 2784537716011869646L; + private String[] selectedColNames; + private List > indexBinMap; + private final Double defaultWoe; + + public WoeModelMapper(TableSchema modelSchema, TableSchema dataSchema, Params params) { + super(modelSchema, dataSchema, params); + + defaultWoe = params.get(WoePredictParams.DEFAULT_WOE); + } + + @Override + protected Tuple4 [], String[]> prepareIoSchema(TableSchema modelSchema, + TableSchema dataSchema, + Params params) { + selectedColNames = params.get(WoePredictParams.SELECTED_COLS); + String[] outputColNames = params.get(WoePredictParams.OUTPUT_COLS); + if (outputColNames == null) { + outputColNames = selectedColNames; + } + Preconditions.checkArgument(outputColNames.length == selectedColNames.length, + "OutputCol length must be equal to selectedCol length"); + String[] reservedColNames = params.get(WoePredictParams.RESERVED_COLS); + + TypeInformation[] outputColTypes = new TypeInformation[selectedColNames.length]; + Arrays.fill(outputColTypes, Types.DOUBLE); + return Tuple4.of(selectedColNames, outputColNames, outputColTypes, reservedColNames); + } + + @Override + public void loadModel(List modelRows) { + Map > featureBinIndexWoeMap = new WoeModelDataConverter().load(modelRows); + indexBinMap = new ArrayList <>(); + for (String s : selectedColNames) { + Map map = featureBinIndexWoeMap.get(s); + Preconditions.checkNotNull(map, "Can not find %s in model!", s); + for (Map.Entry entry : map.entrySet()) { + if (Double.isNaN(entry.getValue())) { + map.put(entry.getKey(), defaultWoe); + } + } + indexBinMap.add(map); + } + } + + @Override + protected void map(SlicedSelectedSample selection, SlicedResult result) throws Exception { + for (int i = 0; i < selectedColNames.length; i++) { + Object bin = selection.get(i); + Double woe = indexBinMap.get(i).get(null == bin ? WoeTrainBatchOp.NULL_STR : bin.toString()); + woe = null == woe ? 
defaultWoe : woe; + Preconditions.checkArgument(!Double.isNaN(woe), + "Woe is not set or is Nan for %s, you can provide default woe!", bin); + result.set(i, woe); + } + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/feature/binning/BinningModelInfoBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/common/feature/binning/BinningModelInfoBatchOp.java new file mode 100644 index 000000000..fe3f79a2f --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/feature/binning/BinningModelInfoBatchOp.java @@ -0,0 +1,27 @@ +package com.alibaba.alink.operator.common.feature.binning; + +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.types.Row; + +import com.alibaba.alink.operator.batch.utils.ExtractModelInfoBatchOp; +import com.alibaba.alink.operator.common.feature.BinningModelInfo; + +import java.util.List; + +public class BinningModelInfoBatchOp + extends ExtractModelInfoBatchOp { + private static final long serialVersionUID = 1735133462550836751L; + + public BinningModelInfoBatchOp() { + this(null); + } + + public BinningModelInfoBatchOp(Params params) { + super(params); + } + + @Override + public BinningModelInfo createModelInfo(List rows) { + return new BinningModelInfo(rows); + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/feature/featurebuilder/BaseWindowStreamOp.java b/core/src/main/java/com/alibaba/alink/operator/common/feature/featurebuilder/BaseWindowStreamOp.java index ca8c8dbe9..bfca65022 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/feature/featurebuilder/BaseWindowStreamOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/feature/featurebuilder/BaseWindowStreamOp.java @@ -29,7 +29,7 @@ import com.alibaba.alink.common.annotation.PortType; import com.alibaba.alink.common.annotation.ReservedColsWithFirstInputSpec; import com.alibaba.alink.common.annotation.TypeCollections; -import com.alibaba.alink.common.utils.DataStreamConversionUtil; +import com.alibaba.alink.operator.stream.utils.DataStreamConversionUtil; import com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.operator.common.feature.featurebuilder.FeatureClauseUtil.ClauseInfo; import com.alibaba.alink.operator.stream.StreamOperator; @@ -138,14 +138,11 @@ public Row map(Row value) throws Exception { .setMLEnvironmentId(getMLEnvironmentId()); if (this instanceof BaseOverWindowStreamOp) { - //rename ROW_TIME_COL_NAME to timeCol for watermark. + //rename ROW_TIME_COL_NAME to timeCol for watermark. ?? 
String[] tmpColNames = res.getColNames(); StringBuilder sbd = new StringBuilder(); for (String tmpColName : tmpColNames) { - if (timeCol.equals(tmpColName)) { - sbd.append(","); - sbd.append(ROW_TIME_COL_NAME).append(" as ").append(tmpColName); - } else if (tmpColName.equals(ROW_TIME_COL_NAME)) { + if (tmpColName.equals(ROW_TIME_COL_NAME)) { // pass } else { sbd.append(","); diff --git a/core/src/main/java/com/alibaba/alink/operator/common/feature/featurebuilder/FeatureClauseOperator.java b/core/src/main/java/com/alibaba/alink/operator/common/feature/featurebuilder/FeatureClauseOperator.java index ddee7b0aa..ef71d8e18 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/feature/featurebuilder/FeatureClauseOperator.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/feature/featurebuilder/FeatureClauseOperator.java @@ -3,7 +3,7 @@ import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeinfo.Types; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.sql.builtin.agg.AvgUdaf; import com.alibaba.alink.common.sql.builtin.agg.BaseUdaf; import com.alibaba.alink.common.sql.builtin.agg.CountUdaf; @@ -150,15 +150,15 @@ public enum FeatureClauseOperator { MTABLE_AGG(AlinkTypes.M_TABLE, new MTableAgg(false)); - private final TypeInformation resType; - private final BaseUdaf calc; + private final TypeInformation resType; + BaseUdaf calc; - FeatureClauseOperator(TypeInformation resType, BaseUdaf calc) { + FeatureClauseOperator(TypeInformation resType, BaseUdaf calc) { this.resType = resType; this.calc = calc; } - public TypeInformation getResType() { + public TypeInformation getResType() { return resType; } @@ -166,4 +166,4 @@ public BaseUdaf getCalc() { return JsonConverter.fromJson(JsonConverter.toJson(calc), calc.getClass()); } - } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/feature/featurebuilder/FeatureClauseUtil.java b/core/src/main/java/com/alibaba/alink/operator/common/feature/featurebuilder/FeatureClauseUtil.java index 59c4d8afa..648f4364a 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/feature/featurebuilder/FeatureClauseUtil.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/feature/featurebuilder/FeatureClauseUtil.java @@ -5,7 +5,7 @@ import com.alibaba.alink.common.MLEnvironment; import com.alibaba.alink.common.exceptions.AkIllegalOperatorParameterException; -import com.alibaba.alink.common.sql.builtin.BuildInAggRegister; +import com.alibaba.alink.common.sql.builtin.BuiltInAggRegister; import com.alibaba.alink.common.sql.builtin.UdafName; import com.alibaba.alink.common.sql.builtin.agg.MTableAgg; import com.alibaba.alink.common.utils.TableUtil; @@ -23,10 +23,15 @@ public class FeatureClauseUtil { private static final String ERROR_MESSAGE = "expressions must op(incol) as outcol, eg, sum(f1) as outf1."; public static FeatureClause[] extractFeatureClauses(String exprStr) { + return extractFeatureClauses(exprStr, null, null); + } + + public static FeatureClause[] extractFeatureClauses(String exprStr, TableSchema tableSchema, String timeCol) { if (exprStr == null || exprStr.isEmpty()) { throw new AkIllegalOperatorParameterException("expressions must be set"); } String[] clauses = splitClause(exprStr); + FeatureClause[] featureClauses = new FeatureClause[clauses.length]; for (int i = 0; i < clauses.length; i++) { String[] opAndResCol = trimArray(clauses[i].split(" (?i)as ")); @@ -39,8 +44,16 @@ public 
static FeatureClause[] extractFeatureClauses(String exprStr) { if (opAndInput.length != 2) { throw new AkIllegalOperatorParameterException(ERROR_MESSAGE); } + + String funcName = opAndInput[0].trim().toUpperCase(); featureClause.op = FeatureClauseOperator .valueOf(opAndInput[0].trim().toUpperCase()); + if (funcName.equals("MTABLE_AGG")) { + featureClause.op.calc = new MTableAgg(false, + getMTableSchema(clauses[i], tableSchema), timeCol); + + //featureClause.op.calc.createAccumulatorAndSet(); + } String[] inputContent = opAndInput[1].split("\\)"); if (inputContent.length != 0) { if (inputContent.length != 1) { @@ -54,8 +67,8 @@ public static FeatureClause[] extractFeatureClauses(String exprStr) { for (int j = 1; j < inputColAndParams.length; j++) { String temp = inputColAndParams[j].trim(); int tempSize = temp.length(); - if (temp.charAt(0) == "\"" .charAt(0) && temp.charAt(tempSize - 1) == "\"" .charAt(0) || - temp.charAt(0) == "\'" .charAt(0) && temp.charAt(tempSize - 1) == "\'" .charAt(0)) { + if (temp.charAt(0) == "\"".charAt(0) && temp.charAt(tempSize - 1) == "\"".charAt(0) || + temp.charAt(0) == "\'".charAt(0) && temp.charAt(tempSize - 1) == "\'".charAt(0)) { featureClause.inputParams[j - 1] = inputColAndParams[j].trim().substring(1, tempSize - 1); } else { featureClause.inputParams[j - 1] = inputColAndParams[j].trim(); @@ -105,9 +118,8 @@ private static String[] trimArray(String[] inStr) { static { aggHideTimeCol.add(UdafName.LAST_DISTINCT.name); - aggHideTimeCol.add(UdafName.LAST_DISTINCT.name + BuildInAggRegister.CONSIDER_NULL_EXTEND); - aggHideTimeCol.add(UdafName.LAST_VALUE.name); - aggHideTimeCol.add(UdafName.LAST_VALUE.name + BuildInAggRegister.CONSIDER_NULL_EXTEND); + aggHideTimeCol.add(UdafName.LAST_DISTINCT.name + BuiltInAggRegister.CONSIDER_NULL_EXTEND); + aggHideTimeCol.add(UdafName.LAST_VALUE_CONSIDER_NULL.name); aggHideTimeCol.add(UdafName.SUM_LAST.name); groupWindowTimeCol.add("TUMBLE_START"); @@ -290,6 +302,7 @@ public static void buildOperatorClause(String[] operatorFunc, String[] operators } else if (operatorFunc[clauseIndex].startsWith("MTABLE_AGG")) { operators[clauseIndex] = registMTableAgg(operator, operatorFunc[clauseIndex], env, tableSchema, timeCol); } else if (operatorFunc[clauseIndex].equals("LAST_VALUE")) { + //last value name is last_value_impl, so need transform name. 
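// ---------------------------------------------------------------------------
// Editor's sketch (not part of the patch): a simplified view of the clause parsing done by
// FeatureClauseUtil.extractFeatureClauses above for expressions of the form "op(incol) as outcol",
// e.g. "sum(f1) as outf1". It ignores literal parameters, nested parentheses and the MTABLE_AGG
// special case handled in the patch; the split patterns mirror the ones used above.
final class ClauseParseSketch {
    static String[] parse(String clause) {
        String[] opAndResCol = clause.split(" (?i)as ");      // ["sum(f1)", "outf1"]
        String[] opAndInput = opAndResCol[0].split("\\(");    // ["sum", "f1)"]
        String op = opAndInput[0].trim().toUpperCase();       // "SUM" -> FeatureClauseOperator.valueOf(...)
        String inputCol = opAndInput[1].split("\\)")[0].trim();
        String outCol = opAndResCol[1].trim();
        return new String[] {op, inputCol, outCol};
    }

    public static void main(String[] args) {
        // prints [SUM, f1, outf1]
        System.out.println(java.util.Arrays.toString(parse("sum(f1) as outf1")));
    }
}
// ---------------------------------------------------------------------------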
String[] components = operator.split("\\("); String[] components2 = components[1].split("\\)"); operators[clauseIndex] = UdafName.LAST_VALUE.name + "(" + @@ -303,7 +316,7 @@ public static String registMTableAgg(String clause, String operatorFunc, MLEnvironment mlEnv, TableSchema tableSchema, String timeCol) { String aggName = "mtable_agg_" + UUID.randomUUID().toString().replace("-", ""); - if ("MTABLE_AGG" .equals(operatorFunc)) { + if ("MTABLE_AGG".equals(operatorFunc)) { mlEnv.getStreamTableEnvironment().registerFunction(aggName, new MTableAgg(false, getMTableSchema(clause, tableSchema), timeCol)); } else { diff --git a/core/src/main/java/com/alibaba/alink/operator/common/feature/pca/PcaModelMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/feature/pca/PcaModelMapper.java index ad8015713..512838ca1 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/feature/pca/PcaModelMapper.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/feature/pca/PcaModelMapper.java @@ -6,7 +6,7 @@ import org.apache.flink.table.api.TableSchema; import org.apache.flink.types.Row; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.linalg.DenseVector; import com.alibaba.alink.common.linalg.SparseVector; import com.alibaba.alink.common.linalg.Vector; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/finance/ScorePredictMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/finance/ScorePredictMapper.java new file mode 100644 index 000000000..3950ef3b5 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/finance/ScorePredictMapper.java @@ -0,0 +1,245 @@ +package com.alibaba.alink.operator.common.finance; + +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.tuple.Tuple4; +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.table.api.TableSchema; +import org.apache.flink.types.Row; + +import com.alibaba.alink.common.linalg.DenseVector; +import com.alibaba.alink.common.linalg.MatVecOp; +import com.alibaba.alink.common.linalg.SparseVector; +import com.alibaba.alink.common.linalg.Vector; +import com.alibaba.alink.common.linalg.VectorUtil; +import com.alibaba.alink.common.mapper.ModelMapper; +import com.alibaba.alink.common.utils.JsonConverter; +import com.alibaba.alink.common.utils.TableUtil; +import com.alibaba.alink.operator.common.dataproc.vector.VectorAssemblerMapper; +import com.alibaba.alink.operator.common.linear.FeatureLabelUtil; +import com.alibaba.alink.operator.common.linear.LinearModelData; +import com.alibaba.alink.operator.common.linear.LinearModelDataConverter; +import com.alibaba.alink.params.classification.LinearModelMapperParams; +import com.alibaba.alink.params.finance.ScorePredictParams; +import com.alibaba.alink.params.mapper.RichModelMapperParams; +import com.alibaba.alink.params.shared.colname.HasFeatureCols; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class ScorePredictMapper extends ModelMapper { + private static final long serialVersionUID = -6096135125528711852L; + private int vectorColIndex = -1; + private LinearModelData model; + private int[] featureIdx; + private int featureN; + private boolean calculateScore; + private boolean calculateScorePerFeature; + private boolean calculateDetail; + private int[] 
selectedIndices; + + public ScorePredictMapper(TableSchema modelSchema, TableSchema dataSchema, Params params) { + super(modelSchema, dataSchema, params); + //from LinearModelMapper + if (null != params) { + String vectorColName = params.get(LinearModelMapperParams.VECTOR_COL); + if (null != vectorColName && vectorColName.length() != 0) { + this.vectorColIndex = TableUtil.findColIndexWithAssertAndHint(dataSchema.getFieldNames(), + vectorColName); + } + } + } + + protected Double predictScore(Vector aVector) { + return MatVecOp.dot(aVector, model.coefVector); + } + + protected double[] predictBinsScore(Vector aVector) { + if (aVector instanceof SparseVector) { + double[] values = ((SparseVector) aVector).getValues(); + int[] indices = ((SparseVector) aVector).getIndices(); + int dim = values.length; + double[] scores = new double[dim]; + for (int i = 0; i < dim; i++) { + scores[i] = values[i] * model.coefVector.get(indices[i]); + } + return scores; + } else { + double[] values = ((DenseVector) aVector).getData(); + int dim = values.length; + double[] score = new double[dim]; + for (int i = 0; i < dim; i++) { + score[i] = values[i] * model.coefVector.get(i); + } + return score; + } + } + + protected Object predictResult(Vector aVector) throws Exception { + double dotValue = MatVecOp.dot(aVector, model.coefVector); + double prob = Math.exp(-dotValue); + if (Double.isNaN(prob) || Double.isInfinite(prob)) { + prob = 0; + } else { + prob = 1.0 / (1.0 + prob); + } + return prob; + } + + //mapper operation. + @Override + public void loadModel(List modelRows) { + LinearModelDataConverter linearModelDataConverter + = new LinearModelDataConverter(LinearModelDataConverter.extractLabelType(super.getModelSchema())); + this.model = linearModelDataConverter.load(modelRows); + if (vectorColIndex == -1) { + TableSchema dataSchema = getDataSchema(); + if (this.model.featureNames != null) { + this.featureN = this.model.featureNames.length; + this.featureIdx = new int[this.featureN]; + String[] predictTableColNames = dataSchema.getFieldNames(); + for (int i = 0; i < this.featureN; i++) { + this.featureIdx[i] = TableUtil.findColIndexWithAssertAndHint(predictTableColNames, + this.model.featureNames[i]); + } + String[] selectedCols = params.get(HasFeatureCols.FEATURE_COLS); + selectedIndices = TableUtil.findColIndicesWithAssertAndHint(selectedCols, this.model.featureNames); + } else { + vectorColIndex = TableUtil.findColIndexWithAssertAndHint(dataSchema.getFieldNames(), + model.vectorColName); + } + } + + } + + @Override + protected void map(SlicedSelectedSample selection, SlicedResult result) throws Exception { + //if has intercept, then add the intercept. + Vector aVector = getFeatureVector(selection, model.hasInterceptItem, this.featureN, + this.featureIdx, + this.vectorColIndex, model.vectorSize); + //here do not include linear model mapper, because the predictResultDetail in linear model mapper is protected. + double[] scores = predictBinsScore(aVector); + double sum = 0; + for (double v : scores) { + sum += v; + } + int index = 0; + if (calculateDetail) { + result.set(index, predictResultDetail(sum)); + index++; + } + if (calculateScore) { + result.set(index, sum); + index++; + } + if (calculateScorePerFeature) { + //whether need to use the feature index? 
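// ---------------------------------------------------------------------------
// Editor's sketch (not part of the patch): predictBinsScore above splits the linear score w . x
// into one partial score per coefficient, so ScorePredictMapper can emit a per-feature score
// column as well as the overall predictionScore (the sum of the partials). Dense case only here.
final class PartialScoreSketch {
    static double[] partialScores(double[] features, double[] coefficients) {
        double[] scores = new double[features.length];
        for (int i = 0; i < features.length; i++) {
            scores[i] = features[i] * coefficients[i];
        }
        return scores;
    }

    public static void main(String[] args) {
        double[] scores = partialScores(new double[] {1.0, 0.0, 1.0}, new double[] {10.0, 5.0, -3.0});
        double sum = 0.0;
        for (double s : scores) {
            sum += s;
        }
        System.out.println(sum); // 7.0 = 10.0 + 0.0 - 3.0
    }
}
// ---------------------------------------------------------------------------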
+ for (int i = 1; i < scores.length; i++) { + if (selectedIndices != null) { + result.set(index + selectedIndices[i - 1], scores[i]); + } else { + result.set(index + i - 1, scores[i]); + } + } + } + } + + @Override + protected Tuple4 [], String[]> prepareIoSchema(TableSchema modelSchema, + TableSchema dataSchema, + Params params) { + String[] reservedColNames = params.get(RichModelMapperParams.RESERVED_COLS); + calculateScore = params.contains(ScorePredictParams.PREDICTION_SCORE_COL); + calculateDetail = params.contains(ScorePredictParams.PREDICTION_DETAIL_COL); + calculateScorePerFeature = params.get(ScorePredictParams.CALCULATE_SCORE_PER_FEATURE); + List helperColNames = new ArrayList <>(); + List helperColTypes = new ArrayList <>(); + if (calculateDetail) { + String predDetailColName = params.get(ScorePredictParams.PREDICTION_DETAIL_COL); + helperColNames.add(predDetailColName); + helperColTypes.add(Types.STRING); + } + if (calculateScore) { + String predictionScore = params.get(ScorePredictParams.PREDICTION_SCORE_COL); + helperColNames.add(predictionScore); + helperColTypes.add(Types.DOUBLE); + } + if (calculateScorePerFeature) { + String[] predictionScorePerFeature = params.get(ScorePredictParams.PREDICTION_SCORE_PER_FEATURE_COLS); + helperColNames.addAll(Arrays.asList(predictionScorePerFeature)); + for (String aPredictionScorePerFeature : predictionScorePerFeature) { + helperColTypes.add(Types.DOUBLE); + } + } + + return Tuple4.of(this.getDataSchema().getFieldNames(), + helperColNames.toArray(new String[0]), + helperColTypes.toArray(new TypeInformation[0]), + reservedColNames); + } + + + // linear model mapper function + protected String predictResultDetail(double sum) { + Double[] result = predictWithProb(sum); + Map detail = new HashMap <>(1); + int labelSize = model.labelValues.length; + for (int i = 0; i < labelSize; ++i) { + detail.put(model.labelValues[i].toString(), result[i].toString()); + } + return JsonConverter.toJson(detail); + } + + private Double[] predictWithProb(double sum) { + double prob = sigmoid(sum); + return new Double[] {prob, 1 - prob}; + + } + + private double sigmoid(double val) { + return 1 - 1.0 / (1.0 + Math.exp(val)); + } + + /** + * Retrieve the feature vector from the input row data. + */ + private static Vector getFeatureVector(SlicedSelectedSample selection, boolean hasInterceptItem, int featureN, + int[] featureIdx, + int vectorColIndex, Integer vectorSize) { + Vector aVector; + if (vectorColIndex != -1) { + Vector vec = VectorUtil.getVector(selection.get(vectorColIndex)); + if (vec instanceof SparseVector) { + SparseVector tmp = (SparseVector) vec; + if (null != vectorSize) { + tmp.setSize(vectorSize); + } + aVector = hasInterceptItem ? tmp.prefix(1.0) : tmp; + } else { + DenseVector tmp = (DenseVector) vec; + aVector = hasInterceptItem ? 
tmp.prefix(1.0) : tmp; + } + } else { + if (hasInterceptItem) { + Object[] objs = new Object[featureN + 1]; + objs[0] = 1.0; + for (int i = 0; i < featureN; i++) { + objs[1 + i] = VectorUtil.getVector(selection.get(featureIdx[i])); + } + aVector = (Vector) VectorAssemblerMapper.assembler(objs); + } else { + Object[] objs = new Object[featureN]; + for (int i = 0; i < featureN; i++) { + objs[i] = VectorUtil.getVector(selection.get(featureIdx[i])); + } + aVector = (Vector) VectorAssemblerMapper.assembler(objs); + } + } + return aVector; + } + +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/finance/ScorecardModelInfo.java b/core/src/main/java/com/alibaba/alink/operator/common/finance/ScorecardModelInfo.java new file mode 100644 index 000000000..4e7658f82 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/finance/ScorecardModelInfo.java @@ -0,0 +1,381 @@ +package com.alibaba.alink.operator.common.finance; + +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.table.api.TableSchema; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import com.alibaba.alink.common.lazy.BasePMMLModelInfo; +import com.alibaba.alink.common.utils.TableUtil; +import com.alibaba.alink.operator.batch.finance.ScorecardTrainBatchOp; +import com.alibaba.alink.operator.common.feature.BinningModelDataConverter; +import com.alibaba.alink.operator.common.feature.binning.Bins; +import com.alibaba.alink.operator.common.feature.binning.FeatureBinsCalculator; +import com.alibaba.alink.operator.common.feature.binning.FeatureBinsUtil; +import com.alibaba.alink.operator.common.linear.LinearModelData; +import com.alibaba.alink.operator.common.linear.LinearModelDataConverter; +import com.alibaba.alink.operator.common.linear.LinearRegressorModelInfo; +import com.alibaba.alink.operator.common.utils.PrettyDisplayUtils; +import com.alibaba.alink.params.finance.BinningPredictParams; +import com.alibaba.alink.params.finance.ScorecardTrainParams; +import com.alibaba.alink.pipeline.ModelExporterUtils; +import com.alibaba.alink.pipeline.PipelineStageBase; +import com.alibaba.alink.pipeline.feature.BinningModel; +import com.alibaba.alink.pipeline.finance.ScoreModel; +import org.dmg.pmml.Application; +import org.dmg.pmml.CompoundPredicate; +import org.dmg.pmml.DataDictionary; +import org.dmg.pmml.DataField; +import org.dmg.pmml.DataType; +import org.dmg.pmml.FieldName; +import org.dmg.pmml.Header; +import org.dmg.pmml.InvalidValueTreatmentMethod; +import org.dmg.pmml.MiningField; +import org.dmg.pmml.MiningFunction; +import org.dmg.pmml.MiningSchema; +import org.dmg.pmml.OpType; +import org.dmg.pmml.Output; +import org.dmg.pmml.OutputField; +import org.dmg.pmml.PMML; +import org.dmg.pmml.SimplePredicate; +import org.dmg.pmml.Timestamp; +import org.dmg.pmml.True; +import org.dmg.pmml.regression.NumericPredictor; +import org.dmg.pmml.regression.RegressionModel; +import org.dmg.pmml.regression.RegressionTable; +import org.dmg.pmml.scorecard.Attribute; +import org.dmg.pmml.scorecard.Characteristic; +import org.dmg.pmml.scorecard.Characteristics; +import org.dmg.pmml.scorecard.Scorecard; + +import java.io.Serializable; +import java.text.SimpleDateFormat; +import java.util.ArrayList; +import java.util.Date; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; + +public class ScorecardModelInfo implements 
BasePMMLModelInfo, Serializable { + private Map >> categoricalScores; + private Map >> numericScores; + private Map splitArrays; + private boolean isLeftOpen; + private Double intercept; + private String modelName; + + private LinearRegressorModelInfo linearModelInfo; + + public ScorecardModelInfo(List modelRows, TableSchema modelSchema) { + List binningModelRows = null; + List linearModelRows = null; + + List , TableSchema, List >> deserialized = + ModelExporterUtils.loadStagesFromPipelineModel(modelRows, modelSchema); + + Params scoreCardParams = deserialized.get(0).f0.getParams(); + + for (Tuple3 , TableSchema, List > d : deserialized) { + if (d.f0 instanceof BinningModel) { + binningModelRows = d.f2; + } + if (d.f0 instanceof ScoreModel) { + linearModelRows = d.f2; + } + } + + if (null == binningModelRows) { + linearModelInfo = new LinearRegressorModelInfo(linearModelRows); + } else { + categoricalScores = new HashMap <>(); + numericScores = new HashMap <>(); + splitArrays = new HashMap <>(); + Preconditions.checkNotNull(linearModelRows); + LinearModelData linearModelData = new LinearModelDataConverter().load(linearModelRows); + modelName = linearModelData.modelName; + intercept = linearModelData.coefVector.get(0); + + List featureBinsCalculatorList + = new BinningModelDataConverter().load(binningModelRows); + ScorecardTrainBatchOp.ScorecardTransformData transformData + = new ScorecardTrainBatchOp.ScorecardTransformData( + true, + scoreCardParams.get(ScorecardTrainParams.SELECTED_COLS), + BinningPredictParams.Encode.WOE.equals(scoreCardParams.get(BinningPredictParams.ENCODE)), + scoreCardParams.get(ScorecardTrainParams.DEFAULT_WOE), + true); + + transformData.scaledModel = ScorecardTrainBatchOp.FeatureBinsToScorecard.getModelData(linearModelData, + transformData); + transformData.unscaledModel = transformData.scaledModel; + if (!transformData.isWoe) { + Map nameBinCountMap = new HashMap <>(); + featureBinsCalculatorList.forEach(featureBinsCalculator -> + nameBinCountMap.put(featureBinsCalculator.getFeatureName(), + FeatureBinsUtil.getBinEncodeVectorSize(featureBinsCalculator))); + ScorecardTrainBatchOp.FeatureBinsToScorecard.initializeStartIndex(transformData, nameBinCountMap); + } + if (null != linearModelData.featureNames) { + transformData.stepwiseSelectedCols = ScorecardTrainBatchOp.trimCols(linearModelData.featureNames, + ScorecardTrainBatchOp.BINNING_OUTPUT_COL); + } else { + transformData.stepwiseSelectedCols = transformData.selectedCols; + } + + for (FeatureBinsCalculator featureBinsCalculator : featureBinsCalculatorList) { + transformData.featureIndex = TableUtil.findColIndex(transformData.stepwiseSelectedCols, + featureBinsCalculator.getFeatureName()); + if (transformData.featureIndex < 0) { + continue; + } + featureBinsCalculator.splitsArrayToInterval(); + List > map = new ArrayList <>(); + if (null != featureBinsCalculator.bin.normBins) { + for (Bins.BaseBin bin : featureBinsCalculator.bin.normBins) { + map.add(Tuple2.of(bin.getValueStr(featureBinsCalculator.getColType()), + getBinScore(transformData, bin, featureBinsCalculator))); + } + } + if (null != featureBinsCalculator.bin.nullBin) { + map.add(Tuple2.of(FeatureBinsUtil.NULL_LABEL, getBinScore(transformData, + featureBinsCalculator.bin.nullBin, featureBinsCalculator))); + } + if (null != featureBinsCalculator.bin.elseBin) { + map.add(Tuple2.of(FeatureBinsUtil.ELSE_LABEL, getBinScore(transformData, + featureBinsCalculator.bin.elseBin, featureBinsCalculator))); + } + if (featureBinsCalculator.isNumeric()) { + 
numericScores.put(featureBinsCalculator.getFeatureName(), map); + splitArrays.put(featureBinsCalculator.getFeatureName(), featureBinsCalculator.getSplitsArray()); + isLeftOpen = featureBinsCalculator.getLeftOpen(); + } else { + categoricalScores.put(featureBinsCalculator.getFeatureName(), map); + } + } + } + } + + private static double getBinScore(ScorecardTrainBatchOp.ScorecardTransformData transformData, + Bins.BaseBin bin, + FeatureBinsCalculator featureBinsCalculator) { + int binIndex = bin.getIndex().intValue(); + String featureColName = featureBinsCalculator.getFeatureName(); + + int linearModelCoefIdx = ScorecardTrainBatchOp.FeatureBinsToScorecard.findLinearModelCoefIdx(featureColName, + binIndex, transformData); + Preconditions.checkArgument(linearModelCoefIdx >= 0); + return transformData.isScaled ? FeatureBinsUtil.keepGivenDecimal( + ScorecardTrainBatchOp.FeatureBinsToScorecard + .getModelValue(linearModelCoefIdx, bin.getWoe(), transformData.scaledModel, transformData.isWoe, + transformData.defaultWoe), 3) : null; + } + + @Override + public String toString() { + StringBuilder sbd = new StringBuilder(PrettyDisplayUtils.displayHeadline("ScorecardModelInfo", '-')); + if (null != linearModelInfo) { + sbd.append(linearModelInfo.toString()); + } else { + int size = categoricalScores.values().stream().mapToInt(m -> m.size()).sum() + + numericScores.values().stream().mapToInt(m -> m.size()).sum() + 1; + String[][] table = new String[size][3]; + int cnt = 0; + for (Map.Entry >> entry : categoricalScores.entrySet()) { + table[cnt][0] = entry.getKey(); + for (Tuple2 entry1 : entry.getValue()) { + if (table[cnt][0] == null) { + table[cnt][0] = ""; + } + table[cnt][1] = entry1.f0; + table[cnt++][2] = entry1.f1.toString(); + } + } + for (Map.Entry >> entry : numericScores.entrySet()) { + table[cnt][0] = entry.getKey(); + for (Tuple2 entry1 : entry.getValue()) { + if (table[cnt][0] == null) { + table[cnt][0] = ""; + } + table[cnt][1] = entry1.f0; + table[cnt++][2] = entry1.f1.toString(); + } + } + table[cnt][0] = "intercept"; + table[cnt][1] = ""; + table[cnt][2] = intercept.toString(); + sbd.append(PrettyDisplayUtils.displayTable(table, table.length, 3, null, + new String[] {"featureName", "FeatureBin", "BinScore"}, null, 20, 3)); + } + return sbd.toString(); + } + + @Override + public PMML toPMML() { + String version = this.getClass().getPackage().getImplementationVersion(); + Application app = new Application().setVersion(version); + Timestamp timestamp = new Timestamp() + .addContent(new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss", Locale.US).format(new Date())); + Header header = new Header() + .setApplication(app) + .setTimestamp(timestamp); + PMML pmml = new PMML("4.2", header, null); + String targetName = "prediction_score"; + FieldName outputField = FieldName.create(targetName); + Output output = new Output().addOutputFields(new OutputField(outputField, OpType.CONTINUOUS, DataType.DOUBLE) + .setTargetField(outputField)); + + if (null != linearModelInfo) { + String description = "Linear Model"; + String modelName = linearModelInfo.getModelName(); + pmml.getHeader().setDescription(description); + double[] weights = linearModelInfo.getWeight().getData(); + String[] featureNames = linearModelInfo.getFeatureNames(); + FieldName[] fields = new FieldName[weights.length - 1]; + DataDictionary dataDictionary = new DataDictionary(); + MiningSchema miningSchema = new MiningSchema(); + + RegressionTable regressionTable = new RegressionTable(weights[0]); + RegressionModel regressionModel = new 
RegressionModel() + .setMiningFunction(MiningFunction.REGRESSION) + .setMiningSchema(miningSchema) + .setModelName(modelName) + //.setOutput(output) + .addRegressionTables(regressionTable); + + for (int i = 0; i < weights.length - 1; i++) { + fields[i] = FieldName.create(featureNames[i]); + dataDictionary.addDataFields(new DataField(fields[i], OpType.CONTINUOUS, DataType.DOUBLE)); + miningSchema.addMiningFields(new MiningField(fields[i]) + .setUsageType(MiningField.UsageType.ACTIVE) + .setInvalidValueTreatment(InvalidValueTreatmentMethod.AS_MISSING)); + regressionTable.addNumericPredictors(new NumericPredictor(fields[i], weights[i + 1]).setExponent(1)); + } + + // for completeness add target field + dataDictionary.addDataFields(new DataField(outputField, OpType.CONTINUOUS, DataType.DOUBLE)); + miningSchema.addMiningFields(new MiningField(outputField).setUsageType(MiningField.UsageType.TARGET)); + + dataDictionary.setNumberOfFields(dataDictionary.getDataFields().size()); + + pmml.setDataDictionary(dataDictionary); + pmml.addModels(regressionModel); + } else { + String description = "Scorecard Model"; + pmml.getHeader().setDescription(description); + DataDictionary dataDictionary = new DataDictionary(); + MiningSchema miningSchema = new MiningSchema(); + Characteristics characteristics = new Characteristics(); + + Scorecard scorecard = new Scorecard() + .setMiningFunction(MiningFunction.REGRESSION) + .setMiningSchema(miningSchema) + .setModelName("Scorecard") + .setAlgorithmName(modelName) + //.setOutput(output) + .setUseReasonCodes(false) + .setInitialScore(intercept) + .setCharacteristics(characteristics); + + for (Map.Entry >> entry : categoricalScores.entrySet()) { + FieldName fieldName = FieldName.create(entry.getKey()); + dataDictionary.addDataFields(new DataField(fieldName, OpType.CATEGORICAL, DataType.STRING)); + miningSchema.addMiningFields(new MiningField(fieldName) + .setUsageType(MiningField.UsageType.ACTIVE) + .setInvalidValueTreatment(InvalidValueTreatmentMethod.AS_MISSING)); + Characteristic characteristic = new Characteristic().setName(entry.getKey() + "_score"); + for (Tuple2 t : entry.getValue()) { + Attribute attribute = new Attribute().setPartialScore(t.f1); + if (t.f0.equals(FeatureBinsUtil.NULL_LABEL)) { + attribute.setPredicate(new SimplePredicate() + .setField(fieldName) + .setOperator(SimplePredicate.Operator.IS_MISSING)); + } else if (t.f0.equals(FeatureBinsUtil.ELSE_LABEL)) { + attribute.setPredicate(new True()); + } else { + String[] values = t.f0.split(Bins.JOIN_DELIMITER); + if (values.length == 1) { + attribute.setPredicate(new SimplePredicate() + .setField(fieldName) + .setOperator(SimplePredicate.Operator.EQUAL) + .setValue(values[0])); + } else { + CompoundPredicate predicate = new CompoundPredicate().setBooleanOperator( + CompoundPredicate.BooleanOperator.OR); + attribute.setPredicate(predicate); + for (String s : values) { + predicate.addPredicates(new SimplePredicate() + .setField(fieldName) + .setOperator(SimplePredicate.Operator.EQUAL) + .setValue(s)); + } + } + } + characteristic.addAttributes(attribute); + } + characteristics.addCharacteristics(characteristic); + } + for (Map.Entry >> entry : numericScores.entrySet()) { + FieldName fieldName = FieldName.create(entry.getKey()); + dataDictionary.addDataFields(new DataField(fieldName, OpType.CONTINUOUS, DataType.DOUBLE)); + miningSchema.addMiningFields(new MiningField(fieldName) + .setUsageType(MiningField.UsageType.ACTIVE) + .setInvalidValueTreatment(InvalidValueTreatmentMethod.AS_MISSING)); + 
Characteristic characteristic = new Characteristic().setName(entry.getKey() + "_score"); + Number[] list = splitArrays.get(entry.getKey()); + SimplePredicate.Operator first = isLeftOpen ? SimplePredicate.Operator.GREATER_THAN : SimplePredicate + .Operator.GREATER_OR_EQUAL; + SimplePredicate.Operator second = isLeftOpen ? SimplePredicate.Operator.LESS_OR_EQUAL : + SimplePredicate.Operator.LESS_THAN; + + for (int i = 0; i < entry.getValue().size(); i++) { + Tuple2 t = entry.getValue().get(i); + Attribute attribute = new Attribute().setPartialScore(t.f1); + if (t.f0.equals(FeatureBinsUtil.NULL_LABEL)) { + attribute.setPredicate(new SimplePredicate() + .setField(fieldName) + .setOperator(SimplePredicate.Operator.IS_MISSING)); + } else { + if (i == 0) { + attribute.setPredicate(new SimplePredicate() + .setField(fieldName) + .setOperator(second) + .setValue(list[i].toString())); + } else if (i == list.length) { + attribute.setPredicate(new SimplePredicate() + .setField(fieldName) + .setOperator(first) + .setValue(list[i - 1].toString())); + } else { + CompoundPredicate predicate = new CompoundPredicate().setBooleanOperator( + CompoundPredicate.BooleanOperator.AND); + attribute.setPredicate(predicate); + predicate.addPredicates(new SimplePredicate() + .setField(fieldName) + .setOperator(first) + .setValue(list[i - 1].toString())) + .addPredicates(new SimplePredicate() + .setField(fieldName) + .setOperator(second) + .setValue(list[i].toString())); + } + } + characteristic.addAttributes(attribute); + } + characteristics.addCharacteristics(characteristic); + } + + // for completeness add target field + dataDictionary.addDataFields(new DataField(outputField, OpType.CONTINUOUS, DataType.DOUBLE)); + miningSchema.addMiningFields(new MiningField(outputField).setUsageType(MiningField.UsageType.TARGET)); + + dataDictionary.setNumberOfFields(dataDictionary.getDataFields().size()); + + pmml.setDataDictionary(dataDictionary); + pmml.addModels(scorecard); + } + return pmml; + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/finance/ScorecardModelInfoBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/common/finance/ScorecardModelInfoBatchOp.java new file mode 100644 index 000000000..57023c2dd --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/finance/ScorecardModelInfoBatchOp.java @@ -0,0 +1,25 @@ +package com.alibaba.alink.operator.common.finance; + +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.types.Row; + +import com.alibaba.alink.operator.batch.utils.ExtractModelInfoBatchOp; + +import java.util.List; + +public class ScorecardModelInfoBatchOp extends ExtractModelInfoBatchOp { + private static final long serialVersionUID = 1735133462550836751L; + + public ScorecardModelInfoBatchOp() { + this(null); + } + + public ScorecardModelInfoBatchOp(Params params) { + super(params); + } + + @Override + public ScorecardModelInfo createModelInfo(List rows) { + return new ScorecardModelInfo(rows, this.getSchema()); + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/finance/ScorecardModelMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/finance/ScorecardModelMapper.java new file mode 100644 index 000000000..9294b72a8 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/finance/ScorecardModelMapper.java @@ -0,0 +1,124 @@ +package com.alibaba.alink.operator.common.finance; + +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.tuple.Tuple2; 
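// ---------------------------------------------------------------------------
// Editor's sketch (not part of the patch): the interval logic behind the numeric branch of
// ScorecardModelInfo.toPMML above. With left-open bins, bin i of a split array s covers
// (s[i-1], s[i]]: the first bin only needs an upper bound, the last only a lower bound, and
// interior bins combine both (mirroring the GREATER_THAN / LESS_OR_EQUAL predicates built above;
// the operators flip when the bins are left-closed).
final class NumericBinPredicateSketch {
    static String describeBin(int i, double[] splits, boolean leftOpen) {
        String lower = leftOpen ? "x > " : "x >= ";
        String upper = leftOpen ? "x <= " : "x < ";
        if (i == 0) {
            return upper + splits[0];
        }
        if (i == splits.length) {
            return lower + splits[splits.length - 1];
        }
        return lower + splits[i - 1] + " AND " + upper + splits[i];
    }

    public static void main(String[] args) {
        double[] splits = {0.0, 10.0};
        for (int i = 0; i <= splits.length; i++) {
            System.out.println(describeBin(i, splits, true)); // x <= 0.0 | x > 0.0 AND x <= 10.0 | x > 10.0
        }
    }
}
// ---------------------------------------------------------------------------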
+import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.tuple.Tuple4; +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.table.api.TableSchema; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import com.alibaba.alink.common.mapper.ComboModelMapper; +import com.alibaba.alink.common.mapper.PipelineModelMapper; +import com.alibaba.alink.operator.common.linear.LinearModelDataConverter; +import com.alibaba.alink.params.finance.BinningPredictParams; +import com.alibaba.alink.params.finance.ScorePredictParams; +import com.alibaba.alink.params.finance.ScorecardPredictParams; +import com.alibaba.alink.params.shared.colname.HasFeatureCols; +import com.alibaba.alink.pipeline.ModelExporterUtils; +import com.alibaba.alink.pipeline.PipelineStageBase; +import com.alibaba.alink.pipeline.feature.BinningModel; +import com.alibaba.alink.pipeline.finance.ScoreModel; +import org.apache.commons.lang3.ArrayUtils; + +import java.util.List; + +public class ScorecardModelMapper extends ComboModelMapper { + private static final long serialVersionUID = 7877677418109112341L; + private static String SCORE_SUFFIX = "_SCORE"; + + public ScorecardModelMapper(TableSchema modelSchema, TableSchema dataSchema, Params params) { + super(modelSchema, dataSchema, params); + + Tuple2 []> tuple2 = PipelineModelMapper.getExtendModelSchema(modelSchema); + String[] selectedCols = tuple2.f0; + + String[] scoreColNames = new String[selectedCols.length]; + for (int i = 0; i < scoreColNames.length; i++) { + scoreColNames[i] = selectedCols[i] + SCORE_SUFFIX; + } + + this.params.set(ScorePredictParams.PREDICTION_SCORE_PER_FEATURE_COLS, scoreColNames); + this.params.set(HasFeatureCols.FEATURE_COLS, selectedCols); + + Preconditions.checkArgument( + this.params.contains(ScorecardPredictParams.PREDICTION_SCORE_COL) + && this.params.contains(ScorecardPredictParams.PREDICTION_DETAIL_COL), + "predictionScore and predictionDetail must be given!" 
+ ); + + if (this.params.get(ScorePredictParams.RESERVED_COLS) == null) { + this.params.set(ScorePredictParams.RESERVED_COLS, dataSchema.getFieldNames()); + } + } + + @Override + public void loadModel(List modelRows) { + List , TableSchema, List >> stages = + ModelExporterUtils.loadStagesFromPipelineModel(modelRows, getModelSchema()); + for (int i = 0; i < stages.size(); i++) { + PipelineStageBase stage = stages.get(i).f0; + + stage.getParams().merge(this.params); + + if (stage instanceof BinningModel) { + stage.getParams().set(ScorecardPredictParams.RESERVED_COLS, null); + this.params.set(HasFeatureCols.FEATURE_COLS, + stage.getParams().get(BinningPredictParams.OUTPUT_COLS)); + } + + if (stage instanceof ScoreModel) { + if (i < stages.size() - 1) { + stage.getParams().remove(ScorePredictParams.PREDICTION_SCORE_COL); + stage.getParams().remove(ScorePredictParams.PREDICTION_SCORE_PER_FEATURE_COLS); + stage.getParams().remove(ScorePredictParams.CALCULATE_SCORE_PER_FEATURE); + stage.getParams().set(ScorePredictParams.RESERVED_COLS, null); + stages.get(i + 1).f0.getParams().merge(this.params); + stages.get(i + 1).f0.getParams().remove(ScorePredictParams.PREDICTION_DETAIL_COL); + stages.get(i + 1).f0.getParams().set( + ScorePredictParams.RESERVED_COLS, + ArrayUtils.add( + stages.get(i + 1).f0.getParams().get(ScorePredictParams.RESERVED_COLS), + stage.getParams().get(ScorePredictParams.PREDICTION_DETAIL_COL) + ) + ); + break; + } + } + } + + this.mapperList = ModelExporterUtils + .loadMapperListFromStages(stages, getDataSchema()); + } + + @Override + protected Tuple4 [], String[]> prepareIoSchema(TableSchema modelSchema, + TableSchema dataSchema, + Params params) { + Tuple2 []> tuple2 = PipelineModelMapper.getExtendModelSchema(modelSchema); + String[] selectedCols = tuple2.f0; + TypeInformation labelType = tuple2.f1[0]; + + String[] scoreColNames = new String[selectedCols.length]; + for (int i = 0; i < scoreColNames.length; i++) { + scoreColNames[i] = selectedCols[i] + SCORE_SUFFIX; + } + + params.set(ScorePredictParams.PREDICTION_SCORE_PER_FEATURE_COLS, scoreColNames); + params.set(HasFeatureCols.FEATURE_COLS, selectedCols); + + if (params.get(ScorePredictParams.RESERVED_COLS) == null) { + params.set(ScorePredictParams.RESERVED_COLS, dataSchema.getFieldNames()); + } + + ScorePredictMapper emptyMapper = new ScorePredictMapper( + new LinearModelDataConverter(labelType).getModelSchema(), + dataSchema, params + ); + Tuple4 [], String[]> tuple4 = emptyMapper.prepareIoSchema(modelSchema, + dataSchema, params); + return Tuple4.of(dataSchema.getFieldNames(), tuple4.f1, tuple4.f2, + params.get(ScorecardPredictParams.RESERVED_COLS)); + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/finance/VizData.java b/core/src/main/java/com/alibaba/alink/operator/common/finance/VizData.java new file mode 100644 index 000000000..3d033caed --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/finance/VizData.java @@ -0,0 +1,106 @@ +package com.alibaba.alink.operator.common.finance; + +import com.alibaba.alink.operator.common.feature.binning.FeatureBinsUtil; +import com.alibaba.alink.operator.common.similarity.SerializableComparator; + +import java.io.Serializable; + +public class VizData implements Serializable { + private static final long serialVersionUID = 6180829409440580211L; + public String featureName; + public Long index; + public String value; + + VizData() {} + + VizData(String featureName) { + this(featureName, null, null); + } + + VizData(String featureName, Long 
index, String value) { + this.featureName = featureName; + this.index = index; + this.value = value; + } + + public static class PSIVizData extends VizData { + private static final long serialVersionUID = 6939286125661156752L; + public Double testPercentage; + public Double basePercentage; + public Double testSubBase; + public Double lnTestDivBase; + public Double psi; + + PSIVizData() {} + + public PSIVizData(String featureName) { + super(featureName); + } + + public PSIVizData(String featureName, Long index, String value) { + super(featureName, index, value); + } + + public void calcPSI() { + testSubBase = testPercentage - basePercentage; + if (basePercentage > 0 && testPercentage > 0) { + lnTestDivBase = Math.log(testPercentage / basePercentage); + psi = testSubBase * lnTestDivBase / 100; + } + testPercentage = FeatureBinsUtil.keepGivenDecimal(testPercentage, 2); + basePercentage = FeatureBinsUtil.keepGivenDecimal(basePercentage, 2); + testSubBase = FeatureBinsUtil.keepGivenDecimal(testSubBase, 2); + lnTestDivBase = FeatureBinsUtil.keepGivenDecimal(lnTestDivBase, 2); + psi = FeatureBinsUtil.keepGivenDecimal(psi, 4); + } + } + + public static class ScorecardVizData extends VizData { + private static final long serialVersionUID = 260525523654434938L; + public Double unscaledValue; + public Double scaledValue; + public Double woe; + public Long total; + public Long positive; + public Long negative; + public Double positiveRate; + public Double negativeRate; + + ScorecardVizData() {} + + public ScorecardVizData(String featureName) { + super(featureName); + } + + public ScorecardVizData(String featureName, Long index, String value) { + super(featureName, index, value); + } + } + + public static SerializableComparator VizDataComparator = new SerializableComparator () { + private static final long serialVersionUID = 8715645640166932623L; + + @Override + public int compare(VizData o1, VizData o2) { + if (o1.index == null) { + return -1; + } + if (o2.index == null) { + return 1; + } + if (o1.index == -1) { + return o2.index < -1 ? -1 : 1; + } + if (o1.index == -2) { + return 1; + } + if (o2.index == -1) { + return o1.index < -1 ? 
1 : -1; + } + if (o2.index == -2) { + return -1; + } + return o1.index.compareTo(o2.index); + } + }; +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/finance/stepwiseSelector/BaseStepWiseSelectorBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/common/finance/stepwiseSelector/BaseStepWiseSelectorBatchOp.java new file mode 100644 index 000000000..cc9fe7a92 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/finance/stepwiseSelector/BaseStepWiseSelectorBatchOp.java @@ -0,0 +1,964 @@ +package com.alibaba.alink.operator.common.finance.stepwiseSelector; + +import org.apache.flink.api.common.functions.FlatMapFunction; +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.MapPartitionFunction; +import org.apache.flink.api.common.functions.ReduceFunction; +import org.apache.flink.api.common.functions.RichFlatMapFunction; +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.functions.RichMapPartitionFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.tuple.Tuple4; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.table.api.Table; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; + +import com.alibaba.alink.common.MLEnvironmentFactory; +import com.alibaba.alink.common.type.AlinkTypes; +import com.alibaba.alink.common.linalg.DenseMatrix; +import com.alibaba.alink.common.linalg.DenseVector; +import com.alibaba.alink.common.linalg.SparseVector; +import com.alibaba.alink.common.linalg.Vector; +import com.alibaba.alink.common.linalg.VectorUtil; +import com.alibaba.alink.common.model.ModelParamName; +import com.alibaba.alink.common.viz.AlinkViz; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; +import com.alibaba.alink.common.utils.JsonConverter; +import com.alibaba.alink.common.utils.TableUtil; +import com.alibaba.alink.common.viz.VizDataWriterInterface; +import com.alibaba.alink.operator.batch.BatchOperator; +import com.alibaba.alink.operator.batch.finance.ScorecardTrainBatchOp; +import com.alibaba.alink.operator.batch.utils.DataSetUtil; +import com.alibaba.alink.operator.common.dataproc.vector.VectorAssemblerMapper; +import com.alibaba.alink.operator.common.feature.SelectorModelData; +import com.alibaba.alink.operator.common.feature.SelectorModelDataConverter; +import com.alibaba.alink.operator.common.linear.BaseLinearModelTrainBatchOp; +import com.alibaba.alink.operator.common.linear.LinearModelData; +import com.alibaba.alink.operator.common.linear.LinearModelDataConverter; +import com.alibaba.alink.operator.common.linear.LinearRegressionSummary; +import com.alibaba.alink.operator.common.linear.LocalLinearModel; +import com.alibaba.alink.operator.common.linear.LogistRegressionSummary; +import com.alibaba.alink.operator.common.linear.ModelSummary; +import com.alibaba.alink.operator.common.linear.ModelSummaryHelper; +import com.alibaba.alink.operator.common.optim.FeatureConstraint; +import com.alibaba.alink.operator.batch.statistics.utils.StatisticsHelper; +import com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummarizer; +import 
com.alibaba.alink.params.finance.BaseStepwiseSelectorParams; +import com.alibaba.alink.params.finance.ConstrainedLogisticRegressionTrainParams; +import com.alibaba.alink.params.shared.linear.LinearTrainParams; + +import java.security.InvalidParameterException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +public class BaseStepWiseSelectorBatchOp extends BatchOperator + implements BaseStepwiseSelectorParams , AlinkViz { + private static final long serialVersionUID = -1353820179732001005L; + + private static final String INNER_VECTOR_COL = "vec"; + private static final String INNER_LABLE_COL = "label"; + private BatchOperator in; + + private DataSet labels; + + private boolean hasConstraint; + private DataSet constraintDataSet; + + private boolean hasVectorSizes; + private DataSet vectorSizes = null; + + //col and label after standard + private String selectColNew; + private String labelColNew; + private int selectedColIdxNew; + private int labelIdxNew; + + //origin + private String selectedCol; + private String[] selectedCols; + + private LinearModelType linearModelType; + + private boolean inScorcard; + + public BaseStepWiseSelectorBatchOp(Params params) { + super(params); + } + + @Override + public BaseStepWiseSelectorBatchOp linkFrom(BatchOperator ... inputs) { + if (inputs.length != 2 && inputs.length != 1) { + throw new InvalidParameterException("input size must be one or two."); + } + + this.linearModelType = getLinearModelType(); + inScorcard = getParams().get(ScorecardTrainBatchOp.IN_SCORECARD); + + this.in = inputs[0]; + BatchOperator constraint = null; + if (inputs.length == 2) { + constraint = inputs[1]; + } + + int labelIdx = TableUtil.findColIndexWithAssertAndHint(in.getColNames(), getLabelCol()); + TypeInformation labelType = TableUtil.findColTypeWithAssertAndHint(in.getSchema(), getLabelCol()); + + int[] forceSelectedColIndices = new int[0]; + if (getParams().contains(BaseStepwiseSelectorParams.FORCE_SELECTED_COLS)) { + forceSelectedColIndices = getForceSelectedCols(); + } + + String positiveLabel = null; + if (LinearModelType.LR == this.linearModelType || inScorcard) { + if (this.getParams().contains(ConstrainedLogisticRegressionTrainParams.POS_LABEL_VAL_STR)) { + positiveLabel = this.getParams().get(ConstrainedLogisticRegressionTrainParams.POS_LABEL_VAL_STR); + } + } + + standardConstraint(constraint); + standardLabel(positiveLabel); + transformToVector(labelIdx, labelType, linearModelType); + + //summary + DataSet summarizerData = in.getDataSet() + .map(new ToVectorWithReservedCols(selectedColIdxNew, labelIdxNew)); + + DataSet summary = StatisticsHelper.summarizer(summarizerData, true); + + DataSet > trainData = StatisticsHelper + .transformToVector(in, null, selectColNew, new String[] {labelColNew}); + + //train + DataSet result = trainData + .mapPartition(new StepWiseMapPartition(forceSelectedColIndices, + getAlphaEntry(), + getAlphaStay(), + getLinearModelType(), + getOptimMethod(), + getStepWiseType(), + getL1(), + getL2(), + hasVectorSizes, + hasConstraint)) + .withBroadcastSet(summary, "summarizer") + .withBroadcastSet(vectorSizes, "vectorSizes") + .withBroadcastSet(constraintDataSet, "constraint") + .setParallelism(1); + + if (getWithViz()) { + writeVizData(result, getLinearModelType(), selectedCols, this.getVizDataWriter()); + } + + setOutput(result.flatMap(new BuildModel(selectedCol, selectedCols, getLinearModelType())).setParallelism(1), + new SelectorModelDataConverter().getModelSchema()); + + Table[] sideTables = new 
Table[2]; + + //linear model + DataSet linearModel; + if (labels != null) { + linearModel = result.flatMap( + new BuildLinearModel(getLinearModelType(), selectedCols, + TableUtil.findColTypes(inputs[0].getSchema(), selectedCols), + getLabelCol(), labelType, positiveLabel, inScorcard)) + .withBroadcastSet(labels, "labelValues"); + } else { + linearModel = result.flatMap( + new BuildLinearModel(getLinearModelType(), selectedCols, + TableUtil.findColTypes(inputs[0].getSchema(), selectedCols), + getLabelCol(), labelType, positiveLabel, inScorcard)); + } + + sideTables[0] = DataSetConversionUtil.toTable(getMLEnvironmentId(), + linearModel, new LinearModelDataConverter(labelType).getModelSchema()); + + //statistics + sideTables[1] = DataSetConversionUtil.toTable(getMLEnvironmentId(), + result, new String[] {"result"}, new TypeInformation[] {Types.STRING}); + + setSideOutputTables(sideTables); + + return this; + } + + public DataSet getStepWiseSummary() { + return getSideOutput(1).getDataSet() + .map(new ToClassificationSelectorResult(this.linearModelType, getSelectedCols())); + } + + private static class ToClassificationSelectorResult implements MapFunction { + private static final long serialVersionUID = 382487577293357907L; + private String[] selectedCols; + private LinearModelType linearModelType; + + ToClassificationSelectorResult(LinearModelType linearModelType, String[] selectedCols) { + this.selectedCols = selectedCols; + this.linearModelType = linearModelType; + } + + @Override + public SelectorResult map(Row row) throws Exception { + SelectorResult result; + if (LinearModelType.LR == linearModelType) { + result = JsonConverter.fromJson((String) row.getField(0), + ClassificationSelectorResult.class); + } else { + result = JsonConverter.fromJson((String) row.getField(0), + RegressionSelectorResult.class); + } + result.selectedCols = BaseStepWiseSelectorBatchOp.getSCurSelectedCols(selectedCols, result + .selectedIndices); + return result; + } + } + + //deal with null + private void standardConstraint(BatchOperator constraint) { + this.hasConstraint = true; + if (constraint == null) { + this.constraintDataSet = MLEnvironmentFactory.get(in.getMLEnvironmentId()) + .getExecutionEnvironment().fromElements(new Row(0)); + this.hasConstraint = false; + } else { + this.constraintDataSet = constraint.getDataSet(); + } + } + + //label value to double type + private void standardLabel(String positiveLabel) { + String labelCol = getLabelCol(); + if (getLinearModelType() == LinearModelType.LR || inScorcard) { + Tuple2 > tuple2 + = ModelSummaryHelper.transformLrLabel(in, labelCol, positiveLabel, getMLEnvironmentId()); + this.in = tuple2.f0; + this.labels = tuple2.f1; + } + } + + //calc vector size + private void calcVectorSizes(int[] selectedColIndices, boolean isVector) { + if (isVector) { + vectorSizes = in.getDataSet() + .mapPartition(new CalcVectorSize(selectedColIndices)) + .reduce(new ReduceFunction () { + private static final long serialVersionUID = -3014179424640804678L; + + @Override + public int[] reduce(int[] left, int[] right) { + int[] result = new int[left.length]; + for (int i = 0; i < left.length; i++) { + result[i] = Math.max(left[i], right[i]); + } + return result; + } + }); + hasVectorSizes = true; + } else { + vectorSizes = MLEnvironmentFactory.get(in.getMLEnvironmentId()) + .getExecutionEnvironment().fromElements(new int[0]); + hasVectorSizes = false; + } + } + + private void transformToVector(int labelIdx, TypeInformation labelType, LinearModelType linearModelType) { + if 
(getParams().contains(BaseStepwiseSelectorParams.SELECTED_COL)) { + selectedCol = getSelectedCol(); + if (selectedCol != null && !selectedCol.isEmpty()) { + selectColNew = selectedCol; + selectedColIdxNew = TableUtil.findColIndexWithAssertAndHint(in.getColNames(), selectedCol); + } + calcVectorSizes(null, false); + + labelColNew = getLabelCol(); + labelIdxNew = TableUtil.findColIndexWithAssertAndHint(in.getColNames(), getLabelCol()); + } + + if (getParams().contains(BaseStepwiseSelectorParams.SELECTED_COLS)) { + selectedCols = getSelectedCols(); + int[] selectedColsIdx = TableUtil.findColIndicesWithAssertAndHint(in.getColNames(), selectedCols); + calcVectorSizes(selectedColsIdx, true); + + //labelType to double when lr + TypeInformation[] resultTypes = new TypeInformation[2]; + resultTypes[0] = AlinkTypes.VECTOR; + if (linearModelType == LinearModelType.LR || this.inScorcard) { + resultTypes[1] = Types.DOUBLE; + } else { + resultTypes[1] = labelType; + } + + //standard input + in = BatchOperator.fromTable(DataSetConversionUtil.toTable(getMLEnvironmentId(), + in.getDataSet() + .map(new VectorAssembler(selectedColsIdx, labelIdx)) + .withBroadcastSet(vectorSizes, "vectorSizes"), + new String[] {INNER_VECTOR_COL, INNER_LABLE_COL}, + resultTypes)); + + selectColNew = INNER_VECTOR_COL; + labelColNew = INNER_LABLE_COL; + selectedColIdxNew = 0; + labelIdxNew = 1; + } + } + + private static void writeVizData(DataSet summary, + LinearModelType linearModelType, + String[] selectedCols, + VizDataWriterInterface writer) { + DataSet dummy = summary.flatMap(new ProcViz(linearModelType, selectedCols, writer)) + .setParallelism(1) + .name("WriteStepWiseViz"); + DataSetUtil.linkDummySink(dummy); + } + + private static class ProcViz implements FlatMapFunction { + private static final long serialVersionUID = -2466978695333617548L; + private LinearModelType linearModelType; + private String[] selectedCols; + private VizDataWriterInterface writer; + + ProcViz(LinearModelType linearModelType, String[] selectedCols, VizDataWriterInterface writer) { + this.linearModelType = linearModelType; + this.selectedCols = selectedCols; + this.writer = writer; + } + + @Override + public void flatMap(Row row, Collector collector) throws Exception { + String result = (String) row.getField(0); + + String vizData; + if (LinearModelType.LR == linearModelType) { + ClassificationSelectorResult selector = JsonConverter.fromJson(result, + ClassificationSelectorResult.class); + selector.selectedCols = getSCurSelectedCols(selectedCols, selector.selectedIndices); + vizData = selector.toVizData(); + } else { + RegressionSelectorResult selector = JsonConverter.fromJson(result, RegressionSelectorResult.class); + selector.selectedCols = getSCurSelectedCols(selectedCols, selector.selectedIndices); + vizData = selector.toVizData(); + } + + writer.writeBatchData(0L, vizData, System.currentTimeMillis()); + } + } + + static class StepWiseMapPartition extends RichMapPartitionFunction , Row> { + private static final long serialVersionUID = -3204445094879081660L; + private int featureSize; + + private boolean hasVectorSizes; + private int[] vectorSizes; + + private BaseVectorSummarizer summary; + private List > trainData; + + //stepwise params + private StepWiseType stepwiseType; + private double alphaEntry; + private double alphaStay; + private int[] forceSelectedIndices; + + //optim params + private com.alibaba.alink.operator.common.linear.LinearModelType linearModelType; + private String optimMethod; + private double l1; + private double l2; + 
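// ---------------------------------------------------------------------------
// Editor's sketch (not part of the patch): open() below turns the per-feature vectorSizes into a
// prefix-sum array, so feature i of the assembled vector is assumed to own the dummy-coefficient
// indices [prefix[i], prefix[i + 1]); that block is what getOne(i, vectorSizes) is expected to
// return when building the selected/rest candidate sets. The exact getOne semantics are an
// assumption made for illustration.
final class DummyIndexSketch {
    static java.util.List<Integer> indicesOfFeature(int feature, int[] prefix) {
        java.util.List<Integer> indices = new java.util.ArrayList<>();
        for (int j = prefix[feature]; j < prefix[feature + 1]; j++) {
            indices.add(j);
        }
        return indices;
    }

    public static void main(String[] args) {
        int[] prefix = {0, 3, 5, 9}; // three features with 3, 2 and 4 dummy columns
        System.out.println(indicesOfFeature(1, prefix)); // [3, 4]
    }
}
// ---------------------------------------------------------------------------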
private boolean hasConstraint; + private FeatureConstraint constraints; + + StepWiseMapPartition(int[] forceSelectedIndices, + double alphaEntry, + double alphaStay, + LinearModelType linearModelType, + String optimMethod, + StepWiseType stepwiseType, + double l1, + double l2, + boolean hasVectorSizes, + boolean hasConstraint) { + this.forceSelectedIndices = forceSelectedIndices; + this.alphaEntry = alphaEntry; + this.alphaStay = alphaStay; + this.linearModelType = com.alibaba.alink.operator.common.linear.LinearModelType.valueOf( + linearModelType.name()); + this.optimMethod = optimMethod.toUpperCase().trim(); + this.stepwiseType = stepwiseType; + this.l1 = l1; + this.l2 = l2; + this.hasVectorSizes = hasVectorSizes; + this.hasConstraint = hasConstraint; + } + + public void open(Configuration conf) { + this.summary = (BaseVectorSummarizer) this.getRuntimeContext().getBroadcastVariable("summarizer").get(0); + + if (hasVectorSizes) { + List obj = this.getRuntimeContext().getBroadcastVariable("vectorSizes"); + int[] curVectorSizes = (int[]) obj.get(0); + vectorSizes = new int[curVectorSizes.length + 1]; + vectorSizes[0] = 0; + for (int i = 0; i < curVectorSizes.length; i++) { + vectorSizes[i + 1] = vectorSizes[i] + curVectorSizes[i]; + } + } + + String constraint = null; + if (hasConstraint) { + Object obj = ((Row) this.getRuntimeContext().getBroadcastVariable("constraint").get(0)).getField(0); + + if (obj instanceof FeatureConstraint) { + constraint = obj.toString(); + } else { + constraint = (String) obj; + } + } + this.constraints = FeatureConstraint.fromJson(constraint); + } + + @Override + public void mapPartition(Iterable > iterable, Collector collector) throws + Exception { + trainData = transformData(iterable); + + //set selected and rest + List selected = new ArrayList <>(); + List selectedDummy = new ArrayList <>(); + List rest = new ArrayList <>(); + + for (int forceSelectedIdx : forceSelectedIndices) { + selected.add(forceSelectedIdx); + selectedDummy.addAll(getOne(forceSelectedIdx, vectorSizes)); + } + + for (int i = 0; i < featureSize; i++) { + if (!selected.contains(i)) { + rest.add(i); + } + } + + ModelSummary lrBest = null; + double maxValue; + List selectorSteps = new ArrayList <>(); + + //trainLinear only has forceSelectedCol and intercept. + if (forceSelectedIndices != null && forceSelectedIndices.length != 0) { + lrBest = train(getIndicesFromList(selectedDummy)); + } + + //train with all features + ModelSummary wholeModelSummary = null; + if (isLinearRegression()) { + wholeModelSummary = train(null); + } + + while (true) { + if (rest.isEmpty()) { + break; + } + int selectedIndex = -1; + ModelSummary lrr = null; + + //forward + maxValue = Double.NEGATIVE_INFINITY; + for (Integer aRest : rest) { + List curSelected = new ArrayList <>(); + curSelected.addAll(selectedDummy); + curSelected.addAll(getOne(aRest, vectorSizes)); + + ModelSummary lrCur = train(getIndicesFromList(curSelected)); + + Tuple2 forwardValues = getForwardValue(lrCur, lrBest, this.stepwiseType); + + //forward selector + if (forwardValues.f0 > maxValue) { + if (selected.size() == 0 || forwardValues.f1 <= alphaEntry) { + maxValue = forwardValues.f0; + lrr = lrCur; + selectedIndex = aRest; + } + } + } + + //forward step all variable p value >= alphaEntry, loop stop. 
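A note on the forward step that the comment above closes out: every remaining feature is refit together with the current selection, and the best-scoring candidate is admitted only when its p-value clears alphaEntry (the p-value gate is skipped while nothing has been selected yet). Below is a minimal, self-contained sketch of that decision, with a toy Scorer standing in for the operator's real model fit; all names in it are illustrative, not Alink API.

    import java.util.ArrayList;
    import java.util.List;

    /** Illustrative forward step of stepwise selection: admit the best candidate that passes alphaEntry. */
    class ForwardStepSketch {
        /** Stand-in for a fitted model's test statistic and p-value (a ModelSummary in the operator). */
        static class Score {
            final double statistic;
            final double pValue;
            Score(double statistic, double pValue) { this.statistic = statistic; this.pValue = pValue; }
        }

        interface Scorer { Score fitWith(List<Integer> selected, int candidate); }

        /** Returns the admitted feature index, or -1 when no candidate qualifies and the loop should stop. */
        static int forwardStep(List<Integer> selected, List<Integer> rest, Scorer scorer, double alphaEntry) {
            double best = Double.NEGATIVE_INFINITY;
            int bestIdx = -1;
            for (int candidate : rest) {
                Score s = scorer.fitWith(selected, candidate);
                if (s.statistic > best && (selected.isEmpty() || s.pValue <= alphaEntry)) {
                    best = s.statistic;
                    bestIdx = candidate;
                }
            }
            return bestIdx;
        }

        public static void main(String[] args) {
            List<Integer> selected = new ArrayList<>();
            List<Integer> rest = new ArrayList<>(List.of(0, 1, 2));
            // Toy scorer: feature 1 has the best statistic and an acceptable p-value.
            Scorer scorer = (sel, c) -> new Score(c == 1 ? 8.0 : 2.0, c == 1 ? 0.01 : 0.3);
            System.out.println(forwardStep(selected, rest, scorer, 0.05)); // prints 1
        }
    }

Returning -1 in this sketch corresponds to the "if (selectedIndex < 0) { break; }" check that follows in the operator.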
+ if (selectedIndex < 0) { + break; + } + + //add var to selected + selectedDummy.addAll(getOne(selectedIndex, vectorSizes)); + selected.add(selectedIndex); + rest.remove(rest.indexOf(selectedIndex)); + + //backward selector + ArrayList deleted = new ArrayList <>(); + if (selected.size() > 1) { + double[] backwardValues = getBackwardValues(lrr, summary, this.stepwiseType, vectorSizes, + selected); + int maxBackWardIdx = argmax(backwardValues); + if (backwardValues[maxBackWardIdx] >= alphaStay + && !isIdxExist(forceSelectedIndices, selected.get(maxBackWardIdx))) { + int deleteIdx = selected.get(maxBackWardIdx); + deleted.add(deleteIdx); + selected.remove(selected.indexOf(deleteIdx)); + selectedDummy.removeAll(getOne(deleteIdx, vectorSizes)); + rest.add(deleteIdx); + } + } + if (deleted.size() == 1 && deleted.get(0).equals(selectedIndex)) { + break; + } + + //cur best model + lrBest = calMallowCp(lrr, wholeModelSummary); + + //add to selector steps + selectorSteps.add(lrBest.toSelectStep(selectedIndex)); + for (Integer aDeleted : deleted) { + int idx = aDeleted; + for (int k = 0; k < selectorSteps.size(); k++) { + if (idx == Integer.valueOf(selectorSteps.get(k).enterCol)) { + selectorSteps.remove(k); + break; + } + } + } + + if (deleted.size() != 0) { + lrBest = train(getIndicesFromList(selected)); + } + } + + Row row = new Row(1); + row.setField(0, bestModelResult(selectorSteps, lrBest)); + collector.collect(row); + } + + private ModelSummary calMallowCp(ModelSummary lrBest, ModelSummary wholeSummary) { + if (isLinearRegression()) { + LinearRegressionSummary lrs = (LinearRegressionSummary) lrBest; + //Cp = ( n - m - 1 )*( SSEp / SSEm )- n + 2*( p + 1 ) + lrs.mallowCp = (lrs.count - featureSize - 1) * (lrs.sse / + ((LinearRegressionSummary) wholeSummary).sse) - lrs.count + 2 * (lrs.lowerConfidence.length + 1); + } + return lrBest; + } + + private boolean isLinearRegression() { + return this.linearModelType == com.alibaba.alink.operator.common.linear.LinearModelType.LinearReg; + } + + private String bestModelResult(List selectorSteps, ModelSummary bestSummary) { + if (!isLinearRegression()) { + ClassificationSelectorResult selector = new ClassificationSelectorResult(); + selector.entryVars = new ClassificationSelectorStep[selectorSteps.size()]; + selector.selectedIndices = new int[selectorSteps.size() + forceSelectedIndices.length]; + System.arraycopy(forceSelectedIndices, 0, selector.selectedIndices, 0, forceSelectedIndices.length); + + for (int i = 0; i < selectorSteps.size(); i++) { + selector.entryVars[i] = (ClassificationSelectorStep) selectorSteps.get(i); + selector.selectedIndices[i + forceSelectedIndices.length] = Integer.valueOf( + selector.entryVars[i].enterCol); + } + selector.modelSummary = (LogistRegressionSummary) bestSummary; + return JsonConverter.toJson(selector); + } else { + RegressionSelectorResult selector = new RegressionSelectorResult(); + selector.entryVars = new RegressionSelectorStep[selectorSteps.size()]; + selector.selectedIndices = new int[selectorSteps.size() + forceSelectedIndices.length]; + + System.arraycopy(forceSelectedIndices, 0, selector.selectedIndices, 0, forceSelectedIndices.length); + + for (int i = 0; i < selectorSteps.size(); i++) { + selector.entryVars[i] = (RegressionSelectorStep) selectorSteps.get(i); + selector.selectedIndices[i + forceSelectedIndices.length] = Integer.valueOf( + selector.entryVars[i].enterCol); + } + + selector.modelSummary = (LinearRegressionSummary) bestSummary; + return JsonConverter.toJson(selector); + } + } + + private List 
> transformData(Iterable > iterable) { + //transform data + if (hasVectorSizes) { + featureSize = vectorSizes.length - 1; + } else { + featureSize = summary.toSummary().vectorSize() - 1; + } + + //f0: weight, f1: label, f2: data + List > trainData = new ArrayList <>(); + for (Tuple2 tuple2 : iterable) { + if (vectorSizes == null && tuple2.f0 instanceof SparseVector) { + ((SparseVector) tuple2.f0).setSize(featureSize); + } + trainData.add(Tuple3.of(1.0, ((Number) tuple2.f1.getField(0)).doubleValue(), tuple2.f0)); + } + return trainData; + } + + private ModelSummary train(int[] indices) { + String constraint = indices == null ? + constraints.toString() : + constraints.extractConstraint(indices); + return + LocalLinearModel.trainWithSummary(trainData, indices, + this.linearModelType, this.optimMethod, + true, false, + constraint, l1, l2, summary); + } + } + + public static String[] getSCurSelectedCols(String[] selectedCols, int[] indices) { + if (selectedCols == null || selectedCols.length == 0) { + return null; + } + + String[] curSelectedCols = new String[indices.length]; + for (int i = 0; i < indices.length; i++) { + curSelectedCols[i] = selectedCols[indices[i]]; + } + + return curSelectedCols; + } + + /** + * build selector model. + */ + public static class BuildModel implements FlatMapFunction { + private static final long serialVersionUID = -4792429339624354557L; + private String selectedCol; + private String[] selectedCols; + private LinearModelType linearModelType; + + BuildModel(String selectedCol, String[] selectedCols, LinearModelType linearModelType) { + this.selectedCol = selectedCol; + this.selectedCols = selectedCols; + this.linearModelType = linearModelType; + } + + @Override + public void flatMap(Row row, Collector collector) throws Exception { + SelectorModelData data = new SelectorModelData(); + + if (LinearModelType.LR == linearModelType) { + ClassificationSelectorResult result = JsonConverter.fromJson((String) row.getField(0), + ClassificationSelectorResult.class); + + data.vectorColName = selectedCol; + data.selectedIndices = result.selectedIndices; + data.vectorColNames = getSCurSelectedCols(selectedCols, result.selectedIndices); + + } else { + RegressionSelectorResult result = JsonConverter.fromJson((String) row.getField(0), + RegressionSelectorResult.class); + + data.vectorColName = selectedCol; + data.selectedIndices = result.selectedIndices; + data.vectorColNames = getSCurSelectedCols(selectedCols, result.selectedIndices); + + } + + new SelectorModelDataConverter().save(data, collector); + } + } + + /** + * build linear model. 
+ */ + public static class BuildLinearModel extends RichFlatMapFunction { + private LinearModelType linearModelType; + private String[] selectedCols; + private TypeInformation[] selectedColsType; + private String labelCol; + private TypeInformation labelType; + private Object[] labelValues; + private String positiveLabel; + private boolean inScorecard; + + public BuildLinearModel(LinearModelType linearModelType, + String[] selectedCols, + TypeInformation[] selectColsTypes, + String labelCol, + TypeInformation labelType, + String positiveLabel, + boolean inScorecard) { + this.linearModelType = linearModelType; + this.selectedCols = selectedCols; + this.selectedColsType = selectColsTypes; + this.labelCol = labelCol; + this.labelType = labelType; + this.positiveLabel = positiveLabel; + this.inScorecard = inScorecard; + } + + public void open(Configuration conf) { + if (LinearModelType.LR == this.linearModelType || inScorecard) { + List labels = this.getRuntimeContext().getBroadcastVariable("labelValues"); + labelValues = ModelSummaryHelper.orderLabels(labels, positiveLabel); + } + } + + @Override + public void flatMap(Row row, Collector out) throws Exception { + LinearModelData modelData; + DenseVector coefs; + int[] selectedIndices; + + if (LinearModelType.LR == linearModelType) { + ClassificationSelectorResult result = JsonConverter.fromJson((String) row.getField(0), + ClassificationSelectorResult.class); + selectedIndices = result.selectedIndices; + coefs = result.modelSummary.beta; + + } else { + RegressionSelectorResult result = JsonConverter.fromJson((String) row.getField(0), + RegressionSelectorResult.class); + selectedIndices = result.selectedIndices; + coefs = result.modelSummary.beta; + } + + String[] featureCols = getSCurSelectedCols(selectedCols, selectedIndices); + String[] featureTypes = new String[featureCols.length]; + for (int i = 0; i < featureCols.length; i++) { + featureTypes[i] = selectedColsType[selectedIndices[i]].getTypeClass().getSimpleName(); + } + + Params meta = new Params() + .set(ModelParamName.MODEL_NAME, "model") + .set(ModelParamName.LINEAR_MODEL_TYPE, + com.alibaba.alink.operator.common.linear.LinearModelType.valueOf(linearModelType.name())) + .set(ModelParamName.LABEL_VALUES, labelValues) + .set(ModelParamName.HAS_INTERCEPT_ITEM, true) + .set(ModelParamName.FEATURE_TYPES, featureTypes) + .set(LinearTrainParams.LABEL_COL, labelCol); + + modelData = BaseLinearModelTrainBatchOp.buildLinearModelData(meta, + featureCols, + this.labelType, + null, + true, + false, + Tuple2.of(coefs, new double[] {0})); + + new LinearModelDataConverter(this.labelType).save(modelData, out); + } + } + + private static List getOne(int selectedIdx, int[] vectorSizes) { + ArrayList result = new ArrayList <>(); + if (vectorSizes == null) { + result.add(selectedIdx); + } else { + for (int i = vectorSizes[selectedIdx]; i < vectorSizes[selectedIdx + 1]; i++) { + result.add(i); + } + } + return result; + } + + private static Tuple2 getForwardValue(ModelSummary modelSummary, ModelSummary lastSummary, + StepWiseType type) { + switch (type) { + case fTest: + LinearRegressionSummary linearRegressionSummary = (LinearRegressionSummary) modelSummary; + return Tuple2.of(linearRegressionSummary.fValue, linearRegressionSummary.pValue); + case scoreTest: + LogistRegressionSummary logistRegressionSummary = (LogistRegressionSummary) modelSummary; + return Tuple2.of(logistRegressionSummary.scoreChiSquareValue, logistRegressionSummary.scorePValue); + case marginalContribution: + double lastLoss = 0; + if 
(lastSummary != null) { + lastLoss = lastSummary.loss; + } + double mc = (modelSummary.loss - lastLoss) / modelSummary.count; + return Tuple2.of(mc, mc); + default: + throw new RuntimeException("It is not support."); + + } + + } + + private static double[] getBackwardValues(ModelSummary modelSummary, + BaseVectorSummarizer srt, + StepWiseType type, + int[] vectorSizes, + List selected) { + switch (type) { + case fTest: + LinearRegressionSummary linearRegressionSummary = (LinearRegressionSummary) modelSummary; + return linearRegressionSummary.tPVaues; + case scoreTest: + LogistRegressionSummary logistRegressionSummary = (LogistRegressionSummary) modelSummary; + return Arrays.copyOfRange(logistRegressionSummary.waldPValues, 1, + logistRegressionSummary.waldPValues.length); + case marginalContribution: + int featureNum = selected.size(); + // if (vectorSizes != null) { + // featureNum = vectorSizes.length - 1; + // } + double[] mcs = new double[featureNum]; + + for (int i = 0; i < featureNum; i++) { + ArrayList curSelectedDummay = new ArrayList <>(); + for (int j = 0; j < i; j++) { + curSelectedDummay.addAll(getOne(selected.get(j), vectorSizes)); + } + for (int j = i + 1; j < featureNum; j++) { + curSelectedDummay.addAll(getOne(selected.get(j), vectorSizes)); + } + + com.alibaba.alink.operator.common.linear.LinearModelType linearModelType; + + if (modelSummary instanceof LogistRegressionSummary) { + linearModelType = com.alibaba.alink.operator.common.linear.LinearModelType.LR; + } else { + linearModelType = com.alibaba.alink.operator.common.linear.LinearModelType.LinearReg; + } + + Tuple4 model = + Tuple4.of(modelSummary.beta, modelSummary.gradient, modelSummary.hessian, modelSummary.loss); + ModelSummary lrCur = + LocalLinearModel.calcModelSummary(model, srt, linearModelType, + getIndicesFromList(curSelectedDummay)); + mcs[i] = (lrCur.loss - modelSummary.loss) / modelSummary.count; + } + return mcs; + default: + throw new RuntimeException("It is not support."); + } + + } + + public static class ToVectorWithReservedCols implements MapFunction { + private static final long serialVersionUID = 5163307315870607698L; + private int vectorColIdx; + private int labelColIdx; + + public ToVectorWithReservedCols(int vectorColIndex, int labelColIdx) { + this.vectorColIdx = vectorColIndex; + this.labelColIdx = labelColIdx; + } + + @Override + public Vector map(Row in) throws Exception { + Vector vec = VectorUtil.getVector(in.getField(vectorColIdx)); + + if (vec == null) { + throw new RuntimeException( + "vector is null, please check your input data."); + } + + return vec.prefix(((Number) in.getField(labelColIdx)).doubleValue()); + } + } + + private static class CalcVectorSize implements MapPartitionFunction { + private static final long serialVersionUID = -4671985561279422505L; + private int[] selectedColIndices; + + CalcVectorSize(int[] selectedColIndices) { + this.selectedColIndices = selectedColIndices; + } + + @Override + public void mapPartition(Iterable iterable, Collector collector) throws Exception { + int[] vectorSizes = new int[selectedColIndices.length]; + Arrays.fill(vectorSizes, 0); + int curLocalSize = 0; + for (Row row : iterable) { + for (int i = 0; i < vectorSizes.length; i++) { + Vector vec = VectorUtil.getVector(row.getField(selectedColIndices[i])); + if (vec instanceof DenseVector) { + curLocalSize = vec.size(); + } else { + SparseVector sv = (SparseVector) vec; + curLocalSize = sv.size() == -1 ? 
sv.numberOfValues() : sv.size(); + } + vectorSizes[i] = Math.max(vectorSizes[i], curLocalSize); + } + } + collector.collect(vectorSizes); + } + } + + public static class VectorAssembler extends RichMapFunction { + private static final long serialVersionUID = 2474917145545199423L; + private int[] vectorSizes; + private int[] selectedColIndices; + private int labelColIdx; + + public VectorAssembler(int[] selectedColIndices, int labelColIdx) { + this.selectedColIndices = selectedColIndices; + this.labelColIdx = labelColIdx; + + } + + public void open(Configuration conf) { + this.vectorSizes = (int[]) this.getRuntimeContext().getBroadcastVariable("vectorSizes").get(0); + } + + @Override + public Row map(Row row) throws Exception { + int featureNum = selectedColIndices.length; + Object[] values = new Object[featureNum]; + for (int i = 0; i < featureNum; i++) { + Vector vec = VectorUtil.getVector(row.getField(selectedColIndices[i])); + if (vec instanceof SparseVector) { + ((SparseVector) vec).setSize(vectorSizes[i]); + } + values[i] = vec; + } + + Row out = new Row(2); + out.setField(0, VectorAssemblerMapper.assembler(values)); + out.setField(1, row.getField(labelColIdx)); + + return out; + } + + } + + private static int[] getIndicesFromList(List indices) { + int[] result = new int[indices.size()]; + for (int i = 0; i < result.length; i++) { + result[i] = indices.get(i); + } + return result; + } + + private static int argmax(double[] values) { + if (values == null && values.length == 0) { + throw new RuntimeException("max values is null."); + } + int idx = 0; + double maxVal = values[0]; + for (int i = 1; i < values.length; i++) { + if (maxVal < values[i]) { + maxVal = values[i]; + idx = i; + } + } + return idx; + } + + private static Boolean isIdxExist(int[] indices, int idx) { + for (int i = 0; i < indices.length; i++) { + if (idx == indices[i]) { + return true; + } + } + return false; + } +} \ No newline at end of file diff --git a/core/src/main/java/com/alibaba/alink/operator/common/finance/stepwiseSelector/ClassificationSelectorResult.java b/core/src/main/java/com/alibaba/alink/operator/common/finance/stepwiseSelector/ClassificationSelectorResult.java new file mode 100644 index 000000000..021af24d4 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/finance/stepwiseSelector/ClassificationSelectorResult.java @@ -0,0 +1,114 @@ +package com.alibaba.alink.operator.common.finance.stepwiseSelector; + +import com.alibaba.alink.common.utils.AlinkSerializable; +import com.alibaba.alink.common.utils.JsonConverter; +import com.alibaba.alink.operator.common.linear.LogistRegressionSummary; + +public class ClassificationSelectorResult extends SelectorResult { + public ClassificationSelectorStep[] entryVars; + public LogistRegressionSummary modelSummary; + + @Override + public String toVizData() { + if (selectedIndices == null) { + selectedIndices = new int[modelSummary.beta.size() - 1]; + for (int i = 0; i < selectedIndices.length; i++) { + selectedIndices[i] = i; + } + } + + int featureSize = selectedIndices.length; + + VizData data = new VizData(); + data.entryVars = entryVars; + + if (data.entryVars == null) { + data.entryVars = new ClassificationSelectorStep[0]; + } + + if (selectedCols != null && selectedCols.length != 0) { + for (int i = 0; i < data.entryVars.length; i++) { + data.entryVars[i].enterCol = selectedCols[i]; + } + } + + //weight + data.weights = new Weight[featureSize + 1]; + for (int i = 0; i < featureSize + 1; i++) { + data.weights[i] = new Weight(); + 
data.weights[i].variable = getColName(selectedCols, selectedIndices, i); + data.weights[i].weight = modelSummary.beta.get(i); + } + + //model summary + data.summary = new ModelSummary[3]; + data.summary[0] = new ModelSummary(); + data.summary[0].criterion = "AIC"; + data.summary[0].value = modelSummary.aic; + + data.summary[1] = new ModelSummary(); + data.summary[1].criterion = "SC"; + data.summary[1].value = modelSummary.sc; + + data.summary[2] = new ModelSummary(); + data.summary[2].criterion = "-2* LL"; + data.summary[2].value = 2 * modelSummary.loss; + + //para est + data.paramEsts = new ParamEst[featureSize + 1]; + for (int i = 0; i < featureSize + 1; i++) { + data.paramEsts[i] = new ParamEst(); + data.paramEsts[i].variable = getColName(selectedCols, selectedIndices, i); + data.paramEsts[i].weight = String.valueOf(modelSummary.beta.get(i)); + data.paramEsts[i].stdEsts = String.valueOf(modelSummary.stdEsts[i]); + data.paramEsts[i].stdErrs = String.valueOf(modelSummary.stdErrs[i]); + data.paramEsts[i].chiSquareValue = String.valueOf(modelSummary.waldChiSquareValue[i]); + data.paramEsts[i].pValue = String.valueOf(modelSummary.waldPValues[i]); + data.paramEsts[i].lowerConfidence = String.valueOf(modelSummary.lowerConfidence[i]); + data.paramEsts[i].uperConfidence = String.valueOf(modelSummary.uperConfidence[i]); + } + + return JsonConverter.toJson(data); + } + + static String getColName(String[] selectedCols, int[] selectedIndices, int id) { + if (id == 0) { + return "Intercept"; + } else { + if (selectedCols == null || selectedCols.length == 0) { + return String.valueOf(selectedIndices[id - 1]); + } else { + return selectedCols[id - 1]; + } + } + } + + private static class Weight implements AlinkSerializable { + public String variable; + public double weight; + } + + private static class ParamEst implements AlinkSerializable { + public String variable; + public String weight; + public String stdEsts; + public String stdErrs; + public String chiSquareValue; + public String pValue; + public String lowerConfidence; + public String uperConfidence; + } + + private static class ModelSummary implements AlinkSerializable { + public String criterion; + public double value; + } + + private static class VizData implements AlinkSerializable { + public Weight[] weights; + public ParamEst[] paramEsts; + public ModelSummary[] summary; + public ClassificationSelectorStep[] entryVars; + } + +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/finance/stepwiseSelector/ClassificationSelectorStep.java b/core/src/main/java/com/alibaba/alink/operator/common/finance/stepwiseSelector/ClassificationSelectorStep.java new file mode 100644 index 000000000..2e50939d2 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/finance/stepwiseSelector/ClassificationSelectorStep.java @@ -0,0 +1,7 @@ +package com.alibaba.alink.operator.common.finance.stepwiseSelector; + +public class ClassificationSelectorStep extends SelectorStep { + public double scoreValue; + public double pValue; + public int numberIn; +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/finance/stepwiseSelector/OptimMethod.java b/core/src/main/java/com/alibaba/alink/operator/common/finance/stepwiseSelector/OptimMethod.java new file mode 100644 index 000000000..8c24b94c2 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/finance/stepwiseSelector/OptimMethod.java @@ -0,0 +1,6 @@ +package com.alibaba.alink.operator.common.finance.stepwiseSelector; + +public enum OptimMethod { + LBFGS, 
+ Netwon +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/finance/stepwiseSelector/RegressionSelectorResult.java b/core/src/main/java/com/alibaba/alink/operator/common/finance/stepwiseSelector/RegressionSelectorResult.java new file mode 100644 index 000000000..7b30ea964 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/finance/stepwiseSelector/RegressionSelectorResult.java @@ -0,0 +1,122 @@ +package com.alibaba.alink.operator.common.finance.stepwiseSelector; + +import com.alibaba.alink.common.utils.AlinkSerializable; +import com.alibaba.alink.common.utils.JsonConverter; +import com.alibaba.alink.operator.common.linear.LinearRegressionSummary; + +public class RegressionSelectorResult extends SelectorResult { + public RegressionSelectorStep[] entryVars; + public LinearRegressionSummary modelSummary; + public int[] deleteIndices; + + @Override + public String toVizData() { + if (selectedIndices == null) { + selectedIndices = new int[modelSummary.beta.size() - 1]; + for (int i = 0; i < selectedIndices.length; i++) { + selectedIndices[i] = i; + } + } + int featureSize = selectedIndices.length; + + VizData data = new VizData(); + data.entryVars = entryVars; + + if (data.entryVars == null) { + data.entryVars = new RegressionSelectorStep[0]; + } + + if (selectedCols != null && selectedCols.length != 0) { + for (int i = 0; i < data.entryVars.length; i++) { + data.entryVars[i].enterCol = selectedCols[i]; + } + } + + //weight + data.weights = new Weight[featureSize + 1]; + + data.weights[0] = new Weight(); + data.weights[0].variable = "Intercept"; + data.weights[0].weight = modelSummary.beta.get(0); + + for (int i = 0; i < featureSize; i++) { + data.weights[i + 1] = new Weight(); + if (selectedCols != null && selectedCols.length != 0) { + data.weights[i + 1].variable = selectedCols[i]; + } else { + data.weights[i + 1].variable = String.valueOf(selectedIndices[i]); + } + data.weights[i + 1].weight = modelSummary.beta.get(i + 1); + } + + //model summary + data.summary = new ModelSummary[2]; + data.summary[0] = new ModelSummary(); + data.summary[0].criterion = "r2"; + data.summary[0].value = modelSummary.r2; + + data.summary[1] = new ModelSummary(); + data.summary[1].criterion = "ra2"; + data.summary[1].value = modelSummary.ra2; + + //para est + data.paramEsts = new ParamEst[featureSize + 1]; + data.paramEsts[0] = new ParamEst(); + + data.paramEsts[0].variable = "Intercept"; + data.paramEsts[0].weight = String.valueOf(modelSummary.beta.get(0)); + data.paramEsts[0].stdEsts = "-"; + data.paramEsts[0].stdErrs = "-"; + data.paramEsts[0].tValues = "-"; + data.paramEsts[0].pValue = "-"; + data.paramEsts[0].lowerConfidence = "-"; + data.paramEsts[0].uperConfidence = "-"; + + for (int i = 0; i < featureSize; i++) { + data.paramEsts[i + 1] = new ParamEst(); + if (selectedCols != null && selectedCols.length != 0) { + data.paramEsts[i + 1].variable = selectedCols[i]; + } else { + data.paramEsts[i + 1].variable = String.valueOf(selectedIndices[i]); + } + data.paramEsts[i + 1].weight = String.valueOf(modelSummary.beta.get(i + 1)); + data.paramEsts[i + 1].stdEsts = String.valueOf(modelSummary.stdEsts[i]); + data.paramEsts[i + 1].stdErrs = String.valueOf(modelSummary.stdErrs[i]); + data.paramEsts[i + 1].tValues = String.valueOf(modelSummary.tValues[i]); + data.paramEsts[i + 1].pValue = String.valueOf(modelSummary.tPVaues[i]); + data.paramEsts[i + 1].lowerConfidence = String.valueOf(modelSummary.lowerConfidence[i]); + data.paramEsts[i + 1].uperConfidence = 
String.valueOf(modelSummary.uperConfidence[i]); + } + + return JsonConverter.toJson(data); + } + + private static class Weight implements AlinkSerializable { + public String variable; + public double weight; + } + + private static class ParamEst implements AlinkSerializable { + public String variable; + public String weight; + public String stdEsts; + public String stdErrs; + public String tValues; + public String pValue; + public String lowerConfidence; + public String uperConfidence; + } + + private static class ModelSummary implements AlinkSerializable { + public String criterion; + public double value; + } + + private static class VizData implements AlinkSerializable { + public Weight[] weights; + public ParamEst[] paramEsts; + public ModelSummary[] summary; + public RegressionSelectorStep[] entryVars; + } + +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/finance/stepwiseSelector/RegressionSelectorStep.java b/core/src/main/java/com/alibaba/alink/operator/common/finance/stepwiseSelector/RegressionSelectorStep.java new file mode 100644 index 000000000..c79364fde --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/finance/stepwiseSelector/RegressionSelectorStep.java @@ -0,0 +1,10 @@ +package com.alibaba.alink.operator.common.finance.stepwiseSelector; + +public class RegressionSelectorStep extends SelectorStep { + public double ra2; + public double r2; + public double mallowCp; + public double fValue; + public double pValue; + public int numberIn; +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/finance/stepwiseSelector/SelectorModelMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/finance/stepwiseSelector/SelectorModelMapper.java new file mode 100644 index 000000000..1b49ec8d8 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/finance/stepwiseSelector/SelectorModelMapper.java @@ -0,0 +1,75 @@ +package com.alibaba.alink.operator.common.finance.stepwiseSelector; + +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.tuple.Tuple4; +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.table.api.TableSchema; +import org.apache.flink.types.Row; + +import com.alibaba.alink.common.type.AlinkTypes; +import com.alibaba.alink.common.linalg.Vector; +import com.alibaba.alink.common.linalg.VectorUtil; +import com.alibaba.alink.common.mapper.ModelMapper; +import com.alibaba.alink.common.utils.TableUtil; +import com.alibaba.alink.operator.common.dataproc.vector.VectorAssemblerMapper; +import com.alibaba.alink.operator.common.feature.SelectorModelData; +import com.alibaba.alink.operator.common.feature.SelectorModelDataConverter; +import com.alibaba.alink.params.finance.SelectorPredictParams; + +import java.util.List; + +public class SelectorModelMapper extends ModelMapper { + private static final long serialVersionUID = -4884089344356950010L; + + private SelectorModelData smd; + private int[] selectedIndices; + private int selectedIdx; + + public SelectorModelMapper(TableSchema modelSchema, TableSchema dataSchema, Params params) { + super(modelSchema, dataSchema, params); + } + + @Override + public void loadModel(List modelRows) { + this.smd = new SelectorModelDataConverter().load(modelRows); + if (smd.vectorColNames != null) { + selectedIndices = TableUtil.findColIndicesWithAssert(this.getDataSchema().getFieldNames(), + smd.vectorColNames); + } else { + String colName = smd.vectorColName; + if 
(params.contains(SelectorPredictParams.SELECTED_COL)) { + colName = params.get(SelectorPredictParams.SELECTED_COL); + } + selectedIdx = TableUtil.findColIndexWithAssert(this.getDataSchema().getFieldNames(), colName); + } + } + + /** + * Returns the tuple of selectedCols, resultCols, resultTypes, reservedCols. + */ + @Override + + protected Tuple4 [], String[]> prepareIoSchema(TableSchema modelSchema, + TableSchema dataSchema, + Params params) { + return Tuple4.of(dataSchema.getFieldNames(), + new String[] {params.get(SelectorPredictParams.PREDICTION_COL)}, + new TypeInformation[] {AlinkTypes.VECTOR}, + params.get(SelectorPredictParams.RESERVED_COLS)); + } + + @Override + protected void map(SlicedSelectedSample selection, SlicedResult result) { + Vector vec = null; + if (smd.vectorColNames != null) { + Object[] values = new Object[smd.vectorColNames.length]; + for (int i = 0; i < smd.vectorColNames.length; i++) { + values[i] = VectorUtil.getVector(selection.get(selectedIndices[i])); + } + vec = (Vector) VectorAssemblerMapper.assembler(values); + } else { + vec = VectorUtil.getVector(selection.get(selectedIdx)).slice(smd.selectedIndices); + } + result.set(0, vec); + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/finance/stepwiseSelector/SelectorResult.java b/core/src/main/java/com/alibaba/alink/operator/common/finance/stepwiseSelector/SelectorResult.java new file mode 100644 index 000000000..8b9a78c62 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/finance/stepwiseSelector/SelectorResult.java @@ -0,0 +1,12 @@ +package com.alibaba.alink.operator.common.finance.stepwiseSelector; + +import com.alibaba.alink.common.utils.AlinkSerializable; + +import java.io.Serializable; + +public abstract class SelectorResult implements AlinkSerializable, Serializable { + public int[] selectedIndices; + public String[] selectedCols; + + public abstract String toVizData(); +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/finance/stepwiseSelector/SelectorStep.java b/core/src/main/java/com/alibaba/alink/operator/common/finance/stepwiseSelector/SelectorStep.java new file mode 100644 index 000000000..3dd4dcab2 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/finance/stepwiseSelector/SelectorStep.java @@ -0,0 +1,7 @@ +package com.alibaba.alink.operator.common.finance.stepwiseSelector; + +import com.alibaba.alink.common.utils.AlinkSerializable; + +public class SelectorStep implements AlinkSerializable { + public String enterCol; +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/finance/stepwiseSelector/StepWiseType.java b/core/src/main/java/com/alibaba/alink/operator/common/finance/stepwiseSelector/StepWiseType.java new file mode 100644 index 000000000..998b8554b --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/finance/stepwiseSelector/StepWiseType.java @@ -0,0 +1,7 @@ +package com.alibaba.alink.operator.common.finance.stepwiseSelector; + +public enum StepWiseType { + fTest, + scoreTest, + marginalContribution +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/finance/stepwiseSelector/StepwiseVizData.java b/core/src/main/java/com/alibaba/alink/operator/common/finance/stepwiseSelector/StepwiseVizData.java new file mode 100644 index 000000000..4c8946839 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/finance/stepwiseSelector/StepwiseVizData.java @@ -0,0 +1,38 @@ +package com.alibaba.alink.operator.common.finance.stepwiseSelector; + +import 
org.apache.flink.types.Row; + +import com.alibaba.alink.common.utils.AlinkSerializable; + +import java.util.List; + +public class StepwiseVizData implements AlinkSerializable { + public String model; + public String modelSummary; + public String parameterEsts; + public String selector; + + public static String formatRows(List row) { + StringBuilder sbd = new StringBuilder(); + + for (int i = 0; i < row.size(); ++i) { + Row data = row.get(i); + for (int j = 0; j < data.getArity(); j++) { + Object obj = data.getField(j); + if (obj instanceof Double || obj instanceof Float) { + sbd.append(String.format("%.4f", ((Number) obj).doubleValue())); + } else { + sbd.append(obj); + } + if (j != data.getArity() - 1) { + sbd.append(","); + } + } + + if (i != row.size() - 1) { + sbd.append("\n"); + } + } + return sbd.toString(); + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/fm/BaseFmTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/common/fm/BaseFmTrainBatchOp.java index 66fbb1394..b0612ee84 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/fm/BaseFmTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/fm/BaseFmTrainBatchOp.java @@ -9,6 +9,7 @@ import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeinfo.Types; import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.tuple.Tuple1; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.tuple.Tuple3; import org.apache.flink.configuration.Configuration; @@ -34,7 +35,7 @@ import com.alibaba.alink.common.linalg.Vector; import com.alibaba.alink.common.linalg.VectorUtil; import com.alibaba.alink.common.model.ModelParamName; -import com.alibaba.alink.common.utils.DataSetConversionUtil; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; import com.alibaba.alink.common.utils.JsonConverter; import com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.operator.batch.BatchOperator; @@ -181,7 +182,7 @@ public void flatMap(Tuple2 value, DataSet modelRows = transformModel(model, labelValues, featSize, params, dim, isRegProc, labelType); this.setOutput(modelRows, new FmModelDataConverter(labelType).getModelSchema()); - this.setSideOutputTables(getSideTablesOfCoefficient(modelRows, labelType)); + this.setSideOutputTables(getSideTablesOfCoefficient(modelRows, model.project(1), labelType)); return (T) this; } @@ -399,7 +400,7 @@ public void mapPartition(Iterable values, Collector modelRow, final TypeInformation labelType) { + private Table[] getSideTablesOfCoefficient(DataSet modelRow, DataSet> cinfo, final TypeInformation labelType) { DataSet model = modelRow.mapPartition(new MapPartitionFunction() { private static final long serialVersionUID = 2063366042018382802L; @@ -422,7 +423,8 @@ public void mapPartition(Iterable values, Collector out) { public void mapPartition(Iterable values, Collector out) { FmModelData model = values.iterator().next(); - double[] cInfo = model.convergenceInfo; + double[] cInfo = ((Tuple1 )getRuntimeContext() + .getBroadcastVariable("cinfo").get(0)).f0; Params meta = new Params(); meta.set(ModelParamName.VECTOR_SIZE, model.vectorSize); meta.set(ModelParamName.LABEL_VALUES, model.labelValues); @@ -433,7 +435,8 @@ public void mapPartition(Iterable values, Collector out) { out.collect(Row.of(1, JsonConverter.toJson(cInfo))); } - }).setParallelism(1).withBroadcastSet(model, "model"); + }).setParallelism(1) + .withBroadcastSet(cinfo, "cinfo"); Table 
summaryTable = DataSetConversionUtil.toTable(getMLEnvironmentId(), summary, new TableSchema( new String[]{"title", "info"}, new TypeInformation[]{Types.INT, Types.STRING})); @@ -527,7 +530,13 @@ public double dldy(double yTruth, double y) { } private double sigmoid(double y) { - return 1.0 / (1.0 + Math.exp(-y)); + if (y < -37) { + return 0.0; + } else if (y > 34) { + return 1.0; + } else { + return 1.0 / (1.0 + Math.exp(-y)); + } } } diff --git a/core/src/main/java/com/alibaba/alink/operator/common/fm/FmModelDataConverter.java b/core/src/main/java/com/alibaba/alink/operator/common/fm/FmModelDataConverter.java index 5dddac88b..4cc1eaf9b 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/fm/FmModelDataConverter.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/fm/FmModelDataConverter.java @@ -43,8 +43,7 @@ public void save(FmModelData modelData, Collector collector) { .set(ModelParamName.FEATURE_COL_NAMES, modelData.featureColNames) .set(ModelParamName.LABEL_VALUES, modelData.labelValues) .set(ModelParamName.DIM, modelData.dim) - .set(ModelParamName.REGULAR, modelData.regular) - .set(ModelParamName.LOSS_CURVE, modelData.convergenceInfo); + .set(ModelParamName.REGULAR, modelData.regular); FmDataFormat factors = modelData.fmModel; collector.collect(Row.of(null, meta.toJson(), null)); @@ -77,7 +76,6 @@ public FmModelData load(List rows) { modelData.dim = meta.get(ModelParamName.DIM); modelData.regular = meta.contains(ModelParamName.REGULAR) ? meta.get(ModelParamName.REGULAR) : null; modelData.vectorSize = meta.get(ModelParamName.VECTOR_SIZE); - modelData.convergenceInfo = meta.get(ModelParamName.LOSS_CURVE); if (meta.contains(ModelParamName.LABEL_VALUES)) { modelData.labelValues = meta.get(ModelParamName.LABEL_VALUES); diff --git a/core/src/main/java/com/alibaba/alink/operator/common/fm/FmModelMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/fm/FmModelMapper.java index ade056a05..87ed7296d 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/fm/FmModelMapper.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/fm/FmModelMapper.java @@ -87,7 +87,13 @@ public double getY(SparseVector feature, boolean isBinCls) { } private static double logit(double y) { - return 1. / (1. + Math.exp(-y)); + if (y < -37) { + return 0.0; + } else if (y > 34) { + return 1.0; + } else { + return 1.0 / (1.0 + Math.exp(-y)); + } } public FmModelData getModel() { diff --git a/core/src/main/java/com/alibaba/alink/operator/common/fm/FmRegressorModelTrainInfo.java b/core/src/main/java/com/alibaba/alink/operator/common/fm/FmRegressorModelTrainInfo.java index 3258ff6b6..f272e2554 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/fm/FmRegressorModelTrainInfo.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/fm/FmRegressorModelTrainInfo.java @@ -67,16 +67,16 @@ public String toString() { sbd.append(PrettyDisplayUtils.displayMap(map, 2, false)).append("\n"); sbd.append(PrettyDisplayUtils.displayHeadline("train convergence info", '-')); - if (convInfo.length < 6) { + if (convInfo.length < 20) { for (String s : convInfo) { sbd.append(s).append("\n"); } } else { - for (int i = 0; i < 3; ++i) { + for (int i = 0; i < 10; ++i) { sbd.append(convInfo[i]).append("\n"); } sbd.append("" + "... ... ... ..." 
+ "\n"); - for (int i = convInfo.length - 3; i < convInfo.length; ++i) { + for (int i = convInfo.length - 10; i < convInfo.length; ++i) { sbd.append(convInfo[i]).append("\n"); } } diff --git a/core/src/main/java/com/alibaba/alink/operator/common/fm/FmTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/common/fm/FmTrainBatchOp.java index ff2e69bab..922616972 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/fm/FmTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/fm/FmTrainBatchOp.java @@ -129,7 +129,6 @@ public void flatMap(Tuple2 value, Collector out) th params.get(FmTrainParams.LAMBDA_2)}; modelData.labelColName = params.get(FmTrainParams.LABEL_COL); modelData.task = params.get(ModelParamName.TASK); - modelData.convergenceInfo = value.f1; if (!isRegProc) { modelData.labelValues = this.labelValues; } else { diff --git a/core/src/main/java/com/alibaba/alink/operator/common/image/ReadImageToTensorMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/image/ReadImageToTensorMapper.java index a4ce62121..01d9ea6e8 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/image/ReadImageToTensorMapper.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/image/ReadImageToTensorMapper.java @@ -6,7 +6,7 @@ import org.apache.flink.ml.api.misc.param.Params; import org.apache.flink.table.api.TableSchema; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.io.filesystem.FilePath; import com.alibaba.alink.common.linalg.tensor.FloatTensor; import com.alibaba.alink.common.mapper.Mapper; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/io/RedisStringOutputFormat.java b/core/src/main/java/com/alibaba/alink/operator/common/io/RedisStringOutputFormat.java index d8dc26b7f..eb032867f 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/io/RedisStringOutputFormat.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/io/RedisStringOutputFormat.java @@ -7,9 +7,8 @@ import org.apache.flink.table.api.TableSchema; import org.apache.flink.types.Row; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.exceptions.AkIllegalDataException; -import com.alibaba.alink.common.exceptions.AkIllegalOperatorParameterException; import com.alibaba.alink.common.io.redis.Redis; import com.alibaba.alink.common.io.redis.RedisClassLoaderFactory; import com.alibaba.alink.common.utils.TableUtil; @@ -17,7 +16,6 @@ import com.alibaba.alink.params.io.RedisStringSinkParams; import java.io.IOException; -import java.util.Optional; public class RedisStringOutputFormat extends RichOutputFormat { diff --git a/core/src/main/java/com/alibaba/alink/operator/common/io/csv/CsvFileInputSplit.java b/core/src/main/java/com/alibaba/alink/operator/common/io/csv/CsvFileInputSplit.java index b9d119c70..5656ae452 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/io/csv/CsvFileInputSplit.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/io/csv/CsvFileInputSplit.java @@ -26,6 +26,7 @@ public class CsvFileInputSplit implements InputSplit { private static final long serialVersionUID = -6929924752443032143L; + /** * Starting position of this split. 
*/ @@ -63,13 +64,37 @@ public CsvFileInputSplit(int numSplits, int splitNo, long contentLength) { this.end = Long.min(start + length + BUFFER_SIZE, contentLength); } + public CsvFileInputSplit(int numSplits, int splitNo, long start, long length, long end) { + this.numSplits = numSplits; + this.splitNo = splitNo; + this.length = length; + this.start = start; + this.end = end; + } + @Override public String toString() { return "split: " + splitNo + "/" + numSplits + ", " + start + " " + length + " " + end; } + public static CsvFileInputSplit fromString(String splitStr) { + int slashPos = splitStr.indexOf("/"); + int spacePosFirst = splitStr.indexOf(" ", 7); + int spacePosSecond = splitStr.indexOf(" ", spacePosFirst + 1); + int spacePosThird = splitStr.indexOf(" ", spacePosSecond + 1); + + int splitNo = Integer.valueOf(splitStr.substring(7, slashPos)); + int numSplits = Integer.valueOf(splitStr.substring(slashPos + 1, spacePosFirst-1)); + long start = Long.valueOf(splitStr.substring(spacePosFirst + 1, spacePosSecond)); + long len = Long.valueOf(splitStr.substring(spacePosSecond + 1, spacePosThird)); + long end = Long.valueOf(splitStr.substring(spacePosThird + 1)); + + return new CsvFileInputSplit(numSplits, splitNo, start, len, end); + } + @Override public int getSplitNumber() { return this.splitNo; } + } diff --git a/core/src/main/java/com/alibaba/alink/operator/common/io/csv/CsvFormatter.java b/core/src/main/java/com/alibaba/alink/operator/common/io/csv/CsvFormatter.java index a31526cfa..c347e3bfa 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/io/csv/CsvFormatter.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/io/csv/CsvFormatter.java @@ -12,6 +12,7 @@ public class CsvFormatter { private TypeInformation[] types; private String fieldDelim; + private String rowDelim; private String quoteChar; private String escapeChar; @@ -25,6 +26,25 @@ public class CsvFormatter { public CsvFormatter(TypeInformation[] types, String fieldDelim, @Nullable Character quoteChar) { this.types = types; this.fieldDelim = fieldDelim; + this.rowDelim = null; + if (quoteChar != null) { + this.quoteChar = quoteChar.toString(); + this.escapeChar = this.quoteChar; + } + } + + /** + * The Constructor. + * + * @param types Column types. + * @param fieldDelim Field delimiter in the text line. + * @param quoteChar Quoting character. Used to quote a string field if it has field delimiters. 
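The CsvFileInputSplit.fromString added above is a hand-rolled inverse of toString(), so it only stays correct while the "split: <splitNo>/<numSplits>, <start> <length> <end>" layout is kept in sync between the two methods. A quick round-trip check, using the five-argument constructor introduced in the same hunk (the concrete numbers are made up):

    import com.alibaba.alink.operator.common.io.csv.CsvFileInputSplit;

    public class SplitRoundTripDemo {
        public static void main(String[] args) {
            CsvFileInputSplit split = new CsvFileInputSplit(8, 2, 1024L, 4096L, 5200L);
            String text = split.toString();                        // "split: 2/8, 1024 4096 5200"
            CsvFileInputSplit parsed = CsvFileInputSplit.fromString(text);
            System.out.println(text.equals(parsed.toString()));    // expected: true
        }
    }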
+ */ + public CsvFormatter(TypeInformation[] types, String fieldDelim, @Nullable String rowDelim, + @Nullable Character quoteChar) { + this.types = types; + this.fieldDelim = fieldDelim; + this.rowDelim = rowDelim; if (quoteChar != null) { this.quoteChar = quoteChar.toString(); this.escapeChar = this.quoteChar; @@ -49,7 +69,8 @@ public String format(Row row) { } if (quoteChar != null && types[i].equals(Types.STRING)) { String str = (String) v; - if (str.isEmpty() || str.contains(fieldDelim) || str.contains(quoteChar)) { + if (str.isEmpty() || str.contains(fieldDelim) || str.contains(quoteChar) || + (rowDelim != null && str.contains(rowDelim))) { sbd.append(quoteChar); sbd.append(str.replace(quoteChar, escapeChar + quoteChar)); sbd.append(quoteChar); diff --git a/core/src/main/java/com/alibaba/alink/operator/common/io/csv/CsvInputFormatBeta.java b/core/src/main/java/com/alibaba/alink/operator/common/io/csv/CsvInputFormatBeta.java new file mode 100644 index 000000000..9155af77b --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/io/csv/CsvInputFormatBeta.java @@ -0,0 +1,101 @@ +package com.alibaba.alink.operator.common.io.csv; + +import org.apache.flink.types.Row; + +import com.alibaba.alink.operator.common.io.reader.FileSplitReader; +import com.alibaba.alink.operator.common.io.reader.QuoteUtil; + +import java.io.IOException; + +public class CsvInputFormatBeta extends GenericCsvInputFormatBeta { + + public CsvInputFormatBeta(FileSplitReader reader, + String lineDelim, boolean ignoreFirstLine) { + super(reader, lineDelim, ignoreFirstLine); + } + + public CsvInputFormatBeta(FileSplitReader reader, + String lineDelim, boolean ignoreFirstLine, + Character quoteChar) { + super(reader, lineDelim, ignoreFirstLine, quoteChar); + } + + public CsvInputFormatBeta(FileSplitReader reader, + String lineDelim, boolean ignoreFirstLine, + boolean unsplittable, Character quoteChar) { + super(reader, lineDelim, ignoreFirstLine, unsplittable, quoteChar); + } + + @Override + public void open(CsvFileInputSplit split) throws IOException { + super.open(split); + if (split.start > 0 || ignoreFirstLine) { + if (!readLine()) { + // if the first partial record already pushes the stream over + // the limit of our split, then no record starts within this split + setEnd(true); + } + } else { + fillBuffer(0); + } + } + + public void openWithoutSkipLine(CsvFileInputSplit split) throws IOException { + super.open(split); + } + + @Override + public CsvFileInputSplit[] createInputSplits(int minNumSplits) throws IOException { + // if parsing quote character, file is unsplittable + if (this.unsplittable) { + minNumSplits = 1; + } + + CsvFileInputSplit[] splits; + splits = new CsvFileInputSplit[minNumSplits]; + long contentLength = reader.getFileLength(); + for (int i = 0; i < splits.length; i++) { + splits[i] = new CsvFileInputSplit(minNumSplits, i, contentLength); + } + return splits; + } + + /** + * This format is serialize a split to row of string. 
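The CsvFormatter change above widens the quoting rule: a string field is wrapped in quotes when it is empty or contains the field delimiter, the quote character, or (newly) the row delimiter, and embedded quote characters are escaped by doubling, since escapeChar is set to the quote character. The rule in isolation, as a small standalone sketch (the real formatter also handles nulls and non-string types):

    /** Illustration of the quoting rule only; names and defaults are ad hoc. */
    public class QuoteRuleDemo {
        static String quoteIfNeeded(String str, String fieldDelim, String rowDelim, String quoteChar) {
            boolean needsQuoting = str.isEmpty()
                || str.contains(fieldDelim)
                || str.contains(quoteChar)
                || (rowDelim != null && str.contains(rowDelim));
            if (!needsQuoting) {
                return str;
            }
            // escapeChar equals quoteChar, so an embedded quote becomes two quotes.
            return quoteChar + str.replace(quoteChar, quoteChar + quoteChar) + quoteChar;
        }

        public static void main(String[] args) {
            System.out.println(quoteIfNeeded("plain", ",", "\n", "\""));        // plain
            System.out.println(quoteIfNeeded("a,b", ",", "\n", "\""));          // "a,b"
            System.out.println(quoteIfNeeded("line1\nline2", ",", "\n", "\"")); // quoted: contains the row delimiter
            System.out.println(quoteIfNeeded("say \"hi\"", ",", "\n", "\""));   // "say ""hi"""
        }
    }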
+ */ + public static class CsvSplitInputFormat extends CsvInputFormatBeta { + private static final int BUFFER_SIZE = 1024 * 1024; + + public CsvSplitInputFormat(FileSplitReader reader, String lineDelim, boolean ignoreFirstLine) { + super(reader, lineDelim, ignoreFirstLine); + } + + public CsvSplitInputFormat(FileSplitReader reader, String lineDelim, boolean ignoreFirstLine, + Character quoteChar) { + super(reader, lineDelim, ignoreFirstLine, false, quoteChar); + } + + @Override + public void open(CsvFileInputSplit split) throws IOException { + super.openWithoutSkipLine(split); + } + + /** + * This format is used to scan each split. It only returns one record then set end flag to true. + * The record includes quote character num and a string that stores variable used to rebuild the split. + * + * @param record Object that may be reused. + * @return quote count, split status and split serialized result + * @throws IOException + */ + @Override + public Row nextRecord(Row record) throws IOException { + long quoteNum = QuoteUtil.analyzeSplit(this.reader, quoteCharacter); + + this.setEnd(true); + this.reader.close(); + + return Row.of(quoteNum, this.reader.getSplitNumber(), this.currentSplit.toString()); + } + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/io/csv/CsvParser.java b/core/src/main/java/com/alibaba/alink/operator/common/io/csv/CsvParser.java index 953b8f5e1..b68fcf5b6 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/io/csv/CsvParser.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/io/csv/CsvParser.java @@ -8,7 +8,7 @@ import org.apache.flink.util.InstantiationUtil; import org.apache.flink.util.StringUtils; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.exceptions.AkParseErrorException; import com.alibaba.alink.common.linalg.DenseVector; import com.alibaba.alink.common.linalg.SparseVector; @@ -184,6 +184,9 @@ private Tuple2 parseField(FieldParser parser, String token if (token.equals("\"\"")) { return Tuple2.of(true, null); // spark output's null value as "" } + if (quoteChar != null && token.startsWith(quoteChar.toString()) && token.endsWith(quoteChar.toString())) { + token = token.substring(1, token.length() - 1); + } byte[] bytes = token.getBytes(); parser.resetErrorStateAndParse(bytes, 0, bytes.length, fieldDelim.getBytes(), null); FieldParser.ParseErrorState errorState = parser.getErrorState(); diff --git a/core/src/main/java/com/alibaba/alink/operator/common/io/csv/CsvTypeConverter.java b/core/src/main/java/com/alibaba/alink/operator/common/io/csv/CsvTypeConverter.java index b22caee58..a162e7469 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/io/csv/CsvTypeConverter.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/io/csv/CsvTypeConverter.java @@ -4,7 +4,8 @@ import org.apache.flink.api.common.typeinfo.Types; import org.apache.flink.ml.api.misc.param.Params; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; +import com.alibaba.alink.common.exceptions.AkUnsupportedOperationException; import com.alibaba.alink.common.linalg.VectorType; import com.alibaba.alink.common.linalg.tensor.DataType; import com.alibaba.alink.params.dataproc.ToTensorParams; @@ -76,7 +77,7 @@ public static DataType tensorTypeInformationToTensorType(TypeInformation typ return DataType.BOOLEAN; } - throw new IllegalArgumentException("Unsupported tensor type. 
" + typeInformation); + throw new AkUnsupportedOperationException("Unsupported tensor type. " + typeInformation); } public static PipelineModel toTensorPipelineModel( @@ -117,7 +118,7 @@ public static VectorType vectorTypeInformationToVectorType(TypeInformation t return VectorType.SPARSE; } - throw new IllegalArgumentException("Unsupported vector type. " + typeInformation); + throw new AkUnsupportedOperationException("Unsupported vector type. " + typeInformation); } public static PipelineModel toVectorPipelineModel( diff --git a/core/src/main/java/com/alibaba/alink/operator/common/io/csv/CsvUtil.java b/core/src/main/java/com/alibaba/alink/operator/common/io/csv/CsvUtil.java index 4a784be5b..62563260b 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/io/csv/CsvUtil.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/io/csv/CsvUtil.java @@ -78,17 +78,26 @@ public static class FormatCsvFunc extends RichMapFunction { private transient CsvFormatter formatter; private final TypeInformation [] colTypes; private final String fieldDelim; + private final String rowDelim; private final Character quoteChar; public FormatCsvFunc(TypeInformation [] colTypes, String fieldDelim, Character quoteChar) { this.colTypes = colTypes; this.fieldDelim = fieldDelim; + this.rowDelim = null; + this.quoteChar = quoteChar; + } + + public FormatCsvFunc(TypeInformation [] colTypes, String fieldDelim, String rowDelim, Character quoteChar) { + this.colTypes = colTypes; + this.fieldDelim = fieldDelim; + this.rowDelim = rowDelim; this.quoteChar = quoteChar; } @Override public void open(Configuration parameters) throws Exception { - this.formatter = new CsvFormatter(colTypes, fieldDelim, quoteChar); + this.formatter = new CsvFormatter(colTypes, fieldDelim, rowDelim, quoteChar); } @Override diff --git a/core/src/main/java/com/alibaba/alink/operator/common/io/csv/GenericCsvInputFormat.java b/core/src/main/java/com/alibaba/alink/operator/common/io/csv/GenericCsvInputFormat.java index 6c1c74ebb..55d4fc9cc 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/io/csv/GenericCsvInputFormat.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/io/csv/GenericCsvInputFormat.java @@ -20,7 +20,6 @@ import org.apache.flink.api.common.io.DefaultInputSplitAssigner; import org.apache.flink.api.common.io.InputFormat; -import org.apache.flink.api.common.io.ParseException; import org.apache.flink.api.common.io.statistics.BaseStatistics; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.configuration.Configuration; @@ -29,6 +28,7 @@ import org.apache.flink.types.parser.FieldParser; import org.apache.flink.util.InstantiationUtil; +import com.alibaba.alink.common.exceptions.AkIllegalDataException; import com.alibaba.alink.common.exceptions.AkParseErrorException; import com.alibaba.alink.common.exceptions.AkUnclassifiedErrorException; import com.alibaba.alink.operator.common.io.reader.FileSplitReader; @@ -103,7 +103,9 @@ private static Class [] extractTypeClasses(TypeInformation[] fieldTypes) { private void initBuffers() { if (BUFFER_SIZE <= this.lineDelim.length) { - throw new IllegalArgumentException("Buffer size must be greater than length of delimiter."); + throw new AkIllegalDataException(String.format( + "Buffer size is %d, and delimiter length is %d. 
Buffer size must be greater than length of delimiter.", + BUFFER_SIZE, this.lineDelim.length)); } if (this.readBuffer == null || this.readBuffer.length != BUFFER_SIZE) { @@ -129,7 +131,8 @@ private void initializeParsers() { if (fieldClasses[i] != null) { Class > parserType = FieldParser.getParserForType(fieldClasses[i]); if (parserType == null) { - throw new AkParseErrorException("No parser available for type '" + fieldClasses[i].getName() + "'."); + throw new AkParseErrorException( + "No parser available for type '" + fieldClasses[i].getName() + "'."); } FieldParser p = InstantiationUtil.instantiate(parserType, FieldParser.class); @@ -160,7 +163,8 @@ public void open(CsvFileInputSplit split) throws IOException { this.bytesRead = 0L; initBuffers(); - this.reader.open(split, split.start, split.end - 1); + //this.reader.open(split, split.start, split.end - 1); + this.reader.open(split); this.readerClosed = false; initializeParsers(); @@ -306,7 +310,7 @@ protected final boolean readLine() throws IOException { // check against the maximum record length if (((long) countInWrapBuffer) + count > LINE_LENGTH_LIMIT) { - throw new IOException("The record length exceeded the maximum record length (" + + throw new AkIllegalDataException("The record length exceeded the maximum record length (" + LINE_LENGTH_LIMIT + ")."); } @@ -366,7 +370,8 @@ private boolean fillBuffer(int offset) throws IOException { // unexpected EOF encountered, re-establish the connection if (read < 0) { this.reader.close(); - this.reader.open(this.split, this.split.start + bytesRead, this.split.end - 1); + //this.reader.open(this.split, this.split.start + bytesRead, this.split.end - 1); + this.reader.reopen(this.split, this.split.start + bytesRead); } tryTimes++; } @@ -446,7 +451,7 @@ protected Row readRecord(Row reuse, byte[] bytes, int offset, int numBytes) thro } else if (parser.getErrorState() == FieldParser.ParseErrorState.EMPTY_COLUMN) { reuseRow.setField(field, null); } else { - throw new ParseException( + throw new AkParseErrorException( String.format("Parsing error for column %1$s of row '%2$s' originated by %3$s: %4$s.", field, new String(bytes, offset, numBytes), parser.getClass().getSimpleName(), parser.getErrorState())); @@ -462,10 +467,11 @@ field, new String(bytes, offset, numBytes), parser.getClass().getSimpleName(), startPos++; } if (startPos + fieldDelim.length > offset + numBytes) { - throw new AkUnclassifiedErrorException("Can't find next field delimiter: " + "\"" + fieldDelimStr + "\"," - + " " + - "Perhaps the data is invalid or do not match the schema." + - "The row is: " + new String(bytes, offset, numBytes)); + throw new AkParseErrorException( + "Can't find next field delimiter: " + "\"" + fieldDelimStr + "\"," + + " " + + "Perhaps the data is invalid or do not match the schema." + + "The row is: " + new String(bytes, offset, numBytes)); } startPos += fieldDelim.length; } diff --git a/core/src/main/java/com/alibaba/alink/operator/common/io/csv/GenericCsvInputFormatBeta.java b/core/src/main/java/com/alibaba/alink/operator/common/io/csv/GenericCsvInputFormatBeta.java new file mode 100644 index 000000000..b2b8b0f11 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/io/csv/GenericCsvInputFormatBeta.java @@ -0,0 +1,543 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.alink.operator.common.io.csv; + +import org.apache.flink.api.common.io.DefaultInputSplitAssigner; +import org.apache.flink.api.common.io.InputFormat; +import org.apache.flink.api.common.io.statistics.BaseStatistics; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.core.io.InputSplit; +import org.apache.flink.core.io.InputSplitAssigner; +import org.apache.flink.types.Row; +import org.apache.flink.types.parser.FieldParser; + +import com.alibaba.alink.common.exceptions.AkIllegalDataException; +import com.alibaba.alink.common.exceptions.AkParseErrorException; +import com.alibaba.alink.common.exceptions.AkUnclassifiedErrorException; +import com.alibaba.alink.operator.common.io.dummy.DummyFiledParser; +import com.alibaba.alink.operator.common.io.reader.FileSplitReader; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.nio.charset.Charset; + +/** + * Generic csv input format which supports multiple types of data reader. + */ +public class GenericCsvInputFormatBeta implements InputFormat { + // -------------------------------------- Constants ------------------------------------------- + + private static final Logger LOG = LoggerFactory.getLogger(GenericCsvInputFormatBeta.class); + + private static final long serialVersionUID = -7327585548493815210L; + + private static final int LINE_LENGTH_LIMIT = 1024 * 1024 * 1024; + + protected static final long READ_WHOLE_SPLIT_FLAG = -1L; + + protected final FileSplitReader reader; + + // The charset used to convert strings to bytes + private String charsetName = "UTF-8"; + + // Charset is not serializable + private transient Charset charset; + + /** + * The default read buffer size = 1MB. + */ + private static final int BUFFER_SIZE = 1024 * 1024; + + // -------------------------------------------------------------------------------------------- + // Variables for internal parsing. + // They are all transient, because we do not want them so be serialized + // They are copied from FileInputFormat,java + // -------------------------------------------------------------------------------------------- + + /** + * The start of the split that this parallel instance must consume. + */ + protected transient long splitStart; + + /** + * remaining bytes of current split to read + */ + private transient long splitLength; + + /** + * The current split that this parallel instance must consume. + */ + protected transient T currentSplit; + + // -------------------------------------------------------------------------------------------- + // The configuration parameters. Configured on the instance and serialized to be shipped. + // -------------------------------------------------------------------------------------------- + + // The delimiter may be set with a byte-sequence or a String. In the latter + // case the byte representation is updated consistent with current charset. 
+ private byte[] delimiter; + + // To speed up readRecord processing. Used to find windows line endings. + // It is set when open so that readRecord does not have to evaluate it + protected boolean lineDelimiterIsLinebreak = false; + + protected final boolean ignoreFirstLine; + + /** + * Some file input formats are not splittable on a block level (deflate) + * Therefore, the FileInputFormat can only read whole files. + * The file is unsplittable when quotedStringParsing is true. + */ + protected boolean unsplittable = false; + protected boolean quotedStringParsing = false; + + protected byte quoteCharacter; + + // -------------------------------------------------------------------------------------------- + // Transient variables copied from DelimitedInputFormat.java + // -------------------------------------------------------------------------------------------- + + private transient byte[] readBuffer; // buffer for holding data read by reader + + private transient byte[] wrapBuffer; + + private transient int readPos; // reading position of the read buffer + + private transient int limit; // number of valid bytes in the read buffer + + private transient byte[] currBuffer; // buffer in which current record byte sequence is found + private transient int currOffset; // offset in above buffer + private transient int currLen; // length of current byte sequence + + private transient boolean overLimit; // flag indicating whether we have read beyond the split + + private transient boolean end; + + private long offset = -1; + + private transient long bytesRead; // number of bytes read by reader + + private transient boolean readerClosed; + + // for parsing fields of a reacord + private transient FieldParser fieldParser = null; + private transient Object[] holders = null; + private boolean fieldInQuote; + + // -------------------------------------------------------------------------------------------- + // Constructors & Getters/setters for the configurable parameters + // -------------------------------------------------------------------------------------------- + + public GenericCsvInputFormatBeta(FileSplitReader reader, String delimiter, boolean ignoreFirstLine) { + this.reader = reader; + this.charset = Charset.forName(charsetName); + this.delimiter = delimiter.getBytes(); + this.ignoreFirstLine = ignoreFirstLine; + } + + public GenericCsvInputFormatBeta(FileSplitReader reader, String delimiter, boolean ignoreFirstLine, + Character quoteChar) { + this(reader, delimiter, ignoreFirstLine); + if (quoteChar != null) { + this.unsplittable = true; + this.quotedStringParsing = true; + this.quoteCharacter = (byte) quoteChar.charValue(); + } + } + + public GenericCsvInputFormatBeta(FileSplitReader reader, String delimiter, boolean ignoreFirstLine, + boolean unsplittable, Character quoteChar) { + this(reader, delimiter, ignoreFirstLine); + if (quoteChar != null) { + this.unsplittable = unsplittable; + this.quoteCharacter = (byte) quoteChar.charValue(); + } + this.quotedStringParsing = true; + } + + protected void setEnd(Boolean end) {this.end = end;} + + protected void setFieldInQuote(boolean fieldInQuote) {this.fieldInQuote = fieldInQuote;} + + /** + * Opens the given input split. This method opens the input stream to the specified file, allocates read buffers + * and positions the stream at the correct position, making sure that any partial record at the beginning is + * skipped. + * + * @param split The input split to open. 
+ * @see org.apache.flink.api.common.io.FileInputFormat#open(org.apache.flink.core.fs.FileInputSplit) + */ + @Override + public void open(T split) throws IOException { + this.currentSplit = split; + // open and assign a split to the reader + this.reader.open(split); + this.readerClosed = false; + + this.splitStart = this.reader.getSplitStart(); + this.splitLength = this.reader.getSplitLength(); + + this.charset = Charset.forName(charsetName); + + this.bytesRead = 0L; + + initBuffers(); + + initializeParsers(); + + // left to right evaluation makes access [0] okay + // this marker is used to fasten up readRecord, so that it doesn't have to check each call if the line ending + // is set to default + if (delimiter.length == 1 && delimiter[0] == '\n') { + this.lineDelimiterIsLinebreak = true; + } + } + + // copied from DelimitedInputFormat.initBuffers + private void initBuffers() { + if (BUFFER_SIZE <= this.delimiter.length) { + throw new AkIllegalDataException(String.format( + "Buffer size is %d, and delimiter length is %d. Buffer size must be greater than length of delimiter.", + BUFFER_SIZE, this.delimiter.length)); + } + + if (this.readBuffer == null || this.readBuffer.length != BUFFER_SIZE) { + this.readBuffer = new byte[BUFFER_SIZE]; + } + if (this.wrapBuffer == null || this.wrapBuffer.length < 256) { + this.wrapBuffer = new byte[256]; + } + + this.readPos = 0; + this.limit = 0; + this.overLimit = false; + this.end = false; + } + + private void initializeParsers() { + + // instantiate the parsers + fieldParser = new DummyFiledParser(); + fieldParser.setCharset(charset); + if (this.quotedStringParsing) { + ((DummyFiledParser) fieldParser).enableQuotedStringParsing(this.quoteCharacter); + } + this.holders = new Object[] {fieldParser.createValue()}; + } + + /** + * Closes the input by releasing all buffers and closing the file input stream. + * + * @throws IOException Thrown, if the closing of the file stream causes an I/O error. + */ + @Override + public void close() throws IOException { + this.wrapBuffer = null; + this.readBuffer = null; + if (!this.readerClosed) { + this.reader.close(); + this.readerClosed = true; + } + } + + /** + * Checks whether the current split is at its end. + * + * @return True, if the split is at its end, false otherwise. + */ + @Override + public boolean reachedEnd() { + return this.end; + } + + @Override + public Row nextRecord(Row record) throws IOException { + if (readLine()) { + if (0 == this.currLen) { + // current line is empty, then skip + return nextRecord(record); + } else { + return readRecord(record, this.currBuffer, this.currOffset, this.currLen); + } + } else { + this.end = true; + return null; + } + } + + /** + * Copy one line of data from "readBuffer" to "currBuffer". + * If "readBuffer" is fully consumed, trigger "fillBuffer()" to fill it. + */ + protected final boolean readLine() throws IOException { + if (this.readerClosed || this.overLimit) { + return false; + } + + int countInWrapBuffer = 0; + + // position of matching positions in the delimiter byte array + int delimPos = 0; + boolean findQuote = this.fieldInQuote; + + while (true) { + if (this.readPos >= this.limit) { + // readBuffer is completely consumed. Fill it again but keep partially read delimiter bytes. + if (!fillBuffer(delimPos)) { + int countInReadBuffer = delimPos; + if (countInWrapBuffer + countInReadBuffer > 0) { + // we have bytes left to emit + if (countInReadBuffer > 0) { + // we have bytes left in the readBuffer. 
Move them into the wrapBuffer + if (this.wrapBuffer.length - countInWrapBuffer < countInReadBuffer) { + // reallocate + byte[] tmp = new byte[countInWrapBuffer + countInReadBuffer]; + System.arraycopy(this.wrapBuffer, 0, tmp, 0, countInWrapBuffer); + this.wrapBuffer = tmp; + } + + // copy readBuffer bytes to wrapBuffer + System.arraycopy(this.readBuffer, 0, this.wrapBuffer, countInWrapBuffer, + countInReadBuffer); + countInWrapBuffer += countInReadBuffer; + } + setResult(this.wrapBuffer, 0, countInWrapBuffer,findQuote); + return true; + } else { + return false; + } + } + } + + int startPos = this.readPos - delimPos; + int count; + // Search for next occurence of delimiter in read buffer. + while (this.readPos < this.limit && delimPos < this.delimiter.length) { + if (!findQuote && (this.readBuffer[this.readPos]) == this.delimiter[delimPos]) { + // Found the expected delimiter character. Continue looking for the next character of delimiter. + delimPos++; + } else if (quotedStringParsing && this.readBuffer[this.readPos] == quoteCharacter) { + findQuote = !findQuote; + } else { + // Delimiter does not match. + // We have to reset the read position to the character after the first matching character + // and search for the whole delimiter again. + readPos -= delimPos; + delimPos = 0; + } + readPos++; + } + + // check why we dropped out + if (delimPos == this.delimiter.length) { + // we found a delimiter + int readBufferBytesRead = this.readPos - startPos; + count = readBufferBytesRead - this.delimiter.length; + + // copy to byte array + if (countInWrapBuffer > 0) { + // check wrap buffer size + if (this.wrapBuffer.length < countInWrapBuffer + count) { + final byte[] nb = new byte[countInWrapBuffer + count]; + System.arraycopy(this.wrapBuffer, 0, nb, 0, countInWrapBuffer); + this.wrapBuffer = nb; + } + if (count >= 0) { + System.arraycopy(this.readBuffer, 0, this.wrapBuffer, countInWrapBuffer, count); + } + setResult(this.wrapBuffer, 0, countInWrapBuffer + count,findQuote); + return true; + } else { + setResult(this.readBuffer, startPos, count,findQuote); + return true; + } + } else { + // we reached the end of the readBuffer + count = this.limit - startPos; + + // check against the maximum record length + if (((long) countInWrapBuffer) + count > LINE_LENGTH_LIMIT) { + throw new AkIllegalDataException("The record length exceeded the maximum record length (" + + LINE_LENGTH_LIMIT + ")."); + } + + // Compute number of bytes to move to wrapBuffer + // Chars of partially read delimiter must remain in the readBuffer. We might need to go back. 
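+ // Example of the carry-over below: with a two-byte row delimiter "\r\n", if the read buffer ends with "abc\r", + // delimPos is 1 when the buffer runs out, so "abc" is copied into wrapBuffer while the partially matched '\r' + // is moved to the front of readBuffer and matching resumes after the next fillBuffer() call.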
+ int bytesToMove = count - delimPos; + // ensure wrapBuffer is large enough + if (this.wrapBuffer.length - countInWrapBuffer < bytesToMove) { + // reallocate + byte[] tmp = new byte[Math.max(this.wrapBuffer.length * 2, countInWrapBuffer + bytesToMove)]; + System.arraycopy(this.wrapBuffer, 0, tmp, 0, countInWrapBuffer); + this.wrapBuffer = tmp; + } + + // copy readBuffer to wrapBuffer (except delimiter chars) + System.arraycopy(this.readBuffer, startPos, this.wrapBuffer, countInWrapBuffer, bytesToMove); + countInWrapBuffer += bytesToMove; + // move delimiter chars to the beginning of the readBuffer + System.arraycopy(this.readBuffer, this.readPos - delimPos, this.readBuffer, 0, delimPos); + + } + } + } + + private void setResult(byte[] buffer, int offset, int len, boolean fieldInQuote) { + this.currBuffer = buffer; + this.currOffset = offset; + this.currLen = len; + this.fieldInQuote = fieldInQuote; + } + + /** + * Fills the read buffer with bytes read from the file starting from an offset. + * Returns false if has reached the end of the split and nothing is read. + */ + protected boolean fillBuffer(int offset) throws IOException { + int maxReadLength = this.readBuffer.length - offset; + // special case for reading the whole split. + if (this.splitLength == READ_WHOLE_SPLIT_FLAG) { + int read = this.reader.read(this.readBuffer, offset, maxReadLength); + if (read == -1) { + this.reader.close(); + this.readerClosed = true; + return false; + } else { + this.readPos = offset; + this.limit = read; + return true; + } + } + + // else .. + int toRead; + if (this.splitLength > 0) { + // if we have some data to read in the split, read that + toRead = this.splitLength > maxReadLength ? maxReadLength : (int) this.splitLength; + } else { + // if we have exhausted our split, we need to complete the current record, or read one + // more across the next split. + // the reason is that the next split will skip over the beginning until it finds the first + // delimiter, discarding it as an incomplete chunk of data that belongs to the last record in the + // previous split. 
+ toRead = maxReadLength; + this.overLimit = true; + } + + int tryTimes = 0; + int maxTryTimes = 10; + int read = -1; + long start = this.reader.getSplitStart(); + long end = this.reader.getSplitEnd(); + + while (this.bytesRead + start < end && read == -1 && tryTimes < maxTryTimes) { + read = this.reader.read(this.readBuffer, offset, toRead); + + // unexpected EOF encountered, re-establish the connection + if (read < 0) { + this.reader.close(); + //this.reader.open(this.currentSplit); + this.reader.reopen(this.currentSplit, splitStart + bytesRead); + } + tryTimes++; + } + + if (tryTimes >= maxTryTimes) { + throw new AkUnclassifiedErrorException("Fail to read data."); + } + + if (read == -1) { + this.reader.close(); + this.readerClosed = true; + return false; + } else { + this.splitLength -= read; + this.readPos = offset; // position from where to start reading + this.limit = read + offset; // number of valid bytes in the read buffer + this.bytesRead += read; + return true; + } + } + + @Override + public void configure(Configuration parameters) { + } + + @Override + public BaseStatistics getStatistics(BaseStatistics cachedStatistics) throws IOException { + return null; + } + + @Override + public T[] createInputSplits(int minNumSplits) throws IOException { + return null; + } + + @Override + public InputSplitAssigner getInputSplitAssigner(T[] inputSplits) { + return new DefaultInputSplitAssigner(inputSplits); + } + + protected Row readRecord(Row reuse, byte[] bytes, int offset, int numBytes) throws IOException { + Row reuseRow; + if (reuse == null) { + reuseRow = new Row(1); + } else { + reuseRow = reuse; + } + + // Found window's end line, so find carriage return before the newline + if (this.lineDelimiterIsLinebreak && numBytes > 0 && bytes[offset + numBytes - 1] == '\r') { + //reduce the number of bytes so that the Carriage return is not taken as data + numBytes--; + } + + int startPos = offset; + int field = 0; + FieldParser parser = (FieldParser ) fieldParser; + + int newStartPos = parser.resetErrorStateAndParse( + bytes, + startPos, + offset + numBytes, + delimiter, + holders[field]); + + if (parser.getErrorState() != FieldParser.ParseErrorState.NONE) { + if (parser.getErrorState() == FieldParser.ParseErrorState.EMPTY_COLUMN) { + reuseRow.setField(field, null); + } else { + throw new AkParseErrorException( + String.format("Parsing error for column %1$s of row '%2$s' originated by %3$s: %4$s.", + field, new String(bytes, offset, numBytes), parser.getClass().getSimpleName(), + parser.getErrorState())); + } + } else { + reuseRow.setField(field, parser.getLastResult()); + } + + if (newStartPos >= 0) { + startPos = newStartPos; + } + + return reuseRow; + } + +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/io/csv/InternalCsvSourceBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/common/io/csv/InternalCsvSourceBatchOp.java index 554688d82..72e82575f 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/io/csv/InternalCsvSourceBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/io/csv/InternalCsvSourceBatchOp.java @@ -14,16 +14,17 @@ import com.alibaba.alink.common.MLEnvironmentFactory; import com.alibaba.alink.common.annotation.Internal; +import com.alibaba.alink.common.exceptions.AkUnclassifiedErrorException; import com.alibaba.alink.common.io.annotations.AnnotationUtils; import com.alibaba.alink.common.io.annotations.IOType; import com.alibaba.alink.common.io.annotations.IoOpAnnotation; import 
com.alibaba.alink.common.io.filesystem.FilePath; import com.alibaba.alink.common.io.filesystem.copy.csv.RowCsvInputFormat; -import com.alibaba.alink.common.utils.DataSetConversionUtil; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; import com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.operator.batch.source.BaseSourceBatchOp; +import com.alibaba.alink.operator.batch.utils.DataSetUtil; import com.alibaba.alink.operator.common.io.partition.CsvSourceCollectorCreator; -import com.alibaba.alink.operator.common.io.partition.Utils; import com.alibaba.alink.operator.common.io.reader.HttpFileSplitReader; import com.alibaba.alink.params.io.CsvSourceParams; @@ -123,12 +124,13 @@ public Table initializeDataSource() { Tuple2 , TableSchema> schemaAndData; try { - schemaAndData = Utils.readFromPartitionBatch( + schemaAndData = DataSetUtil.readFromPartitionBatch( getParams(), getMLEnvironmentId(), - new CsvSourceCollectorCreator(dummySchema, rowDelim, ignoreFirstLine) + new CsvSourceCollectorCreator(dummySchema, rowDelim, ignoreFirstLine, quoteChar) ); } catch (IOException e) { - throw new IllegalStateException(e); + throw new AkUnclassifiedErrorException( + String.format("Fail to list directories in %s and select partitions", getFilePath().getPathStr())); } rows = schemaAndData.f0; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/io/csv/InternalCsvSourceBetaBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/common/io/csv/InternalCsvSourceBetaBatchOp.java new file mode 100644 index 000000000..688e2e69b --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/io/csv/InternalCsvSourceBetaBatchOp.java @@ -0,0 +1,175 @@ +package com.alibaba.alink.operator.common.io.csv; + +import org.apache.flink.api.common.functions.RichFlatMapFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.ExecutionEnvironment; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.core.io.InputSplit; +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.TableSchema; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; + +import com.alibaba.alink.common.MLEnvironmentFactory; +import com.alibaba.alink.common.annotation.Internal; +import com.alibaba.alink.common.exceptions.AkIllegalDataException; +import com.alibaba.alink.common.io.annotations.AnnotationUtils; +import com.alibaba.alink.common.io.annotations.IOType; +import com.alibaba.alink.common.io.annotations.IoOpAnnotation; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; +import com.alibaba.alink.common.utils.TableUtil; +import com.alibaba.alink.operator.batch.source.BaseSourceBatchOp; +import com.alibaba.alink.operator.batch.utils.DataSetUtil; +import com.alibaba.alink.operator.common.io.partition.CsvSourceCollectorCreator; +import com.alibaba.alink.operator.common.io.reader.FSFileSplitReader; +import com.alibaba.alink.operator.common.io.reader.FileSplitReader; +import com.alibaba.alink.operator.common.io.reader.HttpFileSplitReader; +import com.alibaba.alink.params.io.CsvSourceParams; + +import java.io.IOException; +import java.net.MalformedURLException; +import java.net.URL; +import java.util.List; + 
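A rough, hypothetical sketch (the class and variable names below are invented for illustration and are not part of this patch) of the quote-counting idea used by the operator that follows: the split-summary pass reports how many quote characters each split contains, and a running prefix sum decides whether a split starts inside a quoted field (odd count so far) before the parsing pass opens it.

public class QuotePrefixSumSketch {
	public static void main(String[] args) {
		// Example quote counts reported for splits 0..3 by the summary pass.
		long[] quoteCountPerSplit = {3, 2, 1, 4};
		boolean[] startsInQuote = new boolean[quoteCountPerSplit.length];
		long quotesSeenBefore = 0;
		for (int i = 0; i < quoteCountPerSplit.length; i++) {
			// An odd number of quotes before this split means it begins inside a quoted field.
			startsInQuote[i] = quotesSeenBefore % 2 == 1;
			quotesSeenBefore += quoteCountPerSplit[i];
		}
		// Prints [false, true, true, false].
		System.out.println(java.util.Arrays.toString(startsInQuote));
	}
}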
+@Internal +@IoOpAnnotation(name = "internal_csv_beta", ioType = IOType.SourceBatch) +public class InternalCsvSourceBetaBatchOp extends BaseSourceBatchOp + implements CsvSourceParams { + + public InternalCsvSourceBetaBatchOp() { + this(new Params()); + } + + public InternalCsvSourceBetaBatchOp(Params params) { + super(AnnotationUtils.annotatedName(InternalCsvSourceBetaBatchOp.class), params); + + } + + @Override + public Table initializeDataSource() { + final String filePath = getFilePath().getPathStr(); + final String schemaStr = getSchemaStr(); + final String fieldDelim = getFieldDelimiter(); + final String rowDelim = getRowDelimiter(); + final Character quoteChar = getQuoteChar(); + final boolean skipBlankLine = getSkipBlankLine(); + final boolean lenient = getLenient(); + + final String[] colNames = TableUtil.getColNames(schemaStr); + final TypeInformation [] colTypes = TableUtil.getColTypes(schemaStr); + + boolean ignoreFirstLine = getIgnoreFirstLine(); + String protocol = ""; + + try { + URL url = new URL(filePath); + protocol = url.getProtocol(); + } catch (MalformedURLException ignored) { + } + + DataSet rows; + DataSet splits; + ExecutionEnvironment execEnv = MLEnvironmentFactory.get(getMLEnvironmentId()).getExecutionEnvironment(); + TableSchema dummySchema; + + String partitions = getPartitions(); + + if (partitions == null) { + FileSplitReader reader; + + if (protocol.equalsIgnoreCase("http") || protocol.equalsIgnoreCase("https")) { + reader = new HttpFileSplitReader(filePath); + } else { + reader = new FSFileSplitReader(getFilePath()); + } + + if (getQuoteChar() != null) { + dummySchema = new TableSchema(new String[] {"_QUOTE_NUM_", "_SPLIT_NUMBER_", "_SPLIT_INFO_"}, + new TypeInformation[] {Types.LONG, Types.LONG, Types.STRING}); + + splits = execEnv + .createInput(reader.convertFileSplitToInputFormat(rowDelim, ignoreFirstLine, quoteChar), + new RowTypeInfo(dummySchema.getFieldTypes(), dummySchema.getFieldNames())) + .name("csv_split_summary_source"); + + rows = splits.flatMap(new RichFlatMapFunction () { + boolean[] filedInQuote; + + @Override + public void open(Configuration parameters) throws Exception { + super.open(parameters); + + List splits = getRuntimeContext().getBroadcastVariable("splits"); + + filedInQuote = new boolean[splits.size()]; + long[] sum = new long[splits.size()]; + + for (Row row : splits) { + long splitNum = (long) row.getField(1); + sum[(int) splitNum] = (long) row.getField(0); + } + + filedInQuote[0] = false; + + for (int i = 1; i < sum.length; i++) { + sum[i] = sum[i - 1] + sum[i]; + filedInQuote[i] = sum[i - 1] % 2 == 1; + } + } + + @Override + public void flatMap(Row value, Collector out) throws Exception { + long splitNum = (long) value.getField(1); + String splitStr = (String) value.getField(2); + InputSplit split = reader.convertStringToSplitObject(splitStr); + + GenericCsvInputFormatBeta inputFormat = reader.getInputFormat(rowDelim, ignoreFirstLine, + quoteChar); + inputFormat.setFieldInQuote(filedInQuote[(int) splitNum]); + inputFormat.open(split); + + while (true) { + Row line = inputFormat.nextRecord(null); + if (line == null) { + break; + } + out.collect(line); + } + } + }).withBroadcastSet(splits, "splits").name("csv_flat_map"); + } else { + dummySchema = new TableSchema(new String[] {"f1"}, new TypeInformation[] {Types.STRING}); + + rows = execEnv + .createInput(reader.getInputFormat(rowDelim, ignoreFirstLine, quoteChar), + new RowTypeInfo(dummySchema.getFieldTypes(), dummySchema.getFieldNames())) + .name("csv_source"); + } + } else { + 
dummySchema = new TableSchema(new String[] {"f1"}, new TypeInformation[] {Types.STRING}); + + Tuple2 , TableSchema> schemaAndData; + + try { + schemaAndData = DataSetUtil.readFromPartitionBatch( + getParams(), getMLEnvironmentId(), + new CsvSourceCollectorCreator(dummySchema, rowDelim, ignoreFirstLine, quoteChar) + ); + } catch (IOException e) { + throw new AkIllegalDataException( + String.format("Fail to list directories in %s and select partitions", getFilePath().getPathStr())); + } + + rows = schemaAndData.f0; + } + + rows = rows.flatMap(new CsvUtil.ParseCsvFunc(colTypes, fieldDelim, quoteChar, skipBlankLine, lenient)); + + return DataSetConversionUtil.toTable(getMLEnvironmentId(), rows, colNames, colTypes); + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/io/csv/InternalCsvSourceStreamOp.java b/core/src/main/java/com/alibaba/alink/operator/common/io/csv/InternalCsvSourceStreamOp.java index 6c81252c2..63541be58 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/io/csv/InternalCsvSourceStreamOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/io/csv/InternalCsvSourceStreamOp.java @@ -14,17 +14,18 @@ import com.alibaba.alink.common.MLEnvironmentFactory; import com.alibaba.alink.common.annotation.Internal; +import com.alibaba.alink.common.exceptions.AkIllegalDataException; import com.alibaba.alink.common.io.annotations.AnnotationUtils; import com.alibaba.alink.common.io.annotations.IOType; import com.alibaba.alink.common.io.annotations.IoOpAnnotation; import com.alibaba.alink.common.io.filesystem.FilePath; import com.alibaba.alink.common.io.filesystem.copy.csv.RowCsvInputFormat; -import com.alibaba.alink.common.utils.DataStreamConversionUtil; +import com.alibaba.alink.operator.stream.utils.DataStreamConversionUtil; import com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.operator.common.io.partition.CsvSourceCollectorCreator; -import com.alibaba.alink.operator.common.io.partition.Utils; import com.alibaba.alink.operator.common.io.reader.HttpFileSplitReader; import com.alibaba.alink.operator.stream.source.BaseSourceStreamOp; +import com.alibaba.alink.operator.stream.utils.DataStreamUtil; import com.alibaba.alink.params.io.CsvSourceParams; import java.io.IOException; @@ -104,12 +105,13 @@ public Table initializeDataSource() { Tuple2 , TableSchema> schemaAndData; try { - schemaAndData = Utils.readFromPartitionStream( + schemaAndData = DataStreamUtil.readFromPartitionStream( getParams(), getMLEnvironmentId(), - new CsvSourceCollectorCreator(dummySchema, rowDelim, ignoreFirstLine) + new CsvSourceCollectorCreator(dummySchema, rowDelim, ignoreFirstLine, quoteChar) ); } catch (IOException e) { - throw new IllegalStateException(e); + throw new AkIllegalDataException( + String.format("Fail to list directories in %s and select partitions", getFilePath().getPathStr())); } rows = schemaAndData.f0; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/io/dummy/DummyFiledParser.java b/core/src/main/java/com/alibaba/alink/operator/common/io/dummy/DummyFiledParser.java new file mode 100644 index 000000000..92b743de9 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/io/dummy/DummyFiledParser.java @@ -0,0 +1,77 @@ +package com.alibaba.alink.operator.common.io.dummy; + +import org.apache.flink.types.parser.FieldParser; + +public class DummyFiledParser extends FieldParser { + + private boolean quotedStringParsing = false; + private byte quoteCharacter; + + private String result; + + public void 
enableQuotedStringParsing(byte quoteCharacter) { + this.quotedStringParsing = true; + this.quoteCharacter = quoteCharacter; + } + + /** + * DummyFieldParser doesn't split the input line with field delimiter. It only trims the last delimiter of the + * line, then regards the input line as one whole field. + * + * @param bytes The byte array that holds the value. + * @param startPos The index where the field starts + * @param limit The limit unto which the byte contents is valid for the parser. The limit is the + * position one after the last valid byte. + * @param delimiter The field delimiter character + * @param reusable An optional reusable field to hold the value + * @return The index of the next delimiter, if the field was parsed correctly. A value less than 0 otherwise. + */ + @Override + public int parseField(byte[] bytes, int startPos, int limit, byte[] delimiter, String reusable) { + + if (startPos == limit) { + setErrorState(ParseErrorState.EMPTY_COLUMN); + this.result = ""; + return limit; + } + + int i = startPos; + + final int delimLimit = limit - delimiter.length + 1; + + // look for delimiter + boolean lookForQuote = false; + while (i < delimLimit) { + if (!lookForQuote && delimiterNext(bytes, i, delimiter)) { + i += delimiter.length; + break; + } else if (this.quotedStringParsing && bytes[i] == quoteCharacter) { + lookForQuote = !lookForQuote; + } + i++; + } + + if (i >= delimLimit) { + this.result = new String(bytes, startPos, limit - startPos, getCharset()); + return limit; + } else { + // delimiter found. + if (i == startPos) { + setErrorState(ParseErrorState.EMPTY_COLUMN); // mark empty column + } + this.result = new String(bytes, startPos, i - startPos, getCharset()); + return i + delimiter.length; + } + } + + @Override + public String createValue() { + return ""; + } + + @Override + public String getLastResult() { + return this.result; + } +} + diff --git a/core/src/main/java/com/alibaba/alink/operator/common/io/partition/CsvSinkCollectorCreator.java b/core/src/main/java/com/alibaba/alink/operator/common/io/partition/CsvSinkCollectorCreator.java index ea451644a..73cd0347e 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/io/partition/CsvSinkCollectorCreator.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/io/partition/CsvSinkCollectorCreator.java @@ -40,7 +40,7 @@ public void collect(Row record) { flattenCsvFromRow.map(formatCsvFunc.map(record)) ); } catch (Exception e) { - throw new AkUnclassifiedErrorException("Error. ",e); + throw new AkUnclassifiedErrorException("CsvSinkCollectorCreator collect error. ",e); } } @@ -49,7 +49,7 @@ public void close() { try { textOutputFormat.close(); } catch (IOException e) { - throw new AkUnclassifiedErrorException("Error. ",e); + throw new AkUnclassifiedErrorException("CsvSinkCollectorCreator close error. 
",e); } } }; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/io/partition/CsvSourceCollectorCreator.java b/core/src/main/java/com/alibaba/alink/operator/common/io/partition/CsvSourceCollectorCreator.java index e7582a105..ca725438e 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/io/partition/CsvSourceCollectorCreator.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/io/partition/CsvSourceCollectorCreator.java @@ -1,53 +1,73 @@ package com.alibaba.alink.operator.common.io.partition; +import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.core.fs.FileInputSplit; import org.apache.flink.table.api.TableSchema; import org.apache.flink.types.Row; import org.apache.flink.util.Collector; +import com.alibaba.alink.common.io.filesystem.AkUtils; +import com.alibaba.alink.common.io.filesystem.AkUtils.FileProcFunction; import com.alibaba.alink.common.io.filesystem.FilePath; -import com.alibaba.alink.common.io.filesystem.copy.csv.RowCsvInputFormat; +import com.alibaba.alink.operator.common.io.csv.GenericCsvInputFormatBeta; +import com.alibaba.alink.operator.common.io.reader.FSFileSplitReader; import java.io.IOException; public class CsvSourceCollectorCreator implements SourceCollectorCreator { - private final TableSchema dummySchema; - private final String rowDelim; + private final Character quoteChar; + private final boolean ignoreFirstLine; - public CsvSourceCollectorCreator(TableSchema dummySchema, String rowDelim, boolean ignoreFirstLine) { - this.dummySchema = dummySchema; + private final String[] dataFieldNames; + + private final TypeInformation[] dataFieldTypes; + + public CsvSourceCollectorCreator(TableSchema dummySchema, String rowDelim, boolean ignoreFirstLine, + Character quoteChar) { + this.dataFieldNames = dummySchema.getFieldNames(); + this.dataFieldTypes = dummySchema.getFieldTypes(); this.rowDelim = rowDelim; this.ignoreFirstLine = ignoreFirstLine; + this.quoteChar = quoteChar; } @Override public TableSchema schema() { - return dummySchema; + return new TableSchema(dataFieldNames, dataFieldTypes); } @Override public void collect(FilePath filePath, Collector collector) throws IOException { - RowCsvInputFormat inputFormat = new RowCsvInputFormat( - filePath.getPath(), dummySchema.getFieldTypes(), - rowDelim, rowDelim, new int[] {0}, true, - filePath.getFileSystem() - ); - inputFormat.setSkipFirstLineAsHeader(ignoreFirstLine); - - try { - inputFormat.open(new FileInputSplit(1, filePath.getPath(), 0, -1, null)); + AkUtils.getFromFolderForEach( + filePath, + new FileProcFunction () { + @Override + public Boolean apply(FilePath filePath) throws IOException { + FSFileSplitReader reader = new FSFileSplitReader(filePath); + GenericCsvInputFormatBeta inputFormat = reader.getInputFormat(rowDelim, ignoreFirstLine, + quoteChar); - while (!inputFormat.reachedEnd()) { - collector.collect(inputFormat.nextRecord(null)); - } + try { + inputFormat.open(new FileInputSplit(1, filePath.getPath(), 0, reader.getFileLength(), null)); - } finally { - inputFormat.close(); - } + while (!inputFormat.reachedEnd()) { + Row record = inputFormat.nextRecord(null); + if (record != null) { + collector.collect(record); + } else { + break; + } + } + } finally { + inputFormat.close(); + } + return true; + } + }); } } diff --git a/core/src/main/java/com/alibaba/alink/operator/common/io/partition/LocalUtils.java b/core/src/main/java/com/alibaba/alink/operator/common/io/partition/LocalUtils.java index c80ee2995..853bf40fb 100644 --- 
a/core/src/main/java/com/alibaba/alink/operator/common/io/partition/LocalUtils.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/io/partition/LocalUtils.java @@ -9,6 +9,7 @@ import com.alibaba.alink.common.MTableUtil; import com.alibaba.alink.common.MTableUtil.GroupFunction; +import com.alibaba.alink.common.exceptions.AkIllegalDataException; import com.alibaba.alink.common.io.filesystem.AkUtils; import com.alibaba.alink.common.io.filesystem.BaseFileSystem; import com.alibaba.alink.common.io.filesystem.FilePath; @@ -110,10 +111,9 @@ public void calc(List values, Collector out) { ) ); } - } catch (Exception ex) { - ex.printStackTrace(); - throw new RuntimeException(ex.toString()); - } + } catch (IOException e) { + throw new AkIllegalDataException("Fail to create partition directories or write files.",e); + } } }); diff --git a/core/src/main/java/com/alibaba/alink/operator/common/io/partition/Utils.java b/core/src/main/java/com/alibaba/alink/operator/common/io/partition/Utils.java deleted file mode 100644 index a019942c7..000000000 --- a/core/src/main/java/com/alibaba/alink/operator/common/io/partition/Utils.java +++ /dev/null @@ -1,173 +0,0 @@ -package com.alibaba.alink.operator.common.io.partition; - -import org.apache.flink.api.common.functions.FlatMapFunction; -import org.apache.flink.api.common.functions.GroupReduceFunction; -import org.apache.flink.api.java.DataSet; -import org.apache.flink.api.java.tuple.Tuple2; -import org.apache.flink.core.fs.Path; -import org.apache.flink.ml.api.misc.param.Params; -import org.apache.flink.streaming.api.datastream.DataStream; -import org.apache.flink.table.api.TableSchema; -import org.apache.flink.types.Row; -import org.apache.flink.util.Collector; - -import com.alibaba.alink.common.MTableUtil; -import com.alibaba.alink.common.io.filesystem.AkUtils; -import com.alibaba.alink.common.io.filesystem.AkUtils2; -import com.alibaba.alink.common.io.filesystem.BaseFileSystem; -import com.alibaba.alink.common.io.filesystem.FilePath; -import com.alibaba.alink.common.utils.TableUtil; -import com.alibaba.alink.operator.batch.BatchOperator; -import com.alibaba.alink.operator.common.io.dummy.DummyOutputFormat; -import com.alibaba.alink.operator.local.LocalOperator; -import com.alibaba.alink.operator.stream.StreamOperator; -import com.alibaba.alink.operator.stream.sink.Export2FileOutputFormat; -import com.alibaba.alink.params.io.HasFilePath; -import com.alibaba.alink.params.io.shared.HasPartitionColsDefaultAsNull; -import com.alibaba.alink.params.io.shared.HasPartitions; -import org.apache.commons.lang3.ArrayUtils; - -import java.io.IOException; -import java.util.List; - -public class Utils { - - public static Tuple2 , TableSchema> readFromPartitionBatch( - final Params params, final Long sessionId, - final SourceCollectorCreator sourceCollectorCreator) throws IOException { - - return readFromPartitionBatch(params, sessionId, sourceCollectorCreator, null); - } - - public static Tuple2 , TableSchema> readFromPartitionBatch( - final Params params, final Long sessionId, - final SourceCollectorCreator sourceCollectorCreator, String[] partitionCols) throws IOException { - - final FilePath filePath = FilePath.deserialize(params.get(HasFilePath.FILE_PATH)); - final String partitions = params.get(HasPartitions.PARTITIONS); - - BatchOperator selected = AkUtils2 - .selectPartitionBatchOp(sessionId, filePath, partitions, partitionCols); - - final String[] colNames = selected.getColNames(); - - return Tuple2.of( - selected - .getDataSet() - .rebalance() - 
.flatMap(new FlatMapFunction () { - @Override - public void flatMap(Row value, Collector out) throws Exception { - Path path = filePath.getPath(); - - for (int i = 0; i < value.getArity(); ++i) { - path = new Path(path, String.format("%s=%s", colNames[i], value.getField(i))); - } - - sourceCollectorCreator.collect(new FilePath(path, filePath.getFileSystem()), out); - } - }), - sourceCollectorCreator.schema() - ); - } - - public static Tuple2 , TableSchema> readFromPartitionStream( - final Params params, final Long sessionId, - final SourceCollectorCreator sourceCollectorCreator) throws IOException { - - return readFromPartitionStream(params, sessionId, sourceCollectorCreator, null); - } - - public static Tuple2 , TableSchema> readFromPartitionStream( - final Params params, final Long sessionId, - final SourceCollectorCreator sourceCollectorCreator, String[] partitionCols) throws IOException { - - final FilePath filePath = FilePath.deserialize(params.get(HasFilePath.FILE_PATH)); - final String partitions = params.get(HasPartitions.PARTITIONS); - - StreamOperator selected = AkUtils2 - .selectPartitionStreamOp(sessionId, filePath, partitions, partitionCols); - - final String[] colNames = selected.getColNames(); - - return Tuple2.of( - selected - .getDataStream() - .rebalance() - .flatMap(new FlatMapFunction () { - @Override - public void flatMap(Row value, Collector out) throws Exception { - Path path = filePath.getPath(); - - for (int i = 0; i < value.getArity(); ++i) { - path = new Path(path, String.format("%s=%s", colNames[i], value.getField(i))); - } - - sourceCollectorCreator.collect(new FilePath(path, filePath.getFileSystem()), out); - } - }), - sourceCollectorCreator.schema() - ); - } - - public static void partitionAndWriteFile( - BatchOperator input, SinkCollectorCreator sinkCollectorCreator, Params params) { - - TableSchema schema = input.getSchema(); - - String[] partitionCols = params.get(HasPartitionColsDefaultAsNull.PARTITION_COLS); - final int[] partitionColIndices = TableUtil.findColIndicesWithAssertAndHint(schema, partitionCols); - final String[] reservedCols = ArrayUtils.removeElements(schema.getFieldNames(), partitionCols); - final int[] reservedColIndices = TableUtil.findColIndices(schema.getFieldNames(), reservedCols); - final FilePath localFilePath = FilePath.deserialize(params.get(HasFilePath.FILE_PATH)); - - input - .getDataSet() - .groupBy(partitionCols) - .reduceGroup(new GroupReduceFunction () { - @Override - public void reduce(Iterable values, Collector out) throws IOException { - Path root = localFilePath.getPath(); - BaseFileSystem fileSystem = localFilePath.getFileSystem(); - - Collector collector = null; - Path localPath = null; - - for (Row row : values) { - if (collector == null) { - localPath = new Path(root.getPath()); - - for (int partitionColIndex : partitionColIndices) { - localPath = new Path(localPath, row.getField(partitionColIndex).toString()); - } - - fileSystem.mkdirs(localPath); - - collector = sinkCollectorCreator.createCollector(new FilePath( - new Path( - localPath, "0" + Export2FileOutputFormat.IN_PROGRESS_FILE_SUFFIX - ), - fileSystem - )); - } - - collector.collect(Row.project(row, reservedColIndices)); - } - - if (collector != null) { - collector.close(); - - fileSystem.rename( - new Path( - localPath, "0" + Export2FileOutputFormat.IN_PROGRESS_FILE_SUFFIX - ), - new Path( - localPath, "0" - ) - ); - } - } - }) - .output(new DummyOutputFormat <>()); - } -} diff --git 
a/core/src/main/java/com/alibaba/alink/operator/common/io/reader/FSFileSplitReader.java b/core/src/main/java/com/alibaba/alink/operator/common/io/reader/FSFileSplitReader.java new file mode 100644 index 000000000..9f4ca35ba --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/io/reader/FSFileSplitReader.java @@ -0,0 +1,530 @@ +package com.alibaba.alink.operator.common.io.reader; + +import org.apache.flink.api.common.io.FilePathFilter; +import org.apache.flink.api.common.io.GlobFilePathFilter; +import org.apache.flink.api.common.io.statistics.BaseStatistics; +import org.apache.flink.configuration.ConfigConstants; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.configuration.GlobalConfiguration; +import org.apache.flink.core.fs.BlockLocation; +import org.apache.flink.core.fs.FSDataInputStream; +import org.apache.flink.core.fs.FileInputSplit; +import org.apache.flink.core.fs.FileStatus; +import org.apache.flink.core.fs.FileSystem; +import org.apache.flink.core.fs.Path; +import org.apache.flink.core.io.InputSplit; +import org.apache.flink.types.Row; + +import com.alibaba.alink.common.exceptions.AkIllegalArgumentException; +import com.alibaba.alink.common.exceptions.AkUnclassifiedErrorException; +import com.alibaba.alink.common.io.filesystem.BaseFileSystem; +import com.alibaba.alink.common.io.filesystem.FilePath; +import com.alibaba.alink.common.io.filesystem.FileSystemUtils; +import com.alibaba.alink.common.io.filesystem.copy.FileInputFormat.InputSplitOpenThread; +import com.alibaba.alink.operator.common.io.csv.GenericCsvInputFormatBeta; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +public class FSFileSplitReader implements FileSplitReader, AutoCloseable { + private final Path readerPath; + private final BaseFileSystem fs; + protected static final long openTimeout = 300000; + + private FSCsvInputFormat inputFormat = null; + + private transient Path filePath; + private transient FileInputSplit split; + private transient FSDataInputStream stream; + + public FSFileSplitReader(FilePath filePath) { + this.readerPath = filePath.getPath(); + this.fs = filePath.getFileSystem(); + } + + @Override + public void open(InputSplit split) throws IOException { + long splitStart = ((FileInputSplit) split).getStart(); + this.reopen(split, splitStart); + } + + @Override + public void reopen(InputSplit split, long splitStart) throws IOException { + this.split = (FileInputSplit) split; + long splitLength = this.split.getLength(); + + //System.out.println( + // String.valueOf(split.getSplitNumber()) + " opening the Input Split " + this.split.getPath() + " [" + // + splitStart + "," + splitLength + "]: "); + + final InputSplitOpenThread isot = new InputSplitOpenThread(this.split, this.openTimeout, this.fs); + isot.start(); + try { + this.stream = isot.waitForCompletion(); + } catch (Throwable t) { + throw new AkUnclassifiedErrorException("Error opening the Input Split " + this.split.getPath() + + " [" + splitStart + "," + splitLength + "]: " + t.getMessage(), t); + } + this.filePath = this.split.getPath(); + if (splitStart > 0) { + this.stream.seek(splitStart); + } + + } + + @Override + public void close() throws IOException { + if (this.stream != null) { + // close input stream + this.stream.close(); + stream = null; + } + } + + @Override + public int read(byte[] b, int off, int len) throws 
IOException { + return this.stream.read(b, off, len); + } + + @Override + public long getFileLength() { + try { + FileStatus stat = fs.getFileStatus(readerPath); + return stat.getLen(); + } catch (IOException e) { + return 0; + } + + } + + @Override + public long getSplitLength() { + return split.getLength(); + } + + @Override + public long getSplitStart() { + return split.getStart(); + } + + public String getSplit() {return this.split.toString();} + + /** + * The ending position is obtained from the file status, because FileInputSplit, unlike CsvFileInputSplit, doesn't + * have a variable that represents the ending position of this split. + * + * @return the file length, or 0 if the file length cannot be obtained + * @see com.alibaba.alink.operator.common.io.csv.CsvFileInputSplit#end + */ + @Override + public long getSplitEnd() { + try { + FileStatus stat = fs.getFileStatus(filePath); + return stat.getLen(); + } catch (IOException e) { + return 0; + } + } + + @Override + public long getSplitNumber() { + return split.getSplitNumber(); + } + + public Path getFilePath() { + return readerPath; + } + + public BaseFileSystem getFs() { + return fs; + } + + @Override + public FSCsvInputFormat getInputFormat(String lineDelim, boolean ignoreFirstLine, Character quoteChar) { + if (inputFormat == null) { + inputFormat = new FSCsvInputFormat(this, lineDelim, ignoreFirstLine, + quoteChar); + } + return inputFormat; + } + + @Override + public FSCsvInputFormat convertFileSplitToInputFormat(String lineDelim, boolean ignoreFirstLine, + Character quoteChar) { + return new FSCsvSplitInputFormat(this, lineDelim, ignoreFirstLine, quoteChar); + } + + @Override + public InputSplit convertStringToSplitObject(String splitStr) { + + return FSCsvSplitInputFormat.fromString(splitStr); + } + + public static class FSCsvInputFormat extends GenericCsvInputFormatBeta { + // -------------------------------------- Constants ------------------------------------------- + + private static final Logger LOG = LoggerFactory.getLogger(FSCsvInputFormat.class); + + private static final float MAX_SPLIT_SIZE_DISCREPANCY = 1.1f; + + /** + * The timeout (in milliseconds) to wait for a filesystem stream to respond. + */ + private static long DEFAULT_OPENING_TIMEOUT; + + static { + initDefaultsFromConfiguration(GlobalConfiguration.loadConfiguration()); + } + + /** + * Initialize defaults for input format. Needs to be a static method because it is configured for local + * cluster execution. + * + * @param configuration The configuration to load defaults from + */ + private static void initDefaultsFromConfiguration(Configuration configuration) { + final long to = configuration.getLong(ConfigConstants.FS_STREAM_OPENING_TIMEOUT_KEY, + ConfigConstants.DEFAULT_FS_STREAM_OPENING_TIMEOUT); + if (to < 0) { + LOG.error("Invalid timeout value for filesystem stream opening: " + to + ". Using default value of " + + ConfigConstants.DEFAULT_FS_STREAM_OPENING_TIMEOUT); + DEFAULT_OPENING_TIMEOUT = ConfigConstants.DEFAULT_FS_STREAM_OPENING_TIMEOUT; + } else if (to == 0) { + DEFAULT_OPENING_TIMEOUT = 300000; // 5 minutes + } else { + DEFAULT_OPENING_TIMEOUT = to; + } + } + + // -------------------------------------------------------------------------------------------- + // The configuration parameters. Configured on the instance and serialized to be shipped. + // -------------------------------------------------------------------------------------------- + + private final Path filePath; + /** + * The minimal split size, set by the configure() method.
+ */ + protected long minSplitSize = 0; + + /** + * The desired number of splits, as set by the configure() method. + */ + protected int numSplits = -1; + + /** + * Stream opening timeout. + */ + protected long openTimeout = DEFAULT_OPENING_TIMEOUT; + + private long offset = -1; + + /** + * The flag to specify whether recursive traversal of the input directory + * structure is enabled. + */ + protected boolean enumerateNestedFiles = true; + + /** + * Files filter for determining what files/directories should be included. + */ + private FilePathFilter filesFilter = new GlobFilePathFilter(); + + private final BaseFileSystem fs; + + // -------------------------------------------------------------------------------------------- + // Constructors + // -------------------------------------------------------------------------------------------- + + public FSCsvInputFormat(FSFileSplitReader reader, + String lineDelim, boolean ignoreFirstLine) { + super(reader, lineDelim, ignoreFirstLine); + this.fs = ((FSFileSplitReader) reader).getFs(); + this.filePath = ((FSFileSplitReader) reader).getFilePath(); + } + + public FSCsvInputFormat(FSFileSplitReader reader, + String lineDelim, boolean ignoreFirstLine, + Character quoteChar) { + super(reader, lineDelim, ignoreFirstLine, quoteChar); + this.fs = ((FSFileSplitReader) reader).getFs(); + this.filePath = ((FSFileSplitReader) reader).getFilePath(); + } + + public FSCsvInputFormat(FSFileSplitReader reader, + String lineDelim, boolean ignoreFirstLine, + boolean unsplitable, Character quoteChar) { + super(reader, lineDelim, ignoreFirstLine, unsplitable, quoteChar); + this.fs = ((FSFileSplitReader) reader).getFs(); + this.filePath = ((FSFileSplitReader) reader).getFilePath(); + } + + @Override + public void configure(Configuration parameters) { + + } + + @Override + public BaseStatistics getStatistics(BaseStatistics cachedStatistics) throws IOException { + return null; + } + + @Override + public void open(FileInputSplit split) throws IOException { + super.open(split); + if (split.getStart() > 0 || ignoreFirstLine) { + if (!readLine()) { + // if the first partial record already pushes the stream over + // the limit of our split, then no record starts within this split + setEnd(true); + } + + } else { + fillBuffer(0); + } + } + + public void openWithoutSkipLine(FileInputSplit split) throws IOException { + super.open(split); + } + + @Override + public FileInputSplit[] createInputSplits(int minNumSplits) throws IOException { + if (minNumSplits < 1) { + throw new AkIllegalArgumentException("Number of input splits has to be at least 1."); + } + + // take the desired number of splits into account + minNumSplits = Math.max(minNumSplits, this.numSplits); + + final List inputSplits = new ArrayList (minNumSplits); + + // get all the files that are involved in the splits + List files = new ArrayList <>(); + long totalLength = 0; + + if (filePath != null) { + final FileSystem fs = FileSystemUtils.getFlinkFileSystem(this.fs, filePath.toString()); + final FileStatus pathFile = fs.getFileStatus(filePath); + + if (pathFile.isDir()) { + totalLength += addFilesInDir(filePath, files, true); + } else { + files.add(pathFile); + totalLength += pathFile.getLen(); + } + } + + // returns if unsplittable + if (unsplittable) { + int splitNum = 0; + for (final FileStatus file : files) { + final FileSystem fs = FileSystemUtils.getFlinkFileSystem(this.fs, file.getPath().toString()); + final BlockLocation[] blocks = fs.getFileBlockLocations(file, 0, file.getLen()); + Set hosts = new HashSet 
(); + for (BlockLocation block : blocks) { + hosts.addAll(Arrays.asList(block.getHosts())); + } + long len = file.getLen(); + //len = READ_WHOLE_SPLIT_FLAG; + FileInputSplit fis = new FileInputSplit(splitNum++, file.getPath(), 0, len, + hosts.toArray(new String[hosts.size()])); + inputSplits.add(fis); + } + return inputSplits.toArray(new FileInputSplit[inputSplits.size()]); + } + + final long maxSplitSize = totalLength / minNumSplits + (totalLength % minNumSplits == 0 ? 0 : 1); + + // now that we have the files, generate the splits + int splitNum = 0; + for (final FileStatus file : files) { + + final FileSystem fs = FileSystemUtils.getFlinkFileSystem(this.fs, file.getPath().toString()); + final long len = file.getLen(); + final long blockSize = file.getBlockSize(); + + final long minSplitSize; + if (this.minSplitSize <= blockSize) { + minSplitSize = this.minSplitSize; + } else { + if (LOG.isWarnEnabled()) { + LOG.warn("Minimal split size of " + this.minSplitSize + " is larger than the block size of " + + blockSize + ". Decreasing minimal split size to block size."); + } + minSplitSize = blockSize; + } + + final long splitSize = Math.max(minSplitSize, Math.min(maxSplitSize, blockSize)); + final long halfSplit = splitSize >>> 1; + + final long maxBytesForLastSplit = (long) (splitSize * MAX_SPLIT_SIZE_DISCREPANCY); + if (len > 0) { + + // get the block locations and make sure they are in order with respect to their offset + final BlockLocation[] blocks = fs.getFileBlockLocations(file, 0, len); + Arrays.sort(blocks); + + long bytesUnassigned = len; + long position = 0; + + int blockIndex = 0; + + while (bytesUnassigned > maxBytesForLastSplit) { + // get the block containing the majority of the data + blockIndex = getBlockIndexForPosition(blocks, position, halfSplit, blockIndex); + // create a new split + FileInputSplit fis = new FileInputSplit(splitNum++, file.getPath(), position, splitSize, + blocks[blockIndex].getHosts()); + inputSplits.add(fis); + + // adjust the positions + position += splitSize; + bytesUnassigned -= splitSize; + } + + // assign the last split + if (bytesUnassigned > 0) { + blockIndex = getBlockIndexForPosition(blocks, position, halfSplit, blockIndex); + final FileInputSplit fis = new FileInputSplit(splitNum++, file.getPath(), position, + bytesUnassigned, blocks[blockIndex].getHosts()); + inputSplits.add(fis); + } + } else { + // special case with a file of zero bytes size + final BlockLocation[] blocks = fs.getFileBlockLocations(file, 0, 0); + String[] hosts; + if (blocks.length > 0) { + hosts = blocks[0].getHosts(); + } else { + hosts = new String[0]; + } + final FileInputSplit fis = new FileInputSplit(splitNum++, file.getPath(), 0, 0, hosts); + inputSplits.add(fis); + } + } + + return inputSplits.toArray(new FileInputSplit[inputSplits.size()]); + } + + private int getBlockIndexForPosition(BlockLocation[] blocks, long offset, long halfSplitSize, int startIndex) { + // go over all indexes after the startIndex + for (int i = startIndex; i < blocks.length; i++) { + long blockStart = blocks[i].getOffset(); + long blockEnd = blockStart + blocks[i].getLength(); + + if (offset >= blockStart && offset < blockEnd) { + // got the block where the split starts + // check if the next block contains more than this one does + if (i < blocks.length - 1 && blockEnd - offset < halfSplitSize) { + return i + 1; + } else { + return i; + } + } + } + throw new AkIllegalArgumentException("The given offset is not contained in the any block."); + } + + /** + * Enumerate all files in the directory 
and, if enumerateNestedFiles is true, nested files recursively. + * + * @return the total length of accepted files. + */ + private long addFilesInDir(Path path, List files, boolean logExcludedFiles) + throws IOException { + final FileSystem fs = FileSystemUtils.getFlinkFileSystem(this.fs, path.toString()); + + long length = 0; + + for (FileStatus dir : fs.listStatus(path)) { + if (dir.isDir()) { + length += addFilesInDir(dir.getPath(), files, logExcludedFiles); + } else { + if (acceptFile(dir)) { + files.add(dir); + length += dir.getLen(); + } else { + if (logExcludedFiles && LOG.isDebugEnabled()) { + LOG.debug( + "Directory " + dir.getPath().toString() + + " did not pass the file-filter and is excluded."); + } + } + } + } + return length; + } + + private boolean acceptFile(FileStatus fileStatus) { + final String name = fileStatus.getPath().getName(); + final FilePathFilter filesFilter = new GlobFilePathFilter(); + return !name.startsWith("_") + && !name.startsWith(".") + && !filesFilter.filterPath(fileStatus.getPath()); + } + } + + public static class FSCsvSplitInputFormat extends FSCsvInputFormat { + + public FSCsvSplitInputFormat(FSFileSplitReader reader, String lineDelim, boolean ignoreFirstLine, + Character quoteChar) { + super(reader, lineDelim, ignoreFirstLine, false, quoteChar); + } + + @Override + public void open(FileInputSplit split) throws IOException { + super.openWithoutSkipLine(split); + } + + /** + * This format is used to scan each split. It returns a single record and then sets the end flag to true. + * The record contains the quote-character count and a serialized string from which the split can be rebuilt: + * the split's toString() followed by a bracketed, semicolon-separated host list, which fromString(String) parses back. + * + * @param record Object that may be reused. + * @return the quote count, the split number and the serialized split + * @throws IOException + */ + @Override + public Row nextRecord(Row record) throws IOException { + long quoteNum = QuoteUtil.analyzeSplit(this.reader, quoteCharacter); + + StringBuilder sbd = new StringBuilder(); + sbd.append(this.currentSplit.toString()); + sbd.append("["); + for (String host : this.currentSplit.getHostnames()) { + sbd.append(host).append(";"); + } + sbd.append("]"); + + this.setEnd(true); + this.reader.close(); + + return Row.of(quoteNum, this.reader.getSplitNumber(), sbd.toString()); + } + + public static FileInputSplit fromString(String splitStr) { + int leftBracketsPosFirst = splitStr.indexOf("["); + int rightBracketsPosFirst = splitStr.indexOf("]"); + int leftBracketsPosSecond = splitStr.indexOf("[", leftBracketsPosFirst + 1); + int rightBracketsPosSecond = splitStr.indexOf("]", rightBracketsPosFirst + 1); + int filePathEndPos = splitStr.lastIndexOf(":"); + int plusPos = splitStr.lastIndexOf("+"); + + int num = Integer.valueOf(splitStr.substring(leftBracketsPosFirst + 1, rightBracketsPosFirst)); + String filePath = splitStr.substring(rightBracketsPosFirst + 2, filePathEndPos); + long start = Long.valueOf(splitStr.substring(filePathEndPos + 1, plusPos)); + long length = Long.valueOf(splitStr.substring(plusPos + 1, leftBracketsPosSecond)); + String[] hosts = splitStr.substring(leftBracketsPosSecond + 1, rightBracketsPosSecond).split(";"); + Path path = new Path(filePath); + + return new FileInputSplit(num, path, start, length, hosts); + } + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/io/reader/FileSplitReader.java b/core/src/main/java/com/alibaba/alink/operator/common/io/reader/FileSplitReader.java index a3298e55e..2d118215c 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/io/reader/FileSplitReader.java +++ 
b/core/src/main/java/com/alibaba/alink/operator/common/io/reader/FileSplitReader.java @@ -1,6 +1,8 @@ package com.alibaba.alink.operator.common.io.reader; -import com.alibaba.alink.operator.common.io.csv.CsvFileInputSplit; +import org.apache.flink.core.io.InputSplit; + +import com.alibaba.alink.operator.common.io.csv.GenericCsvInputFormatBeta; import java.io.IOException; import java.io.Serializable; @@ -13,7 +15,9 @@ public interface FileSplitReader extends Serializable { /** * Open for reading data range [start, end] */ - void open(CsvFileInputSplit split, long start, long end) throws IOException; + void open(InputSplit split) throws IOException; + + void reopen(InputSplit split, long splitStart) throws IOException; /** * Close the reader. @@ -43,4 +47,18 @@ public interface FileSplitReader extends Serializable { * @return The length. */ long getFileLength(); + + long getSplitLength(); + + long getSplitStart(); + + long getSplitEnd(); + + long getSplitNumber(); + + GenericCsvInputFormatBeta getInputFormat(String lineDelim, boolean ignoreFirstLine, Character quoteChar); + + GenericCsvInputFormatBeta convertFileSplitToInputFormat(String lineDelim, boolean ignoreFirstLine, Character quoteChar); + + InputSplit convertStringToSplitObject(String splitStr); } diff --git a/core/src/main/java/com/alibaba/alink/operator/common/io/reader/HttpFileSplitReader.java b/core/src/main/java/com/alibaba/alink/operator/common/io/reader/HttpFileSplitReader.java index 44adb777a..4d4ab0950 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/io/reader/HttpFileSplitReader.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/io/reader/HttpFileSplitReader.java @@ -1,7 +1,11 @@ package com.alibaba.alink.operator.common.io.reader; -import com.alibaba.alink.common.exceptions.AkUnclassifiedErrorException; +import org.apache.flink.core.io.InputSplit; + +import com.alibaba.alink.common.exceptions.AkIllegalDataException; import com.alibaba.alink.operator.common.io.csv.CsvFileInputSplit; +import com.alibaba.alink.operator.common.io.csv.CsvInputFormatBeta; +import com.alibaba.alink.operator.common.io.csv.CsvInputFormatBeta.CsvSplitInputFormat; import java.io.IOException; import java.io.InputStream; @@ -15,6 +19,8 @@ public class HttpFileSplitReader implements FileSplitReader, AutoCloseable { private static final long serialVersionUID = 510714179228030029L; private String path; + private CsvInputFormatBeta inputFormat = null; + private transient CsvFileInputSplit split; private transient HttpURLConnection connection; private transient InputStream stream; @@ -22,8 +28,7 @@ public HttpFileSplitReader(String path) { this.path = path; } - @Override - public void open(CsvFileInputSplit split, long start, long end) throws IOException { + public void open(InputSplit split, long start, long end) throws IOException { assert start >= 0; URL url = new URL(path); this.connection = (HttpURLConnection) url.openConnection(); @@ -34,6 +39,20 @@ public void open(CsvFileInputSplit split, long start, long end) throws IOExcepti this.connection.setRequestProperty("Range", String.format("bytes=%d-%d", start, end)); this.connection.connect(); this.stream = this.connection.getInputStream(); + this.split = (CsvFileInputSplit) split; + } + + @Override + public void open(InputSplit split) throws IOException { + long start = ((CsvFileInputSplit) split).start; + long end = ((CsvFileInputSplit) split).end - 1; + this.open(split, start, end); + } + + @Override + public void reopen(InputSplit split, long start) throws IOException { 
+ long end = ((CsvFileInputSplit) split).end - 1; + this.open(split, start, end); } @Override @@ -66,23 +85,64 @@ public long getFileLength() { boolean splitable = acceptRanges != null && acceptRanges.equalsIgnoreCase("bytes"); if (contentLength < 0) { - throw new AkUnclassifiedErrorException("The content length can't be determined."); + throw new AkIllegalDataException("The content length can't be determined (the reported content length is < 0)."); } // If the http server does not accept ranges, then we quit the program. // This is because 'accept ranges' is required to achieve robustness (through re-connection), // and efficiency (through concurrent read). if (!splitable) { - throw new AkUnclassifiedErrorException("The http server does not support range reading."); + throw new AkIllegalDataException("The HTTP response has no 'Accept-Ranges' header or its value is not " + + "'bytes'; the http server does not support range reading."); } return contentLength; } catch (Exception e) { - throw new AkUnclassifiedErrorException("Fail to connect to http server", e); + throw new AkIllegalDataException(String.format("Fail to connect to http address %s", path), e); } finally { if (headerConnection != null) { headerConnection.disconnect(); } } } + + @Override + public long getSplitStart() { + return split.start; + } + + @Override + public long getSplitEnd() { + return split.end; + } + + @Override + public long getSplitNumber() { + return split.getSplitNumber(); + } + + @Override + public long getSplitLength() { + return split.length; + } + + @Override + public CsvInputFormatBeta getInputFormat(String lineDelim, boolean ignoreFirstLine, + Character quoteChar) { + if (inputFormat == null) { + inputFormat = new CsvInputFormatBeta(this, lineDelim, ignoreFirstLine, quoteChar); + } + return inputFormat; + } + + @Override + public CsvInputFormatBeta convertFileSplitToInputFormat(String lineDelim, boolean ignoreFirstLine, + Character quoteChar) { + return new CsvSplitInputFormat(this, lineDelim, ignoreFirstLine, quoteChar); + } + + @Override + public InputSplit convertStringToSplitObject(String splitStr) { + return CsvFileInputSplit.fromString(splitStr); + } } diff --git a/core/src/main/java/com/alibaba/alink/operator/common/io/reader/QuoteUtil.java b/core/src/main/java/com/alibaba/alink/operator/common/io/reader/QuoteUtil.java new file mode 100644 index 000000000..a85c0c293 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/io/reader/QuoteUtil.java @@ -0,0 +1,32 @@ +package com.alibaba.alink.operator.common.io.reader; + +import java.io.IOException; + +public class QuoteUtil { + private static final int BUFFER_SIZE = 1024 * 1024; + + public static long analyzeSplit(FileSplitReader reader, byte quoteCharacter) throws IOException { + byte[] buf = new byte[BUFFER_SIZE]; + long splitLength = reader.getSplitLength(); + long byteRead = 0; + long quoteNum = 0; + int read; + + while (byteRead < splitLength) { + read = reader.read(buf, 0, (int) Long.min(BUFFER_SIZE, splitLength - byteRead)); + if (read > 0) { + byteRead += read; + } else { + break; + } + + for (int i = 0; i < read; i++) { + if (buf[i] == quoteCharacter) { + quoteNum++; + } + } + } + return quoteNum; + } + +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/io/types/JdbcTypeConverter.java b/core/src/main/java/com/alibaba/alink/operator/common/io/types/JdbcTypeConverter.java index 44fe90243..b828cc738 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/io/types/JdbcTypeConverter.java 
+++ b/core/src/main/java/com/alibaba/alink/operator/common/io/types/JdbcTypeConverter.java @@ -44,7 +44,7 @@ public class JdbcTypeConverter { m1.put(SqlTimeTypeInfo.TIME, Types.TIME); m1.put(SqlTimeTypeInfo.TIMESTAMP, Types.TIMESTAMP); m1.put(BasicTypeInfo.BIG_DEC_TYPE_INFO, Types.DECIMAL); - m1.put(PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO, Types.BINARY); + m1.put(PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO, Types.VARBINARY); MAP_FLINK_TYPE_TO_INDEX = Collections.unmodifiableMap(m1); HashMap > m3 = new HashMap <>(); @@ -62,7 +62,7 @@ public class JdbcTypeConverter { m3.put(Types.TIME, SqlTimeTypeInfo.TIME); m3.put(Types.TIMESTAMP, SqlTimeTypeInfo.TIMESTAMP); m3.put(Types.DECIMAL, BasicTypeInfo.BIG_DEC_TYPE_INFO); - m3.put(Types.BINARY, PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO); + m3.put(Types.VARBINARY, PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO); MAP_INDEX_TO_FLINK_TYPE = Collections.unmodifiableMap(m3); } diff --git a/core/src/main/java/com/alibaba/alink/operator/common/linear/AftRegObjFunc.java b/core/src/main/java/com/alibaba/alink/operator/common/linear/AftRegObjFunc.java index 20cf81f3f..2881b1443 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/linear/AftRegObjFunc.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/linear/AftRegObjFunc.java @@ -12,6 +12,8 @@ import com.alibaba.alink.common.linalg.Vector; import com.alibaba.alink.operator.common.optim.objfunc.OptimObjFunc; +import java.util.List; + /** * Accelerated failure time Regression object function. */ @@ -66,7 +68,7 @@ public static double getDotProduct(Vector labelVector, DenseVector coefVector) { * @return the loss value and weight. */ @Override - protected double calcLoss(Tuple3 labelVector, DenseVector coefVector) { + public double calcLoss(Tuple3 labelVector, DenseVector coefVector) { /* * loss = censor * coefVector.get(coefVector.size() - 1) - censor * epsilon + Math.exp(epsilon) * the last one of coefVector is the log(sigma) @@ -84,7 +86,7 @@ protected double calcLoss(Tuple3 labelVector, DenseVect * @param updateGrad gradient need to update. */ @Override - protected void updateGradient(Tuple3 labelVector, DenseVector coefVector, + public void updateGradient(Tuple3 labelVector, DenseVector coefVector, DenseVector updateGrad) { double sigma = Math.exp(coefVector.get(coefVector.size() - 1)); @@ -116,7 +118,7 @@ protected void updateGradient(Tuple3 labelVector, Dense * @param updateHessian hessian matrix need to update. 
*/ @Override - protected void updateHessian(Tuple3 labelVector, DenseVector coefVector, + public void updateHessian(Tuple3 labelVector, DenseVector coefVector, DenseMatrix updateHessian) { double sigma = Math.exp(coefVector.get(coefVector.size() - 1)); double epsilon = (labelVector.f1 - getDotProduct(labelVector.f2, coefVector)) / sigma; @@ -317,7 +319,7 @@ public double[] calcSearchValues(Iterable > labe */ @Override public double[] constraintCalcSearchValues( - Iterable > labelVectors, + List > labelVectors, DenseVector coefVector, DenseVector dirVec, double beta, int numStep) { double[] losses = new double[numStep + 1]; double[] coefArray = coefVector.getData(); diff --git a/core/src/main/java/com/alibaba/alink/operator/common/linear/BaseConstrainedLinearModelTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/common/linear/BaseConstrainedLinearModelTrainBatchOp.java new file mode 100644 index 000000000..0013d2f5e --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/linear/BaseConstrainedLinearModelTrainBatchOp.java @@ -0,0 +1,985 @@ +package com.alibaba.alink.operator.common.linear; + +import org.apache.flink.api.common.functions.AbstractRichFunction; +import org.apache.flink.api.common.functions.FlatMapFunction; +import org.apache.flink.api.common.functions.GroupReduceFunction; +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.MapPartitionFunction; +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.functions.RichMapPartitionFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.table.api.TableSchema; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +import com.alibaba.alink.common.MLEnvironment; +import com.alibaba.alink.common.MLEnvironmentFactory; +import com.alibaba.alink.common.annotation.FeatureColsVectorColMutexRule; +import com.alibaba.alink.common.annotation.InputPorts; +import com.alibaba.alink.common.annotation.OutputPorts; +import com.alibaba.alink.common.annotation.ParamSelectColumnSpec; +import com.alibaba.alink.common.annotation.PortSpec; +import com.alibaba.alink.common.annotation.PortType; +import com.alibaba.alink.common.annotation.TypeCollections; +import com.alibaba.alink.common.linalg.DenseVector; +import com.alibaba.alink.common.linalg.SparseVector; +import com.alibaba.alink.common.linalg.Vector; +import com.alibaba.alink.common.linalg.VectorUtil; +import com.alibaba.alink.common.model.ModelParamName; +import com.alibaba.alink.common.viz.AlinkViz; +import com.alibaba.alink.common.utils.TableUtil; +import com.alibaba.alink.common.viz.VizDataWriterForModelInfo; +import com.alibaba.alink.common.viz.VizDataWriterInterface; +import com.alibaba.alink.operator.batch.BatchOperator; +import com.alibaba.alink.operator.batch.finance.ScorecardTrainBatchOp; +import com.alibaba.alink.operator.common.evaluation.EvaluationUtil; +import com.alibaba.alink.operator.common.linear.unarylossfunc.LogLossFunc; +import com.alibaba.alink.operator.common.linear.unarylossfunc.SquareLossFunc; +import 
com.alibaba.alink.operator.common.optim.FeatureConstraint; +import com.alibaba.alink.operator.common.optim.Lbfgs; +import com.alibaba.alink.operator.common.optim.Newton; +import com.alibaba.alink.operator.common.optim.activeSet.ConstraintObjFunc; +import com.alibaba.alink.operator.common.optim.activeSet.Sqp; +import com.alibaba.alink.operator.common.optim.barrierIcq.LogBarrier; +import com.alibaba.alink.operator.common.optim.local.ConstrainedLocalOptimizer; +import com.alibaba.alink.operator.common.optim.objfunc.OptimObjFunc; +import com.alibaba.alink.operator.batch.statistics.utils.StatisticsHelper; +import com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary; +import com.alibaba.alink.operator.common.statistics.basicstatistic.SparseVectorSummary; +import com.alibaba.alink.operator.common.tree.Preprocessing; +import com.alibaba.alink.params.finance.ConstrainedLinearModelParams; +import com.alibaba.alink.params.finance.ConstrainedLogisticRegressionTrainParams; +import com.alibaba.alink.params.finance.HasConstrainedOptimizationMethod; +import com.alibaba.alink.params.shared.linear.LinearTrainParams; +import org.apache.commons.math3.optim.PointValuePair; +import org.apache.commons.math3.optim.linear.LinearConstraint; +import org.apache.commons.math3.optim.linear.LinearConstraintSet; +import org.apache.commons.math3.optim.linear.LinearObjectiveFunction; +import org.apache.commons.math3.optim.linear.Relationship; +import org.apache.commons.math3.optim.linear.SimplexSolver; +import org.apache.commons.math3.optim.nonlinear.scalar.GoalType; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +/** + * Base class of constrained linear model training. Linear binary classification and linear regression algorithms should inherit + * this class; a subclass then only needs to implement its loss function and regularization term. + * + * Constraints can be set via params or via the second input op. + * If constraints are present, standardization is disabled. + * + * In scorecard or LR, a positive label value is required. + * + * The default optimization method is SQP, whether constraints are present or not. + * + * @param parameter of this class, e.g. the linear regression or LR operator type. + */ +@InputPorts(values = @PortSpec(PortType.DATA)) +@OutputPorts(values = { + @PortSpec(PortType.MODEL) +}) + +@ParamSelectColumnSpec(name = "featureCols", + allowedTypeCollections = TypeCollections.NUMERIC_TYPES) +@ParamSelectColumnSpec(name = "vectorCol", + allowedTypeCollections = TypeCollections.VECTOR_TYPES) +@ParamSelectColumnSpec(name = "labelCol") +@ParamSelectColumnSpec(name = "weightCol", + allowedTypeCollections = TypeCollections.NUMERIC_TYPES) +@FeatureColsVectorColMutexRule +public abstract class BaseConstrainedLinearModelTrainBatchOp> + extends BatchOperator + implements AlinkViz { + private static final long serialVersionUID = 1180583968098354917L; + private String modelName; + private LinearModelType linearModelType; + //for viz: when feature num > NUM_FEATURE_THRESHOLD, do not visualize + private static final int NUM_FEATURE_THRESHOLD = 10000; + private static final String META = "meta"; + private static final String MEAN_VAR = "meanVar"; + //inner variable for broadcast. + private static final String VECTOR_SIZE = "vectorSize"; + private static final String LABEL_VALUES = "labelValues"; + + /** + * @param params parameters needed by training process. + * @param modelType model type: LR, LinearReg + * @param modelName name of the model. 
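For orientation, a minimal usage sketch of the parameters that linkFrom (further below) reads; only the parameter keys referenced in this class are assumed to exist, while the column names, label value and the concrete trainer variable are hypothetical.

// Illustrative only: "trainOp" stands for any concrete subclass of this base class.
Params params = new Params()
    .set(LinearTrainParams.FEATURE_COLS, new String[] {"f0", "f1"})
    .set(LinearTrainParams.LABEL_COL, "label")
    .set(ConstrainedLogisticRegressionTrainParams.POS_LABEL_VAL_STR, "1")
    // Leave CONSTRAINT empty so constraints are taken from the optional second input op instead.
    .set(ConstrainedLinearModelParams.CONSTRAINT, "");
// trainOp.linkFrom(dataOp, constraintOp);   // data is required, the constraint input is optional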
+ */ + public BaseConstrainedLinearModelTrainBatchOp(Params params, LinearModelType modelType, String modelName) { + super(params); + this.modelName = modelName; + this.linearModelType = modelType; + } + + /** + * @param inputs first is data, second is constraint. + * first is required and second is optioned. + * constraint can from second input or params. + */ + @Override + public T linkFrom(BatchOperator ... inputs) { + Params params = getParams(); + + BatchOperator in = inputs[0]; + DataSet constraints = null; + + //when has constraint, STANDARDIZATION is false. + if (inputs.length == 2 && inputs[1] != null) { + constraints = inputs[1].getDataSet(); + params.set(LinearTrainParams.STANDARDIZATION, false); + } + + //constraint can from second input or params. + String cons = params.get(ConstrainedLinearModelParams.CONSTRAINT); + + //in scorecard, it is always "" not null. + if (!"".equals(cons)) { + constraints = MLEnvironmentFactory + .get(this.getMLEnvironmentId()).getExecutionEnvironment().fromElements( + Row.of(FeatureConstraint.fromJson(cons))); + params.set(LinearTrainParams.STANDARDIZATION, false); + } else { + if (constraints != null) { + constraints = inputs[1].getDataSet(); + //here deal with input constraint which may be string (from constrained linear model) + constraints = constraints.map(new GenerateConstraint()); + } else { + constraints = MLEnvironmentFactory + .get(this.getMLEnvironmentId()).getExecutionEnvironment().fromElements( + Row.of(new FeatureConstraint())); + } + } + String positiveLabel = null; + //three cases: lr, linearreg in scorecard, linearreg. + boolean useLabel = LinearModelType.LR == linearModelType || getParams().get(ScorecardTrainBatchOp + .IN_SCORECARD); + if (useLabel) { + positiveLabel = params.get(ConstrainedLogisticRegressionTrainParams.POS_LABEL_VAL_STR); + } + + //may set true when not have constraint. + if (!params.contains(HasConstrainedOptimizationMethod.CONST_OPTIM_METHOD)) { + params.set(HasConstrainedOptimizationMethod.CONST_OPTIM_METHOD, + HasConstrainedOptimizationMethod.ConstOptimMethod.SQP); + } + String method = params.get(HasConstrainedOptimizationMethod.CONST_OPTIM_METHOD).toString().toUpperCase(); + + // Get type of processing: regression or not + boolean isRegProc = linearModelType == LinearModelType.LinearReg; + boolean standardization = params.get(LinearTrainParams.STANDARDIZATION); + + // Get label info: including label values and label type. + //If linear regression not in scorecard, label values is new Object(). + Tuple2 , TypeInformation> labelInfo = getLabelInfo(in, params, !useLabel); + + //Tuple3 format train data + DataSet > initData = BaseLinearModelTrainBatchOp + .transform(in, params, isRegProc, standardization); + + //Tuple3 + DataSet > utilInfo = BaseLinearModelTrainBatchOp + .getUtilInfo(initData, standardization, isRegProc); + + //get means and variances + DataSet meanVar = utilInfo.map( + new MapFunction , DenseVector[]>() { + private static final long serialVersionUID = 7127767376687624403L; + + @Override + public DenseVector[] map(Tuple3 value) + throws Exception { + return value.f0; + } + }); + + //get feature size + DataSet featSize = utilInfo.map( + new MapFunction , Integer>() { + private static final long serialVersionUID = 2773811388068064638L; + + @Override + public Integer map(Tuple3 value) + throws Exception { + return value.f2[0]; + } + }); + + //change orders of labels, first value is positiveLabel. 
+ DataSet labelValues = utilInfo.flatMap(new BuildLabels(isRegProc, positiveLabel)); + + // + // vector is after standard if standardization is true. + // label is change to 1.0/0.0, if lr or linearReg in scoreCard. + DataSet > trainData + = BaseLinearModelTrainBatchOp + .preProcess(initData, params, isRegProc, meanVar, labelValues, featSize); + + //stat non zero of all features. + DataSet countZero = StatisticsHelper.summary(trainData.map( + new MapFunction , Vector>() { + private static final long serialVersionUID = 6207307350053531656L; + + @Override + public Vector map(Tuple3 value) throws Exception { + return value.f2; + } + }).withForwardedFields()) + .map(new MapFunction () { + private static final long serialVersionUID = 2322849507320367330L; + + @Override + public DenseVector map(BaseVectorSummary value) throws Exception { + if (value instanceof SparseVectorSummary) { + return (DenseVector) ((SparseVectorSummary) value).numNonZero(); + } + return new DenseVector(0); + } + }); + + // Solve the optimization problem. + // return + Tuple2 , DataSet > optParam = getOptParam(constraints, params, featSize, + linearModelType, MLEnvironmentFactory.get(this.getMLEnvironmentId()), method, countZero); + DataSet > coefVectorSet = optimize(params, optParam, trainData, modelName, + method); + + // Prepare the meta info of linear model. + DataSet meta = labelInfo.f0 + .mapPartition(new CreateMeta(modelName, linearModelType, params, useLabel, positiveLabel)) + .setParallelism(1); + + // Build linear model rows, the format to be output. + DataSet modelRows; + String[] featureColTypes = BaseLinearModelTrainBatchOp.getFeatureTypes(in, + params.get(LinearTrainParams.FEATURE_COLS)); + modelRows = coefVectorSet + .mapPartition(new BaseLinearModelTrainBatchOp.BuildModelFromCoefs(labelInfo.f1, + params.get(LinearTrainParams.FEATURE_COLS), + params.get(LinearTrainParams.STANDARDIZATION), + params.get(LinearTrainParams.WITH_INTERCEPT), + featureColTypes)) + .withBroadcastSet(meta, META) + .withBroadcastSet(meanVar, MEAN_VAR) + .setParallelism(1); + // Convert the model rows to table. 
+ this.setOutput(modelRows, new LinearModelDataConverter(labelInfo.f1).getModelSchema()); + writeVizData(modelRows, featSize); + return (T) this; + } + + private static class GenerateConstraint implements MapFunction { + private static final long serialVersionUID = -6999309059934707482L; + + @Override + public Row map(Row value) { + FeatureConstraint cons; + if (value.getField(0) instanceof FeatureConstraint) { + cons = (FeatureConstraint) value.getField(0); + } else { + cons = FeatureConstraint.fromJson((String) value.getField(0)); + } + return Row.of(cons); + } + } + + protected static DataSet > transform(BatchOperator in, Params params, + DataSet labelValues, + boolean isRegProc, String posLabel, + TypeInformation labelType) { + String[] featureColNames = params.get(LinearTrainParams.FEATURE_COLS); + String labelName = params.get(LinearTrainParams.LABEL_COL); + String weightColName = params.get(LinearTrainParams.WEIGHT_COL); + String vectorColName = params.get(LinearTrainParams.VECTOR_COL); + TableSchema dataSchema = in.getSchema(); + if (null == featureColNames && null == vectorColName) { + featureColNames = TableUtil.getNumericCols(dataSchema, new String[] {labelName}); + params.set(LinearTrainParams.FEATURE_COLS, featureColNames); + } + int[] featureIndices = null; + int labelIdx = TableUtil.findColIndexWithAssertAndHint(dataSchema.getFieldNames(), labelName); + if (featureColNames != null) { + featureIndices = new int[featureColNames.length]; + for (int i = 0; i < featureColNames.length; ++i) { + int idx = TableUtil.findColIndexWithAssertAndHint(in.getColNames(), featureColNames[i]); + featureIndices[i] = idx; + TypeInformation type = in.getSchema().getFieldTypes()[idx]; + + Preconditions.checkState(TableUtil.isSupportedNumericType(type), + "linear algorithm only support numerical data type. type is : " + type); + } + } + int weightIdx = weightColName != null ? TableUtil.findColIndexWithAssertAndHint(in.getColNames(), + weightColName) + : -1; + int vecIdx = vectorColName != null ? 
TableUtil.findColIndexWithAssertAndHint(in.getColNames(), vectorColName) + : -1; + return in.getDataSet().map( + new Transform(isRegProc, weightIdx, vecIdx, featureIndices, labelIdx, posLabel, labelType)) + .withBroadcastSet(labelValues, LABEL_VALUES); + } + + private static class Transform extends RichMapFunction > { + private static final long serialVersionUID = 3541655329500762922L; + private String positiveLableValueString; + + private boolean isRegProc; + private int weightIdx; + private int vecIdx; + private int labelIdx; + private int[] featureIndices; + private TypeInformation type; + + public Transform(boolean isRegProc, int weightIdx, int vecIdx, int[] featureIndices, + int labelIdx, String posLabel, TypeInformation type) { + this.isRegProc = isRegProc; + this.weightIdx = weightIdx; + this.vecIdx = vecIdx; + this.featureIndices = featureIndices; + this.labelIdx = labelIdx; + this.positiveLableValueString = posLabel; + this.type = type; + } + + @Override + public void open(Configuration parameters) throws Exception { + + if (!this.isRegProc) { + List labelRows = getRuntimeContext().getBroadcastVariable(LABEL_VALUES); + Object[] labels = orderLabels(labelRows, positiveLableValueString); + if (this.positiveLableValueString == null) { + throw new RuntimeException("constrained logistic regression must set positive label!"); + } + EvaluationUtil.ComparableLabel posLabel = + new EvaluationUtil.ComparableLabel(this.positiveLableValueString, this.type); + if (!posLabel.equals(new EvaluationUtil.ComparableLabel(labels[0].toString(), type)) + && !posLabel.equals(new EvaluationUtil.ComparableLabel(labels[1].toString(), type))) { + throw new RuntimeException("the user defined positive label is not in the data!"); + } + } + } + + @Override + public Tuple3 map(Row row) throws Exception { + Double weight = weightIdx != -1 ? 
((Number) row.getField(weightIdx)).doubleValue() : 1.0; + Double val = FeatureLabelUtil.getLabelValue(row, this.isRegProc, + labelIdx, this.positiveLableValueString); + if (featureIndices != null) { + DenseVector vec = new DenseVector(featureIndices.length); + for (int i = 0; i < featureIndices.length; ++i) { + vec.set(i, ((Number) row.getField(featureIndices[i])).doubleValue()); + } + return Tuple3.of(weight, val, vec); + } else { + Vector vec = VectorUtil.getVector(row.getField(vecIdx)); + Preconditions.checkState((vec != null), + "vector for linear model train is null, please check your input data."); + + return Tuple3.of(weight, val, vec); + } + + } + } + + protected static Object[] orderLabels(Iterable unorderedLabelRows, String positiveLabel) { + List tmpArr = new ArrayList <>(); + for (Object row : unorderedLabelRows) { + tmpArr.add(row); + } + Object[] labels = tmpArr.toArray(new Object[0]); + + Preconditions.checkState((labels.length >= 2), "labels count should be at least 2 in classification algo."); + String str1 = labels[1].toString(); + + if (str1.equals(positiveLabel)) { + Object t = labels[0]; + labels[0] = labels[1]; + labels[1] = t; + } + return labels; + } + + protected static Object[] orderLabels(Object[] unorderedLabelRows, String positiveLabel) { + + Preconditions.checkState((unorderedLabelRows.length >= 2), + "labels count should be at least 2 in classification algo."); + String str1 = unorderedLabelRows[1].toString(); + + if (str1.equals(positiveLabel)) { + Object t = unorderedLabelRows[0]; + unorderedLabelRows[0] = unorderedLabelRows[1]; + unorderedLabelRows[1] = t; + } + return unorderedLabelRows; + } + + /** + * + * @param constraints + * @param params + * @param vectorSize + * @param modelType + * @param session + * @param method + * @param countZero + * @return tuple of objFunc and coefficientDim. + */ + private static Tuple2 , DataSet > getOptParam(DataSet constraints, + Params params, + DataSet vectorSize, + LinearModelType modelType, + MLEnvironment session, String method, + DataSet countZero) { + boolean hasInterceptItem = params.get(LinearTrainParams.WITH_INTERCEPT); + String[] featureColNames = params.get(LinearTrainParams.FEATURE_COLS); + String vectorColName = params.get(LinearTrainParams.VECTOR_COL); + if ("".equals(vectorColName)) { + vectorColName = null; + } + if (org.apache.commons.lang3.ArrayUtils.isEmpty(featureColNames)) { + featureColNames = null; + } + + DataSet coefficientDim; + + if (vectorColName != null && vectorColName.length() != 0) { + coefficientDim = vectorSize; + } else { + coefficientDim = session.getExecutionEnvironment().fromElements(featureColNames.length + + (hasInterceptItem ? 1 : 0)); + } + // Loss object function. + // Check whether the constraints are feasible or not. + DataSet objFunc = session.getExecutionEnvironment() + .fromElements(getObjFunction(modelType, params)) + .map(new GetConstraint(featureColNames, hasInterceptItem, method, + params.get(ScorecardTrainBatchOp.WITH_ELSE))) + .withBroadcastSet(coefficientDim, "coef") + .withBroadcastSet(constraints, "constraints") + .withBroadcastSet(countZero, "countZero"); + return Tuple2.of(objFunc, coefficientDim); + } + + /** + * Optimize the linear problem. + * @param optParam + * @param trainData + * + * @return coefficients of the linear problem. 
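As the switch in optimize() below shows, the constrained solver is chosen via CONST_OPTIM_METHOD and falls back to SQP when the parameter is absent. A sketch of selecting it explicitly (the call mirrors the one made in linkFrom above):

// Pick the constrained solver explicitly; note that constraints are only extracted for
// SQP and BARRIER, not for LBFGS or NEWTON (see GetConstraint.map below).
params.set(HasConstrainedOptimizationMethod.CONST_OPTIM_METHOD,
    HasConstrainedOptimizationMethod.ConstOptimMethod.SQP);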
+ */ + public static DataSet > optimize(Params params, + Tuple2 , DataSet > + optParam, + DataSet > + trainData, + String modelName, String method) { + DataSet objFunc = optParam.f0; + DataSet coefficientDim = optParam.f1; + + if (params.contains(HasConstrainedOptimizationMethod.CONST_OPTIM_METHOD)) { + switch (ConstrainedOptMethod.valueOf(method)) { + case SQP: + return new Sqp(objFunc, trainData, coefficientDim, params).optimize(); + case BARRIER: + return new LogBarrier(objFunc, trainData, coefficientDim, params).optimize(); + case LBFGS: + return new Lbfgs(objFunc, trainData, coefficientDim, params).optimize(); + case NEWTON: + return new Newton(objFunc, trainData, coefficientDim, params).optimize(); + default: + throw new RuntimeException("do not support the " + method + " method!"); + } + } + //default opt method. + return new Sqp(objFunc, trainData, coefficientDim, params).optimize(); + } + + protected enum ConstrainedOptMethod { + SQP, + BARRIER, + LBFGS, + NEWTON + } + + private static class GetConstraint extends RichMapFunction { + private static final long serialVersionUID = -7810872210451727729L; + private int coefDim; + private FeatureConstraint constraint; + private String[] featureColNames; + private boolean hasInterceptItem; + private ConstrainedOptMethod method; + private DenseVector countZero = null; + private Map hasElse; + + GetConstraint(String[] featureColNames, boolean hasInterceptItem, String method, + Map withElse) { + this.featureColNames = featureColNames; + this.hasInterceptItem = hasInterceptItem; + this.method = ConstrainedOptMethod.valueOf(method.toUpperCase()); + this.hasElse = withElse; + } + + @Override + public void open(Configuration parameters) throws Exception { + coefDim = (int) getRuntimeContext().getBroadcastVariable("coef").get(0); + constraint = (FeatureConstraint) ((Row) getRuntimeContext().getBroadcastVariable("constraints").get(0)) + .getField(0); + if (constraint.fromScorecard()) { + this.countZero = (DenseVector) getRuntimeContext().getBroadcastVariable("countZero").get(0); + } + } + + @Override + public OptimObjFunc map(OptimObjFunc value) throws Exception { + ConstraintObjFunc objFunc = (ConstraintObjFunc) value; + if (!(ConstrainedOptMethod.LBFGS.equals(method) || ConstrainedOptMethod.NEWTON.equals(method))) { + ConstrainedLocalOptimizer.extractConstraintsForFeatureAndBin(constraint, objFunc, featureColNames, + coefDim, hasInterceptItem, countZero, hasElse); + //checkout feasible constraint or not. + int length = objFunc.equalityItem.size() + objFunc.inequalityItem.size(); + if (length != 0) { + int dim = objFunc.equalityConstraint.numRows() != 0 ? 
objFunc.equalityConstraint.numCols() : + objFunc.inequalityConstraint.numCols(); + double[] objData = new double[dim]; + LinearObjectiveFunction obj = new LinearObjectiveFunction(objData, 0); + List cons = new ArrayList <>(); + for (int i = 0; i < objFunc.equalityItem.size(); i++) { + double[] constraint = new double[dim]; + System.arraycopy(objFunc.equalityConstraint.getRow(i), 0, constraint, 0, dim); + double item = objFunc.equalityItem.get(i); + cons.add(new LinearConstraint(constraint, Relationship.EQ, item)); + } + for (int i = 0; i < objFunc.inequalityItem.size(); i++) { + double[] constraint = new double[dim]; + System.arraycopy(objFunc.inequalityConstraint.getRow(i), 0, constraint, 0, dim); + double item = objFunc.inequalityItem.get(i); + cons.add(new LinearConstraint(constraint, Relationship.GEQ, item)); + } + LinearConstraintSet conSet = new LinearConstraintSet(cons); + try { + PointValuePair pair = new SimplexSolver().optimize(obj, conSet, GoalType.MINIMIZE); + } catch (Exception e) { + throw new RuntimeException("infeasible constraint!", e); + } + } + } + + return objFunc; + } + } + + /** + * Get obj function. + * + * @param modelType model type. + * @param params parameters for train. + * @return + */ + public static OptimObjFunc getObjFunction(LinearModelType modelType, Params params) { + OptimObjFunc objFunc; + // For different model type, we must set corresponding loss object function. + if (modelType == LinearModelType.LinearReg) { + objFunc = new ConstraintObjFunc(new SquareLossFunc(), params); + } else if (modelType == LinearModelType.LR) { + objFunc = new ConstraintObjFunc(new LogLossFunc(), params); + } else { + throw new RuntimeException("Not implemented yet!"); + } + return objFunc; + } + + /** + * Write visualized data. + * + * @param modelRows model data in row format. + * @param vectorSize vector Size. + */ + private void writeVizData(DataSet modelRows, DataSet vectorSize) { + VizDataWriterInterface writer = getVizDataWriter(); + if (writer == null) { + return; + } + DataSet processedModelRows = modelRows.mapPartition(new RichMapPartitionFunction () { + private static final long serialVersionUID = -7146244281747193903L; + + @Override + public void mapPartition(Iterable values, Collector out) { + final int featureSize = (Integer) (getRuntimeContext() + .getBroadcastVariable(VECTOR_SIZE).get(0)); + if (featureSize <= NUM_FEATURE_THRESHOLD) { + values.forEach(out::collect); + } else { + String errorStr = "Not support models with #features > " + NUM_FEATURE_THRESHOLD; + out.collect(Row.of(errorStr)); + } + } + }).withBroadcastSet(vectorSize, VECTOR_SIZE).setParallelism(1); + VizDataWriterForModelInfo.writeModelInfo(writer, this.getClass().getSimpleName(), + this.getOutputTable().getSchema(), processedModelRows, getParams()); + } + + /** + * Create meta info. 
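For reference, a rough sketch of the meta Params that the CreateMeta class below emits as the first model row; the keys are the ones set in its mapPartition, while the values here are placeholders.

// Illustrative meta record; real values come from the trainer's params.
Params meta = new Params();
meta.set(ModelParamName.MODEL_NAME, "model");                    // modelName passed to the trainer
meta.set(ModelParamName.LINEAR_MODEL_TYPE, LinearModelType.LR);  // LR or LinearReg
meta.set(ModelParamName.HAS_INTERCEPT_ITEM, true);
meta.set(ModelParamName.VECTOR_COL_NAME, null);                  // null when feature columns are used
meta.set(LinearTrainParams.LABEL_COL, "label");
// For classification, ModelParamName.LABEL_VALUES is additionally set to the ordered label array.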
+ */ + public static class CreateMeta implements MapPartitionFunction { + private static final long serialVersionUID = -7148219424266582224L; + private String modelName; + private LinearModelType modelType; + private boolean hasInterceptItem; + private String vectorColName; + private String labelName; + private boolean calcLabel; + private String positiveLabel; + + public CreateMeta(String modelName, LinearModelType modelType, + Params params, boolean calcLabel, String positiveLabel) { + this.modelName = modelName; + this.modelType = modelType; + this.hasInterceptItem = params.get(LinearTrainParams.WITH_INTERCEPT); + this.vectorColName = params.get(LinearTrainParams.VECTOR_COL); + this.labelName = params.get(LinearTrainParams.LABEL_COL); + this.calcLabel = calcLabel; + this.positiveLabel = positiveLabel; + } + + @Override + public void mapPartition(Iterable rows, Collector metas) throws Exception { + Object[] labels = null; + if (calcLabel) { + labels = orderLabels(rows, positiveLabel); + } + + Params meta = new Params(); + meta.set(ModelParamName.MODEL_NAME, this.modelName); + meta.set(ModelParamName.LINEAR_MODEL_TYPE, this.modelType); + meta.set(ModelParamName.LABEL_VALUES, labels); + meta.set(ModelParamName.HAS_INTERCEPT_ITEM, this.hasInterceptItem); + meta.set(ModelParamName.VECTOR_COL_NAME, vectorColName); + meta.set(LinearTrainParams.LABEL_COL, labelName); + metas.collect(meta); + } + } + + /** + * The size of coefficient. Transform dimension of trainData, if has Intercept item, dimension++. + */ + private static class DimTrans extends AbstractRichFunction + implements MapFunction { + private static final long serialVersionUID = 1997987979691400583L; + private boolean hasInterceptItem; + private Integer featureDim = null; + + public DimTrans(boolean hasInterceptItem) { + this.hasInterceptItem = hasInterceptItem; + } + + @Override + public void open(Configuration parameters) throws Exception { + this.featureDim = (Integer) getRuntimeContext() + .getBroadcastVariable(VECTOR_SIZE).get(0); + } + + @Override + public Integer map(Integer integer) throws Exception { + return this.featureDim + (this.hasInterceptItem ? 1 : 0); + } + } + + protected static Tuple3 , DataSet , DataSet > getStatInfo( + DataSet > trainData, + final boolean standardization) { + //may pass the param isScorecard, so that only need to get it when scorecard. 
+ DataSet summary = StatisticsHelper.summary(trainData.map( + new MapFunction , Vector>() { + private static final long serialVersionUID = 6207307350053531656L; + + @Override + public Vector map(Tuple3 value) throws Exception { + return value.f2; + } + }).withForwardedFields()); + DataSet countZero = summary.map(new MapFunction () { + + private static final long serialVersionUID = 2322849507320367330L; + + @Override + public DenseVector map(BaseVectorSummary value) throws Exception { + if (value instanceof SparseVectorSummary) { + return (DenseVector) ((SparseVectorSummary) value).numNonZero(); + } + return new DenseVector(0); + } + }); + if (standardization) { + DataSet coefficientDim = summary.map(new MapFunction () { + private static final long serialVersionUID = -8051245706564042978L; + + @Override + public Integer map(BaseVectorSummary value) throws Exception { + return value.vectorSize(); + } + }); + DataSet meanVar = summary.map(new MapFunction () { + private static final long serialVersionUID = -6992060467629008691L; + + @Override + public DenseVector[] map(BaseVectorSummary value) { + if (value instanceof SparseVectorSummary) { + // If the train data are sparse vectors, use maxAbs as the variance and set the mean to zero, + // so the standardization operation turns into a scaling operation. + // Standardizing a sparse vector directly would convert it into a dense one. + DenseVector max = ((SparseVector) value.max()).toDenseVector(); + DenseVector min = ((SparseVector) value.min()).toDenseVector(); + for (int i = 0; i < max.size(); ++i) { + max.set(i, Math.max(Math.abs(max.get(i)), Math.abs(min.get(i)))); + min.set(i, 0.0); + } + return new DenseVector[] {min, max}; + } else { + return new DenseVector[] {(DenseVector) value.mean(), + (DenseVector) value.standardDeviation()}; + } + } + }); + return Tuple3.of(coefficientDim, meanVar, countZero); + } else { + // If standardization is not applied, use a map-reduce to get the vector dimension; mean and var are set to empty vectors. + DataSet coefficientDim = trainData.mapPartition( + new MapPartitionFunction , Integer>() { + private static final long serialVersionUID = 3426157421982727224L; + + @Override + public void mapPartition(Iterable > values, Collector + out) + throws Exception { + int ret = -1; + for (Tuple3 val : values) { + if (val.f2 instanceof DenseVector) { + ret = ((DenseVector) val.f2).getData().length; + break; + } else { + + int[] ids = ((SparseVector) val.f2).getIndices(); + for (int id : ids) { + ret = Math.max(ret, id + 1); + } + } + ret = Math.max(ret, val.f2.size()); + } + + out.collect(ret); + } + }).reduceGroup(new GroupReduceFunction () { + private static final long serialVersionUID = 2752381384411882555L; + + @Override + public void reduce(Iterable values, Collector out) { + int ret = -1; + for (int vSize : values) { + ret = Math.max(ret, vSize); + } + out.collect(ret); + } + }); + + DataSet meanVar = coefficientDim.map(new MapFunction () { + private static final long serialVersionUID = 5448632685946933829L; + + @Override + public DenseVector[] map(Integer value) { + return new DenseVector[] {new DenseVector(0), new DenseVector(0)}; + } + }); + return Tuple3.of(coefficientDim, meanVar, countZero); + } + } + + /** + * Do standardization and add the intercept item to train data. + * + * @param initData initial data. + * @param params train parameters. + * @param meanVar mean and variance of train data. + * @return train data after standardization. 
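The mapping implemented in preProcess below can be summarized with a small dense example; the numbers are made up, meanVar[0] is the mean vector and meanVar[1] the standard deviation (or maxAbs for sparse input, as noted in getStatInfo above).

// Dense row with standardization and intercept: prepend 1.0, then z-score each feature.
double[] x    = {2.0, 8.0};
double[] mean = {1.0, 4.0};
double[] std  = {2.0, 4.0};
double[] out  = new double[x.length + 1];
out[0] = 1.0;                                  // intercept item
for (int i = 0; i < x.length; i++) {
    out[i + 1] = (x[i] - mean[i]) / std[i];    // -> {1.0, 0.5, 1.0}
}
// Without an intercept the row is only scaled by 1/std; sparse rows use mean 0, so they are scaled by 1/maxAbs.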
+ */ + protected static DataSet > preProcess( + DataSet > initData, + Params params, + DataSet meanVar) { + // Get parameters. + final boolean hasInterceptItem = params.get(LinearTrainParams.WITH_INTERCEPT); + final boolean standardization = params.get(LinearTrainParams.STANDARDIZATION); + + return initData.map( + new RichMapFunction , Tuple3 >() { + private static final long serialVersionUID = -5342628140781184056L; + private DenseVector[] meanVar; + + @Override + public void open(Configuration parameters) throws Exception { + this.meanVar = (DenseVector[]) getRuntimeContext() + .getBroadcastVariable(MEAN_VAR).get(0); + modifyMeanVar(standardization, meanVar); + } + + @Override + public Tuple3 map(Tuple3 value) + throws Exception { + + Vector aVector = value.f2; + if (aVector instanceof DenseVector) { + DenseVector bVector; + if (standardization) { + if (hasInterceptItem) { + bVector = new DenseVector(aVector.size() + 1); + bVector.set(0, 1.0); + for (int i = 0; i < aVector.size(); ++i) { + bVector.set(i + 1, (aVector.get(i) - meanVar[0].get(i)) / meanVar[1].get(i)); + } + } else { + bVector = (DenseVector) aVector; + for (int i = 0; i < aVector.size(); ++i) { + bVector.set(i, aVector.get(i) / meanVar[1].get(i)); + } + } + } else { + if (hasInterceptItem) { + bVector = new DenseVector(aVector.size() + 1); + bVector.set(0, 1.0); + for (int i = 0; i < aVector.size(); ++i) { + bVector.set(i + 1, aVector.get(i)); + } + } else { + bVector = (DenseVector) aVector; + } + } + return Tuple3.of(value.f0, value.f1, bVector); + + } else { + SparseVector bVector = (SparseVector) aVector; + + if (standardization) { + if (hasInterceptItem) { + + int[] indices = bVector.getIndices(); + double[] vals = bVector.getValues(); + for (int i = 0; i < indices.length; ++i) { + vals[i] = (vals[i] - meanVar[0].get(indices[i])) / meanVar[1].get( + indices[i]); + } + bVector = bVector.prefix(1.0); + } else { + int[] indices = bVector.getIndices(); + double[] vals = bVector.getValues(); + for (int i = 0; i < indices.length; ++i) { + vals[i] = vals[i] / meanVar[1].get(indices[i]); + } + } + } else { + if (hasInterceptItem) { + bVector = bVector.prefix(1.0); + } + } + return Tuple3.of(value.f0, value.f1, bVector); + } + } + }).withBroadcastSet(meanVar, MEAN_VAR); + } + + /** + * Get label info: including label values and label type. + * + * @param in input train data in BatchOperator format. + * @param params train parameters. + * @param isRegProc not use label, include linear regression which not in scorecard. + * @return label info. 
+ */ + protected static Tuple2 , TypeInformation> getLabelInfo(BatchOperator in, + Params params, + boolean isRegProc) { + String labelName = params.get(LinearTrainParams.LABEL_COL); + // Prepare label values + DataSet labelValues; + TypeInformation labelType = null; + if (isRegProc) { + labelType = Types.DOUBLE; + labelValues = MLEnvironmentFactory.get(in.getMLEnvironmentId()) + .getExecutionEnvironment().fromElements(new Object()); + } else { + labelType = in.getColTypes()[TableUtil.findColIndexWithAssertAndHint(in.getColNames(), labelName)]; + labelValues = Preprocessing.distinctLabels(Preprocessing.select(in, new String[] {labelName}) + .getDataSet().map(new MapFunction () { + private static final long serialVersionUID = -419245917074561046L; + + @Override + public Object map(Row value) throws Exception { + return value.getField(0); + } + })).flatMap(new FlatMapFunction () { + private static final long serialVersionUID = -5089566319196319692L; + + @Override + public void flatMap(Object[] value, Collector out) throws Exception { + for (Object obj : value) { + out.collect(obj); + } + } + }); + } + return Tuple2.of(labelValues, labelType); + } + + /** + * modify mean and variance, if variance equals zero, then modify them. + * + * @param standardization do standardization or not. + * @param meanVar mean and variance. + */ + private static void modifyMeanVar(boolean standardization, DenseVector[] meanVar) { + if (standardization) { + for (int i = 0; i < meanVar[1].size(); ++i) { + if (meanVar[1].get(i) == 0) { + meanVar[1].set(i, 1.0); + meanVar[0].set(i, 0.0); + } + } + } + } + + private static class BuildLabels implements + FlatMapFunction , Object[]> { + private boolean isRegProc; + private String positiveLabel; + + BuildLabels(boolean isRegProc, String positiveLabel) { + this.isRegProc = isRegProc; + this.positiveLabel = positiveLabel; + } + + private static final long serialVersionUID = 5375954526931728363L; + + @Override + public void flatMap(Tuple3 value, + Collector out) + throws Exception { + if (!isRegProc) { + Preconditions.checkState((value.f1.length == 2), + "labels count should be 2 in in classification algo."); + + out.collect(orderLabels(value.f1, positiveLabel)); + } else { + out.collect(value.f1); + } + } + } +} \ No newline at end of file diff --git a/core/src/main/java/com/alibaba/alink/operator/common/linear/BaseLinearModelTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/common/linear/BaseLinearModelTrainBatchOp.java index 323e44aa8..205b981f0 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/linear/BaseLinearModelTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/linear/BaseLinearModelTrainBatchOp.java @@ -13,6 +13,7 @@ import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeinfo.Types; import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.tuple.Tuple1; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.tuple.Tuple3; import org.apache.flink.api.java.tuple.Tuple5; @@ -37,22 +38,16 @@ import com.alibaba.alink.common.exceptions.AkIllegalArgumentException; import com.alibaba.alink.common.exceptions.AkIllegalDataException; import com.alibaba.alink.common.exceptions.AkPreconditions; -import com.alibaba.alink.common.exceptions.AkUnimplementedOperationException; -import com.alibaba.alink.common.lazy.WithTrainInfo; +import com.alibaba.alink.operator.batch.utils.WithTrainInfo; import 
com.alibaba.alink.common.linalg.DenseVector; import com.alibaba.alink.common.linalg.SparseVector; import com.alibaba.alink.common.linalg.Vector; import com.alibaba.alink.common.linalg.VectorUtil; import com.alibaba.alink.common.model.ModelParamName; -import com.alibaba.alink.common.utils.DataSetConversionUtil; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; import com.alibaba.alink.common.utils.JsonConverter; import com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.operator.batch.BatchOperator; -import com.alibaba.alink.operator.common.linear.unarylossfunc.LogLossFunc; -import com.alibaba.alink.operator.common.linear.unarylossfunc.PerceptronLossFunc; -import com.alibaba.alink.operator.common.linear.unarylossfunc.SmoothHingeLossFunc; -import com.alibaba.alink.operator.common.linear.unarylossfunc.SquareLossFunc; -import com.alibaba.alink.operator.common.linear.unarylossfunc.SvrLossFunc; import com.alibaba.alink.operator.common.optim.Lbfgs; import com.alibaba.alink.operator.common.optim.Optimizer; import com.alibaba.alink.operator.common.optim.OptimizerFactory; @@ -192,7 +187,8 @@ public void flatMap(Tuple3 value, DataSet > trainData = preProcess(initData, params, isRegProc, meanVar, labelValues, featSize); - DataSet initModelDataSet = getInitialModel(initModel, featSize, meanVar, params, linearModelType); + DataSet initModelDataSet = getInitialModel(initModel, featSize, meanVar, params, + linearModelType); // Solve the optimization problem. DataSet > coefVectorSet = optimize(params, featSize, @@ -215,7 +211,7 @@ public void flatMap(Tuple3 value, // Convert the model rows to table. this.setOutput(modelRows, new LinearModelDataConverter(labelType).getModelSchema()); - this.setSideOutputTables(getSideTablesOfCoefficient(modelRows, initData, featSize, + this.setSideOutputTables(getSideTablesOfCoefficient(coefVectorSet.project(1), modelRows, initData, featSize, params.get(LinearTrainParams.FEATURE_COLS), params.get(LinearTrainParams.WITH_INTERCEPT), getMLEnvironmentId())); @@ -227,48 +223,51 @@ public static DataSet getInitialModel(BatchOperator initModel, DataSet meanVar, Params params, final LinearModelType localLinearModelType) { - return initModel == null ? null : initModel.getDataSet().reduceGroup( - new RichGroupReduceFunction () { - @Override - public void reduce(Iterable values, Collector out) { - int featSize = (int) getRuntimeContext().getBroadcastVariable("featSize").get(0); - DenseVector[] meanVar = - (DenseVector[]) getRuntimeContext().getBroadcastVariable("meanVar").get(0); - List modelRows = new ArrayList <>(0); - for (Row row : values) { - modelRows.add(row); - } - LinearModelData model = new LinearModelDataConverter().load(modelRows); + return initModel == null ? null : initModel.getDataSet().reduceGroup( + new RichGroupReduceFunction () { + @Override + public void reduce(Iterable values, Collector out) { + int featSize = (int) getRuntimeContext().getBroadcastVariable("featSize").get(0); + DenseVector[] meanVar = + (DenseVector[]) getRuntimeContext().getBroadcastVariable("meanVar").get(0); + List modelRows = new ArrayList <>(0); + for (Row row : values) { + modelRows.add(row); + } + LinearModelData model = new LinearModelDataConverter().load(modelRows); - if (!(model.hasInterceptItem == params.get(HasWithIntercept.WITH_INTERCEPT))) { - throw new AkIllegalArgumentException("Initial linear model is not compatible with parameter setting." 
- + "InterceptItem parameter setting error."); - } - if (!(model.linearModelType == localLinearModelType)) { - throw new AkIllegalArgumentException("Initial linear model is not compatible with parameter setting." - + "linearModelType setting error."); - } - if (!(model.vectorSize == featSize)) { - throw new AkIllegalDataException("Initial linear model is not compatible with training data. " - + " vector size not equal, vector size in init model is : " + model.vectorSize + - " and vector size of train data is : " + featSize); - } - int n = meanVar[0].size(); - if (model.hasInterceptItem) { - double sum = 0.0; - for (int i = 1; i < n; ++i) { - sum += model.coefVector.get(i) * meanVar[0].get(i); - model.coefVector.set(i, model.coefVector.get(i) * meanVar[1].get(i)); + if (!(model.hasInterceptItem == params.get(HasWithIntercept.WITH_INTERCEPT))) { + throw new AkIllegalArgumentException( + "Initial linear model is not compatible with parameter setting." + + "InterceptItem parameter setting error."); } - model.coefVector.set(0, model.coefVector.get(0) + sum); - } else { - for (int i = 0; i < n; ++i) { - model.coefVector.set(i, model.coefVector.get(i) * meanVar[1].get(i)); + if (!(model.linearModelType == localLinearModelType)) { + throw new AkIllegalArgumentException( + "Initial linear model is not compatible with parameter setting." + + "linearModelType setting error."); + } + if (!(model.vectorSize == featSize)) { + throw new AkIllegalDataException("Initial linear model is not compatible with training " + + "data. " + + " vector size not equal, vector size in init model is : " + model.vectorSize + + " and vector size of train data is : " + featSize); } + int n = meanVar[0].size(); + if (model.hasInterceptItem) { + double sum = 0.0; + for (int i = 1; i < n; ++i) { + sum += model.coefVector.get(i) * meanVar[0].get(i); + model.coefVector.set(i, model.coefVector.get(i) * meanVar[1].get(i)); + } + model.coefVector.set(0, model.coefVector.get(0) + sum); + } else { + for (int i = 0; i < n; ++i) { + model.coefVector.set(i, model.coefVector.get(i) * meanVar[1].get(i)); + } + } + out.collect(model.coefVector); } - out.collect(model.coefVector); - } - }) + }) .withBroadcastSet(featSize, "featSize") .withBroadcastSet(meanVar, "meanVar"); } @@ -462,25 +461,25 @@ public void reduce(Iterable > values, }); } - public static Table[] getSideTablesOfCoefficient(DataSet modelRow, + public static Table[] getSideTablesOfCoefficient(DataSet > coefInfo, + DataSet modelRows, DataSet > inputData, DataSet vecSize, final String[] featureNames, final boolean hasInterception, long environmentId) { - DataSet model = modelRow.mapPartition(new MapPartitionFunction () { - private static final long serialVersionUID = 2063366042018382802L; - - @Override - public void mapPartition(Iterable values, Collector out) { - List rows = new ArrayList <>(); - for (Row row : values) { - rows.add(row); - } - out.collect(new LinearModelDataConverter().load(rows)); - } - }).setParallelism(1); - + DataSet model = modelRows.mapPartition(new MapPartitionFunction () { + private static final long serialVersionUID = 2063366042018382802L; + + @Override + public void mapPartition(Iterable values, Collector out) { + List rows = new ArrayList <>(); + for (Row row : values) { + rows.add(row); + } + out.collect(new LinearModelDataConverter().load(rows)); + } + }).setParallelism(1); DataSet > allInfo = inputData .mapPartition( new RichMapPartitionFunction , Tuple3 > @@ -570,15 +569,19 @@ public Tuple3 reduce(Tuple3 >() { private static final long 
serialVersionUID = 7815111101106759520L; private DenseVector coefVec; - private LinearModelData model; + private Tuple2 model; private double[] cinfo; + private Params metaInfo; @Override public void open(Configuration parameters) throws Exception { super.open(parameters); - model = ((LinearModelData) getRuntimeContext().getBroadcastVariable("model").get(0)); + cinfo = ((Tuple1 ) getRuntimeContext().getBroadcastVariable( + "cinfo").get(0)).f0; + LinearModelData model = (LinearModelData) getRuntimeContext().getBroadcastVariable( + "model").get(0); coefVec = model.coefVector; - cinfo = model.convergenceInfo; + metaInfo = model.getMetaInfo(); } @Override @@ -613,11 +616,13 @@ public void flatMap(Tuple3 value, } out.collect( - Tuple5.of(JsonConverter.toJson(model.getMetaInfo()), colNames, coefVec.getData(), + Tuple5.of(JsonConverter.toJson(metaInfo), colNames, coefVec.getData(), importance, cinfo)); } - }).setParallelism(1).withBroadcastSet(model, "model"); + }).setParallelism(1) + .withBroadcastSet(model, "model") + .withBroadcastSet(coefInfo, "cinfo"); DataSet importance = allInfo.mapPartition( new MapPartitionFunction , Row>() { @@ -793,7 +798,7 @@ public Integer map(Integer value) { } // Loss object function DataSet objFunc = session.getExecutionEnvironment() - .fromElements(getObjFunction(modelType, params)); + .fromElements(OptimObjFunc.getObjFunction(modelType, params)); Optimizer optimizer; if (params.contains(LinearTrainParams.OPTIM_METHOD)) { LinearTrainParams.OptimMethod method = params.get(LinearTrainParams.OPTIM_METHOD); @@ -807,41 +812,6 @@ public Integer map(Integer value) { return optimizer.optimize(); } - /** - * Get obj function. - * - * @param modelType Model type. - * @param params Parameters for train. - * @return Obj function. - */ - public static OptimObjFunc getObjFunction(LinearModelType modelType, Params params) { - OptimObjFunc objFunc; - // For different model type, we must set corresponding loss object function. - switch (modelType) { - case LinearReg: - objFunc = new UnaryLossObjFunc(new SquareLossFunc(), params); - break; - case SVR: - double svrTau = params.get(LinearSvrTrainParams.TAU); - objFunc = new UnaryLossObjFunc(new SvrLossFunc(svrTau), params); - break; - case LR: - objFunc = new UnaryLossObjFunc(new LogLossFunc(), params); - break; - case SVM: - objFunc = new UnaryLossObjFunc(new SmoothHingeLossFunc(), params); - break; - case Perceptron: - objFunc = new UnaryLossObjFunc(new PerceptronLossFunc(), params); - break; - case AFT: - objFunc = new AftRegObjFunc(params); - break; - default: - throw new AkUnimplementedOperationException("Linear model type is Not implemented yet!"); - } - return objFunc; - } /** * Transform train data to Tuple3 format. @@ -912,6 +882,8 @@ protected static String[] getFeatureTypes(BatchOperator in, String[] feature featureColTypes[i] = "short"; } else if (type.equals(Types.BOOLEAN)) { featureColTypes[i] = "bool"; + } else if (type.equals(Types.BIG_DEC)) { + featureColTypes[i] = "decimal"; } else { throw new AkIllegalArgumentException( "Linear algorithm only support numerical data type. 
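Editorial note: the hunk above also deletes the local getObjFunction helper and instead calls OptimObjFunc.getObjFunction at the optimize() call site. Assuming the relocated helper keeps the same behaviour, the removed switch mapped each LinearModelType to its loss as restated below in a self-contained sketch (the Loss names are stand-ins for Alink's loss classes, not real types):

```java
// Sketch of the model-type -> loss mapping from the removed
// BaseLinearModelTrainBatchOp.getObjFunction switch shown above.
public final class LossSelection {
    enum ModelType { LinearReg, SVR, LR, SVM, Perceptron, AFT }
    enum Loss { SQUARE, SVR_TAU, LOG, SMOOTH_HINGE, PERCEPTRON, AFT_REG }

    static Loss lossFor(ModelType type) {
        switch (type) {
            case LinearReg:  return Loss.SQUARE;        // least-squares regression
            case SVR:        return Loss.SVR_TAU;       // SVR loss, parameterised by tau
            case LR:         return Loss.LOG;           // logistic loss
            case SVM:        return Loss.SMOOTH_HINGE;  // smoothed hinge loss
            case Perceptron: return Loss.PERCEPTRON;    // perceptron loss
            case AFT:        return Loss.AFT_REG;       // accelerated failure time objective
            default: throw new IllegalArgumentException("Unsupported model type: " + type);
        }
    }

    public static void main(String[] args) {
        System.out.println(lossFor(ModelType.SVM)); // SMOOTH_HINGE
    }
}
```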
Current type is : " + type); @@ -942,68 +914,68 @@ protected static DataSet > preProcess( final boolean standardization = params.get(LinearTrainParams.STANDARDIZATION); final boolean hasIntercept = params.get(LinearTrainParams.WITH_INTERCEPT); return initData.mapPartition( - new RichMapPartitionFunction , Tuple3 >() { - private static final long serialVersionUID = -3931917328901089041L; - private DenseVector[] meanVar; - private Object[] labelValues = null; - private int featureSize; + new RichMapPartitionFunction , Tuple3 >() { + private static final long serialVersionUID = -3931917328901089041L; + private DenseVector[] meanVar; + private Object[] labelValues = null; + private int featureSize; - @Override - public void open(Configuration parameters) { - this.meanVar = (DenseVector[]) getRuntimeContext() - .getBroadcastVariable(MEAN_VAR).get(0); - this.labelValues = (Object[]) getRuntimeContext() - .getBroadcastVariable(LABEL_VALUES).get(0); - this.featureSize = (int) getRuntimeContext().getBroadcastVariable("featureSize").get(0); - modifyMeanVar(standardization, meanVar); - } + @Override + public void open(Configuration parameters) { + this.meanVar = (DenseVector[]) getRuntimeContext() + .getBroadcastVariable(MEAN_VAR).get(0); + this.labelValues = (Object[]) getRuntimeContext() + .getBroadcastVariable(LABEL_VALUES).get(0); + this.featureSize = (int) getRuntimeContext().getBroadcastVariable("featureSize").get(0); + modifyMeanVar(standardization, meanVar); + } - @Override - public void mapPartition(Iterable > values, - Collector > out) { - for (Tuple3 value : values) { - Vector aVector = value.f2; - - if (value.f0 > 0) { - Double label = isRegProc ? Double.parseDouble(value.f1.toString()) - : (value.f1.equals(labelValues[0]) ? 1.0 : -1.0); - if (aVector instanceof DenseVector) { - if (aVector.size() < featureSize) { - DenseVector tmp = new DenseVector(featureSize); - for (int i = 0; i < aVector.size(); ++i) { - tmp.set(i, aVector.get(i)); - } - aVector = tmp; - } - if (standardization) { - if (hasIntercept) { + @Override + public void mapPartition(Iterable > values, + Collector > out) { + for (Tuple3 value : values) { + Vector aVector = value.f2; + + if (value.f0 > 0) { + Double label = isRegProc ? Double.parseDouble(value.f1.toString()) + : (value.f1.equals(labelValues[0]) ? 
1.0 : -1.0); + if (aVector instanceof DenseVector) { + if (aVector.size() < featureSize) { + DenseVector tmp = new DenseVector(featureSize); for (int i = 0; i < aVector.size(); ++i) { - aVector.set(i, - (aVector.get(i) - meanVar[0].get(i)) / meanVar[1].get(i)); + tmp.set(i, aVector.get(i)); } - } else { - for (int i = 0; i < aVector.size(); ++i) { - aVector.set(i, aVector.get(i) / meanVar[1].get(i)); + aVector = tmp; + } + if (standardization) { + if (hasIntercept) { + for (int i = 0; i < aVector.size(); ++i) { + aVector.set(i, + (aVector.get(i) - meanVar[0].get(i)) / meanVar[1].get(i)); + } + } else { + for (int i = 0; i < aVector.size(); ++i) { + aVector.set(i, aVector.get(i) / meanVar[1].get(i)); + } } } - } - } else { - if (standardization) { - int[] indices = ((SparseVector) aVector).getIndices(); - double[] vals = ((SparseVector) aVector).getValues(); - for (int i = 0; i < indices.length; ++i) { - vals[i] = vals[i] / meanVar[1].get(indices[i]); + } else { + if (standardization) { + int[] indices = ((SparseVector) aVector).getIndices(); + double[] vals = ((SparseVector) aVector).getValues(); + for (int i = 0; i < indices.length; ++i) { + vals[i] = vals[i] / meanVar[1].get(indices[i]); + } + } + if (aVector.size() == -1 || aVector.size() == 0) { + ((SparseVector) aVector).setSize(featureSize); } } - if (aVector.size() == -1 || aVector.size() == 0) { - ((SparseVector) aVector).setSize(featureSize); - } + out.collect(Tuple3.of(value.f0, label, aVector)); } - out.collect(Tuple3.of(value.f0, label, aVector)); } } - } - }).withBroadcastSet(meanVar, MEAN_VAR) + }).withBroadcastSet(meanVar, MEAN_VAR) .withBroadcastSet(labelValues, LABEL_VALUES) .withBroadcastSet(featSize, "featureSize"); } @@ -1096,7 +1068,6 @@ public static LinearModelData buildLinearModelData(Params meta, } LinearModelData modelData = new LinearModelData(labelType, meta, featureNames, coefVector.f0); - modelData.convergenceInfo = coefVector.f1; modelData.labelName = meta.get(LinearTrainParams.LABEL_COL); modelData.featureTypes = meta.get(ModelParamName.FEATURE_TYPES); diff --git a/core/src/main/java/com/alibaba/alink/operator/common/linear/FeatureLabelUtil.java b/core/src/main/java/com/alibaba/alink/operator/common/linear/FeatureLabelUtil.java index 1cfd381db..0eb89d988 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/linear/FeatureLabelUtil.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/linear/FeatureLabelUtil.java @@ -16,19 +16,8 @@ public class FeatureLabelUtil { public static Vector getVectorFeature(Object input, boolean hasInterceptItem, Integer vectorSize) { - Vector aVector; Vector vec = VectorUtil.getVector(input); - if (vec instanceof SparseVector) { - SparseVector tmp = (SparseVector) vec; - if (null != vectorSize && tmp.size() > 0) { - tmp.setSize(vectorSize); - } - aVector = hasInterceptItem ? tmp.prefix(1.0) : tmp; - } else { - DenseVector tmp = (DenseVector) vec; - aVector = hasInterceptItem ? tmp.prefix(1.0) : tmp; - } - return aVector; + return hasInterceptItem ? 
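Editorial note: the large preProcess hunk above reads as a re-indentation of the same logic. Its core is per-sample preparation: regression labels are parsed to double, binary labels are mapped to +1/-1 against labelValues[0], and features are standardized. A condensed plain-Java sketch of the standardization part (arrays instead of DenseVector/SparseVector); note that sparse vectors are only scaled, never mean-centered, which I read as a way to keep them sparse:

```java
// Sketch of the per-sample standardization done in preProcess.
// Dense features: x'_i = (x_i - mean_i) / scale_i when an intercept is used,
// otherwise x'_i = x_i / scale_i. Sparse features: only divide the stored
// values by the scale at their indices (no centering).
public final class Standardize {
    static void dense(double[] x, double[] mean, double[] scale, boolean hasIntercept) {
        for (int i = 0; i < x.length; ++i) {
            x[i] = hasIntercept ? (x[i] - mean[i]) / scale[i] : x[i] / scale[i];
        }
    }

    static void sparse(int[] indices, double[] values, double[] scale) {
        for (int i = 0; i < indices.length; ++i) {
            values[i] /= scale[indices[i]];
        }
    }

    public static void main(String[] args) {
        double[] x = {3.0, 8.0};
        dense(x, new double[] {1.0, 2.0}, new double[] {2.0, 3.0}, true);
        System.out.println(java.util.Arrays.toString(x)); // [1.0, 2.0]
    }
}
```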
vec.prefix(1.0) : vec; } /** diff --git a/core/src/main/java/com/alibaba/alink/operator/common/linear/LinearModelData.java b/core/src/main/java/com/alibaba/alink/operator/common/linear/LinearModelData.java index 423a7ed23..e3f3db7f0 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/linear/LinearModelData.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/linear/LinearModelData.java @@ -35,7 +35,6 @@ public class LinearModelData implements Serializable { public Object[] labelValues = null; public LinearModelType linearModelType; public boolean hasInterceptItem = true; - public double[] convergenceInfo;// public TypeInformation labelType;// /** diff --git a/core/src/main/java/com/alibaba/alink/operator/common/linear/LinearModelDataConverter.java b/core/src/main/java/com/alibaba/alink/operator/common/linear/LinearModelDataConverter.java index 6ac66832e..a9d057f3c 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/linear/LinearModelDataConverter.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/linear/LinearModelDataConverter.java @@ -119,7 +119,6 @@ private ModelData getModelData(LinearModelData data) { modelData.featureColTypes = data.featureTypes; modelData.coefVector = data.coefVector; modelData.coefVectors = data.coefVectors; - modelData.convergenceInfo = data.convergenceInfo; return modelData; } @@ -130,7 +129,6 @@ private void setModelData(ModelData modelData, LinearModelData data) { data.featureNames = modelData.featureColNames; data.featureTypes = modelData.featureColTypes; data.coefVector = modelData.coefVector; - data.convergenceInfo = modelData.convergenceInfo; if (data.modelName.equals("softmax")) { double[] w = modelData.coefVector.getData(); diff --git a/core/src/main/java/com/alibaba/alink/operator/common/linear/LinearModelTrainInfo.java b/core/src/main/java/com/alibaba/alink/operator/common/linear/LinearModelTrainInfo.java index 999725614..ac6d4578c 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/linear/LinearModelTrainInfo.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/linear/LinearModelTrainInfo.java @@ -153,16 +153,16 @@ public String toString() { } } sbd.append(PrettyDisplayUtils.displayHeadline("train convergence info", '-')); - if (convInfo.length < 6) { + if (convInfo.length < 20) { for (String s : convInfo) { sbd.append(s).append("\n"); } } else { - for (int i = 0; i < 3; ++i) { + for (int i = 0; i < 10; ++i) { sbd.append(convInfo[i]).append("\n"); } sbd.append("" + "... ... ... ..." 
+ "\n"); - for (int i = convInfo.length - 3; i < convInfo.length; ++i) { + for (int i = convInfo.length - 10; i < convInfo.length; ++i) { sbd.append(convInfo[i]).append("\n"); } } diff --git a/core/src/main/java/com/alibaba/alink/operator/common/linear/LinearRegressionSummary.java b/core/src/main/java/com/alibaba/alink/operator/common/linear/LinearRegressionSummary.java new file mode 100644 index 000000000..2edc85461 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/linear/LinearRegressionSummary.java @@ -0,0 +1,34 @@ +package com.alibaba.alink.operator.common.linear; + +import com.alibaba.alink.operator.common.finance.stepwiseSelector.RegressionSelectorStep; +import com.alibaba.alink.operator.common.finance.stepwiseSelector.SelectorStep; + +public class LinearRegressionSummary extends ModelSummary { + + public double ra2; + public double r2; + public double mallowCp; + public double fValue; + public double pValue; + public double sse; + public double[] stdEsts; + public double[] stdErrs; + public double[] tValues; + public double[] tPVaues; + public double[] lowerConfidence; + public double[] uperConfidence; + + @Override + public SelectorStep toSelectStep(int inId) { + RegressionSelectorStep step = new RegressionSelectorStep(); + step.enterCol = String.valueOf(inId); + step.fValue = this.fValue; + step.mallowCp = this.mallowCp; + step.r2 = this.r2; + step.ra2 = this.ra2; + step.pValue = this.pValue; + step.numberIn = this.beta.size() - 1; + + return step; + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/linear/LocalLinearModel.java b/core/src/main/java/com/alibaba/alink/operator/common/linear/LocalLinearModel.java new file mode 100644 index 000000000..7a8e667ee --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/linear/LocalLinearModel.java @@ -0,0 +1,462 @@ +package com.alibaba.alink.operator.common.linear; + +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.tuple.Tuple4; +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.table.api.TableSchema; + +import com.alibaba.alink.common.type.AlinkTypes; +import com.alibaba.alink.common.linalg.DenseMatrix; +import com.alibaba.alink.common.linalg.DenseVector; +import com.alibaba.alink.common.linalg.Vector; +import com.alibaba.alink.common.linalg.jama.JMatrixFunc; +import com.alibaba.alink.common.probabilistic.CDF; +import com.alibaba.alink.common.probabilistic.PDF; +import com.alibaba.alink.operator.common.optim.LocalOptimizer; +import com.alibaba.alink.operator.common.optim.local.ConstrainedLocalOptimizer; +import com.alibaba.alink.operator.common.optim.objfunc.OptimObjFunc; +import com.alibaba.alink.operator.common.regression.LinearRegressionModel; +import com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummarizer; +import com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary; +import com.alibaba.alink.params.ParamUtil; +import com.alibaba.alink.params.feature.HasConstraint; +import com.alibaba.alink.params.finance.HasConstrainedOptimizationMethod; +import com.alibaba.alink.params.regression.LinearRegPredictParams; +import com.alibaba.alink.params.regression.LinearRegTrainParams; +import com.alibaba.alink.params.shared.linear.HasL1; +import com.alibaba.alink.params.shared.linear.HasL2; +import com.alibaba.alink.params.shared.linear.LinearTrainParams; +import 
com.alibaba.alink.params.shared.linear.LinearTrainParams.OptimMethod; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +public class LocalLinearModel { + + public static ModelSummary trainWithSummary(List > trainData, + int[] indices, + LinearModelType modelType, + String optimMethod, + boolean hasIntercept, + boolean standardization, + String constraint, + double l1, + double l2, + BaseVectorSummarizer srt) { + if (indices == null) { + int size = trainData.get(0).f2.size(); + indices = new int[size]; + for (int i = 0; i < indices.length; i++) { + indices[i] = i; + } + } + + // if (modelType == LinearModelType.LinearReg && l1 <= 0 && l2 <= 0) { + // //train with srt. + // return null; + // } else { + + Tuple4 model = train(trainData, indices, modelType, + optimMethod, hasIntercept, standardization, constraint, l1, l2); + return calcModelSummary(model, srt, modelType, indices); + // } + } + + public static ModelSummary calcModelSummary(Tuple4 model, + BaseVectorSummarizer srt, + LinearModelType modelType, + int[] indices) { + + if (modelType == LinearModelType.LR) { + return calcLrSummary(model, srt); + } else { + return calcLinearRegressionSummary(model, srt, 0, indicesAddOne(indices)); + } + } + + public static Tuple4 train( + List > trainData, + int[] indices, + LinearModelType modelType, + OptimMethod optimMethod, + boolean hasIntercept, + boolean standardization, + double l1, + double l2) { + int featureSize = indices.length; + List > selectedData; + if (hasIntercept) { + selectedData = new ArrayList <>(); + for (Tuple3 data : trainData) { + selectedData.add(Tuple3.of(data.f0, data.f1, data.f2.slice(indices).prefix(1.0))); + } + } else { + selectedData = trainData; + } + final OptimObjFunc objFunc = OptimObjFunc.getObjFunction(modelType, new Params()); + + Params optParams = new Params() + .set(LinearTrainParams.OPTIM_METHOD, + ParamUtil.searchEnum(LinearTrainParams.OPTIM_METHOD, optimMethod.name())) + .set(LinearTrainParams.WITH_INTERCEPT, hasIntercept) + .set(LinearTrainParams.STANDARDIZATION, standardization) + .set(HasL1.L_1, l1) + .set(HasL2.L_2, l2); + + DenseVector initialWeights = DenseVector.zeros(featureSize + (hasIntercept ? 
1 : 0)); + + Tuple4 temp; + if (optimMethod == OptimMethod.Newton) { + try { + temp = LocalOptimizer.newtonWithHessian(selectedData, initialWeights, optParams, objFunc); + } catch (Exception e) { + throw new RuntimeException("Local trainLinear failed.", e); + } + } else { + try { + Tuple2 tuple2 = LocalOptimizer.optimize(objFunc, selectedData, initialWeights, + optParams); + + Params newtonParams = new Params() + .set(LinearTrainParams.OPTIM_METHOD, ParamUtil.searchEnum(LinearTrainParams.OPTIM_METHOD, + "newton")) + .set(LinearTrainParams.WITH_INTERCEPT, hasIntercept) + .set(LinearTrainParams.STANDARDIZATION, standardization) + .set(HasL1.L_1, l1) + .set(HasL2.L_2, l2) + .set(LinearRegTrainParams.MAX_ITER, 1); + temp = LocalOptimizer.newtonWithHessian(selectedData, tuple2.f0, newtonParams, objFunc); + } catch (Exception e) { + throw new RuntimeException("Local trainLinear failed.", e); + } + } + + return Tuple4.of(temp.f0, temp.f1, temp.f2, temp.f3[temp.f3.length - 3]); + } + + public static Tuple4 constrainedTrain( + List > trainData, + int[] indices, + LinearModelType modelType, + HasConstrainedOptimizationMethod.ConstOptimMethod optimMethod, + boolean hasIntercept, + boolean standardization, + String constraint, + double l1, + double l2) { + List > selectedData; + if (hasIntercept) { + selectedData = new ArrayList <>(); + for (Tuple3 data : trainData) { + selectedData.add(Tuple3.of(data.f0, data.f1, data.f2.slice(indices).prefix(1.0))); + } + } else { + selectedData = trainData; + } + Params optParams = new Params() + .set(HasConstrainedOptimizationMethod.CONST_OPTIM_METHOD, + ParamUtil.searchEnum(HasConstrainedOptimizationMethod.CONST_OPTIM_METHOD, optimMethod.name())) + .set(LinearTrainParams.WITH_INTERCEPT, hasIntercept) + .set(LinearTrainParams.STANDARDIZATION, standardization) + .set(HasL1.L_1, l1) + .set(HasL2.L_2, l2) + .set(HasConstraint.CONSTRAINT, constraint); + + if (optimMethod == HasConstrainedOptimizationMethod.ConstOptimMethod.SQP || + optimMethod == HasConstrainedOptimizationMethod.ConstOptimMethod.Barrier) { + return ConstrainedLocalOptimizer.optimizeWithHessian(selectedData, modelType, optParams); + } else { + throw new RuntimeException("It is not support for constrainedTrain"); + } + } + + public static Tuple4 train( + List > trainData, + int[] indices, + LinearModelType modelType, + String optimMethod, + boolean hasIntercept, + boolean standardization, + String constraint, + double l1, + double l2) { + String optimMethodUper = optimMethod.toUpperCase().trim(); + if ("SQP".equals(optimMethodUper) || "BARRIER".equals(optimMethodUper)) { + return constrainedTrain(trainData, + indices, + modelType, + HasConstrainedOptimizationMethod.ConstOptimMethod.valueOf(optimMethodUper), + hasIntercept, + standardization, + constraint, + l1, + l2); + } else { + return train(trainData, + indices, + modelType, + OptimMethod.valueOf(optimMethodUper), + hasIntercept, + standardization, + l1, + l2); + } + + } + + private static LinearRegressionSummary calcLinearRegressionSummary( + Tuple4 model, + BaseVectorSummarizer srt, + int indexY, + int[] indexX) { + LinearRegressionSummary summary = calcLinearRegressionSummary(model.f0, srt, indexY, indexX); + summary.gradient = model.f1; + summary.hessian = model.f2; + summary.loss = model.f3; + return summary; + + } + + static LinearRegressionSummary calcLinearRegressionSummary(DenseVector beta, + BaseVectorSummarizer srt, + int indexY, + int[] indexX) { + BaseVectorSummary summary = srt.toSummary(); + + if (summary.count() == 0) { + throw new 
RuntimeException("table is empty!"); + } + + if (summary.vectorSize() < indexX.length) { + throw new RuntimeException("record size Less than features size!"); + } + + int nx = indexX.length; + long N = summary.count(); + if (N == 0) { + throw new RuntimeException("Y valid value num is zero!"); + } + + String nameY = "label"; + String[] nameX = new String[indexX.length]; + Arrays.fill(nameX, "col"); + + LinearRegressionModel lrr = new LinearRegressionModel(N, nameY, nameX); + + double[] XBar = new double[nx]; + for (int i = 0; i < nx; i++) { + XBar[i] = summary.mean(indexX[i]); + } + double yBar = summary.mean(indexY); + + double[][] cov = srt.covariance().getArrayCopy2D(); + DenseMatrix dot = srt.getOuterProduct(); + + DenseMatrix A = new DenseMatrix(nx, nx); + for (int i = 0; i < nx; i++) { + for (int j = 0; j < nx; j++) { + A.set(i, j, cov[indexX[i]][indexX[j]]); + } + } + DenseMatrix C = new DenseMatrix(nx, 1); + for (int i = 0; i < nx; i++) { + C.set(i, 0, cov[indexX[i]][indexY]); + } + + lrr.beta = beta.getData(); + + double S = summary.variance(indexY) * (summary.count() - 1); + double alpha = lrr.beta[0] - yBar; + double U = 0.0; + U += alpha * alpha * N; + for (int i = 0; i < nx; i++) { + U += 2 * alpha * summary.sum(indexX[i]) * lrr.beta[i + 1]; + } + for (int i = 0; i < nx; i++) { + for (int j = 0; j < nx; j++) { + U += lrr.beta[i + 1] * lrr.beta[j + 1] * (cov[indexX[i]][indexX[j]] * (N - 1) + summary.mean(indexX[i]) + * summary.mean(indexX[j]) * N); + } + } + + double ms = summary.normL2(indexY); + for (int i = 0; i < nx && indexX[i] < dot.numCols(); i++) { + ms -= 2 * lrr.beta[i + 1] * dot.get(indexY, indexX[i]); + } + ms -= 2 * lrr.beta[0] * summary.sum(indexY); + + for (int i = 0; i < nx; i++) { + for (int j = 0; j < nx; j++) { + if (indexX[i] < dot.numCols() && indexX[j] < dot.numCols()) { + ms += lrr.beta[i + 1] * lrr.beta[j + 1] * dot.get(indexX[i], indexX[j]); + } + } + ms += 2 * lrr.beta[i + 1] * lrr.beta[0] * summary.sum(indexX[i]); + } + + ms += summary.count() * lrr.beta[0] * lrr.beta[0]; + + lrr.SST = S; + lrr.SSR = U; + lrr.SSE = S - U; + if (lrr.SSE < 0) { + lrr.SSE = ms; + } + lrr.dfSST = N - 1; + lrr.dfSSR = nx; + lrr.dfSSE = N - nx - 1 - 1; //1 is intercept + lrr.R2 = Math.max(0.0, Math.min(1.0, lrr.SSR / lrr.SST)); + lrr.R = Math.sqrt(lrr.R2); + lrr.MST = lrr.SST / lrr.dfSST; + lrr.MSR = lrr.SSR / lrr.dfSSR; + lrr.MSE = lrr.SSE / lrr.dfSSE; + lrr.Ra2 = 1 - lrr.MSE / lrr.MST; + lrr.s = Math.sqrt(lrr.MSE); + lrr.F = lrr.MSR / lrr.MSE; + if (lrr.F < 0) { + lrr.F = 0; + } + lrr.AIC = N * Math.log(lrr.SSE) + 2 * nx; + + A.scaleEqual(N - 1); + + DenseMatrix invA = A.solveLS(JMatrixFunc.identity(A.numRows(), A.numRows())); + + for (int i = 0; i < nx; i++) { + lrr.FX[i] = lrr.beta[i + 1] * lrr.beta[i + 1] / (lrr.MSE * invA.get(i, i)); + lrr.TX[i] = lrr.beta[i + 1] / (lrr.s * Math.sqrt(invA.get(i, i))); + } + + int p = nameX.length; + double df2 = N - p - 1; + if (df2 <= 0) { + df2 = N - 2; + } + + lrr.pEquation = 1 - CDF.F(lrr.F, p, df2); + lrr.pX = new double[nx]; + for (int i = 0; i < nx; i++) { + lrr.pX[i] = (1 - CDF.studentT(Math.abs(lrr.TX[i]), df2)) * 2; + } + + LinearRegressionSummary lrSummary = new LinearRegressionSummary(); + + lrSummary.count = summary.count(); + lrSummary.beta = beta; + lrSummary.fValue = lrr.F; + lrSummary.mallowCp = lrr.getCp(indexX.length, lrr.SSE); + lrSummary.r2 = lrr.R2; + lrSummary.ra2 = lrr.Ra2; + lrSummary.pValue = lrr.pEquation; + lrSummary.tValues = lrr.TX; + lrSummary.tPVaues = lrr.pX; + lrSummary.sse = lrr.SSE; + 
lrSummary.stdEsts = new double[indexX.length]; + lrSummary.stdErrs = new double[indexX.length]; + lrSummary.lowerConfidence = new double[indexX.length]; + lrSummary.uperConfidence = new double[indexX.length]; + for (int i = 0; i < indexX.length; i++) { + double estimate = lrSummary.beta.get(i + 1); + lrSummary.stdEsts[i] = estimate * summary.standardDeviation(indexX[i]); + lrSummary.stdErrs[i] = lrr.s * Math.sqrt(invA.get(i, i)); + lrSummary.lowerConfidence[i] = estimate - 1.96 * lrSummary.stdErrs[i]; + lrSummary.uperConfidence[i] = estimate + 1.96 * lrSummary.stdErrs[i]; + } + + return lrSummary; + } + + public static LogistRegressionSummary calcLrSummary( + Tuple4 weightsAndHessian, + BaseVectorSummarizer srt) { + DenseVector weights = weightsAndHessian.f0; + DenseVector gradient = weightsAndHessian.f1; + DenseMatrix hessian = weightsAndHessian.f2; + double loss = weightsAndHessian.f3; + + int featureNum = gradient.size() - 1; + + LogistRegressionSummary summary = new LogistRegressionSummary(); + + summary.loss = loss; + summary.gradient = weightsAndHessian.f1; + summary.hessian = weightsAndHessian.f2; + summary.beta = weightsAndHessian.f0; + + summary.scoreChiSquareValue = hessian.solveLS(gradient).dot(gradient); + summary.scorePValue = PDF.chi2(summary.scoreChiSquareValue, 1); + + DenseMatrix hessianInv = hessian.pseudoInverse(); + summary.waldChiSquareValue = new double[featureNum + 1]; + summary.waldPValues = new double[featureNum + 1]; + for (int i = 0; i < featureNum + 1; i++) { + summary.waldChiSquareValue[i] = weights.get(i) * weights.get(i) / hessianInv.get(i, i); + summary.waldPValues[i] = PDF.chi2(summary.waldChiSquareValue[i], 1); + } + + summary.stdEsts = new double[featureNum + 1]; + summary.stdErrs = new double[featureNum + 1]; + summary.lowerConfidence = new double[featureNum + 1]; + summary.uperConfidence = new double[featureNum + 1]; + + BaseVectorSummary dataSummary = srt.toSummary(); + for (int i = 0; i < featureNum + 1; i++) { + summary.stdEsts[i] = dataSummary.standardDeviation(i) * summary.beta.get(i) / Math.sqrt(3) / Math.PI; + summary.stdErrs[i] = Math.sqrt(hessianInv.get(i, i)); + summary.lowerConfidence[i] = summary.beta.get(i) - 1.96 * summary.stdErrs[i]; + summary.uperConfidence[i] = summary.beta.get(i) + 1.96 * summary.stdErrs[i]; + } + + summary.aic = 2 * summary.loss + 2 * (featureNum + 1); + summary.sc = 2 * summary.loss + (featureNum + 1) * Math.log(dataSummary.count()); + + return summary; + } + + private static int[] indicesAddOne(int[] indices) { + int[] result = new int[indices.length]; + Arrays.setAll(result, i -> indices[i] + 1); + return result; + } + + public static String getDefaultOptimMethod(String optimMethod, String constrained) { + if (optimMethod == null || optimMethod.isEmpty()) { + if (constrained == null || constrained.isEmpty()) { + optimMethod = OptimMethod.LBFGS.name(); + } else { + optimMethod = HasConstrainedOptimizationMethod.ConstOptimMethod.SQP.name(); + } + } + return optimMethod; + } + + public static List > predict(LinearModelData model, List data) { + LinearModelDataConverter converter = new LinearModelDataConverter(model.labelType); + TableSchema modelSchema = converter.getModelSchema(); + TableSchema dataSchema = new TableSchema(new String[] {"features"}, new TypeInformation[] {AlinkTypes + .VECTOR}); + + LinearModelMapper mapper = new LinearModelMapper( + modelSchema, dataSchema, + new Params().set(LinearRegPredictParams.PREDICTION_COL, "pred")); + mapper.loadModel(model); + + List > result = new ArrayList <>(); + for 
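Editorial note: calcLrSummary above turns the final (coefficients, gradient, Hessian, loss) tuple into Wald and score statistics plus information criteria. Below is a compact plain-Java restatement of the scalar formulas it applies; the chi-square lookups done through Alink's PDF.chi2 are left out, and hessInvDiag stands for a diagonal entry of the pseudo-inverted Hessian:

```java
// Sketch of the per-coefficient and global statistics in calcLrSummary.
public final class LrStats {
    static double waldChiSquare(double beta, double hessInvDiag) {
        return beta * beta / hessInvDiag;          // Wald statistic per coefficient
    }

    static double[] confidence95(double beta, double hessInvDiag) {
        double stdErr = Math.sqrt(hessInvDiag);    // standard error of the estimate
        return new double[] {beta - 1.96 * stdErr, beta + 1.96 * stdErr};
    }

    static double aic(double loss, int numCoef) {
        return 2.0 * loss + 2.0 * numCoef;         // numCoef counts the intercept too
    }

    static double sc(double loss, int numCoef, long count) {
        return 2.0 * loss + numCoef * Math.log(count);  // Schwarz criterion (BIC)
    }

    public static void main(String[] args) {
        System.out.println(waldChiSquare(3.0, 0.25));   // 36.0
        // approximately [2.02, 3.98]
        System.out.println(java.util.Arrays.toString(confidence95(3.0, 0.25)));
    }
}
```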
(Vector vec : data) { + try { + if (model.hasInterceptItem) { + result.add(Tuple2.of(mapper.predict(vec.prefix(1.0)), vec)); + } else { + result.add(Tuple2.of(mapper.predict(vec), vec)); + } + } catch (Exception ex) { + throw new RuntimeException(ex); + } + } + return result; + } + +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/linear/LocalLinearRegression.java b/core/src/main/java/com/alibaba/alink/operator/common/linear/LocalLinearRegression.java new file mode 100644 index 000000000..cb5c64ed9 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/linear/LocalLinearRegression.java @@ -0,0 +1,71 @@ +package com.alibaba.alink.operator.common.linear; + +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.tuple.Tuple4; +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.table.api.Types; + +import com.alibaba.alink.common.linalg.DenseMatrix; +import com.alibaba.alink.common.linalg.DenseVector; +import com.alibaba.alink.common.linalg.Vector; +import com.alibaba.alink.common.model.ModelParamName; +import com.alibaba.alink.params.shared.linear.LinearTrainParams; + +import java.util.ArrayList; +import java.util.List; + +public class LocalLinearRegression { + + public static LinearModelData train(List > trainData, + int[] indices, + boolean hasIntercept, + double l1, + double l2, + String optimMethod, + String constrained) { + + optimMethod = LocalLinearModel.getDefaultOptimMethod(optimMethod, constrained); + LinearModelType linearModelType = LinearModelType.LinearReg; + boolean standardization = true; + + List > selectedData = new ArrayList <>(); + + for (Tuple3 data : trainData) { + if (indices == null) { + indices = new int[data.f2.size()]; + for (int i = 0; i < indices.length; i++) { + indices[i] = i; + } + } + selectedData.add(Tuple3.of(data.f0, ((Number) data.f1).doubleValue(), data.f2.slice(indices))); + } + + //coef, grad, hession, loss + Tuple4 model = LocalLinearModel.train(selectedData, + indices, linearModelType, optimMethod, + hasIntercept, standardization, constrained, l1, l2); + + Params meta = new Params() + .set(ModelParamName.MODEL_NAME, "model") + .set(ModelParamName.LINEAR_MODEL_TYPE, linearModelType) + .set(ModelParamName.HAS_INTERCEPT_ITEM, hasIntercept) + .set(ModelParamName.VECTOR_COL_NAME, "features") + .set(ModelParamName.FEATURE_TYPES, new String[] {}) + .set(LinearTrainParams.LABEL_COL, "label"); + + return BaseLinearModelTrainBatchOp.buildLinearModelData(meta, + null, + Types.DOUBLE(), + null, + hasIntercept, + false, + Tuple2.of(model.f0, new double[] {model.f3})); + + } + + public static List > predict(LinearModelData model, List data) { + return LocalLinearModel.predict(model, data); + } + +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/linear/LocalLogistRegression.java b/core/src/main/java/com/alibaba/alink/operator/common/linear/LocalLogistRegression.java new file mode 100644 index 000000000..a20d3466a --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/linear/LocalLogistRegression.java @@ -0,0 +1,116 @@ +package com.alibaba.alink.operator.common.linear; + +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.tuple.Tuple4; +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.table.api.Types; + +import com.alibaba.alink.common.linalg.DenseMatrix; +import 
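Editorial note: throughout LocalLinearModel the intercept is handled by prepending a constant 1.0 to the feature vector (the vec.prefix(1.0) calls in train and predict above), so the bias is simply the first coefficient. A minimal sketch of that convention with plain arrays:

```java
// Sketch: intercept-as-first-coefficient, mirroring the prefix(1.0) calls above.
// coef = [bias, w_1, ..., w_n]; the score is coef . (1, x_1, ..., x_n).
public final class PrefixIntercept {
    static double score(double[] coef, double[] x) {
        double s = coef[0];                 // bias pairs with the implicit leading 1.0
        for (int i = 0; i < x.length; ++i) {
            s += coef[i + 1] * x[i];
        }
        return s;
    }

    public static void main(String[] args) {
        // bias 0.5, weights (2, -1) applied to x = (3, 4): 0.5 + 6 - 4 = 2.5
        System.out.println(score(new double[] {0.5, 2.0, -1.0}, new double[] {3.0, 4.0}));
    }
}
```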
com.alibaba.alink.common.linalg.DenseVector; +import com.alibaba.alink.common.linalg.Vector; +import com.alibaba.alink.common.model.ModelParamName; +import com.alibaba.alink.params.shared.linear.LinearTrainParams; + +import java.util.ArrayList; +import java.util.List; + +public class LocalLogistRegression { + + /** + * @param trainData: f0 is label, f1 is weight, f2 is feature. + * @param indices: selected indices. + * @param hasIntercept + * @param l1 + * @param l2 + * @param optimMethod + * @param constrained + * @return + */ + public static LinearModelData train(List > trainData, + int[] indices, + boolean hasIntercept, + double l1, + double l2, + String optimMethod, + String constrained) { + + optimMethod = LocalLinearModel.getDefaultOptimMethod(optimMethod, constrained); + LinearModelType linearModelType = LinearModelType.LR; + boolean standardization = false; + + Tuple2 >, Object[]> dataAndLabelValues = getLabelValues(trainData); + + List > selectedData = new ArrayList <>(); + for (Tuple3 data : dataAndLabelValues.f0) { + if (indices == null) { + indices = new int[data.f2.size()]; + for (int i = 0; i < indices.length; i++) { + indices[i] = i; + } + } + selectedData.add(Tuple3.of(data.f0, data.f1, data.f2.slice(indices).prefix(1.0))); + } + + Tuple4 model = + LocalLinearModel.train(dataAndLabelValues.f0, indices, linearModelType, optimMethod, + hasIntercept, standardization, constrained, l1, l2); + + Params meta = new Params() + .set(ModelParamName.MODEL_NAME, "model") + .set(ModelParamName.LINEAR_MODEL_TYPE, linearModelType) + .set(ModelParamName.LABEL_VALUES, dataAndLabelValues.f1) + .set(ModelParamName.HAS_INTERCEPT_ITEM, hasIntercept) + .set(ModelParamName.VECTOR_COL_NAME, "features") + .set(ModelParamName.FEATURE_TYPES, new String[] {}) + .set(LinearTrainParams.LABEL_COL, "label"); + + return BaseLinearModelTrainBatchOp.buildLinearModelData(meta, + null, + Types.DOUBLE(), + null, + hasIntercept, + false, + Tuple2.of(model.f0, new double[] {model.f3})); + + } + + /** + * @param trainData: f0: weight, f1 label, f2 feature. + * @return f0: data, f2: labelValues. + */ + public static Tuple2 >, Object[]> getLabelValues( + List > trainData) { + List > result = new ArrayList <>(); + Object[] labels = new Object[2]; + + int length = trainData.size(); + if (length < 1) { + throw new RuntimeException("row number must be larger than 0."); + } + labels[0] = trainData.get(0).f1; + for (int i = 1; i < length; i++) { + Object candidate = trainData.get(i).f1; + if (!candidate.equals(labels[0])) { + labels[1] = candidate; + break; + } + } + for (Tuple3 value : trainData) { + double label = labels[0].equals(value.f1.toString()) ? 
1 : -1; + result.add(Tuple3.of(value.f0, label, value.f2)); + } + + return Tuple2.of(result, labels); + } + + /** + * @param model: LinearModelData + * @param data: data + * @return + */ + public List > predict(LinearModelData model, List data) { + return LocalLinearModel.predict(model, data); + } + +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/linear/LogistRegressionSummary.java b/core/src/main/java/com/alibaba/alink/operator/common/linear/LogistRegressionSummary.java new file mode 100644 index 000000000..0c67ddd6a --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/linear/LogistRegressionSummary.java @@ -0,0 +1,29 @@ +package com.alibaba.alink.operator.common.linear; + +import com.alibaba.alink.operator.common.finance.stepwiseSelector.ClassificationSelectorStep; +import com.alibaba.alink.operator.common.finance.stepwiseSelector.SelectorStep; + +public class LogistRegressionSummary extends ModelSummary { + + public double scoreChiSquareValue; + public double scorePValue; + public double aic; + public double sc; + public double[] stdEsts; + public double[] stdErrs; + public double[] waldChiSquareValue; + public double[] waldPValues; + public double[] lowerConfidence; + public double[] uperConfidence; + + @Override + public SelectorStep toSelectStep(int inId) { + ClassificationSelectorStep step = new ClassificationSelectorStep(); + step.enterCol = String.valueOf(inId); + step.scoreValue = this.scoreChiSquareValue; + step.pValue = this.scorePValue; + step.numberIn = this.beta.size() - 1; + + return step; + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/linear/ModelSummary.java b/core/src/main/java/com/alibaba/alink/operator/common/linear/ModelSummary.java new file mode 100644 index 000000000..b12aac272 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/linear/ModelSummary.java @@ -0,0 +1,17 @@ +package com.alibaba.alink.operator.common.linear; + +import com.alibaba.alink.common.linalg.DenseMatrix; +import com.alibaba.alink.common.linalg.DenseVector; +import com.alibaba.alink.common.utils.AlinkSerializable; +import com.alibaba.alink.operator.common.finance.stepwiseSelector.SelectorStep; + +public abstract class ModelSummary implements AlinkSerializable { + public double loss; + public DenseVector beta; + public DenseVector gradient; + public DenseMatrix hessian; + public long count; + + public abstract SelectorStep toSelectStep(int inId); + +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/linear/ModelSummaryHelper.java b/core/src/main/java/com/alibaba/alink/operator/common/linear/ModelSummaryHelper.java new file mode 100644 index 000000000..8b621abac --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/linear/ModelSummaryHelper.java @@ -0,0 +1,391 @@ +package com.alibaba.alink.operator.common.linear; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.ReduceFunction; +import org.apache.flink.api.common.functions.RichFlatMapFunction; +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.functions.RichMapPartitionFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.tuple.Tuple4; +import 
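Editorial note: getLabelValues above fixes the positive class as the label of the first row and encodes the binary target as +1/-1. A small plain-Java sketch of that encoding, using String labels for simplicity:

```java
import java.util.Arrays;
import java.util.List;

// Sketch of the +1/-1 encoding in LocalLogistRegression.getLabelValues:
// the label of the first row becomes the positive class.
public final class BinaryLabelEncode {
    static double[] encode(List<String> labels) {
        String positive = labels.get(0);
        double[] y = new double[labels.size()];
        for (int i = 0; i < labels.size(); ++i) {
            y[i] = positive.equals(labels.get(i)) ? 1.0 : -1.0;
        }
        return y;
    }

    public static void main(String[] args) {
        List<String> labels = Arrays.asList("spam", "ham", "spam");
        System.out.println(Arrays.toString(encode(labels))); // [1.0, -1.0, 1.0]
    }
}
```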
org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +import com.alibaba.alink.common.linalg.DenseMatrix; +import com.alibaba.alink.common.linalg.DenseVector; +import com.alibaba.alink.common.linalg.Vector; +import com.alibaba.alink.common.linalg.VectorUtil; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; +import com.alibaba.alink.common.utils.TableUtil; +import com.alibaba.alink.operator.batch.BatchOperator; +import com.alibaba.alink.operator.batch.source.TableSourceBatchOp; +import com.alibaba.alink.operator.common.finance.stepwiseSelector.BaseStepWiseSelectorBatchOp; +import com.alibaba.alink.operator.common.finance.stepwiseSelector.ClassificationSelectorResult; +import com.alibaba.alink.operator.common.finance.stepwiseSelector.RegressionSelectorResult; +import com.alibaba.alink.operator.common.finance.stepwiseSelector.SelectorResult; +import com.alibaba.alink.operator.common.optim.objfunc.OptimObjFunc; +import com.alibaba.alink.operator.batch.statistics.utils.StatisticsHelper; +import com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummarizer; +import com.alibaba.alink.params.finance.HasConstrainedLinearModelType; + +import java.security.InvalidParameterException; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; + +import static com.alibaba.alink.operator.batch.statistics.utils.StatisticsHelper.transformToVector; + +public class ModelSummaryHelper { + + public static DataSet calModelSummary(BatchOperator data, + HasConstrainedLinearModelType.LinearModelType + linearModelType, + DataSet modelData, + String vectorCol, + String[] selectedCols, + String labelCol) { + if (HasConstrainedLinearModelType.LinearModelType.LR == linearModelType) { + return calBinarySummary(data, modelData, vectorCol, selectedCols, labelCol) + .map(new MapFunction () { + private static final long serialVersionUID = -3321750369444781812L; + + @Override + public SelectorResult map(LogistRegressionSummary summary) throws Exception { + ClassificationSelectorResult result = new ClassificationSelectorResult(); + result.modelSummary = summary; + result.selectedCols = selectedCols; + return result; + } + }); + } else { + return calRegSummary(data, modelData, vectorCol, selectedCols, labelCol) + .map(new MapFunction () { + private static final long serialVersionUID = 2198755005878386785L; + + @Override + public SelectorResult map(LinearRegressionSummary summary) throws Exception { + RegressionSelectorResult result = new RegressionSelectorResult(); + result.modelSummary = summary; + result.selectedCols = selectedCols; + return result; + } + }); + } + } + + public static DataSet calRegSummary(BatchOperator data, + DataSet modelData, + String vectorCol, + String[] selectedCols, + String labelCol) { + DataSet summarizer = null; + if (null != vectorCol && !vectorCol.isEmpty()) { + int selectedColIdxNew = TableUtil.findColIndexWithAssertAndHint(data.getColNames(), vectorCol); + int labelIdxNew = TableUtil.findColIndexWithAssertAndHint(data.getColNames(), labelCol); + summarizer = StatisticsHelper.summarizer(data.getDataSet() + .map(new BaseStepWiseSelectorBatchOp.ToVectorWithReservedCols(selectedColIdxNew, labelIdxNew)), true); + } + + if (null != selectedCols && selectedCols.length != 0) { + String[] statCols = new String[selectedCols.length + 1]; + statCols[0] = labelCol; + 
System.arraycopy(selectedCols, 0, statCols, 1, selectedCols.length); + summarizer = StatisticsHelper.summarizer(transformToVector(data, statCols, null), true); + } + + if (null == summarizer) { + throw new InvalidParameterException("select col and select cols must be set one"); + } + + return summarizer + .flatMap(new CalRegSummary()) + .withBroadcastSet(modelData, "linearModelData"); + + } + + public static DataSet calBinarySummary(BatchOperator data, + DataSet modelData, + String vectorCol, + String[] featureCols, + String labelCol) { + + int[] featureIndices = null; + int labelIdx = TableUtil.findColIndexWithAssertAndHint(data.getColNames(), labelCol); + if (featureCols != null) { + featureIndices = new int[featureCols.length]; + for (int i = 0; i < featureCols.length; ++i) { + int idx = TableUtil.findColIndexWithAssertAndHint(data.getColNames(), featureCols[i]); + featureIndices[i] = idx; + TypeInformation type = data.getSchema().getFieldTypes()[idx]; + Preconditions.checkState(TableUtil.isSupportedNumericType(type), + "linear algorithm only support numerical data type. type is : " + type); + } + } + int weightIdx = -1; + int vecIdx = vectorCol != null ? TableUtil.findColIndexWithAssertAndHint(data.getColNames(), vectorCol) : -1; + + //transform data + DataSet > tuple3Data = data.getDataSet() + .map(new TransformLrLabel(vecIdx, featureIndices, labelIdx, weightIdx)) + .withBroadcastSet(modelData, "linearModelData") + .name("TransferLrData"); + + //summarizer + DataSet summarizer = StatisticsHelper.summarizer( + tuple3Data.map(new MapFunction , Vector>() { + private static final long serialVersionUID = -1205844850032698897L; + + @Override + public Vector map(Tuple3 tuple3) throws Exception { + tuple3.f2.set(0, tuple3.f0); + return tuple3.f2; + } + }), false); + + //stat + return tuple3Data.mapPartition(new CalcGradientAndHessian()) + .withBroadcastSet(modelData, "linearModelData") + .reduce(new ReduceFunction >() { + private static final long serialVersionUID = -9187304403661961376L; + + @Override + public Tuple4 reduce( + Tuple4 left, + Tuple4 right) throws Exception { + return Tuple4.of(left.f0, + left.f1.plus(right.f1), + left.f2.plus(right.f2), + left.f3 + right.f3 + ); + } + }).name("combine gradient and hessian") + .mapPartition(new CalcLrSummary()) + .withBroadcastSet(summarizer, "Summarizer"); + } + + /** + * return: data, labels + */ + public static Tuple2 > transformLrLabel(BatchOperator data, + String labelCol, + String positiveLabel, + Long sessionId) { + int labelColIdx = TableUtil.findColIndexWithAssertAndHint(data.getColNames(), labelCol); + + DataSet lables = getLabelInfo(data, labelCol); + DataSet transformedData = data.getDataSet() + .map(new TransformLrLabelWithLabel(labelColIdx, positiveLabel)) + .withBroadcastSet(lables, "labels"); + + TypeInformation[] outTypes = data.getColTypes(); + outTypes[labelColIdx] = Types.DOUBLE; + + return Tuple2.of(new TableSourceBatchOp( + DataSetConversionUtil.toTable(sessionId, transformedData, data.getColNames(), outTypes)) + , lables); + } + + public static class TransformLrLabel extends RichMapFunction > { + private static final long serialVersionUID = -1178726287840080632L; + private LinearModelData modelData; + private String positiveLableValueString; + private int vecIdx; + private int[] featureIndices; + private int labelIdx; + private int weightIdx; + + public TransformLrLabel(int vecIdx, int[] featureIndices, int labelIdx, int weightIdx) { + this.vecIdx = vecIdx; + this.featureIndices = featureIndices; + this.labelIdx = 
labelIdx; + this.weightIdx = weightIdx; + } + + public void open(Configuration conf) { + this.modelData = (LinearModelData) this.getRuntimeContext(). + getBroadcastVariable("linearModelData") + .get(0); + this.positiveLableValueString = this.modelData.labelValues[0].toString(); + } + + @Override + public Tuple3 map(Row row) throws Exception { + return transferLabel(row, featureIndices, vecIdx, labelIdx, weightIdx, positiveLableValueString); + } + } + + public static class TransformLrLabelWithLabel extends RichMapFunction { + private static final long serialVersionUID = 9009298353079015437L; + private String positiveLableValueString = null; + private int labelIdx; + + public TransformLrLabelWithLabel(int labelIdx, String positiveLabel) { + this.labelIdx = labelIdx; + this.positiveLableValueString = positiveLabel; + } + + @Override + public void open(Configuration parameters) throws Exception { + List labelRows = getRuntimeContext().getBroadcastVariable("labels"); + this.positiveLableValueString = orderLabels(labelRows, positiveLableValueString)[0].toString(); + } + + @Override + public Row map(Row row) throws Exception { + Double val = FeatureLabelUtil.getLabelValue(row, false, + labelIdx, positiveLableValueString); + row.setField(labelIdx, val); + return row; + } + } + + public static Object[] orderLabels(Iterable unorderedLabelRows, String positiveLabel) { + List tmpArr = new ArrayList <>(); + for (Object row : unorderedLabelRows) { + tmpArr.add(row); + } + Object[] labels = tmpArr.toArray(new Object[0]); + Preconditions.checkState((labels.length == 2), "labels count should be 2 in 2 classification algo."); + String str0 = labels[0].toString(); + String str1 = labels[1].toString(); + + String positiveLabelValueString = positiveLabel; + if (positiveLabelValueString == null) { + positiveLabelValueString = (str1.compareTo(str0) > 0) ? str1 : str0; + } + + if (labels[1].toString().equals(positiveLabelValueString)) { + Object t = labels[0]; + labels[0] = labels[1]; + labels[1] = t; + } + return labels; + } + + public static class CalcGradientAndHessian extends RichMapPartitionFunction , + Tuple4 > { + private static final long serialVersionUID = 2861532185191763285L; + private LinearModelData modelData; + + public void open(Configuration conf) { + this.modelData = (LinearModelData) this.getRuntimeContext(). + getBroadcastVariable("linearModelData") + .get(0); + } + + @Override + public void mapPartition(Iterable > data, + Collector > collector) + throws Exception { + final OptimObjFunc objFunc = OptimObjFunc.getObjFunction(LinearModelType.LR, + new Params()); + int n = modelData.coefVector.size(); + DenseVector coef = modelData.coefVector; + DenseMatrix hessian = new DenseMatrix(n, n); + DenseVector gradient = new DenseVector(n); + + Tuple2 tuple2 = objFunc.calcHessianGradientLoss(data, coef, hessian, gradient); + //LogistRegressionSummary summary = LocalLinearModel.calcLrSummary(Tuple4.of(coef, gradient, hessian, + // tuple2.f1), BaseVectorSummarizer); + collector.collect(Tuple4.of(coef, gradient, hessian, tuple2.f1)); + } + } + + public static class CalcLrSummary extends RichMapPartitionFunction < + Tuple4 , + LogistRegressionSummary> { + private static final long serialVersionUID = 1799381476070262842L; + private BaseVectorSummarizer srt; + + public void open(Configuration conf) { + this.srt = (BaseVectorSummarizer) this.getRuntimeContext(). 
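Editorial note: orderLabels in the hunk above decides which of the two class labels counts as positive: an explicitly supplied positiveLabel wins, otherwise the lexicographically larger string does, and the array is reordered so the positive label comes first. A minimal plain-Java restatement:

```java
// Sketch of ModelSummaryHelper.orderLabels: put the positive label first.
// Without an explicit choice, the lexicographically larger string wins.
public final class OrderLabels {
    static Object[] order(Object a, Object b, String positiveLabel) {
        String sa = a.toString();
        String sb = b.toString();
        String positive = positiveLabel != null ? positiveLabel
            : (sb.compareTo(sa) > 0 ? sb : sa);
        return sb.equals(positive) ? new Object[] {b, a} : new Object[] {a, b};
    }

    public static void main(String[] args) {
        // With no explicit choice, "yes" > "no", so "yes" becomes the positive class.
        System.out.println(java.util.Arrays.toString(order("no", "yes", null))); // [yes, no]
    }
}
```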
+ getBroadcastVariable("Summarizer") + .get(0); + } + + @Override + public void mapPartition(Iterable > model, + Collector collector) throws Exception { + Iterator > iter = model.iterator(); + if (iter.hasNext()) { + collector.collect(LocalLinearModel.calcLrSummary(iter.next(), srt)); + } + } + } + + public static class CalRegSummary extends RichFlatMapFunction { + private static final long serialVersionUID = 1372774780273725623L; + private LinearModelData modelData; + + public void open(Configuration conf) { + this.modelData = (LinearModelData) this.getRuntimeContext(). + getBroadcastVariable("linearModelData") + .get(0); + } + + @Override + public void flatMap(BaseVectorSummarizer summarizer, Collector collector) + throws Exception { + DenseVector beta = modelData.coefVector; + int vectorSize = summarizer.toSummary().vectorSize(); + int[] statIndices = new int[vectorSize - 1]; + for (int i = 0; i < statIndices.length; i++) { + statIndices[i] = i + 1; + } + collector.collect(LocalLinearModel.calcLinearRegressionSummary(beta, summarizer, 0, statIndices)); + } + } + + public static Boolean isLinearRegression(String linearModelType) { + linearModelType = linearModelType.trim().toUpperCase(); + if (!linearModelType.equals("LINEARREG") && !linearModelType.equals("LR")) { + throw new RuntimeException("model type not support. " + linearModelType); + } + return "LINEARREG".equals(linearModelType); + } + + public static DataSet getLabelInfo(BatchOperator in, + String labelCol) { + return in.select(new String[] {labelCol}).distinct().getDataSet().map( + new MapFunction () { + private static final long serialVersionUID = 2044498497762182626L; + + @Override + public Object map(Row row) { + return row.getField(0); + } + }); + } + + private static Tuple3 transferLabel(Row row, + int[] featureIndices, int vecIdx, int labelIdx, + int weightIdx, + String positiveLableValueString) throws Exception { + Double weight = weightIdx != -1 ? ((Number) row.getField(weightIdx)).doubleValue() : 1.0; + Double val = FeatureLabelUtil.getLabelValue(row, false, + labelIdx, positiveLableValueString); + Tuple3 tuple3; + if (featureIndices != null) { + DenseVector vec = new DenseVector(featureIndices.length); + for (int i = 0; i < featureIndices.length; ++i) { + vec.set(i, ((Number) row.getField(featureIndices[i])).doubleValue()); + } + tuple3 = Tuple3.of(weight, val, vec); + } else { + Vector vec = VectorUtil.getVector(row.getField(vecIdx)); + Preconditions.checkState((vec != null), + "vector for linear model train is null, please check your input data."); + + tuple3 = Tuple3.of(weight, val, vec); + } + + return Tuple3.of(tuple3.f0, tuple3.f1, tuple3.f2.prefix(1.0)); + } + +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/linear/SoftmaxObjFunc.java b/core/src/main/java/com/alibaba/alink/operator/common/linear/SoftmaxObjFunc.java index ed85264c7..5c15ec509 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/linear/SoftmaxObjFunc.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/linear/SoftmaxObjFunc.java @@ -22,7 +22,6 @@ public class SoftmaxObjFunc extends OptimObjFunc { * k1 = numClass - 1 */ private final int k1; - private Tuple2 etas = null; /** * Constructor. @@ -42,7 +41,7 @@ public SoftmaxObjFunc(Params params) { * @return the loss value and weight value. 
*/ @Override - protected double calcLoss(Tuple3 labelVector, DenseVector coefVector) { + public double calcLoss(Tuple3 labelVector, DenseVector coefVector) { int featDim = coefVector.size() / k1; double[] weights = coefVector.getData(); double sumExp = 1; @@ -95,13 +94,9 @@ public double[] calcSearchValues( DenseVector dirVec, double beta, int numStep) { double[] losses = new double[numStep + 1]; double[] stateValues = new double[numStep + 1]; + Tuple2 etas = Tuple2.of(new double[k1 + 1], new double[k1 + 1]); for (Tuple3 labelVector : labelVectors) { - if (etas == null) { - double[] f0 = new double[k1 + 1]; - double[] f1 = new double[k1 + 1]; - etas = Tuple2.of(f0, f1); - } calcEta(labelVector, coefVector, dirVec, beta, etas); int yk = labelVector.f1.intValue(); @@ -180,7 +175,7 @@ private void calcEta(Tuple3 labelVector, DenseVector co * @param updateGrad gradient need to update. */ @Override - protected void updateGradient(Tuple3 labelVector, DenseVector coefVector, + public void updateGradient(Tuple3 labelVector, DenseVector coefVector, DenseVector updateGrad) { double[] phi = calcPhi(labelVector, coefVector); @@ -268,7 +263,7 @@ private double[] calcPhi(Tuple3 labelVector, DenseVecto * @param updateHessian hessian matrix need to update. */ @Override - protected void updateHessian(Tuple3 labelVector, DenseVector coefVector, + public void updateHessian(Tuple3 labelVector, DenseVector coefVector, DenseMatrix updateHessian) { double[] phi = calcPhi(labelVector, coefVector); int featDim = coefVector.size() / k1; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/linear/UnaryLossObjFunc.java b/core/src/main/java/com/alibaba/alink/operator/common/linear/UnaryLossObjFunc.java index b105b01d9..a29fc614f 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/linear/UnaryLossObjFunc.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/linear/UnaryLossObjFunc.java @@ -48,7 +48,7 @@ public boolean hasSecondDerivative() { * @return the loss value and weight value. */ @Override - protected double calcLoss(Tuple3 labelVector, DenseVector coefVector) { + public double calcLoss(Tuple3 labelVector, DenseVector coefVector) { double eta = getEta(labelVector, coefVector); return this.unaryLossFunc.loss(eta, labelVector.f1); } @@ -61,7 +61,7 @@ protected double calcLoss(Tuple3 labelVector, DenseVect * @param updateGrad gradient need to update. */ @Override - protected void updateGradient(Tuple3 labelVector, DenseVector coefVector, + public void updateGradient(Tuple3 labelVector, DenseVector coefVector, DenseVector updateGrad) { double eta = getEta(labelVector, coefVector); double div = labelVector.f0 * unaryLossFunc.derivative(eta, labelVector.f1); @@ -76,7 +76,7 @@ protected void updateGradient(Tuple3 labelVector, Dense * @param updateHessian hessian matrix need to update. 
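Editorial note: the objective-function hunks here mostly widen calcLoss, updateGradient, and updateHessian from protected to public, presumably so the new local/constrained training code can drive them directly; in SoftmaxObjFunc the cached etas buffer also becomes a local variable inside calcSearchValues. For orientation, a tiny plain-Java sketch of the unary-loss pattern UnaryLossObjFunc implements, specialised to a squared loss purely as an example (the real class delegates to pluggable UnaryLossFunc implementations):

```java
// Sketch of the unary-loss pattern behind UnaryLossObjFunc, specialised to
// squared loss: loss = 0.5 * (eta - y)^2, d(loss)/d(eta) = eta - y.
// Gradient contribution of one weighted sample: weight * derivative * x.
public final class UnaryLossSketch {
    static double eta(double[] coef, double[] x) {
        double s = 0.0;
        for (int i = 0; i < x.length; ++i) { s += coef[i] * x[i]; }
        return s;
    }

    static double squareLoss(double eta, double y) { return 0.5 * (eta - y) * (eta - y); }

    static double squareLossDerivative(double eta, double y) { return eta - y; }

    static void addGradient(double[] grad, double[] coef, double[] x, double y, double weight) {
        double div = weight * squareLossDerivative(eta(coef, x), y);
        for (int i = 0; i < x.length; ++i) { grad[i] += div * x[i]; }
    }

    public static void main(String[] args) {
        double[] grad = new double[2];
        addGradient(grad, new double[] {1.0, 1.0}, new double[] {2.0, 3.0}, 4.0, 1.0);
        System.out.println(java.util.Arrays.toString(grad)); // eta=5, residual=1 -> [2.0, 3.0]
    }
}
```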
*/ @Override - protected void updateHessian(Tuple3 labelVector, DenseVector coefVector, + public void updateHessian(Tuple3 labelVector, DenseVector coefVector, DenseMatrix updateHessian) { Vector vec = labelVector.f2; double eta = getEta(labelVector, coefVector); diff --git a/core/src/main/java/com/alibaba/alink/operator/common/stream/model/FileModelStreamSink.java b/core/src/main/java/com/alibaba/alink/operator/common/modelstream/FileModelStreamSink.java similarity index 98% rename from core/src/main/java/com/alibaba/alink/operator/common/stream/model/FileModelStreamSink.java rename to core/src/main/java/com/alibaba/alink/operator/common/modelstream/FileModelStreamSink.java index 87f722660..214fa9478 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/stream/model/FileModelStreamSink.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/modelstream/FileModelStreamSink.java @@ -1,4 +1,4 @@ -package com.alibaba.alink.operator.common.stream.model; +package com.alibaba.alink.operator.common.modelstream; import org.apache.flink.core.fs.FSDataOutputStream; import org.apache.flink.core.fs.FileSystem.WriteMode; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/stream/model/ModelStreamFileScanner.java b/core/src/main/java/com/alibaba/alink/operator/common/modelstream/ModelStreamFileScanner.java similarity index 99% rename from core/src/main/java/com/alibaba/alink/operator/common/stream/model/ModelStreamFileScanner.java rename to core/src/main/java/com/alibaba/alink/operator/common/modelstream/ModelStreamFileScanner.java index d1c01ff4e..f8dbcbaa7 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/stream/model/ModelStreamFileScanner.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/modelstream/ModelStreamFileScanner.java @@ -1,4 +1,4 @@ -package com.alibaba.alink.operator.common.stream.model; +package com.alibaba.alink.operator.common.modelstream; import org.apache.flink.api.common.time.Time; import org.apache.flink.core.fs.Path; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/stream/model/ModelStreamMeta.java b/core/src/main/java/com/alibaba/alink/operator/common/modelstream/ModelStreamMeta.java similarity index 82% rename from core/src/main/java/com/alibaba/alink/operator/common/stream/model/ModelStreamMeta.java rename to core/src/main/java/com/alibaba/alink/operator/common/modelstream/ModelStreamMeta.java index b12676851..585dea3e3 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/stream/model/ModelStreamMeta.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/modelstream/ModelStreamMeta.java @@ -1,4 +1,4 @@ -package com.alibaba.alink.operator.common.stream.model; +package com.alibaba.alink.operator.common.modelstream; import java.io.Serializable; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/stream/model/ModelStreamUtils.java b/core/src/main/java/com/alibaba/alink/operator/common/modelstream/ModelStreamUtils.java similarity index 92% rename from core/src/main/java/com/alibaba/alink/operator/common/stream/model/ModelStreamUtils.java rename to core/src/main/java/com/alibaba/alink/operator/common/modelstream/ModelStreamUtils.java index 989f34e85..6dc274028 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/stream/model/ModelStreamUtils.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/modelstream/ModelStreamUtils.java @@ -1,8 +1,5 @@ -package com.alibaba.alink.operator.common.stream.model; +package 
com.alibaba.alink.operator.common.modelstream; -import org.apache.flink.api.common.functions.MapFunction; -import org.apache.flink.api.common.functions.Partitioner; -import org.apache.flink.api.common.functions.RichFlatMapFunction; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeinfo.Types; import org.apache.flink.api.java.tuple.Tuple2; @@ -11,10 +8,8 @@ import org.apache.flink.core.fs.FileStatus; import org.apache.flink.core.fs.Path; import org.apache.flink.ml.api.misc.param.Params; -import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.table.api.TableSchema; import org.apache.flink.types.Row; -import org.apache.flink.util.Collector; import org.apache.flink.util.StringUtils; import com.alibaba.alink.common.exceptions.AkPreconditions; @@ -251,32 +246,6 @@ public static TableSchema createSchemaFromFilePath(FilePath filePath, String sch return schema; } - public static DataStream broadcastStream(DataStream input) { - return input - .flatMap(new RichFlatMapFunction >() { - private static final long serialVersionUID = 6421400378693673120L; - - @Override - public void flatMap(Row row, Collector > out) - throws Exception { - int numTasks = getRuntimeContext().getNumberOfParallelSubtasks(); - for (int i = 0; i < numTasks; ++i) { - out.collect(Tuple2.of(i, row)); - } - } - }).partitionCustom(new Partitioner () { - - @Override - public int partition(Integer key, int numPartitions) {return key;} - }, 0).map(new MapFunction , Row>() { - - @Override - public Row map(Tuple2 value) throws Exception { - return value.f1; - } - }); - } - private static final int YEAR_LENGTH = 4; private static final int MONTH_LENGTH = 2; private static final int DAY_LENGTH = 2; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/nlp/DocCountVectorizerModelMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/nlp/DocCountVectorizerModelMapper.java index e2df2b6b3..9c4c4f84f 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/nlp/DocCountVectorizerModelMapper.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/nlp/DocCountVectorizerModelMapper.java @@ -8,7 +8,7 @@ import org.apache.flink.table.api.TableSchema; import org.apache.flink.types.Row; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.linalg.SparseVector; import com.alibaba.alink.common.mapper.SISOModelMapper; import com.alibaba.alink.common.utils.JsonConverter; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/nlp/DocHashCountVectorizerModelMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/nlp/DocHashCountVectorizerModelMapper.java index ebbbdde3b..0d70ce81e 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/nlp/DocHashCountVectorizerModelMapper.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/nlp/DocHashCountVectorizerModelMapper.java @@ -6,7 +6,7 @@ import org.apache.flink.table.api.TableSchema; import org.apache.flink.types.Row; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.linalg.SparseVector; import com.alibaba.alink.common.mapper.SISOModelMapper; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/nlp/Word2VecModelDataConverter.java b/core/src/main/java/com/alibaba/alink/operator/common/nlp/Word2VecModelDataConverter.java index 5fe09883b..bf6a82148 100644 --- 
a/core/src/main/java/com/alibaba/alink/operator/common/nlp/Word2VecModelDataConverter.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/nlp/Word2VecModelDataConverter.java @@ -6,7 +6,7 @@ import org.apache.flink.types.Row; import org.apache.flink.util.Collector; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.model.ModelDataConverter; import java.util.List; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/nlp/WordCountUtil.java b/core/src/main/java/com/alibaba/alink/operator/common/nlp/WordCountUtil.java index 6cbde6821..b856a8650 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/nlp/WordCountUtil.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/nlp/WordCountUtil.java @@ -22,7 +22,7 @@ import org.apache.flink.util.Collector; import com.alibaba.alink.common.exceptions.AkIllegalArgumentException; -import com.alibaba.alink.common.utils.DataSetConversionUtil; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; import com.alibaba.alink.common.utils.JsonConverter; import com.alibaba.alink.common.utils.RowUtil; import com.alibaba.alink.operator.batch.BatchOperator; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/nlp/bert/BertTextEmbeddingMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/nlp/bert/BertTextEmbeddingMapper.java index fb30bdf67..823edfd17 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/nlp/bert/BertTextEmbeddingMapper.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/nlp/bert/BertTextEmbeddingMapper.java @@ -6,7 +6,7 @@ import org.apache.flink.ml.api.misc.param.Params; import org.apache.flink.table.api.TableSchema; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.dl.BertResources; import com.alibaba.alink.common.dl.plugin.TFPredictorClassLoaderFactory; import com.alibaba.alink.common.io.plugin.ResourcePluginFactory; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/nlp/bert/PreTrainedTokenizerMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/nlp/bert/PreTrainedTokenizerMapper.java index 4eeba78d2..269cc84de 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/nlp/bert/PreTrainedTokenizerMapper.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/nlp/bert/PreTrainedTokenizerMapper.java @@ -6,7 +6,7 @@ import org.apache.flink.table.api.TableSchema; import com.alibaba.alink.common.linalg.tensor.IntTensor; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.mapper.Mapper; import com.alibaba.alink.operator.common.nlp.bert.tokenizer.EncodingKeys; import com.alibaba.alink.operator.common.nlp.bert.tokenizer.Kwargs; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/optim/ConstraintBetweenBins.java b/core/src/main/java/com/alibaba/alink/operator/common/optim/ConstraintBetweenBins.java new file mode 100644 index 000000000..89473d441 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/optim/ConstraintBetweenBins.java @@ -0,0 +1,210 @@ +package com.alibaba.alink.operator.common.optim; + +import org.apache.flink.api.java.tuple.Tuple4; +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonProperty; + +import com.alibaba.alink.common.utils.JsonConverter; + +import java.io.Serializable; +import 
java.util.ArrayList; +import java.util.List; + +/** + * Bin constraints of one feature. Constraints only apply to the bins of a single feature. + */ +public class ConstraintBetweenBins implements Serializable { + private static final long serialVersionUID = -7770851276936359039L; + public String name = "initial"; + public int dim = 0; + + //Upper bound of the weight constraint of a variable. + //number type: integer double. + //eg, 1, 1.0 then bin_1 < 1.0. + @JsonProperty("UP") + public List<Number[]> lessThan = new ArrayList<>(); + + //Lower bound of the weight constraint of a variable. + //number type: integer double. + //eg, 1, 1.0 then bin_1 > 1.0. + @JsonProperty("LO") + public List<Number[]> largerThan = new ArrayList<>(); + + //The weight of the variable is equal to a fixed value. + //number type: integer double. + //eg, 1, 1.0 then bin_1 = 1.0. + @JsonProperty("=") + public List<Number[]> equal = new ArrayList<>(); + + //The weights of variables are proportional to each other. + //number type: integer integer double. + //eg, 1,2,1.5 then bin_1/bin_2=1.5. + @JsonProperty("%") + public List<Number[]> scale = new ArrayList<>(); + + //The weights of variables satisfy an ascending order constraint, in the given order. + //number type: integer, integer, integer .... + //eg, 1,3,6,4 then bin_1 < bin_3 < bin_6 < bin_4. + @JsonProperty("<") + public List<Number[]> lessThanBin = new ArrayList<>(); + + //The weights of variables satisfy a descending order constraint, in the given order. + //number type: integer, integer, integer .... + //eg, 1,3,6,4 then bin_1 > bin_3 > bin_6 > bin_4. + @JsonProperty(">") + public List<Number[]> largerThanBin = new ArrayList<>(); + + public ConstraintBetweenBins() {} + + /** + * @param name feature colName. + * @param dim number of bins of the feature. + */ + public ConstraintBetweenBins(String name, int dim) { + this.name = name; + this.dim = dim; + } + + public void setName(String name) { + this.name = name; + } + + public void setDim(int dim) { + this.dim = dim; + } + + //Upper bound of the weight constraint of a variable. + //number type: integer double. + //eg, 1, 1.0 then bin_1 < 1.0. + public void addLessThan(Number[] item) { + lessThan.add(item); + } + + //Lower bound of the weight constraint of a variable. + //number type: integer double. + //eg, 1, 1.0 then bin_1 > 1.0. + public void addLargerThan(Number[] item) { + largerThan.add(item); + } + + //The weight of the variable is equal to a fixed value. + //number type: integer double. + //eg, 1, 1.0 then bin_1 = 1.0. + public void addEqual(Number[] item) { + equal.add(item); + } + + //The weights of variables are proportional to each other. + //number type: integer integer double. + //eg, 1,2,1.5 then bin_1/bin_2=1.5. + public void addScale(Number[] item) { + scale.add(item); + } + + //The weights of variables satisfy an ascending order constraint, in the given order. + //number type: integer, integer, integer .... + //eg, 1,3,6,4 then bin_1 < bin_3 < bin_6 < bin_4. + public void addLessThanBin(Number[] item) { + lessThanBin.add(item); + } + + //The weights of variables satisfy a descending order constraint, in the given order. + //number type: integer, integer, integer .... + //eg, 1,3,6,4 then bin_1 > bin_3 > bin_6 > bin_4.
+ public void addLargerThanBin(Number[] item) { + largerThanBin.add(item); + } + + public int getInequalSize() { + return lessThan.size() + largerThan.size() + lessThanBin.size() + largerThanBin.size(); + } + + public int getEqualSize() { + return equal.size() + scale.size(); + } + + public String toString() { + return JsonConverter.toJson(this); + } + + //Divide the continuous unequal(< or >) into pairwise. + public static ConstraintBetweenBins fromJson(String constraintJson) { + if (constraintJson == null || constraintJson.equals("")) { + return new ConstraintBetweenBins(); + } + ConstraintBetweenBins constraint = JsonConverter.fromJson(constraintJson, ConstraintBetweenBins.class); + List lessThanBin = new ArrayList <>(); + for (Number[] item : constraint.lessThanBin) { + int size = item.length; + if (size == 2) { + lessThanBin.add(item); + } else { + for (int i = 0; i < size - 1; i++) { + lessThanBin.add(new Number[] {item[i], item[i + 1]}); + } + } + } + + List largerThanBin = new ArrayList <>(); + for (Number[] item : constraint.largerThanBin) { + int size = item.length; + if (size == 2) { + largerThanBin.add(item); + } else { + for (int i = 0; i < size - 1; i++) { + largerThanBin.add(new Number[] {item[i], item[i + 1]}); + } + } + } + constraint.lessThanBin = lessThanBin; + constraint.largerThanBin = largerThanBin; + return constraint; + } + + public Tuple4 getConstraints(int dim) { + int inequalSize = lessThan.size() + largerThan.size() + lessThanBin.size() + largerThanBin.size(); + int equalSize = equal.size() + scale.size(); + double[][] equalityConstraint = new double[equalSize][dim]; + double[][] inequalityConstraint = new double[inequalSize][dim]; + double[] equalityItem = new double[equalSize]; + double[] inequalityItem = new double[inequalSize]; + //in sqp, it needs to write in the form of ≥. 
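// Illustrative example (editor's addition, not part of the original patch): with dim = 3 and w_j denoting
// the coefficient of bin j, a lessThan entry {0, 2.0} ("bin_0 < 2.0") produces the inequality row
//   [-1, 0, 0] with item -2.0, i.e. -w_0 >= -2.0,
// and a lessThanBin entry {0, 2} ("bin_0 < bin_2") produces
//   [-1, 0, 1] with item 0, i.e. w_2 - w_0 >= 0.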
+ int index = 0; + for (int i = 0; i < lessThan.size(); i++) { + inequalityConstraint[index][(int) lessThan.get(i)[0]] = -1; + inequalityItem[index] = (double) lessThan.get(i)[1] * -1; + index++; + } + for (int i = 0; i < largerThan.size(); i++) { + inequalityConstraint[index][(int) largerThan.get(i)[0]] = 1; + inequalityItem[index] = (double) largerThan.get(i)[1]; + index++; + } + for (int i = 0; i < lessThanBin.size(); i++) { + inequalityConstraint[index][(int) lessThanBin.get(i)[0]] = -1; + inequalityConstraint[index][(int) lessThanBin.get(i)[1]] = 1; + inequalityItem[index] = 0; + index++; + } + for (int i = 0; i < largerThanBin.size(); i++) { + inequalityConstraint[index][(int) largerThanBin.get(i)[0]] = 1; + inequalityConstraint[index][(int) largerThanBin.get(i)[1]] = -1; + inequalityItem[index] = 0; + index++; + } + + index = 0; + for (int i = 0; i < equal.size(); i++) { + equalityConstraint[index][(int) equal.get(i)[0]] = 1; + equalityItem[index] = (double) equal.get(i)[1]; + index++; + } + for (int i = 0; i < scale.size(); i++) { + equalityConstraint[index][(int) scale.get(i)[0]] = 1; + equalityConstraint[index][(int) scale.get(i)[1]] = -1 * ((double) scale.get(i)[2]); + equalityItem[index] = 0; + index++; + } + return Tuple4.of(inequalityConstraint, inequalityItem, equalityConstraint, equalityItem); + } + +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/optim/ConstraintBetweenFeatures.java b/core/src/main/java/com/alibaba/alink/operator/common/optim/ConstraintBetweenFeatures.java new file mode 100644 index 000000000..50087528d --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/optim/ConstraintBetweenFeatures.java @@ -0,0 +1,316 @@ +package com.alibaba.alink.operator.common.optim; + +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonProperty; + +import com.alibaba.alink.common.utils.JsonConverter; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; + +/** + * ConstraintBetweenFeatures supports six types of constraints over two kinds of features. + * In a scorecard, one feature may be divided into several partitions, so this class supports + * constraints between features, or between partitions. + */ +public class ConstraintBetweenFeatures implements Serializable { + private static final long serialVersionUID = 1830363299295476243L; + private final String name = "constraintBetweenFeatures"; + @JsonProperty("UP") + public List<Object[]> lessThan = new ArrayList<>(); + @JsonProperty("LO") + public List<Object[]> largerThan = new ArrayList<>(); + @JsonProperty("=") + public List<Object[]> equal = new ArrayList<>(); + @JsonProperty("%") + public List<Object[]> scale = new ArrayList<>(); + @JsonProperty("<") + public List<Object[]> lessThanFeature = new ArrayList<>(); + @JsonProperty(">") + public List<Object[]> largerThanFeature = new ArrayList<>(); + + //for table.
+ public void addLessThan(String f1, int i1, double value) { + if (lessThan.size() == 0 || lessThan.get(0).length == 3) { + Object[] temp = new Object[] {f1, i1, value}; + lessThan.add(temp); + } else { + throw new RuntimeException(); + } + } + + public void addLargerThan(String f1, int i1, double value) { + if (largerThan.size() == 0 || largerThan.get(0).length == 3) { + Object[] temp = new Object[] {f1, i1, value}; + largerThan.add(temp); + } else { + throw new RuntimeException(); + } + } + + public void addEqual(String f1, int i1, double value) { + if (equal.size() == 0 || equal.get(0).length == 3) { + Object[] temp = new Object[] {f1, i1, value}; + equal.add(temp); + } else { + throw new RuntimeException(); + } + } + + public void addScale(String f1, int i1, String f2, int i2, double time) { + if (scale.size() == 0 || scale.get(0).length == 5) { + Object[] temp = new Object[] {f1, i1, f2, i2, time}; + scale.add(temp); + } else { + throw new RuntimeException(); + } + } + + public void addLessThanFeature(String f1, int i1, String f2, int i2) { + if (lessThanFeature.size() == 0 || lessThanFeature.get(0).length == 4) { + Object[] temp = new Object[] {f1, i1, f2, i2}; + lessThanFeature.add(temp); + } else { + throw new RuntimeException(); + } + } + + public void addLargerThanFeature(String f1, int i1, String f2, int i2) { + if (largerThanFeature.size() == 0 || largerThanFeature.get(0).length == 4) { + Object[] temp = new Object[] {f1, i1, f2, i2}; + largerThanFeature.add(temp); + } else { + throw new RuntimeException(); + } + } + + //for vector. + public void addLessThan(int i1, double value) { + if (lessThan.size() == 0 || lessThan.get(0).length == 2) { + Object[] temp = new Object[] {i1, value}; + lessThan.add(temp); + } else { + throw new RuntimeException(); + } + } + + public void addLargerThan(int i1, double value) { + if (largerThan.size() == 0 || largerThan.get(0).length == 2) { + Object[] temp = new Object[] {i1, value}; + largerThan.add(temp); + } else { + throw new RuntimeException(); + } + } + + public void addEqual(int i1, double value) { + if (equal.size() == 0 || equal.get(0).length == 2) { + Object[] temp = new Object[] {i1, value}; + equal.add(temp); + } else { + throw new RuntimeException(); + } + } + + public void addScale(int i1, int i2, double time) { + if (scale.size() == 0 || scale.get(0).length == 3) { + Object[] temp = new Object[] {i1, i2, time}; + scale.add(temp); + } else { + throw new RuntimeException(); + } + } + + public void addLessThanFeature(int i1, int i2) { + if (lessThanFeature.size() == 0 || lessThanFeature.get(0).length == 2) { + Object[] temp = new Object[] {i1, i2}; + lessThanFeature.add(temp); + } else { + throw new RuntimeException(); + } + } + + public void addLargerThanFeature(int i1, int i2) { + if (largerThanFeature.size() == 0 || largerThanFeature.get(0).length == 2) { + Object[] temp = new Object[] {i1, i2}; + largerThanFeature.add(temp); + } else { + throw new RuntimeException(); + } + } + + public String toString() { + return JsonConverter.toJson(this); + } + + public static ConstraintBetweenFeatures fromJson(String constraintJson) { + if (constraintJson == null || constraintJson.equals("")) { + return new ConstraintBetweenFeatures(); + } + ConstraintBetweenFeatures constraint = JsonConverter.fromJson(constraintJson, ConstraintBetweenFeatures.class); + + if (constraint.lessThanFeature.size() + constraint.largerThanFeature.size() == 0) { + return constraint; + } + List lessThanFeature = new ArrayList <>(); + List largerThanFeature = new 
ArrayList <>(); + if (constraint.lessThanFeature.size() != 0 && constraint.lessThanFeature.get(0)[0] instanceof String || + constraint.largerThanFeature.size() != 0 && constraint.largerThanFeature.get(0)[0] instanceof String) { + for (Object[] item : constraint.lessThanFeature) { + int size = item.length / 2; + if (size == 2) { + lessThanFeature.add(item); + } else { + for (int i = 0; i < size - 1; i++) { + lessThanFeature.add( + new Object[] {item[i * 2], item[i * 2 + 1], item[i * 2 + 2], item[i * 2 + 3]}); + } + } + } + + for (Object[] item : constraint.largerThanFeature) { + int size = item.length / 2; + if (size == 2) { + largerThanFeature.add(item); + } else { + for (int i = 0; i < size - 1; i++) { + largerThanFeature.add( + new Object[] {item[i * 2], item[i * 2 + 1], item[i * 2 + 2], item[i * 2 + 3]}); + } + } + } + } else { + for (Object[] item : constraint.lessThanFeature) { + int size = item.length; + if (size == 2) { + lessThanFeature.add(item); + } else { + for (int i = 0; i < size - 1; i++) { + lessThanFeature.add(new Object[] {item[i], item[i + 1]}); + } + } + } + + for (Object[] item : constraint.largerThanFeature) { + int size = item.length; + if (size == 2) { + largerThanFeature.add(item); + } else { + for (int i = 0; i < size - 1; i++) { + largerThanFeature.add(new Object[] {item[i], item[i + 1]}); + } + } + } + } + constraint.lessThanFeature = lessThanFeature; + constraint.largerThanFeature = largerThanFeature; + return constraint; + } + + /** + * extract constraints which are in the indices from the all constraint. + */ + public ConstraintBetweenFeatures extractConstraint(String[] indices) { + HashSet indicesSet = new HashSet <>(indices.length); + for (String index : indices) { + indicesSet.add(index); + } + ConstraintBetweenFeatures constraint = new ConstraintBetweenFeatures(); + for (Object[] objects : this.lessThan) { + if (indicesSet.contains(objects[0])) { + constraint.addLessThan((String) objects[0], (int) objects[1], (double) objects[2]); + } + } + for (Object[] objects : this.largerThan) { + if (indicesSet.contains(objects[0])) { + constraint.addLargerThan((String) objects[0], (int) objects[1], (double) objects[2]); + } + } + for (Object[] objects : this.equal) { + if (indicesSet.contains(objects[0])) { + constraint.addEqual((String) objects[0], (int) objects[1], (double) objects[2]); + } + } + for (Object[] objects : this.scale) { + if (indicesSet.contains(objects[0]) || indicesSet.contains(objects[2])) { + constraint.addScale((String) objects[0], (int) objects[1], (String) objects[2], (int) objects[3], + (double) objects[4]); + } + } + for (Object[] objects : this.largerThanFeature) { + if (indicesSet.contains(objects[0]) || indicesSet.contains(objects[2])) { + constraint.addLargerThanFeature((String) objects[0], (int) objects[1], (String) objects[2], + (int) objects[3]); + } + } + for (Object[] objects : this.lessThanFeature) { + if (indicesSet.contains(objects[0]) || indicesSet.contains(objects[2])) { + constraint.addLessThanFeature((String) objects[0], (int) objects[1], (String) objects[2], + (int) objects[3]); + } + } + return constraint; + } + + public ConstraintBetweenFeatures extractConstraint(int[] indices) { + ConstraintBetweenFeatures constraint = new ConstraintBetweenFeatures(); + for (Object[] objects : this.lessThan) { + int idx = findIdx(indices, (int) objects[0]); + if (idx >= 0) { + constraint.addLessThan(idx, (double) objects[1]); + } + } + for (Object[] objects : this.largerThan) { + int idx = findIdx(indices, (int) objects[0]); + if (idx >= 
0) { + constraint.addLargerThan(idx, (double) objects[1]); + } + } + for (Object[] objects : this.equal) { + int idx = findIdx(indices, (int) objects[0]); + if (idx >= 0) { + constraint.addEqual(idx, (double) objects[1]); + } + } + for (Object[] objects : this.scale) { + int idx0 = findIdx(indices, (int) objects[0]); + int idx1 = findIdx(indices, (int) objects[1]); + if (idx0 >= 0 && idx1 >= 0) { + constraint.addScale(idx0, idx1, (double) objects[2]); + } + } + for (Object[] objects : this.lessThanFeature) { + int idx0 = findIdx(indices, (int) objects[0]); + int idx1 = findIdx(indices, (int) objects[1]); + if (idx0 >= 0 && idx1 >= 0) { + constraint.addLessThanFeature(idx0, idx1); + } + } + for (Object[] objects : this.largerThanFeature) { + int idx0 = findIdx(indices, (int) objects[0]); + int idx1 = findIdx(indices, (int) objects[1]); + if (idx0 >= 0 && idx1 >= 0) { + constraint.addLargerThanFeature(idx0, idx1); + } + } + return constraint; + } + + public int getInequalSize() { + return lessThan.size() + largerThan.size() + lessThanFeature.size() + largerThanFeature.size(); + } + + public int getEqualSize() { + return equal.size() + scale.size(); + } + + private int findIdx(int[] indices, int idx) { + for (int i = 0; i < indices.length; i++) { + if (idx == indices[i]) { + return i; + } + } + return -1; + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/optim/FeatureConstraint.java b/core/src/main/java/com/alibaba/alink/operator/common/optim/FeatureConstraint.java new file mode 100644 index 000000000..46ff666ca --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/optim/FeatureConstraint.java @@ -0,0 +1,548 @@ +package com.alibaba.alink.operator.common.optim; + +import org.apache.flink.api.java.tuple.Tuple4; + +import com.alibaba.alink.common.linalg.DenseVector; +import com.alibaba.alink.common.utils.JsonConverter; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class FeatureConstraint implements Serializable { + + private static final long serialVersionUID = -8834177894594296080L; + //bin constraints; one entry per feature. + private List<ConstraintBetweenBins> featureConstraint = new ArrayList<>(); + //the constraints between features; these are supported even when there are no bin constraints. + private ConstraintBetweenFeatures constraintBetweenFeatures; + //countZero records the number of samples falling in each bin of the dataset (used to detect empty bins). + private DenseVector countZero = null; + private List<Integer> elseNullSave = null; + + public void setCountZero(DenseVector countZero) { + if (countZero != null) { + this.countZero = countZero; + } + } + + public boolean fromScorecard() { + return null != elseNullSave; + } + + public FeatureConstraint() { + } + + //if the length will not change, the list could be replaced with an array. + public void addBinConstraint(ConstraintBetweenBins...
featureConstraint) { + if (null == this.featureConstraint) { + this.featureConstraint = new ArrayList <>(Arrays.asList(featureConstraint)); + } else { + this.featureConstraint.addAll(Arrays.asList(featureConstraint)); + } + + } + + public HashMap getDim() { + int size = getBinConstraintSize(); + HashMap res = new HashMap <>(size); + for (int i = 0; i < size; i++) { + ConstraintBetweenBins cons = this.featureConstraint.get(i); + res.put(cons.name, i); + } + return res; + } + + public void addConstraintBetweenFeature(ConstraintBetweenFeatures constraint) { + constraintBetweenFeatures = constraint; + } + + public int getBinConstraintSize() { + return featureConstraint.size(); + } + + public String toString() { + return JsonConverter.toJson(this); + } + + public static FeatureConstraint fromJson(String constraintJson) { + if (constraintJson == null || constraintJson.equals("")) { + return new FeatureConstraint(); + } + FeatureConstraint constraints = JsonConverter.fromJson(constraintJson, FeatureConstraint.class); + return constraints; + } + + //here the bin constraint must not be null. + public Tuple4 , Integer, Integer, Integer> getParamsWithBinAndFeature() { + int size = featureConstraint.size(); + int inequalSize = 0; + int equalSize = 0; + int dim = 0; + HashMap featureIndex = new HashMap <>(size); + for (int i = 0; i < size; i++) { + if (i != 0) { + dim += featureConstraint.get(i - 1).dim; + } + featureIndex.put(featureConstraint.get(i).name, dim); + int tempInequal = featureConstraint.get(i).getInequalSize(); + int tempEqual = featureConstraint.get(i).getEqualSize(); + if (tempEqual + tempInequal > 0) { + //add default constraint + equalSize += 1; + } + inequalSize += tempInequal; + equalSize += tempEqual; + } + dim += featureConstraint.get(size - 1).dim; + return Tuple4.of(featureIndex, inequalSize, equalSize, dim); + } + + //for feature with no bins. It's suitable for table. Judge table or vector in the algo. + public Tuple4 getConstraintsForFeatures( + HashMap featureIndex) { + int dim = featureIndex.size(); + int inequalSize = 0; + int equalSize = 0; + if (constraintBetweenFeatures != null) { + inequalSize = constraintBetweenFeatures.getInequalSize(); + equalSize = constraintBetweenFeatures.getEqualSize(); + } + double[][] equalityConstraint = new double[equalSize][dim]; + double[][] inequalityConstraint = new double[inequalSize][dim]; + double[] equalityItem = new double[equalSize]; + double[] inequalityItem = new double[inequalSize]; + int equalStartIndex = 0; + int inequalStartIndex = 0; + addConstraintBetweenBinAndFeatures(constraintBetweenFeatures, featureIndex, equalStartIndex, inequalStartIndex, + equalityConstraint, inequalityConstraint, equalityItem, inequalityItem); + return Tuple4.of(inequalityConstraint, inequalityItem, equalityConstraint, equalityItem); + } + + //for feature with no bins. It's suitable for vector. Judge table or vector in the algo. 
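// A minimal usage sketch (editor's addition; the generic parameters of the return type are inferred from
// the method body, and constraintJson / nCoef are hypothetical variables):
//   FeatureConstraint fc = FeatureConstraint.fromJson(constraintJson);
//   Tuple4<double[][], double[], double[][], double[]> cons = fc.getConstraintsForFeatures(nCoef);
//   double[][] inequalityMat = cons.f0;   // inequality rows, written in ">=" form
//   double[] inequalityRhs = cons.f1;     // right-hand sides of the inequality rows
//   double[][] equalityMat = cons.f2;     // equality rows
//   double[] equalityRhs = cons.f3;       // right-hand sides of the equality rows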
+ public Tuple4 getConstraintsForFeatures(int dim) { + int inequalSize = 0; + int equalSize = 0; + if (constraintBetweenFeatures != null) { + inequalSize = constraintBetweenFeatures.getInequalSize(); + equalSize = constraintBetweenFeatures.getEqualSize(); + } + double[][] equalityConstraint = new double[equalSize][dim]; + double[][] inequalityConstraint = new double[inequalSize][dim]; + double[] equalityItem = new double[equalSize]; + double[] inequalityItem = new double[inequalSize]; + int equalStartIndex = 0; + int inequalStartIndex = 0; + int binStartIndex = 0; + addConstraintBetweenFeatures(constraintBetweenFeatures, equalStartIndex, inequalStartIndex, binStartIndex, + equalityConstraint, inequalityConstraint, equalityItem, inequalityItem); + return Tuple4.of(inequalityConstraint, inequalityItem, equalityConstraint, equalityItem); + } + + //this is for bins in one feature. + //this is only for bin!! it add the default constraint. + public Tuple4 getConstraintsForFeatureWithBin() { + Tuple4 , Integer, Integer, Integer> binFeatureParams = getParamsWithBinAndFeature(); + HashMap featureIndex = binFeatureParams.f0; + int inequalSize = binFeatureParams.f1; + int equalSize = binFeatureParams.f2; + int dim = binFeatureParams.f3; + if (constraintBetweenFeatures != null) { + inequalSize += constraintBetweenFeatures.getInequalSize(); + equalSize += constraintBetweenFeatures.getEqualSize(); + } + double[][] equalityConstraint = new double[equalSize][dim]; + double[][] inequalityConstraint = new double[inequalSize][dim]; + double[] equalityItem = new double[equalSize]; + double[] inequalityItem = new double[inequalSize]; + int equalStartIndex = 0; + int inequalStartIndex = 0; + int binStartIndex = 0; + int elseNullIndex = 0; + for (ConstraintBetweenBins constraintBetweenBins : featureConstraint) { + boolean hasConstraint = addConstraintBetweenBins(constraintBetweenBins, + equalStartIndex, inequalStartIndex, binStartIndex, + equalityConstraint, inequalityConstraint, equalityItem, inequalityItem); + equalStartIndex += constraintBetweenBins.getEqualSize(); + inequalStartIndex += constraintBetweenBins.getInequalSize(); + //here add default constraint, consider the constraint on null and else + //the index of null is -1, and if have else, the index is -2. + //if there is constraint on null or else, default constraint should include them, + // even the constraint is zero. but in fact, if the constraint is zero, the default + //if there is not constraint on null or else, and there are not samples under null or else, default + // constraint + //should include them. + //else do not add default constraint. + if (hasConstraint) { + //if it's null, it is not from scorecard. 
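// Editor's note (inferred from addDefaultBinConstraint below): the "default constraint" added per feature
// is the single equality row  sum_j w_j = 0  over that feature's bins, with the trailing else/null bins
// excluded from the sum when they contain no samples (countZero == 0).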
+ if (this.elseNullSave == null) { + addDefaultBinConstraint(constraintBetweenBins, equalStartIndex, binStartIndex, + equalityConstraint, equalityItem, null, null); + } else { + addDefaultBinConstraint(constraintBetweenBins, equalStartIndex, binStartIndex, + equalityConstraint, equalityItem, this.elseNullSave.get(elseNullIndex), countZero); + } + equalStartIndex += 1; + } + binStartIndex += constraintBetweenBins.dim; + elseNullIndex++; + } + addConstraintBetweenBinAndFeatures(constraintBetweenFeatures, featureIndex, equalStartIndex, inequalStartIndex, + equalityConstraint, inequalityConstraint, equalityItem, inequalityItem); + return Tuple4.of(inequalityConstraint, inequalityItem, equalityConstraint, equalityItem); + } + + private static void addDefaultBinConstraint(ConstraintBetweenBins constraint, int equalStartIndex, + int binStartIndex, + double[][] equalityConstraint, double[] equalityItem, + Integer elseNullIndex, DenseVector countZero) { + int dim = constraint.dim; + for (int i = binStartIndex; i < binStartIndex + dim; i++) { + equalityConstraint[equalStartIndex][i] = 1; + } + //if no data on else/null, no default constraint, even if have constraint; + //if have data on else/null, add default constraint. + if (null == elseNullIndex) { + } else if (elseNullIndex == 1) { + if (countZero.get(binStartIndex + dim - 1) == 0) { + equalityConstraint[equalStartIndex][binStartIndex + dim - 1] = 0; + } + } else if (elseNullIndex == 2) { + if (countZero.get(binStartIndex + dim - 1) == 0) { + equalityConstraint[equalStartIndex][binStartIndex + dim - 1] = 0; + } + equalityConstraint[equalStartIndex][binStartIndex + dim - 1] = 0; + } else if (elseNullIndex == 0) { + if (countZero.get(binStartIndex + dim - 1) == 0) { + equalityConstraint[equalStartIndex][binStartIndex + dim - 1] = 0; + } + if (countZero.get(binStartIndex + dim - 2) == 0) { + equalityConstraint[equalStartIndex][binStartIndex + dim - 2] = 0; + } + } + equalityItem[equalStartIndex] = 0; + } + + //only for constraints between bins in one feature. 
+ private boolean addConstraintBetweenBins(ConstraintBetweenBins constraint, + int equalStartIndex, int inequalStartIndex, int binStartIndex, + double[][] equalityConstraint, double[][] inequalityConstraint, + double[] equalityItem, double[] inequalityItem) { + if (constraint == null) { + return false; + } + boolean hasConstraint = false; + for (int i = 0; i < constraint.lessThan.size(); i++) { + inequalityConstraint[inequalStartIndex][(int) constraint.lessThan.get(i)[0] + binStartIndex] = -1; + inequalityItem[inequalStartIndex] = constraint.lessThan.get(i)[1].doubleValue() * -1; + inequalStartIndex++; + hasConstraint = true; + } + for (int i = 0; i < constraint.largerThan.size(); i++) { + inequalityConstraint[inequalStartIndex][(int) constraint.largerThan.get(i)[0] + binStartIndex] = 1; + inequalityItem[inequalStartIndex] = constraint.largerThan.get(i)[1].doubleValue(); + inequalStartIndex++; + hasConstraint = true; + } + for (int i = 0; i < constraint.lessThanBin.size(); i++) { + inequalityConstraint[inequalStartIndex][(int) constraint.lessThanBin.get(i)[0] + binStartIndex] = -1; + inequalityConstraint[inequalStartIndex][(int) constraint.lessThanBin.get(i)[1] + binStartIndex] = 1; + inequalityItem[inequalStartIndex] = 0; + inequalStartIndex++; + hasConstraint = true; + } + for (int i = 0; i < constraint.largerThanBin.size(); i++) { + inequalityConstraint[inequalStartIndex][(int) constraint.largerThanBin.get(i)[0] + binStartIndex] = 1; + inequalityConstraint[inequalStartIndex][(int) constraint.largerThanBin.get(i)[1] + binStartIndex] = -1; + inequalityItem[inequalStartIndex] = 0; + inequalStartIndex++; + hasConstraint = true; + } + + for (int i = 0; i < constraint.equal.size(); i++) { + equalityConstraint[equalStartIndex][(int) constraint.equal.get(i)[0] + binStartIndex] = 1; + equalityItem[equalStartIndex] = constraint.equal.get(i)[1].doubleValue(); + equalStartIndex++; + hasConstraint = true; + } + for (int i = 0; i < constraint.scale.size(); i++) { + equalityConstraint[equalStartIndex][(int) constraint.scale.get(i)[0] + binStartIndex] = 1; + equalityConstraint[equalStartIndex][(int) constraint.scale.get(i)[1] + binStartIndex] = + -1 * constraint.scale.get(i)[2].doubleValue(); + equalityItem[equalStartIndex] = 0; + equalStartIndex++; + hasConstraint = true; + } + return hasConstraint; + } + + //this is for features in the form of vector. 
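// For example (editor's addition): with binStartIndex = 0, a vector-form equal entry {1, 0.8} yields the
// equality row  w_1 = 0.8, and a scale entry {0, 2, 1.5} ("w_0 / w_2 = 1.5") yields  w_0 - 1.5 * w_2 = 0.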
+ private static void addConstraintBetweenFeatures(ConstraintBetweenFeatures constraint, + int equalStartIndex, int inequalStartIndex, int binStartIndex, + double[][] equalityConstraint, double[][] inequalityConstraint, + double[] equalityItem, double[] inequalityItem) { + if (constraint == null) { + return; + } + int index = inequalStartIndex; + for (int i = 0; i < constraint.lessThan.size(); i++) { + inequalityConstraint[index][(int) constraint.lessThan.get(i)[0] + binStartIndex] = -1; + inequalityItem[index] = ((Number) constraint.lessThan.get(i)[1]).doubleValue() * -1; + index++; + } + for (int i = 0; i < constraint.largerThan.size(); i++) { + inequalityConstraint[index][(int) constraint.largerThan.get(i)[0] + binStartIndex] = 1; + inequalityItem[index] = ((Number) constraint.largerThan.get(i)[1]).doubleValue(); + index++; + } + for (int i = 0; i < constraint.lessThanFeature.size(); i++) { + inequalityConstraint[index][(int) constraint.lessThanFeature.get(i)[0] + binStartIndex] = -1; + inequalityConstraint[index][(int) constraint.lessThanFeature.get(i)[1] + binStartIndex] = 1; + inequalityItem[index] = 0; + index++; + } + for (int i = 0; i < constraint.largerThanFeature.size(); i++) { + inequalityConstraint[index][(int) constraint.largerThanFeature.get(i)[0] + binStartIndex] = 1; + inequalityConstraint[index][(int) constraint.largerThanFeature.get(i)[1] + binStartIndex] = -1; + inequalityItem[index] = 0; + index++; + } + + index = equalStartIndex; + for (int i = 0; i < constraint.equal.size(); i++) { + equalityConstraint[index][(int) constraint.equal.get(i)[0] + binStartIndex] = 1; + equalityItem[index] = ((Number) constraint.equal.get(i)[1]).doubleValue(); + index++; + } + for (int i = 0; i < constraint.scale.size(); i++) { + equalityConstraint[index][(int) constraint.scale.get(i)[0] + binStartIndex] = 1; + equalityConstraint[index][(int) constraint.scale.get(i)[1] + binStartIndex] = + -1 * ((Number) constraint.scale.get(i)[2]).doubleValue(); + equalityItem[index] = 0; + index++; + } + } + + //for features in the form of table. 
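// For example (editor's addition, with hypothetical offsets): given featureIndex = {f1 -> 0, f2 -> 4}, a
// table-form lessThanFeature entry {"f1", 2, "f2", 0} ("bin 2 of f1 < bin 0 of f2") addresses columns
// 0 + 2 = 2 and 4 + 0 = 4 and yields the inequality row  w_4 - w_2 >= 0.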
+ private static void addConstraintBetweenBinAndFeatures(ConstraintBetweenFeatures constraint, + HashMap featureIndex, + int equalStartIndex, int inequalStartIndex, + double[][] equalityConstraint, + double[][] inequalityConstraint, + double[] equalityItem, double[] inequalityItem) { + if (constraint == null) { + return; + } + int index = inequalStartIndex; + for (int i = 0; i < constraint.lessThan.size(); i++) { + int first = featureIndex.get(constraint.lessThan.get(i)[0]) + (int) constraint.lessThan.get(i)[1]; + inequalityConstraint[index][first] = 1; + inequalityItem[index] = ((Number) constraint.lessThan.get(i)[2]).doubleValue(); + index++; + } + for (int i = 0; i < constraint.largerThan.size(); i++) { + int first = featureIndex.get(constraint.largerThan.get(i)[0]) + (int) constraint.largerThan.get(i)[1]; + inequalityConstraint[index][first] = 1; + inequalityItem[index] = ((Number) constraint.largerThan.get(i)[2]).doubleValue(); + index++; + } + for (int i = 0; i < constraint.lessThanFeature.size(); i++) { + int first = featureIndex.get(constraint.lessThanFeature.get(i)[0]) + (int) constraint.lessThanFeature.get( + i)[1]; + int second = featureIndex.get(constraint.lessThanFeature.get(i)[2]) + (int) constraint.lessThanFeature.get( + i)[3]; + inequalityConstraint[index][first] = -1; + inequalityConstraint[index][second] = 1; + inequalityItem[index] = 0; + index++; + } + for (int i = 0; i < constraint.largerThanFeature.size(); i++) { + int first = featureIndex.get(constraint.largerThanFeature.get(i)[0]) + (int) constraint.largerThanFeature + .get(i)[1]; + int second = featureIndex.get(constraint.largerThanFeature.get(i)[2]) + (int) constraint.largerThanFeature + .get(i)[3]; + inequalityConstraint[index][first] = 1; + inequalityConstraint[index][second] = -1; + inequalityItem[index] = 0; + index++; + } + + index = equalStartIndex; + for (int i = 0; i < constraint.equal.size(); i++) { + int first = featureIndex.get(constraint.equal.get(i)[0]) + (int) constraint.equal.get(i)[1]; + equalityConstraint[index][first] = 1; + equalityItem[index] = ((Number) constraint.equal.get(i)[2]).doubleValue(); + index++; + } + for (int i = 0; i < constraint.scale.size(); i++) { + int first = featureIndex.get(constraint.scale.get(i)[0]) + (int) constraint.scale.get(i)[1]; + int second = featureIndex.get(constraint.scale.get(i)[2]) + (int) constraint.scale.get(i)[3]; + equalityConstraint[index][first] = 1; + equalityConstraint[index][second] = + -1 * ((Number) constraint.scale.get(i)[4]).doubleValue(); + equalityItem[index] = 0; + index++; + } + } + + //ano: bin, all is here. 
+ public void addDim(FeatureConstraint anotherConstraint) { + HashMap featureDim = this.getDim(); + int length = anotherConstraint.featureConstraint.size(); + //ConstraintBetweenBins constraintBetweenBins : anotherConstraint.featureConstraint + for (int i = 0; i < length; ++i) { + ConstraintBetweenBins constraintBetweenBins = anotherConstraint.featureConstraint.get(i); + if (featureDim.containsKey(constraintBetweenBins.name)) { + int dim = constraintBetweenBins.dim; + constraintBetweenBins = this.featureConstraint.get(featureDim.get(constraintBetweenBins.name)); + constraintBetweenBins.dim = dim; + anotherConstraint.featureConstraint.set(i, constraintBetweenBins); + } + } + this.featureConstraint = anotherConstraint.featureConstraint; + } + + public void modify(Map hasElse) { + this.elseNullSave = new ArrayList <>(this.featureConstraint.size()); + for (ConstraintBetweenBins constraint : this.featureConstraint) { + boolean withElse = false; + boolean withNull = false; + HashMap replace = new HashMap <>(2); + boolean hasName = hasElse.get(constraint.name); + int dim = constraint.dim; + if (hasName) { + replace.put(-1, dim - 2); + replace.put(-2, dim - 1); + } else { + replace.put(-1, dim - 1); + } + for (Number[] equal : constraint.equal) { + int number = (int) equal[0]; + if (number == -1) { + withElse = true; + } + if (number == -2) { + withNull = true; + } + if (number == -1 || number == -2) { + equal[0] = replace.get(number); + } + } + for (Number[] scale : constraint.scale) { + int number = (int) scale[0]; + if (number == -1) { + withElse = true; + } + if (number == -2) { + withNull = true; + } + if (number == -1 || number == -2) { + scale[0] = replace.get(number); + } + number = (int) scale[1]; + if (number == -1) { + withElse = true; + } + if (number == -2) { + withNull = true; + } + if (number == -1 || number == -2) { + scale[1] = replace.get(number); + } + } + for (Number[] lessThanBin : constraint.lessThanBin) { + int number = (int) lessThanBin[0]; + if (number == -1) { + withElse = true; + } + if (number == -2) { + withNull = true; + } + if (number == -1 || number == -2) { + lessThanBin[0] = replace.get(number); + } + number = (int) lessThanBin[1]; + if (number == -1) { + withElse = true; + } + if (number == -2) { + withNull = true; + } + if (number == -1 || number == -2) { + lessThanBin[1] = replace.get(number); + } + } + for (Number[] largerThanBin : constraint.largerThanBin) { + int number = (int) largerThanBin[0]; + if (number == -1) { + withElse = true; + } + if (number == -2) { + withNull = true; + } + if (number == -1 || number == -2) { + largerThanBin[0] = replace.get(number); + } + number = (int) largerThanBin[1]; + if (number == -1) { + withElse = true; + } + if (number == -2) { + withNull = true; + } + if (number == -1 || number == -2) { + largerThanBin[1] = replace.get(number); + } + } + for (Number[] largerThan : constraint.largerThan) { + int number = (int) largerThan[0]; + if (number == -1) { + withElse = true; + } + if (number == -2) { + withNull = true; + } + if (number == -1 || number == -2) { + largerThan[0] = replace.get(number); + } + } + for (Number[] lessThan : constraint.lessThan) { + int number = (int) lessThan[0]; + if (number == -1) { + withElse = true; + } + if (number == -2) { + withNull = true; + } + if (number == -1 || number == -2) { + lessThan[0] = replace.get(number); + } + } + if (!withElse && !withNull) { + this.elseNullSave.add(0); + } + if (withElse && !withNull) { + this.elseNullSave.add(1); + } + if (!withElse && withNull) { + 
this.elseNullSave.add(2); + } + if (withElse && withNull) { + this.elseNullSave.add(3); + } + } + } + + public String extractConstraint(int[] indices) { + FeatureConstraint constraint = new FeatureConstraint(); + if (this.constraintBetweenFeatures != null) { + constraint.constraintBetweenFeatures = + this.constraintBetweenFeatures.extractConstraint(indices); + } + return constraint.toString(); + } + +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/optim/LocalFmOptimizer.java b/core/src/main/java/com/alibaba/alink/operator/common/optim/LocalFmOptimizer.java new file mode 100644 index 000000000..e49a1b4bd --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/optim/LocalFmOptimizer.java @@ -0,0 +1,298 @@ +package com.alibaba.alink.operator.common.optim; + +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.ml.api.misc.param.Params; + +import com.alibaba.alink.common.linalg.DenseVector; +import com.alibaba.alink.common.linalg.SparseVector; +import com.alibaba.alink.common.linalg.Vector; +import com.alibaba.alink.common.model.ModelParamName; +import com.alibaba.alink.operator.common.fm.BaseFmTrainBatchOp.FmDataFormat; +import com.alibaba.alink.operator.common.fm.BaseFmTrainBatchOp.LogitLoss; +import com.alibaba.alink.operator.common.fm.BaseFmTrainBatchOp.LossFunction; +import com.alibaba.alink.operator.common.fm.BaseFmTrainBatchOp.SquareLoss; +import com.alibaba.alink.operator.common.fm.BaseFmTrainBatchOp.Task; +import com.alibaba.alink.params.recommendation.FmTrainParams; + +import java.util.Arrays; +import java.util.Comparator; +import java.util.List; + +/** + * Local fm optimizer. + */ +public class LocalFmOptimizer { + private final List > trainData; + private final int[] dim; + protected FmDataFormat fmModel = null; + private final double[] lambda; + + private FmDataFormat sigmaGii; + private final double learnRate; + private final LossFunction lossFunc; + private final int numEpochs; + private final Task task; + + private final double[] y; + private final double[] loss; + private final double[] vx; + private final double[] v2x2; + private long oldTime; + private final double[] lossCurve; + private double oldLoss = 1.0; + + /** + * construct function. + * + * @param trainData train data. + * @param params parameters for optimizer. + */ + public LocalFmOptimizer(List > trainData, Params params) { + this.numEpochs = params.get(FmTrainParams.NUM_EPOCHS); + this.trainData = trainData; + this.y = new double[trainData.size()]; + this.loss = new double[4]; + this.dim = new int[3]; + dim[0] = params.get(FmTrainParams.WITH_INTERCEPT) ? 1 : 0; + dim[1] = params.get(FmTrainParams.WITH_LINEAR_ITEM) ? 
1 : 0; + dim[2] = params.get(FmTrainParams.NUM_FACTOR); + vx = new double[dim[2]]; + v2x2 = new double[dim[2]]; + this.lambda = new double[3]; + lambda[0] = params.get(FmTrainParams.LAMBDA_0); + lambda[1] = params.get(FmTrainParams.LAMBDA_1); + lambda[2] = params.get(FmTrainParams.LAMBDA_2); + task = params.get(ModelParamName.TASK); + this.learnRate = params.get(FmTrainParams.LEARN_RATE); + oldTime = System.currentTimeMillis(); + if (task.equals(Task.REGRESSION)) { + double minTarget = -1.0e20; + double maxTarget = 1.0e20; + double d = maxTarget - minTarget; + d = Math.max(d, 1.0); + maxTarget = maxTarget + d * 0.2; + minTarget = minTarget - d * 0.2; + lossFunc = new SquareLoss(maxTarget, minTarget); + } else { + lossFunc = new LogitLoss(); + } + + lossCurve = new double[numEpochs * 3]; + } + + /** + * initialize fmModel. + */ + public void setWithInitFactors(FmDataFormat model) { + this.fmModel = model; + int vectorSize = fmModel.factors.length; + sigmaGii = new FmDataFormat(vectorSize, dim, 0.0); + } + + /** + * optimize Fm problem. + * + * @return fm model. + */ + public Tuple2 optimize() { + for (int i = 0; i < numEpochs; ++i) { + updateFactors(); + calcLossAndEvaluation(); + if (termination(i)) { + break; + } + } + return Tuple2.of(fmModel, lossCurve); + } + + /** + * Termination function of fm iteration. + */ + public boolean termination(int step) { + lossCurve[3 * step] = loss[0] / loss[1]; + lossCurve[3 * step + 2] = loss[3] / loss[1]; + if (task.equals(Task.BINARY_CLASSIFICATION)) { + lossCurve[3 * step + 1] = loss[2]; + + System.out.println("step : " + step + " loss : " + + loss[0] / loss[1] + " auc : " + loss[2] + " accuracy : " + + loss[3] / loss[1] + " time : " + (System.currentTimeMillis() + - oldTime)); + } else { + lossCurve[3 * step + 1] = loss[2] / loss[1]; + System.out.println("step : " + step + " loss : " + + loss[0] / loss[1] + " mae : " + loss[2] / loss[1] + " mse : " + + loss[3] / loss[1] + " time : " + (System.currentTimeMillis() + - oldTime)); + } + oldTime = System.currentTimeMillis(); + if (Math.abs(oldLoss - loss[0] / loss[1]) / oldLoss < 1.0e-6) { + oldLoss = loss[0] / loss[1]; + return true; + } else { + oldLoss = loss[0] / loss[1]; + return false; + } + } + + /** + * Calculate loss and evaluations. 
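 * Editor's note (inferred from the implementation): after this call loss[0] holds the summed loss,
 * loss[1] the number of samples, loss[2] the AUC (classification) or MAE (regression), and loss[3]
 * the number of correct predictions (classification) or MSE (regression).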
+ */ + public void calcLossAndEvaluation() { + double lossSum = 0.; + for (int i = 0; i < y.length; i++) { + double yTruth = trainData.get(i).f1; + double l = lossFunc.l(yTruth, y[i]); + lossSum += l; + } + + if (this.task.equals(Task.REGRESSION)) { + double mae = 0.0; + double mse = 0.0; + for (int i = 0; i < y.length; i++) { + double yDiff = y[i] - trainData.get(i).f1; + mae += Math.abs(yDiff); + mse += yDiff * yDiff; + } + loss[2] = mae; + loss[3] = mse; + } else { + Integer[] order = new Integer[y.length]; + double correctNum = 0.0; + for (int i = 0; i < y.length; i++) { + order[i] = i; + if (y[i] > 0 && trainData.get(i).f1 > 0.5) { + correctNum += 1.0; + } + if (y[i] < 0 && trainData.get(i).f1 < 0.5) { + correctNum += 1.0; + } + } + Arrays.sort(order, Comparator.comparingDouble(o -> y[o])); + int mSum = 0; + int nSum = 0; + double posRankSum = 0.; + for (int i = 0; i < order.length; i++) { + int sampleId = order[i]; + int rank = i + 1; + boolean isPositiveSample = trainData.get(sampleId).f1 > 0.5; + if (isPositiveSample) { + mSum++; + posRankSum += rank; + } else { + nSum++; + } + } + if (mSum != 0 && nSum != 0) { + double auc = (posRankSum - 0.5 * mSum * (mSum + 1.0)) / ((double) mSum * (double) nSum); + loss[2] = auc; + } else { + loss[2] = 0.0; + } + loss[3] = correctNum; + } + loss[0] = lossSum; + loss[1] = y.length; + } + + private void updateFactors() { + for (int i1 = 0; i1 < trainData.size(); ++i1) { + Tuple3 sample = trainData.get(i1); + Vector vec = sample.f2; + Tuple2 yVx = calcY(vec, fmModel, dim); + y[i1] = yVx.f0; + + double yTruth = sample.f1; + double dldy = lossFunc.dldy(yTruth, yVx.f0); + + int[] indices; + double[] vals; + if (sample.f2 instanceof SparseVector) { + indices = ((SparseVector) sample.f2).getIndices(); + vals = ((SparseVector) sample.f2).getValues(); + } else { + indices = new int[sample.f2.size()]; + for (int i = 0; i < sample.f2.size(); ++i) { + indices[i] = i; + } + vals = ((DenseVector) sample.f2).getData(); + } + double localLearnRate = sample.f0 * learnRate; + + double eps = 1.0e-8; + if (dim[0] > 0) { + double grad = dldy + lambda[0] * fmModel.bias; + sigmaGii.bias += grad * grad; + fmModel.bias -= localLearnRate * grad / (Math.sqrt(sigmaGii.bias + eps)); + } + + for (int i = 0; i < indices.length; ++i) { + int idx = indices[i]; + // update fmModel + for (int j = 0; j < dim[2]; j++) { + double vixi = vals[i] * fmModel.factors[idx][j]; + double d = vals[i] * (yVx.f1[j] - vixi); + double grad = dldy * d + lambda[2] * fmModel.factors[idx][j]; + sigmaGii.factors[idx][j] += grad * grad; + fmModel.factors[idx][j] -= localLearnRate * grad / (Math.sqrt(sigmaGii.factors[idx][j] + eps)); + } + if (dim[1] > 0) { + double grad = dldy * vals[i] + lambda[1] * fmModel.factors[idx][dim[2]]; + sigmaGii.factors[idx][dim[2]] += grad * grad; + fmModel.factors[idx][dim[2]] + -= grad * localLearnRate / (Math.sqrt(sigmaGii.factors[idx][dim[2]]+ eps)); + } + } + } + } + + /** + * calculate the value of y with given fm model. 
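 * Editor's note (inferred from the implementation below): the first element of the returned tuple is
 *   y = bias + sum_i w_i * x_i + 0.5 * sum_f [ (sum_i v_{i,f} * x_i)^2 - sum_i (v_{i,f} * x_i)^2 ]
 * where the bias term is used when dim[0] > 0, the linear weights w_i (stored at factors[i][dim[2]])
 * when dim[1] > 0, and f ranges over the dim[2] latent factors; the second element is the vector of
 * partial sums  sum_i v_{i,f} * x_i.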
+ */ + private Tuple2 calcY(Vector vec, FmDataFormat fmModel, int[] dim) { + int[] featureIds; + double[] featureValues; + if (vec instanceof SparseVector) { + featureIds = ((SparseVector) vec).getIndices(); + featureValues = ((SparseVector) vec).getValues(); + } else { + featureIds = new int[vec.size()]; + for (int i = 0; i < vec.size(); ++i) { + featureIds[i] = i; + } + featureValues = ((DenseVector) vec).getData(); + } + + Arrays.fill(vx, 0.0); + Arrays.fill(v2x2, 0.0); + + // (1) compute y + double y = 0.; + + if (dim[0] > 0) { + y += fmModel.bias; + } + + for (int i = 0; i < featureIds.length; i++) { + int featurePos = featureIds[i]; + double x = featureValues[i]; + + // the linear term + if (dim[1] > 0) { + y += x * fmModel.factors[featurePos][dim[2]]; + } + // the quadratic term + for (int j = 0; j < dim[2]; j++) { + double vixi = x * fmModel.factors[featurePos][j]; + vx[j] += vixi; + v2x2[j] += vixi * vixi; + } + } + + for (int i = 0; i < dim[2]; i++) { + y += 0.5 * (vx[i] * vx[i] - v2x2[i]); + } + return Tuple2.of(y, vx); + } +} \ No newline at end of file diff --git a/core/src/main/java/com/alibaba/alink/operator/common/optim/LocalOptimizer.java b/core/src/main/java/com/alibaba/alink/operator/common/optim/LocalOptimizer.java index 7a03c602c..c4012693c 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/optim/LocalOptimizer.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/optim/LocalOptimizer.java @@ -8,12 +8,16 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.AlinkGlobalConfiguration; +import com.alibaba.alink.common.exceptions.AkUnsupportedOperationException; import com.alibaba.alink.common.linalg.DenseMatrix; import com.alibaba.alink.common.linalg.DenseVector; import com.alibaba.alink.common.linalg.Vector; import com.alibaba.alink.operator.common.optim.objfunc.OptimObjFunc; +import com.alibaba.alink.operator.local.AlinkLocalSession.TaskRunner; +import com.alibaba.alink.operator.local.LocalOperator; import com.alibaba.alink.params.regression.HasEpsilon; import com.alibaba.alink.params.shared.HasNumCorrections_30; +import com.alibaba.alink.params.shared.HasNumThreads; import com.alibaba.alink.params.shared.iter.HasMaxIterDefaultAs100; import com.alibaba.alink.params.shared.linear.HasEpsilonDefaultAs0000001; import com.alibaba.alink.params.shared.linear.HasL1; @@ -32,8 +36,8 @@ */ public class LocalOptimizer { - private static final int MAX_FEATURE_NUM = 1000; - private static double EPS = 1.0e-18; + private static final int NEWTON_MAX_FEATURE_NUM = 1024; + private static final double EPS = 1.0e-18; /** * autoCross function. @@ -43,12 +47,14 @@ public class LocalOptimizer { * @param initCoef the initial coefficient of problem. * @param params some parameters of optimization method. 
*/ - public static Tuple2 optimize(OptimObjFunc objFunc, - List > trainData, - DenseVector initCoef, Params params) throws Exception { + public static Tuple2 optimize(OptimObjFunc objFunc, + List > trainData, + DenseVector initCoef, Params params) { LinearTrainParams.OptimMethod method = params.get(LinearTrainParams.OPTIM_METHOD); if (null == method) { - if (params.get(HasL1.L_1) > 0) { + if (trainData.get(0).f2.size() <= NEWTON_MAX_FEATURE_NUM) { + method = OptimMethod.Newton; + } else if (params.get(HasL1.L_1) > 0) { method = OptimMethod.OWLQN; } else { method = OptimMethod.LBFGS; @@ -71,9 +77,24 @@ public static Tuple2 optimize(OptimObjFunc objFunc, } } - public static Tuple2 gd(List > labledVectors, - DenseVector initCoefs, - Params params, OptimObjFunc objFunc) throws Exception { + public static int getNumThreads(List > labledVectors, Params params) { + int numThreads = LocalOperator.getDefaultNumThreads(); + if (params.contains(HasNumThreads.NUM_THREADS)) { + numThreads = params.get(HasNumThreads.NUM_THREADS); + } + return Math.min(numThreads, labledVectors.size()); + } + + public static double[] getFinalConvergeInfos(double[] convergeInfos, int curIter) { + int n = 3 * (curIter + 1); + double[] finalInfos = new double[n]; + System.arraycopy(convergeInfos, 0, finalInfos, 0, n); + return finalInfos; + } + + public static Tuple2 gd(List > labledVectors, + DenseVector initCoefs, + Params params, OptimObjFunc objFunc) { DenseVector coefs = initCoefs.clone(); double epsilon = params.get(HasEpsilon.EPSILON); int maxIter = params.get(HasMaxIterDefaultAs100.MAX_ITER).intValue(); @@ -82,15 +103,15 @@ public static Tuple2 gd(List gd(List (coefs, minLoss); + return new Tuple2 <>(coefs, convergeInfos); } - public static Tuple2 sgd(List > labledVectors, - DenseVector initCoefs, - Params params, OptimObjFunc objFunc) throws Exception { + public static Tuple2 sgd(List > labledVectors, + DenseVector initCoefs, + Params params, OptimObjFunc objFunc) { int maxIter = params.get(HasMaxIterDefaultAs100.MAX_ITER).intValue(); double learningRate = params.get(HasLearningRateDefaultAs01.LEARNING_RATE); double epsilon = params.get(HasEpsilonDefaultAs0000001.EPSILON); @@ -120,75 +146,94 @@ public static Tuple2 sgd(List lb = labledVectors.get(rand.nextInt(labledVectors.size())); - for (int k = 0; k < grad.size(); ++k) { - grad.set(k, 0.0); - } List > list = new ArrayList <>(); list.add(lb); objFunc.calcGradient(list, initCoefs, grad); initCoefs.plusScaleEqual(grad, -learningRate); } + Double curLoss = objFunc.calcObjValue(labledVectors, initCoefs).f0; + + //DenseVector oldCoef = initCoefs.clone(); + //for (int j = 0; j < labledVectors.size(); ++j) { + // Tuple3 lb = labledVectors.get(rand.nextInt(labledVectors.size())); + // //Arrays.fill(grad.getData(), 0.0); + // List > list = new ArrayList <>(); + // list.add(lb); + // objFunc.calcGradient(list, initCoefs, grad); + // initCoefs.plusScaleEqual(grad, -learningRate); + //} + // + //Double curLoss = objFunc.calcObjValue(labledVectors, initCoefs).f0; + //if (curLoss.isNaN() || curLoss.isInfinite()) { + // initCoefs = oldCoef; + // learningRate *= 0.25; + // System.out.println("learning rate changed from " + learningRate + " to " + learningRate * 0.5); + // i--; + // continue; + //} - double curLoss = objFunc.calcObjValue(labledVectors, initCoefs).f0; if (curLoss <= minLoss) { minLoss = curLoss; - for (int j = 0; j < nCoef; ++j) { - minVec.set(j, initCoefs.get(j)); - } + minVec.setEqual(initCoefs); } - if (AlinkGlobalConfiguration.isPrintProcessInfo()) { 
System.out.println( "sgd step (" + (i + 1) + ") " + " current loss : " + curLoss + " and minLoss : " + minLoss); } - if (curLoss < epsilon) { + convergeInfos[3 * i] = curLoss; + convergeInfos[3 * i + 1] = grad.normL2(); + convergeInfos[3 * i + 2] = learningRate; + if (curLoss < epsilon || Math.abs(lastLoss - curLoss) / curLoss < epsilon) { if (AlinkGlobalConfiguration.isPrintProcessInfo()) { System.out.println("sgd converged at step : " + i); } + convergeInfos = getFinalConvergeInfos(convergeInfos, i); break; } + lastLoss = curLoss; } - return new Tuple2 <>(initCoefs, minLoss); + return new Tuple2 <>(initCoefs, convergeInfos); } - public static Tuple2 newton(List > labledVectors, - DenseVector initCoefs, Params params, OptimObjFunc objFunc) - throws Exception { - Tuple4 tuple4 = + public static Tuple2 newton(List > labledVectors, + DenseVector initCoefs, Params params, OptimObjFunc objFunc) { + Tuple4 tuple4 = newtonWithHessian(labledVectors, initCoefs, params, objFunc); return Tuple2.of(tuple4.f0, tuple4.f3); } - public static Tuple4 newtonWithHessian( + public static Tuple4 newtonWithHessian( List > labledVectors, - DenseVector initCoefs, Params params, OptimObjFunc objFunc) - throws Exception { + DenseVector initCoefs, Params params, OptimObjFunc objFunc) { int maxIter = params.get(HasMaxIterDefaultAs100.MAX_ITER).intValue(); double epsilon = params.get(HasEpsilonDefaultAs0000001.EPSILON); int nCoef = initCoefs.size(); if (!objFunc.hasSecondDerivative()) { throw new RuntimeException("the loss function doesn't have 2 order Derivative, newton can't work."); } - if (nCoef > MAX_FEATURE_NUM) { + if (nCoef > NEWTON_MAX_FEATURE_NUM) { throw new RuntimeException("Too many coefficients, newton can't work."); } + int numThreads = getNumThreads(labledVectors, params); + DenseVector minVec = initCoefs.clone(); - double minLoss = objFunc.calcObjValue(labledVectors, initCoefs).f0; + double minLoss = calcObjValueMT(objFunc, numThreads, labledVectors, initCoefs).f0; DenseVector gradSum = new DenseVector(nCoef); DenseMatrix hessianSum = new DenseMatrix(nCoef, nCoef); DenseVector t = new DenseVector(nCoef); DenseMatrix bMat = new DenseMatrix(nCoef, 1); + double[] convergeInfos = new double[3 * maxIter]; for (int i = 0; i < maxIter; i++) { - gradSum.scaleEqual(0.0); - hessianSum.scaleEqual(0.0); - Tuple2 - oldloss = objFunc.calcHessianGradientLoss(labledVectors, initCoefs, hessianSum, gradSum); + Tuple2 oldloss + = calcHessianGradientLossMT(objFunc, numThreads, labledVectors, initCoefs, hessianSum, gradSum); for (int j = 0; j < nCoef; j++) { bMat.set(j, 0, gradSum.get(j)); } @@ -198,13 +243,11 @@ public static Tuple4 newtonWithH } initCoefs.minusEqual(t); - double curLoss = objFunc.calcObjValue(labledVectors, initCoefs).f0; + double curLoss = calcObjValueMT(objFunc, numThreads, labledVectors, initCoefs).f0; if (curLoss <= minLoss) { minLoss = curLoss; - for (int j = 0; j < nCoef; ++j) { - minVec.set(j, initCoefs.get(j)); - } + minVec.setEqual(initCoefs); } double gradNorm = gradSum.scale(1 / oldloss.f0).normL2(); double lossChangeRatio = Math.abs(curLoss - oldloss.f1 / oldloss.f0) / curLoss; @@ -213,69 +256,47 @@ public static Tuple4 newtonWithH "********** local newton step (" + (i + 1) + ") current loss ratio : " + lossChangeRatio + " and minLoss : " + minLoss + " grad : " + gradNorm); } + convergeInfos[3 * i] = curLoss; + convergeInfos[3 * i + 1] = gradNorm; + convergeInfos[3 * i + 2] = Double.NaN; if (gradNorm < epsilon || lossChangeRatio < epsilon) { + convergeInfos = 
getFinalConvergeInfos(convergeInfos, i); break; } } - return Tuple4.of(minVec, gradSum, hessianSum, minLoss); - } - - /* - * fix the first several weight and only update the last some weight. - */ - public static Tuple2 lbfgsWithLast(List > labledVectors, - DenseVector initCoefs, Params params, - OptimObjFunc objFunc) { - double[] fixedCoefs = params.get(FIXED_COEFS); - Tuple2 optimizedCoef = lbfgsBase(labledVectors, initCoefs, params, objFunc, fixedCoefs); - double[] optCoefData = optimizedCoef.f0.getData(); - double[] allCoefData = new double[fixedCoefs.length + optCoefData.length]; - System.arraycopy(fixedCoefs, 0, allCoefData, 0, fixedCoefs.length); - System.arraycopy(optCoefData, 0, allCoefData, fixedCoefs.length, optCoefData.length); - optimizedCoef.f0.setData(allCoefData); - return optimizedCoef; - } - - public static ParamInfo FIXED_COEFS = ParamInfoFactory - .createParamInfo("fixedCoefs", double[].class) - .build(); - - public static Tuple2 lbfgs(List > labledVectors, - DenseVector initCoefs, Params params, OptimObjFunc objFunc) { - return lbfgsBase(labledVectors, initCoefs, params, objFunc, new double[0]); + return Tuple4.of(minVec, gradSum, hessianSum, convergeInfos); } - //input initCoefs is candidate coefs, while fixedCoefs saved. they are all onehoted. - private static Tuple2 lbfgsBase(List > labledVectors, - DenseVector initCoefs, Params params, OptimObjFunc objFunc, - double[] fixedCoefs) { - DenseVector allCoefs = new DenseVector(initCoefs.size() + fixedCoefs.length); - System.arraycopy(initCoefs.getData(), 0, allCoefs.getData(), fixedCoefs.length, initCoefs.size()); - System.arraycopy(fixedCoefs, 0, allCoefs.getData(), 0, fixedCoefs.length); - - int maxIter = params.get(HasMaxIterDefaultAs100.MAX_ITER).intValue(); + public static Tuple2 lbfgs(List > labledVectors, + DenseVector initCoefs, Params params, OptimObjFunc objFunc) { + int maxIter = params.get(HasMaxIterDefaultAs100.MAX_ITER); double learnRate = params.get(HasLearningRateDefaultAs01.LEARNING_RATE); double epsilon = params.get(HasEpsilonDefaultAs0000001.EPSILON); - int nCoef = initCoefs.size(); int numCorrection = params.get(HasNumCorrections_30.NUM_CORRECTIONS); + int nCoef = initCoefs.size(); DenseVector[] yK = new DenseVector[numCorrection]; DenseVector[] sK = new DenseVector[numCorrection]; for (int i = 0; i < numCorrection; ++i) { - yK[i] = new DenseVector(initCoefs.size()); - sK[i] = new DenseVector(initCoefs.size()); + yK[i] = new DenseVector(nCoef); + sK[i] = new DenseVector(nCoef); } - DenseVector oldGradient = new DenseVector(initCoefs.size()); + DenseVector oldGradient = new DenseVector(nCoef); DenseVector minVec = initCoefs.clone(); DenseVector dir = null; double[] alpha = new double[numCorrection]; - double minLoss = objFunc.calcObjValue(labledVectors, allCoefs).f0; + double minLoss = objFunc.calcObjValue(labledVectors, initCoefs).f0; double stepLength = -1.0; DenseVector gradient = initCoefs.clone(); + double[] convergeInfos = new double[3 * maxIter]; + + int numThreads = getNumThreads(labledVectors, params); + + DenseVector coefs = initCoefs.clone(); for (int i = 0; i < maxIter; i++) { - double weightSum = objFunc.calcGradient(labledVectors, allCoefs, gradient); + double weightSum = calcGradientMT(objFunc, numThreads, labledVectors, coefs, gradient); if (i == 0) { dir = gradient.clone(); @@ -285,7 +306,9 @@ private static Tuple2 lbfgsBase(List lbfgsBase(List lbfgsBase(List (minVec, minLoss); + return new Tuple2 <>(minVec, convergeInfos); } - // DenseVector allCoefs = new 
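The Newton loop above repeatedly solves H t = g against the accumulated Hessian and gradient and then updates the coefficients with initCoefs.minusEqual(t), i.e. w <- w - H^{-1} g. A stand-alone sketch of one such step on a toy two-variable quadratic, with a hand-rolled 2x2 solve and illustrative numbers (it does not use Alink's DenseMatrix):

```java
// One Newton step on f(w) = 0.5 * w^T H w - b^T w; for a quadratic it lands on the minimizer.
public class NewtonStepDemo {
    public static void main(String[] args) {
        double[][] h = {{4, 1}, {1, 3}};     // Hessian (constant for a quadratic)
        double[] b = {1, 2};
        double[] w = {0, 0};

        // gradient at w: g = H w - b
        double[] g = {
            h[0][0] * w[0] + h[0][1] * w[1] - b[0],
            h[1][0] * w[0] + h[1][1] * w[1] - b[1]
        };

        // solve H t = g by Cramer's rule (2x2 only)
        double det = h[0][0] * h[1][1] - h[0][1] * h[1][0];
        double t0 = (g[0] * h[1][1] - h[0][1] * g[1]) / det;
        double t1 = (h[0][0] * g[1] - g[0] * h[1][0]) / det;

        w[0] -= t0;                           // w <- w - H^{-1} g
        w[1] -= t1;
        System.out.println(w[0] + ", " + w[1]);   // ~0.0909, ~0.6364 (= H^{-1} b)
    }
}
```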
DenseVector(initCoefs.size()+fixedCoefs.length); - // System.arraycopy(initCoefs.getData(), 0, allCoefs.getData(), 0, initCoefs.size()); - // System.arraycopy(fixedCoefs, 0, allCoefs.getData(), initCoefs.size(), fixedCoefs.length); - // DenseVector gradient = initCoefs.clone(); - // - // int maxIter = params.get(HasMaxIterDefaultAs100.MAX_ITER).intValue(); - // double learnRate = params.get(HasLearningRateDv01.LEARNING_RATE); - // double epsilon = params.get(HasEpsilonDv0000001.EPSILON); - // int nCoef = initCoefs.size(); - // int numCorrection = params.get(HasNumCorrections_30.NUM_CORRECTIONS); - // - // DenseVector[] yK = new DenseVector[numCorrection]; - // DenseVector[] sK = new DenseVector[numCorrection]; - // for (int i = 0; i < numCorrection; ++i) { - // yK[i] = new DenseVector(initCoefs.size()); - // sK[i] = new DenseVector(initCoefs.size()); - // } - // DenseVector oldGradient = new DenseVector(initCoefs.size()); - // DenseVector minVec = initCoefs.clone(); - // DenseVector dir = null; - // double[] alpha = new double[numCorrection]; - // - // double minLoss = objFunc.calcObjValue(labledVectors, allCoefs).f0; - // - // double stepLength = -1.0; - // - // for (int i = 0; i < maxIter; i++) { - // double weightSum = objFunc.calcGradient(labledVectors, allCoefs, gradient); - // - // if (i == 0) { - // dir = gradient.clone(); - // } - // dir = calcDir(yK, sK, numCorrection, gradient, oldGradient, i, alpha, stepLength, dir); - // - // int numSearchStep = 10; - // - // double beta = learnRate / numSearchStep; - // double[] losses = objFunc.calcSearchValues(labledVectors, allCoefs, dir, beta, numSearchStep); - // int pos = -1; - // for (int j = 1; j < losses.length; ++j) { - // if (losses[j] < losses[0]) { - // losses[0] = losses[j]; - // pos = j; - // } - // } - // - // if (pos == -1) { - // stepLength = 0.0; - // learnRate *= 0.1; - // } else if (pos == 10) { - // stepLength = beta * pos; - // learnRate *= 10.0; - // } else { - // stepLength = beta * pos; - // } - // - // updateFunc.updateModel(initCoefs, null, dir, stepLength); - // - // if (losses[0] / weightSum <= minLoss) { - // minLoss = losses[0] / weightSum; - // for (int j = 0; j < nCoef; ++j) { - // minVec.set(j, initCoefs.get(j)); - // } - // } - // double gradNorm = Math.sqrt(gradient.normL2Square()); - // - // if (GlobalConfiguration.isPrintProcessInfo()) { - // System.out.println("LBFGS step (" + (i + 1) + ") learnRate : " + learnRate - // + " current loss : " + losses[0] / weightSum + " and minLoss : " + minLoss + " grad norm : " - // + gradNorm); - // } - // if (gradNorm < epsilon || learnRate < EPS) { - // if (GlobalConfiguration.isPrintProcessInfo()) { - // System.out.println("LBFGS converged at step : " + i); - // } - // break; - // } - // } - // return new Tuple2<>(minVec, minLoss); - // } - - public static Tuple2 owlqn(List > labledVectors, - DenseVector initCoefs, Params params, OptimObjFunc objFunc) - throws Exception { + + public static Tuple2 owlqn(List > labledVectors, + DenseVector initCoefs, Params params, OptimObjFunc objFunc) { int maxIter = params.get(HasMaxIterDefaultAs100.MAX_ITER).intValue(); double learnRate = params.get(HasLearningRateDefaultAs01.LEARNING_RATE); double epsilon = params.get(HasEpsilonDefaultAs0000001.EPSILON); @@ -433,8 +376,12 @@ public static Tuple2 owlqn(List 0.0) { /** transfer gradient to pseudo-gradient */ @@ -470,7 +417,8 @@ public static Tuple2 owlqn(List owlqn(List (minVec, minLoss); + return new Tuple2 <>(minVec, convergeInfos); } private static DenseVector 
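The "transfer gradient to pseudo-gradient" step in owlqn is the usual OWL-QN construction for an L1-regularized objective F(w) = f(w) + l1 * ||w||_1: on nonzero coordinates the L1 subgradient is simply added to the smooth gradient, and a coordinate sitting at zero only receives a nonzero entry when moving off zero can decrease F. A sketch of that construction on plain arrays (illustrative only; the in-place update above may differ in details):

```java
public class PseudoGradientDemo {
    /** Pseudo-gradient of F(w) = f(w) + l1 * ||w||_1 at w, given grad = f'(w). */
    static double[] pseudoGradient(double[] w, double[] grad, double l1) {
        double[] pg = new double[w.length];
        for (int i = 0; i < w.length; i++) {
            if (w[i] > 0) {
                pg[i] = grad[i] + l1;            // smooth on the positive orthant
            } else if (w[i] < 0) {
                pg[i] = grad[i] - l1;            // smooth on the negative orthant
            } else if (grad[i] + l1 < 0) {
                pg[i] = grad[i] + l1;            // moving positive decreases F
            } else if (grad[i] - l1 > 0) {
                pg[i] = grad[i] - l1;            // moving negative decreases F
            } else {
                pg[i] = 0.0;                     // zero is already optimal for this coordinate
            }
        }
        return pg;
    }
}
```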
calcDir(DenseVector[] yK, DenseVector[] sK, int m, DenseVector gradient, @@ -598,7 +550,7 @@ private static DenseVector calcDir(DenseVector[] yK, DenseVector[] sK, int m, De private static DenseVector calcOwlqnDir(DenseVector[] yK, DenseVector[] sK, int m, DenseVector gradient, DenseVector oldGradient, DenseVector pseGradient, int k, - double[] alpha) throws Exception { + double[] alpha) { DenseVector qL = pseGradient.clone(); // update Y_k = g_k+1 - g_k if (k == 0) { @@ -637,4 +589,213 @@ private static DenseVector calcOwlqnDir(DenseVector[] yK, DenseVector[] sK, int } return qL; } + + private static Tuple2 calcObjValueMT(OptimObjFunc objFunc, int numThreads, + List > labelVectors, + DenseVector coefVector) { + if (numThreads <= 1) { + return objFunc.calcObjValue(labelVectors, coefVector); + } else { + TaskRunner taskRunner = new TaskRunner(); + final SubIterator >[] subIterators = new SubIterator[numThreads]; + final double[] outWeightSum = new double[numThreads]; + final double[] outFVal = new double[numThreads]; + for (int k = 0; k < numThreads; k++) { + subIterators[k] = new SubIterator <>(labelVectors, numThreads, k); + } + for (int k = 0; k < numThreads; k++) { + final int curThread = k; + taskRunner.submit(() -> { + double weightSum = 0.0; + double fVal = 0.0; + double loss; + for (Tuple3 labelVector : subIterators[curThread]) { + loss = objFunc.calcLoss(labelVector, coefVector); + fVal += loss * labelVector.f0; + weightSum += labelVector.f0; + } + outFVal[curThread] = fVal; + outWeightSum[curThread] = weightSum; + } + ); + } + taskRunner.join(); + double weightSum = outWeightSum[0]; + double fVal = outFVal[0]; + for (int k = 1; k < numThreads; k++) { + weightSum += outWeightSum[k]; + fVal += outFVal[k]; + } + return objFunc.finalizeObjValue(coefVector, fVal, weightSum); + } + } + + private static Tuple2 calcHessianGradientLossMT(OptimObjFunc objFunc, int numThreads, + List > labelVectors, + DenseVector coefVector, + DenseMatrix hessian, + DenseVector grad) { + if (numThreads <= 1) { + return objFunc.calcHessianGradientLoss(labelVectors, coefVector, hessian, grad); + } else { + if (!objFunc.hasSecondDerivative()) { + throw new AkUnsupportedOperationException( + "loss function can't support second derivative, newton precondition can not work."); + } + + TaskRunner taskRunner = new TaskRunner(); + final SubIterator >[] subIterators = new SubIterator[numThreads]; + final double[] outWeightSum = new double[numThreads]; + final double[] outLoss = new double[numThreads]; + int nCoefs = coefVector.size(); + final DenseVector[] subGrads = new DenseVector[numThreads]; + final DenseMatrix[] subHessian = new DenseMatrix[numThreads]; + for (int k = 0; k < numThreads; k++) { + subIterators[k] = new SubIterator <>(labelVectors, numThreads, k); + subGrads[k] = new DenseVector(nCoefs); + subHessian[k] = new DenseMatrix(nCoefs, nCoefs); + } + for (int k = 0; k < numThreads; k++) { + final int curThread = k; + taskRunner.submit(() -> { + double weightSum = 0.0; + double loss = 0.0; + for (Tuple3 labelVector : subIterators[curThread]) { + loss = objFunc.calcLoss(labelVector, coefVector); + weightSum += labelVector.f0; + objFunc.updateGradient(labelVector, coefVector, subGrads[curThread]); + objFunc.updateHessian(labelVector, coefVector, subHessian[curThread]); + } + outLoss[curThread] = loss; + outWeightSum[curThread] = weightSum; + } + ); + } + taskRunner.join(); + + double weightSum = outWeightSum[0]; + double loss = outLoss[0]; + grad.setEqual(subGrads[0]); + 
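The calcDir / calcOwlqnDir helpers referenced above build the limited-memory quasi-Newton direction from the stored (sK, yK) correction pairs. For readers unfamiliar with it, this is the textbook L-BFGS two-loop recursion, shown here on plain arrays as a generic sketch rather than a copy of Alink's helper; the returned q approximates H^{-1} g, and the actual search direction is its negation:

```java
public class TwoLoopRecursionDemo {
    static double dot(double[] a, double[] b) {
        double s = 0.0;
        for (int i = 0; i < a.length; i++) {
            s += a[i] * b[i];
        }
        return s;
    }

    /** s[i] = x_{k-i} - x_{k-i-1}, y[i] = grad_{k-i} - grad_{k-i-1}, newest pair first. */
    static double[] lbfgsDirection(double[] grad, double[][] s, double[][] y) {
        int m = s.length;
        int n = grad.length;
        double[] q = grad.clone();
        double[] alpha = new double[m];
        for (int i = 0; i < m; i++) {                         // first loop: newest to oldest
            double rho = 1.0 / dot(y[i], s[i]);
            alpha[i] = rho * dot(s[i], q);
            for (int j = 0; j < n; j++) {
                q[j] -= alpha[i] * y[i][j];
            }
        }
        double gamma = dot(s[0], y[0]) / dot(y[0], y[0]);     // initial scaling H_0 = gamma * I
        for (int j = 0; j < n; j++) {
            q[j] *= gamma;
        }
        for (int i = m - 1; i >= 0; i--) {                    // second loop: oldest to newest
            double rho = 1.0 / dot(y[i], s[i]);
            double beta = rho * dot(y[i], q);
            for (int j = 0; j < n; j++) {
                q[j] += (alpha[i] - beta) * s[i][j];
            }
        }
        return q;                                             // q ~ H^{-1} grad; step along -q
    }
}
```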
System.arraycopy(subHessian[0].getData(), 0, hessian.getData(), 0, nCoefs * nCoefs); + for (int k = 1; k < numThreads; k++) { + weightSum += outWeightSum[k]; + loss += outLoss[k]; + grad.plusEqual(subGrads[k]); + hessian.plusEquals(subHessian[k]); + } + + objFunc.finalizeHessianGradientLoss(coefVector, hessian, grad, weightSum); + + return Tuple2.of(weightSum, loss); + } + } + + private static double calcGradientMT(OptimObjFunc objFunc, int numThreads, + List > labelVectors, + DenseVector coefVector, DenseVector grad) { + if (numThreads <= 1) { + return objFunc.calcGradient(labelVectors, coefVector, grad); + } else { + int nCoefs = coefVector.size(); + TaskRunner taskRunner = new TaskRunner(); + final SubIterator >[] subIterators = new SubIterator[numThreads]; + final double[] outValues = new double[numThreads]; + final DenseVector[] subGrads = new DenseVector[numThreads]; + for (int k = 0; k < numThreads; k++) { + subIterators[k] = new SubIterator <>(labelVectors, numThreads, k); + subGrads[k] = new DenseVector(nCoefs); + } + for (int k = 0; k < numThreads; k++) { + final int curThread = k; + taskRunner.submit(() -> { + for (Tuple3 labelVector : subIterators[curThread]) { + outValues[curThread] += labelVector.f0; + objFunc.updateGradient(labelVector, coefVector, subGrads[curThread]); + } + } + ); + } + taskRunner.join(); + double weightSum = outValues[0]; + grad.setEqual(subGrads[0]); + for (int k = 1; k < numThreads; k++) { + weightSum += outValues[k]; + grad.plusEqual(subGrads[k]); + } + objFunc.finalizeGradient(coefVector, grad, weightSum); + + return weightSum; + } + } + + private static double[] calcSearchValuesMT(final OptimObjFunc objFunc, int numThreads, + List > labelVectors, + DenseVector coefVector, + DenseVector dirVec, double beta, int numStep) { + if (numThreads <= 1) { + return objFunc.calcSearchValues(labelVectors, coefVector, dirVec, beta, numStep); + } else { + TaskRunner taskRunner = new TaskRunner(); + final SubIterator >[] subIterators = new SubIterator[numThreads]; + final double[][] lossMat = new double[numThreads][numStep + 1]; + for (int k = 0; k < numThreads; k++) { + subIterators[k] = new SubIterator <>(labelVectors, numThreads, k); + } + for (int k = 0; k < numThreads; k++) { + final int curThread = k; + taskRunner.submit(() -> { + double[] sublosses = objFunc.calcSearchValues(subIterators[curThread], coefVector, dirVec, + beta, + numStep); + System.arraycopy(sublosses, 0, lossMat[curThread], 0, numStep + 1); + } + ); + } + taskRunner.join(); + double[] losses = lossMat[0]; + for (int j = 1; j < numThreads; j++) { + for (int jj = 0; jj <= numStep; jj++) { + losses[jj] += lossMat[j][jj]; + } + } + return losses; + } + } + + private static double[] constraintCalcSearchValuesMT(OptimObjFunc objFunc, int numThreads, + List > labelVectors, + DenseVector coefVector, + DenseVector dirVec, double beta, int numStep) { + if (numThreads <= 1) { + return objFunc.constraintCalcSearchValues(labelVectors, coefVector, dirVec, beta, numStep); + } else { + TaskRunner taskRunner = new TaskRunner(); + final List >[] subLists = new List[numThreads]; + final double[][] lossMat = new double[numThreads][numStep + 1]; + int nTotal = labelVectors.size(); + int nSub = nTotal / numThreads; + for (int k = 0; k < numThreads - 1; k++) { + subLists[k] = labelVectors.subList(nSub * k, nSub * (k + 1)); + } + subLists[numThreads - 1] = labelVectors.subList(nSub * (numThreads - 1), nTotal); + for (int k = 0; k < numThreads; k++) { + final int curThread = k; + taskRunner.submit(() -> { + double[] 
sublosses = objFunc.constraintCalcSearchValues(subLists[curThread], coefVector, + dirVec, beta, numStep); + System.arraycopy(sublosses, 0, lossMat[curThread], 0, numStep + 1); + } + ); + } + taskRunner.join(); + double[] losses = lossMat[0]; + for (int j = 1; j < numThreads; j++) { + for (int jj = 0; jj <= numStep; jj++) { + losses[jj] += lossMat[j][jj]; + } + } + return losses; + } + } + } diff --git a/core/src/main/java/com/alibaba/alink/operator/common/optim/SubIterator.java b/core/src/main/java/com/alibaba/alink/operator/common/optim/SubIterator.java new file mode 100644 index 000000000..26007a7de --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/optim/SubIterator.java @@ -0,0 +1,40 @@ +package com.alibaba.alink.operator.common.optim; + +import java.util.Iterator; +import java.util.List; + +public class SubIterator implements Iterator , Iterable { + final List list; + final int numSubs; + final int curSub; + final int listSize; + int cursor; + + public SubIterator(List list, int numSubs, int curSub) { + this.list = list; + this.numSubs = numSubs; + this.curSub = curSub; + this.listSize = list.size(); + cursor = curSub; + } + + @Override + public boolean hasNext() { + return cursor < listSize; + } + + @Override + public T next() { + T result = null; + if (hasNext()) { + result = list.get(cursor); + cursor += numSubs; + } + return result; + } + + @Override + public Iterator iterator() { + return this; + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/optim/activeSet/ConstraintObjFunc.java b/core/src/main/java/com/alibaba/alink/operator/common/optim/activeSet/ConstraintObjFunc.java new file mode 100644 index 000000000..cdd090de2 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/optim/activeSet/ConstraintObjFunc.java @@ -0,0 +1,62 @@ +package com.alibaba.alink.operator.common.optim.activeSet; + +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.ml.api.misc.param.Params; + +import com.alibaba.alink.common.linalg.DenseMatrix; +import com.alibaba.alink.common.linalg.DenseVector; +import com.alibaba.alink.common.linalg.Vector; +import com.alibaba.alink.operator.common.linear.UnaryLossObjFunc; +import com.alibaba.alink.operator.common.linear.unarylossfunc.UnaryLossFunc; + +public class ConstraintObjFunc extends UnaryLossObjFunc { + + private static final long serialVersionUID = -8984917852066425449L; + public DenseMatrix equalityConstraint; + public DenseMatrix inequalityConstraint; + public DenseVector equalityItem; + public DenseVector inequalityItem; + + /** + * Constructor. + * + * @param unaryLossFunc loss function. + * @param params input parameters. + */ + public ConstraintObjFunc(UnaryLossFunc unaryLossFunc, Params params) { + super(unaryLossFunc, params); + } + + /** + * Calculate loss values for line search in optimization. + * + * @param labelVectors train data. + * @param coefVector coefficient of current time. + * @param dirVecOrigin descend direction of optimization problem. + * @param numStep num of line search step. + * @return double[] losses. 
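The new SubIterator gives each worker a strided view of a shared list (worker k visits elements k, k + numThreads, k + 2*numThreads, ...), so the multi-threaded calc*MT helpers can accumulate disjoint partial results without locking. A small usage sketch; the fixed thread pool below only stands in for Alink's internal TaskRunner, and the data is made up:

```java
import com.alibaba.alink.operator.common.optim.SubIterator;

import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

public class SubIteratorDemo {
    public static void main(String[] args) throws InterruptedException {
        List<Integer> data = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
        int numThreads = 3;
        double[] partial = new double[numThreads];

        ExecutorService pool = Executors.newFixedThreadPool(numThreads);
        for (int k = 0; k < numThreads; k++) {
            final int cur = k;
            pool.submit(() -> {
                // each worker sums only its own strided shard and writes its own slot
                for (int v : new SubIterator<>(data, numThreads, cur)) {
                    partial[cur] += v;
                }
            });
        }
        pool.shutdown();
        pool.awaitTermination(1, TimeUnit.MINUTES);

        System.out.println(Arrays.stream(partial).sum());   // 55.0, same as a serial sum
    }
}
```

Because every shard writes only its own slot of the partial array, the final join is a plain sequential reduction, which is the pattern calcGradientMT and the other MT helpers follow.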
+ */ + public double[] calcLineSearch(Iterable > labelVectors, DenseVector coefVector, + DenseVector dirVecOrigin, int numStep, double l2Weight) { + double[] losses = new double[2 * numStep + 1]; + DenseVector[] stepVec = new DenseVector[2 * numStep + 1]; + stepVec[numStep] = coefVector.clone(); + DenseVector dirVec = dirVecOrigin.clone(); + stepVec[numStep + 1] = coefVector.plus(dirVec); + stepVec[numStep - 1] = coefVector.minus(dirVec); + for (int i = 2; i < numStep + 1; i++) { + DenseVector temp = dirVec.scale(9 * Math.pow(10, 1 - i)); + stepVec[numStep + i] = stepVec[numStep + i - 1].minus(temp); + stepVec[numStep - i] = stepVec[numStep - i + 1].plus(temp); + } + + double l2Item = coefVector.normL2() * l2Weight; + for (Tuple3 labelVector : labelVectors) { + for (int i = 0; i < numStep * 2 + 1; ++i) { + losses[i] += calcLoss(labelVector, stepVec[i]) * labelVector.f0 + l2Item; + } + } + return losses; + } + +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/optim/activeSet/ConstraintVariable.java b/core/src/main/java/com/alibaba/alink/operator/common/optim/activeSet/ConstraintVariable.java new file mode 100644 index 000000000..191347318 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/optim/activeSet/ConstraintVariable.java @@ -0,0 +1,16 @@ +package com.alibaba.alink.operator.common.optim.activeSet; + +import com.alibaba.alink.operator.common.optim.subfunc.OptimVariable; + +public class ConstraintVariable extends OptimVariable { + public final static String constraints = "constraints"; + public final static String weightDim = "dim"; + public final static String lossAllReduce = "lossAllReduce"; + public final static String lastLoss = "lastLoss"; + public final static String loss = "loss"; + public final static String convergence = "convergence"; + public final static String weight = "weight"; + public final static String linearSearchTimes = "linearSearchTimes"; + public final static String minL2Weight = "minL2Weight"; + public final static String newtonRetryTime = "retryTime"; +} \ No newline at end of file diff --git a/core/src/main/java/com/alibaba/alink/operator/common/optim/activeSet/QpProblem.java b/core/src/main/java/com/alibaba/alink/operator/common/optim/activeSet/QpProblem.java new file mode 100644 index 000000000..265736849 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/optim/activeSet/QpProblem.java @@ -0,0 +1,161 @@ +package com.alibaba.alink.operator.common.optim.activeSet; + +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; + +import com.alibaba.alink.common.linalg.BLAS; +import com.alibaba.alink.common.linalg.DenseMatrix; +import com.alibaba.alink.common.linalg.DenseVector; + +import java.util.Arrays; + +public class QpProblem { + public static Tuple3 qpact(DenseMatrix hessian, DenseVector grad, + DenseMatrix equalMatrix, DenseVector equalConstant, + DenseMatrix inequalMatrix, + DenseVector inequalConstant, + DenseVector x0) { + int dim = x0.size(); + int hesRow = hessian.numRows(); + int hesCol = hessian.numCols(); + if (hesRow == 0 || hesCol == 0) { + hessian = DenseMatrix.zeros(dim, dim); + } + int gradSize = grad.size(); + if (gradSize == 0) { + grad = DenseVector.zeros(dim); + } + + double epsilon = 1e-9; + double err = 1e-6; + double iter = 0; + int maxIter = 1000; + int exitFlag = 0; + DenseVector x = x0; + + if (equalConstant == null) { + equalConstant = new DenseVector(); + } + if (equalMatrix == null) { + equalMatrix = new DenseMatrix(); + } + if 
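calcLineSearch above probes the loss on a symmetric, geometrically shrinking grid of step sizes around the current coefficients: the candidate points are w + t * dir for t in {0, +-1, +-0.1, +-0.01, ...}. A tiny sketch that reproduces only the multipliers (numStep = 4 is an arbitrary choice):

```java
public class LineSearchGridDemo {
    public static void main(String[] args) {
        int numStep = 4;
        double[] t = new double[2 * numStep + 1];
        t[numStep] = 0.0;                                     // the unmoved point
        t[numStep + 1] = 1.0;
        t[numStep - 1] = -1.0;
        for (int i = 2; i <= numStep; i++) {
            double shrink = 9 * Math.pow(10, 1 - i);          // 0.9, 0.09, 0.009, ...
            t[numStep + i] = t[numStep + i - 1] - shrink;     // 0.1, 0.01, 0.001, ...
            t[numStep - i] = t[numStep - i + 1] + shrink;     // -0.1, -0.01, -0.001, ...
        }
        System.out.println(java.util.Arrays.toString(t));
    }
}
```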
(inequalConstant == null) { + inequalConstant = new DenseVector(); + } + if (inequalMatrix == null) { + inequalMatrix = new DenseMatrix(); + } + int equalNum = equalConstant.size(); + int inequalNum = inequalConstant.size(); + DenseVector lambdaK = DenseVector.zeros(inequalNum + equalNum); + int[] index = new int[inequalNum]; + Arrays.fill(index, 1); + double[][] inequalMatrixData = inequalMatrix.getArrayCopy2D(); + for (int i = 0; i < inequalNum; i++) { + if (BLAS.dot(inequalMatrixData[i], x.getData()) > inequalConstant.get(i) + epsilon) { + index[i] = 0; + } + } + while (iter <= maxIter) { + int indexSum = equalNum; + for (int i = 0; i < inequalNum; i++) { + indexSum += index[i]; + } + double[][] constraintsData = new double[indexSum][dim]; + if (equalNum > 0) { + SqpUtil.fillMatrix(constraintsData, 0, 0, equalMatrix.getArrayCopy2D()); + } + int begin = equalNum; + for (int j = 0; j < inequalNum; j++) { + if (index[j] > 0) { + System.arraycopy(inequalMatrix.getRow(j), 0, constraintsData[begin], 0, dim); + begin++; + } + } + DenseMatrix constraintsMatrix = new DenseMatrix(constraintsData); + DenseVector gk = hessian.multiplies(x).plus(grad); + int m1 = indexSum; + DenseVector[] vectors = subProblem(hessian, gk, constraintsMatrix, DenseVector.zeros(m1)); + DenseVector dk = vectors[0]; + lambdaK = vectors[1]; + if (dk.normL1() < err) { + //y is the min value of lambda, while jk is the index of the lambda. + double y = 0; + int jk = 0; + if (lambdaK.size() > equalNum) { + Tuple2 tuple2 = SqpUtil.findMin(lambdaK, equalNum, m1 - equalNum); + y = tuple2.f0; + jk = tuple2.f1 + 1; + } + if (y > 0) { + exitFlag = 0; + } else { + exitFlag = 1; + int indexTempSum = 0; + for (int i = 0; i < inequalNum; i++) { + indexTempSum += index[i]; + //seems equalNum no need to plus. + if (index[i] == 1 & (indexTempSum) == jk) { + index[i] = 0; + break; + } + } + } + iter++; + } else { + exitFlag = 1; + //calculate the step length + double alpha = 1; + double step = 1; + int activeIndex = 0; + for (int i = 0; i < inequalNum; i++) { + double[] inequalRow = inequalMatrix.getRow(i); + if (index[i] == 0 && BLAS.dot(inequalRow, dk.getData()) < 0) { + double tempStep = (inequalConstant.get(i) - BLAS.dot(inequalRow, x.getData())) / + BLAS.dot(inequalRow, dk.getData()); + if (tempStep < step) { + step = tempStep; + activeIndex = i; + } + } + } + if (alpha > step) { + alpha = step; + } + x = x.plus(dk.scale(alpha)); + if (step < 1) { + index[activeIndex] = 1; + } + } + if (exitFlag == 0) { + break; + } + iter++; + } + return Tuple3.of(x, lambdaK, exitFlag); + } + + //this is the sub problem which only considers equality constraints. + //may use lapack to solve this problem. + public static DenseVector[] subProblem(DenseMatrix hessian, DenseVector gradient, DenseMatrix equalMatrix, + DenseVector equalConstant) { + //may throw exceptions. 
+ DenseMatrix ginvH = hessian.inverse(); + int equalNum = equalMatrix.numRows(); + DenseVector lambda; + DenseVector x; + if (equalNum > 0) { + DenseMatrix rb = equalMatrix.multiplies(ginvH); + DenseMatrix temp = (rb.multiplies(equalMatrix.transpose())).inverse(); + DenseMatrix B = temp.multiplies(rb); + DenseVector temp2 = B.multiplies(gradient); + lambda = temp2.plus(temp.multiplies(equalConstant)); + DenseMatrix G = ginvH.minus(ginvH.multiplies(equalMatrix.transpose()).multiplies(temp).multiplies(rb)); + x = B.transpose().multiplies(equalConstant).minus(G.multiplies(gradient)); + } else { + x = ginvH.multiplies(gradient).scale(-1); + lambda = DenseVector.zeros(1); + } + return new DenseVector[] {x, lambda}; + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/optim/activeSet/Sqp.java b/core/src/main/java/com/alibaba/alink/operator/common/optim/activeSet/Sqp.java new file mode 100644 index 000000000..dd5eb0b71 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/optim/activeSet/Sqp.java @@ -0,0 +1,442 @@ +package com.alibaba.alink.operator.common.optim.activeSet; + +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.functions.RichMapPartitionFunction; +import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; + +import com.alibaba.alink.common.comqueue.ComContext; +import com.alibaba.alink.common.comqueue.CompareCriterionFunction; +import com.alibaba.alink.common.comqueue.CompleteResultFunction; +import com.alibaba.alink.common.comqueue.ComputeFunction; +import com.alibaba.alink.common.comqueue.IterativeComQueue; +import com.alibaba.alink.common.comqueue.communication.AllReduce; +import com.alibaba.alink.common.linalg.DenseMatrix; +import com.alibaba.alink.common.linalg.DenseVector; +import com.alibaba.alink.common.linalg.Vector; +import com.alibaba.alink.common.model.ModelParamName; +import com.alibaba.alink.operator.common.optim.Optimizer; +import com.alibaba.alink.operator.common.optim.local.LocalSqp; +import com.alibaba.alink.operator.common.optim.objfunc.OptimObjFunc; +import com.alibaba.alink.operator.common.optim.subfunc.OptimVariable; +import com.alibaba.alink.operator.common.optim.subfunc.PreallocateMatrix; +import com.alibaba.alink.operator.common.optim.subfunc.PreallocateVector; +import com.alibaba.alink.params.shared.iter.HasMaxIterDefaultAs100; +import com.alibaba.alink.params.shared.linear.HasEpsilonDefaultAs0000001; +import com.alibaba.alink.params.shared.linear.HasL2; +import com.alibaba.alink.params.shared.linear.HasWithIntercept; +import org.apache.commons.math3.optim.PointValuePair; +import org.apache.commons.math3.optim.linear.LinearConstraint; +import org.apache.commons.math3.optim.linear.LinearConstraintSet; +import org.apache.commons.math3.optim.linear.LinearObjectiveFunction; +import org.apache.commons.math3.optim.linear.Relationship; +import org.apache.commons.math3.optim.linear.SimplexSolver; +import org.apache.commons.math3.optim.nonlinear.scalar.GoalType; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +public class Sqp extends Optimizer { + private static final int MAX_FEATURE_NUM = 3000; + + /** + * construct function. + * + * @param objFunc object function, calc loss and grad. 
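For reference, subProblem above returns the closed-form solution of the equality-constrained QP (a sketch of the standard result, assuming the Hessian $H$ is symmetric positive definite and the constraint matrix $A$ has full row rank; the unconstrained branch in the code is omitted):

$$\min_p \; \tfrac12 p^T H p + g^T p \quad \text{s.t.} \quad A p = b .$$

Writing $R = A H^{-1}$ (the code's `rb`) and $S = (A H^{-1} A^T)^{-1}$ (the code's `temp`), the returned pair is

$$\lambda = S\,(R g + b), \qquad G = H^{-1} - H^{-1} A^T S R, \qquad p = (S R)^T b - G\,g ,$$

which satisfies the KKT system $H p + g = A^T \lambda$ and $A p = b$.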
+ * @param trainData data for training. + * @param coefDim the dimension of features. + * @param params some parameters of optimization method. + */ + public Sqp(DataSet objFunc, DataSet > trainData, + DataSet coefDim, Params params) { + super(objFunc, trainData, coefDim, params); + } + + /** + * Solve the following quadratic programming problem: +

+ * min 0.5 * p^TGp + g_k^Tp + * s.t. A_i \dot p = b - A_i \dot x_k, where i belongs to equality constraints + * A_i \dot p >= b - A_i \dot x_k, where i belongs to inequality constraints + */ + @Override + public DataSet > optimize() { + int maxIter = params.get(HasMaxIterDefaultAs100.MAX_ITER); + this.coefVec = this.coefDim.map(new InitialCoef()).withBroadcastSet(this.objFuncSet, "objFunc"); + DataSet model = new IterativeComQueue() + .initWithPartitionedData(OptimVariable.trainData, trainData) + .initWithBroadcastData(OptimVariable.model, coefVec) + .initWithBroadcastData(OptimVariable.objFunc, objFuncSet) + .initWithBroadcastData(ConstraintVariable.weightDim, coefDim) + .add(new InitializeParams()) + .add(new PreallocateVector(OptimVariable.grad, new double[2]))//initial grad + .add(new PreallocateMatrix(OptimVariable.hessian, MAX_FEATURE_NUM))//initial hessian + .add(new CalcGradAndHessian())//update grad and hessian with up-to-date coef. + .add(new AllReduce(OptimVariable.gradHessAllReduce)) + .add(new GetGradientAndHessian()) + .add(new CalcDir(params))//sqp calculate dir + .add(new LineSearch(params.get(HasL2.L_2)))//line search the best weight, and the best grad and hessian. + .add(new AllReduce(ConstraintVariable.lossAllReduce)) + .add(new GetMinCoef()) + .add(new CalcConvergence())//consider to put the following three step at the head. + .setCompareCriterionOfNode0(new IterTermination(params.get(HasEpsilonDefaultAs0000001.EPSILON))) + .setMaxIter(maxIter) + .closeWith(new BuildModel()) + .exec(); + + return model.mapPartition(new ParseRowModel()); + } + + private static class InitializeParams extends ComputeFunction { + + private static final long serialVersionUID = 3775769468090056172L; + + @Override + public void calc(ComContext context) { + if (context.getStepNo() == 1) { + DenseVector weight = ((List ) context.getObj(OptimVariable.model)).get(0); + context.putObj(ConstraintVariable.weight, weight); + context.putObj(ConstraintVariable.minL2Weight, 1e-8); + context.putObj(ConstraintVariable.linearSearchTimes, 40); + context.putObj(ConstraintVariable.newtonRetryTime, 12); + context.putObj(ConstraintVariable.loss, 0.); + ConstraintObjFunc objFunc = + (ConstraintObjFunc) ((List ) context.getObj(OptimVariable.objFunc)).get(0); + context.putObj(SqpVariable.icmBias, objFunc.inequalityItem); + context.putObj(SqpVariable.ecmBias, objFunc.equalityItem); + + } + } + } + + //initialize the coef. + public static double[] phaseOne(double[][] equalMatrix, double[] equalItem, + double[][] inequalMatrix, double[] inequalItem, int dim) { + int equalNum = 0; + if (equalItem != null) { + equalNum = equalItem.length; + } + + int inequalNum = 0; + if (inequalItem != null) { + inequalNum = inequalItem.length; + } + + int constraintLength = equalNum + inequalNum; + //if no constraint, return zeros. + if (constraintLength == 0) { + double[] res = new double[dim]; + Arrays.fill(res, 1e-4); + return res; + } + //optimize the phase function, which is 0*x1+0*x2+...z1+z2+...,first place equal, then place inequal. 
+ double[] objData = new double[dim + constraintLength]; + Arrays.fill(objData, dim, dim + constraintLength, 1); + LinearObjectiveFunction objFunc = new LinearObjectiveFunction(objData, 0); + List cons = new ArrayList <>(); + for (int i = 0; i < equalNum; i++) { + double[] constraint = new double[dim + constraintLength]; + System.arraycopy(equalMatrix[i], 0, constraint, 0, dim); + constraint[i + dim] = 1; + double item = equalItem[i]; + cons.add(new LinearConstraint(constraint, Relationship.EQ, item)); + } + for (int i = 0; i < inequalNum; i++) { + double[] constraint = new double[dim + constraintLength]; + System.arraycopy(inequalMatrix[i], 0, constraint, 0, dim); + constraint[i + dim + equalNum] = 1; + double item = inequalItem[i]; + cons.add(new LinearConstraint(constraint, Relationship.GEQ, item)); + } + for (int i = dim; i < dim + constraintLength; i++) { + double[] constraint = new double[dim + constraintLength]; + constraint[i] = 1; + cons.add(new LinearConstraint(constraint, Relationship.GEQ, 0)); + } + + LinearConstraintSet conSet = new LinearConstraintSet(cons); + PointValuePair pair = new SimplexSolver().optimize(objFunc, conSet, GoalType.MINIMIZE); + double[] res = new double[dim]; + System.arraycopy(pair.getPoint(), 0, res, 0, dim); + return res; + } + + public static class InitialCoef extends RichMapFunction { + private static final long serialVersionUID = -1725328800337420019L; + private double[][] equalMatrix; + private double[] equalItem; + private double[][] inequalMatrix; + private double[] inequalItem; + + @Override + public void open(Configuration parameters) throws Exception { + ConstraintObjFunc objFunc = + (ConstraintObjFunc) getRuntimeContext() + .getBroadcastVariable("objFunc").get(0); + this.inequalMatrix = objFunc.inequalityConstraint.getArrayCopy2D(); + this.inequalItem = objFunc.inequalityItem.getData(); + this.equalMatrix = objFunc.equalityConstraint.getArrayCopy2D(); + this.equalItem = objFunc.equalityItem.getData(); + } + + @Override + public DenseVector map(Integer n) throws Exception { + return new DenseVector(phaseOne(equalMatrix, equalItem, inequalMatrix, inequalItem, n)); + } + } + + public static class CalcDir extends ComputeFunction { + + private static final long serialVersionUID = 7694040433763461801L; + private boolean hasIntercept; + private double l2; + + public CalcDir(Params params) { + hasIntercept = params.get(HasWithIntercept.WITH_INTERCEPT); + l2 = params.get(HasL2.L_2); + } + + @Override + public void calc(ComContext context) { + ConstraintObjFunc objFunc = (ConstraintObjFunc) ((List ) context.getObj( + OptimVariable.objFunc)).get(0); + Double loss = context.getObj(ConstraintVariable.loss); + Tuple2 grad = context.getObj(OptimVariable.grad); + DenseVector gradient = grad.f0; + DenseMatrix hessian = context.getObj(OptimVariable.hessian); + int dim = ((List ) context.getObj(ConstraintVariable.weightDim)).get(0); + final int retryTime = context.getObj(ConstraintVariable.newtonRetryTime); + final double minL2Weight = context.getObj(ConstraintVariable.minL2Weight); + DenseVector weight = context.getObj(ConstraintVariable.weight); + DenseVector dir = SqpPai.getStartDir(objFunc, weight, + context.getObj(SqpVariable.icmBias), + context.getObj(SqpVariable.ecmBias)); + boolean[] activeSet = SqpPai.getActiveSet(objFunc.inequalityConstraint, objFunc.inequalityItem, dir, dim); + // dir = QpProblem.subProblem(hessian, gradient, new DenseMatrix(),new DenseVector())[0]; + Tuple3 dirItems = + SqpPai.calcDir(retryTime, dim, objFunc, dir, weight, hessian, 
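phaseOne above finds an initial feasible point through a phase-one LP: one slack variable z_i >= 0 is added per constraint, the objective is the sum of the slacks, and an optimum with objective zero yields coefficients that satisfy the original constraints. A hypothetical call on a tiny two-dimensional constraint set (constraint values are made up):

```java
import com.alibaba.alink.operator.common.optim.activeSet.Sqp;

public class PhaseOneDemo {
    public static void main(String[] args) {
        double[][] equalMatrix = {{1.0, -1.0}};   // x0 - x1 = 0.2
        double[] equalItem = {0.2};
        double[][] inequalMatrix = {{1.0, 1.0}};  // x0 + x1 >= 1
        double[] inequalItem = {1.0};

        double[] x = Sqp.phaseOne(equalMatrix, equalItem, inequalMatrix, inequalItem, 2);
        // expected: a point satisfying both constraints, since the slacks can reach zero
        System.out.println(x[0] + ", " + x[1]);
    }
}
```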
gradient, l2, minL2Weight, hasIntercept, + activeSet); + // LocalSqp.calcDir(retryTime, dim, objFunc, weight, hessian, gradient, loss, l2, + // minL2Weight, hasIntercept); + dir = dirItems.f0; + grad.f0 = dirItems.f1; + hessian = dirItems.f2; + context.putObj(OptimVariable.grad, grad); + context.putObj(OptimVariable.hessian, hessian); + context.putObj(ConstraintVariable.loss, loss);//grad and hessian has been put in. + context.putObj(OptimVariable.dir, dir); + } + } + + public static class LineSearch extends ComputeFunction { + + private static final long serialVersionUID = 2611682666208211053L; + private ConstraintObjFunc objFunc; + private double l2Weight; + + public LineSearch(double l2Weight) { + this.l2Weight = l2Weight; + } + + @Override + public void calc(ComContext context) { + objFunc = (ConstraintObjFunc) ((List ) context.getObj(OptimVariable.objFunc)).get(0); + DenseVector dir = context.getObj(OptimVariable.dir); + final int linearSearchTimes = context.getObj(ConstraintVariable.linearSearchTimes); + final double minL2Weight = context.getObj(ConstraintVariable.minL2Weight); + if (l2Weight == 0) { + l2Weight += minL2Weight; + } + Iterable > labledVectors = context.getObj(OptimVariable.trainData); + DenseVector weight = context.getObj(ConstraintVariable.weight); + double[] losses = objFunc.calcLineSearch(labledVectors, weight, dir, linearSearchTimes, l2Weight); + context.putObj(ConstraintVariable.lossAllReduce, losses); + } + } + + public static class GetMinCoef extends ComputeFunction { + + private static final long serialVersionUID = 239058213400494835L; + + @Override + public void calc(ComContext context) { + double[] losses = context.getObj(ConstraintVariable.lossAllReduce); + Tuple2 grad = context.getObj(OptimVariable.grad); + DenseVector dir = context.getObj(OptimVariable.dir); + DenseVector weight = context.getObj(ConstraintVariable.weight); + double loss = LocalSqp.lineSearch(losses, weight, grad.f0, dir); + context.putObj(ConstraintVariable.weight, weight); + int stepNum = context.getStepNo(); + if (stepNum != 1) { + context.putObj(ConstraintVariable.lastLoss, context.getObj(ConstraintVariable.loss)); + } + context.putObj(ConstraintVariable.loss, loss); + } + } + + public static class CalcConvergence extends ComputeFunction { + + private static final long serialVersionUID = 1936163292416518616L; + + @Override + public void calc(ComContext context) { + //restore items + ConstraintObjFunc objFunc = + (ConstraintObjFunc) ((List ) context.getObj(OptimVariable.objFunc)).get(0); + objFunc.equalityItem = context.getObj(SqpVariable.ecmBias); + objFunc.inequalityItem = context.getObj(SqpVariable.icmBias); + + if (context.getStepNo() != 1) { + double loss = context.getObj(ConstraintVariable.loss); + double lastLoss = context.getObj(ConstraintVariable.lastLoss); + int iter = context.getStepNo(); + int lossStep = 5; + double convergence; + if (iter <= lossStep) { + convergence = (lastLoss - loss) / (Math.abs(loss) * iter); + } else { + convergence = (lastLoss - loss) / (Math.abs(loss) * lossStep); + } + // if (context.getTaskId() == 0) { + // System.out.println("iter: " + iter + ", convergence: " + convergence); + // } + context.putObj(ConstraintVariable.convergence, convergence); + } + } + } + + public static class CalcGradAndHessian extends ComputeFunction { + private static final long serialVersionUID = 4760392853024920737L; + private OptimObjFunc objFunc; + + @Override + public void calc(ComContext context) { + Iterable > labledVectors = context.getObj(OptimVariable.trainData); + 
Tuple2 grad = context.getObj(OptimVariable.grad); + DenseVector weight = context.getObj(ConstraintVariable.weight); + DenseMatrix hessian = context.getObj(OptimVariable.hessian); + + int size = grad.f0.size(); + + if (objFunc == null) { + objFunc = ((List ) context.getObj(OptimVariable.objFunc)).get(0); + } + //here does not add loss calculation。 + Tuple2 loss = objFunc.calcHessianGradientLoss(labledVectors, weight, hessian, grad.f0); + + /** + * prepare buffer vec for allReduce. the last two elements of vec are weight Sum and current loss. + */ + double[] buffer = context.getObj(OptimVariable.gradHessAllReduce); + if (buffer == null) { + buffer = new double[size + size * size + 2]; + context.putObj(OptimVariable.gradHessAllReduce, buffer); + } + for (int i = 0; i < size; ++i) { + buffer[i] = grad.f0.get(i); + for (int j = 0; j < size; ++j) { + buffer[(i + 1) * size + j] = hessian.get(i, j); + } + } + buffer[size + size * size] = loss.f0; + buffer[size + size * size + 1] = loss.f1; + } + } + + public static class GetGradientAndHessian extends ComputeFunction { + + private static final long serialVersionUID = 2724626183370161805L; + + @Override + public void calc(ComContext context) { + Tuple2 grad = context.getObj(OptimVariable.grad); + DenseMatrix hessian = context.getObj(OptimVariable.hessian); + int size = grad.f0.size(); + double[] gradarr = context.getObj(OptimVariable.gradHessAllReduce); + grad.f1[0] = gradarr[size + size * size]; + for (int i = 0; i < size; ++i) { + grad.f0.set(i, gradarr[i] / grad.f1[0]); + for (int j = 0; j < size; ++j) { + hessian.set(i, j, gradarr[(i + 1) * size + j] / grad.f1[0]); + } + } + grad.f1[0] = gradarr[size + size * size]; + grad.f1[1] = gradarr[size + size * size + 1] / grad.f1[0]; + } + } + + public static class IterTermination extends CompareCriterionFunction { + private static final long serialVersionUID = -9142037869276356229L; + private double epsilon; + + IterTermination(double epsilon) { + this.epsilon = epsilon; + } + + @Override + public boolean calc(ComContext context) { + if (context.getStepNo() == 1) { + return false; + } + double convergence = context.getObj(ConstraintVariable.convergence); + ; + return Math.abs(convergence) <= epsilon; + } + } + + public static class BuildModel extends CompleteResultFunction { + + private static final long serialVersionUID = -7967444945852772659L; + + @Override + public List calc(ComContext context) { + if (context.getTaskId() != 0) { + return null; + } + DenseVector weight = context.getObj(ConstraintVariable.weight); + double[] losses = new double[2]; + if (context.containsObj(ConstraintVariable.lastLoss)) { + losses[0] = context.getObj(ConstraintVariable.lastLoss); + } + if (context.containsObj(ConstraintVariable.loss)) { + losses[1] = context.getObj(ConstraintVariable.loss); + } + Params params = new Params(); + params.set(ModelParamName.COEF, weight); + params.set(ModelParamName.LOSS_CURVE, losses); + List model = new ArrayList <>(1); + model.add(Row.of(params.toJson())); + return model; + } + } + + public static class ParseRowModel extends RichMapPartitionFunction > { + + private static final long serialVersionUID = 7590757264182157832L; + + @Override + public void mapPartition(Iterable iterable, + Collector > collector) throws Exception { + DenseVector coefVector = null; + double[] lossCurve = null; + int taskId = getRuntimeContext().getIndexOfThisSubtask(); + if (taskId == 0) { + for (Row row : iterable) { + Params params = Params.fromJson((String) row.getField(0)); + coefVector = 
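CalcGradAndHessian and GetGradientAndHessian above ship the per-worker gradient, Hessian, weight sum, and loss through one flat double[] so that a single AllReduce can sum everything at once. The layout is [gradient (n) | Hessian, row-major (n*n) | weightSum | loss]. A plain-array sketch of that pack/unpack convention (hypothetical helper names, not Alink code):

```java
public class GradHessBufferDemo {
    static double[] pack(double[] grad, double[][] hessian, double weightSum, double loss) {
        int n = grad.length;
        double[] buf = new double[n + n * n + 2];
        for (int i = 0; i < n; i++) {
            buf[i] = grad[i];
            for (int j = 0; j < n; j++) {
                buf[(i + 1) * n + j] = hessian[i][j];      // rows start right after the gradient
            }
        }
        buf[n + n * n] = weightSum;
        buf[n + n * n + 1] = loss;
        return buf;
    }

    static void unpack(double[] buf, double[] grad, double[][] hessian) {
        int n = grad.length;
        double weightSum = buf[n + n * n];
        for (int i = 0; i < n; i++) {
            grad[i] = buf[i] / weightSum;                  // mirrors GetGradientAndHessian,
            for (int j = 0; j < n; j++) {                  // which divides by the summed weight
                hessian[i][j] = buf[(i + 1) * n + j] / weightSum;
            }
        }
    }
}
```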
params.get(ModelParamName.COEF); + lossCurve = params.get(ModelParamName.LOSS_CURVE); + } + + if (coefVector != null) { + collector.collect(Tuple2.of(coefVector, lossCurve)); + } + } + } + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/optim/activeSet/SqpPai.java b/core/src/main/java/com/alibaba/alink/operator/common/optim/activeSet/SqpPai.java new file mode 100644 index 000000000..0f2cb1260 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/optim/activeSet/SqpPai.java @@ -0,0 +1,372 @@ +package com.alibaba.alink.operator.common.optim.activeSet; + +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; + +import com.alibaba.alink.common.linalg.DenseMatrix; +import com.alibaba.alink.common.linalg.DenseVector; + +import java.util.Arrays; + +public class SqpPai { + public static boolean[] getActiveSet(DenseMatrix inequalityConstraint, DenseVector inequalityItem, + DenseVector dir, int dim) { + int inequalNum = inequalityItem.size(); + boolean[] activeSet = new boolean[inequalNum]; + Arrays.fill(activeSet, false); + for (int i = 0; i < inequalNum; i++) { + double sum = 0; + double[] row = inequalityConstraint.getRow(i); + for (int j = 0; j < dim; j++) { + sum += row[j] * dir.get(j); + } + if (Math.abs(sum - inequalityItem.get(i)) <= 1e-7) { + activeSet[i] = true; + } else { + if (!(sum < inequalityItem.get(i))) { + activeSet[i] = true; + } + } + } + return activeSet; + } + + //todo check dir and weight, which is qp_dir and dir. + private static Tuple2 searchActiveSet(DenseMatrix inequalityConstraint, + DenseVector inequalityItem, + DenseVector dir, DenseVector qpDir, boolean[] activeSet, + int dim) { + double alpha = 1.; + int inequalNum = inequalityItem.size(); + int active = -1; + for (int i = 0; i < inequalNum; i++) { + if (activeSet[i]) { + continue; + } + double ax = 0; + double ap = 0; + double p = 0; + for (int j = 0; j < dim; j++) { + double val = inequalityConstraint.get(i, j); + ax -= val * dir.get(j); + ap -= val * qpDir.get(j); + } + if (ap >= 0) { + continue; + } + p = (-inequalityItem.get(i) - ax) / ap; + if (p < alpha) { + alpha = p; + active = i;//满足条件的话就设置i是满足要求,可以active的。 + } + } + return Tuple2.of(alpha, active); + } + + private static double[][] enableActiveSet(DenseMatrix hessian, int dim, boolean[] activeSet, + DenseMatrix equalityConstraint, DenseVector equalityItem, + DenseMatrix inequalityConstraint, DenseVector inequalityItem) { + int equalNum = equalityItem.size(); + int inequalNum = inequalityItem.size(); + int kktSize = calcKktSize(dim, equalityItem, activeSet); + double[][] h = new double[kktSize][kktSize]; + double[][] hessianArray = hessian.getArrayCopy2D(); + SqpUtil.fillMatrix(h, 0, 0, hessianArray); + for (int i = 0; i < equalNum; i++) { + int row = i + dim; + for (int j = 0; j < dim; j++) { + h[row][j] = equalityConstraint.get(i, j); + h[j][row] = equalityConstraint.get(i, j); + } + } + int asCnt = 0; + for (int i = 0; i < inequalNum; i++) { + if (activeSet[i]) { + int row = equalNum + dim + asCnt; + for (int j = 0; j < dim; j++) { + h[row][j] = inequalityConstraint.get(i, j); + h[j][row] = inequalityConstraint.get(i, j); + } + asCnt++; + } + } + return h; + } + + private static Tuple2 checkLambda(DenseMatrix equalityConstraint, + DenseMatrix inequalityConstraint, int dim, + boolean[] activeSet, DenseVector qpDir) { + int inequalNum = inequalityConstraint.numRows(); + int equalNum = equalityConstraint.numRows(); + int lagrangeId = 0; + double maxLam = -Double.MAX_VALUE; + 
int maxId = -1; + for (int i = 0; i < inequalNum; i++) { + if (activeSet[i]) { + double lambda = qpDir.get(dim + equalNum + lagrangeId); + if (lambda > maxLam) { + maxId = i; + maxLam = lambda; + } + lagrangeId++; + } + } + if (maxId > 0 && maxLam > 0) { + activeSet[maxId] = false; + return Tuple2.of(activeSet, true); + } + return Tuple2.of(activeSet, false); + } + + //todo 注意一下,pai上有dim,也就是只乘前面的数据 + private static double calculateQpLoss(DenseMatrix h, DenseVector p, DenseVector g, int dim) { + double loss; + DenseVector tmp = new DenseVector(dim); + matDotVec(h, p, tmp, dim); + loss = dot(tmp, p, dim); + return loss * 0.5 + dot(p, g, dim); + } + + private static Tuple2 solveQuadProblem(DenseMatrix equalityConstraint, + DenseVector equalityItem, + DenseVector qpDir, + DenseMatrix inequalityConstraint, + DenseVector inequalityItem, + DenseVector dir, boolean[] activeSet, + DenseMatrix hessian, + DenseVector grad, DenseVector weight) { + int dim = weight.size(); + int kktSize = calcKktSize(dim, equalityItem, activeSet); + DenseVector gpGrad = new DenseVector(kktSize); + matDotVec(hessian, dir, gpGrad, dim); + vecAddVec(grad, gpGrad, dim); + DenseMatrix h = new DenseMatrix( + enableActiveSet(hessian, dim, activeSet, + equalityConstraint, equalityItem, inequalityConstraint, inequalityItem)); + double norm = 1 / gpGrad.normL1(); + h.scaleEqual(norm); + gpGrad.scaleEqual(norm); + try { + DenseMatrix ginvH = h.inverse(); + qpDir = ginvH.multiplies(gpGrad); + } catch (Exception e) { + return Tuple2.of(qpDir, -1.); + } + double sum = 0; + for (int i = 0; i < dim; i++) { + sum += Math.pow(qpDir.get(i), 2); + } + sum = 1. * Math.sqrt(sum) / dim; + return Tuple2.of(qpDir, sum); + } + + // private static Tuple2 solveQuadProblem2(DenseMatrix equalityConstraint, DenseVector + // equalityItem, + // DenseVector qpDir, + // DenseMatrix inequalityConstraint, DenseVector + // inequalityItem, + // DenseVector dir, boolean[] activeSet, + // DenseMatrix hessian, + // DenseVector grad, DenseVector weight) { + // int dim = weight.size(); + // int equalSize = equalityItem.size(); + // int addInequalCount = equalSize; + // for (boolean b : activeSet) { + // if (b) { + // equalSize++; + // } + // } + // int kktSize = calcKktSize(dim, equalityItem, activeSet); + // DenseVector gpGrad = new DenseVector(kktSize); + // matDotVec(hessian, dir, gpGrad, dim); + // vecAddVec(grad, gpGrad, dim); + // + // double[][] matrixData = new double[equalSize][dim]; + // double[] vectorData = new double[equalSize]; + // SqpUtil.fillMatrix(matrixData, 0, 0, equalityConstraint.getArrayCopy2D()); + // System.arraycopy(equalityItem.getData(), 0, vectorData, 0, addInequalCount); + // for (int i = 0; i < activeSet.length; i++) { + // if (activeSet[i]) { + // for (int j = 0; j < addInequalCount; j++) { + // matrixData[addInequalCount][j] = inequalityConstraint.get(i, j); + // } + // vectorData[addInequalCount] = inequalityItem.get(i); + // addInequalCount++; + // } + // } + // try { + // DenseVector dirRes = QpProblem.subProblem(hessian, gpGrad, + // new DenseMatrix(matrixData), new DenseVector(vectorData))[0]; + // double sum = dirRes.normL2() / dim; + // return Tuple2.of(dirRes, sum); + // } catch (Exception e) { + // return Tuple2.of(qpDir, -1.); + // } + // } + + //the main run func. 
+ private static boolean solveActiveSetProblem(DenseMatrix equalityConstraint, DenseVector equalityItem, + DenseMatrix inequalityConstraint, DenseVector inequalityItem, + int dim, DenseVector dir, boolean[] activeSet, DenseMatrix hessian, + DenseVector grad, DenseVector weight) { + int iterTime = inequalityItem.size(); + if (iterTime == 0) { + iterTime = 1; + } + double loss = 0; + double lastLoss = 0; + int kktSize = calcKktSize(dim, equalityItem, activeSet); + DenseVector qpDir = new DenseVector(kktSize); + for (int i = 0; i < iterTime; i++) { + Tuple2 items = solveQuadProblem(equalityConstraint, equalityItem, qpDir, + inequalityConstraint, inequalityItem, + dir, activeSet, hessian, grad, weight); + double p = items.f1; + qpDir = items.f0; + if (p < 0) { + return false; + } + if (p < 1e-6) { + Tuple2 res = checkLambda(equalityConstraint, inequalityConstraint, dim, + activeSet, qpDir); + activeSet = res.f0; + if (!res.f1) { + break; + } + continue; + } + Tuple2 items2 = searchActiveSet(inequalityConstraint, inequalityItem, dir, qpDir, + activeSet, dim); + double alpha = items2.f0; + int activeConst = items2.f1; + if (activeConst >= 0) { + activeSet[activeConst] = true; + } + //初始的dir是全0的,在存在约束条件的时候,会迭代,将dir累加上去。 + for (int j = 0; j < dim; j++) { + dir.add(j, alpha * qpDir.get(j)); + } + loss = calculateQpLoss(hessian, dir, grad, dim); + if (lastLoss != 0) { + double cond = (lastLoss - loss) / Math.abs(loss); + if (cond < 1e-6) { + Tuple2 res = checkLambda(equalityConstraint, inequalityConstraint, dim, + activeSet, qpDir); + activeSet = res.f0; + if (!res.f1) { + break; + } + } + } + lastLoss = loss; + } + return true; + } + + private static void matDotVec(DenseMatrix matrix, DenseVector vector, DenseVector res, int dim) { + + for (int i = 0; i < dim; i++) { + res.set(i, 0); + double[] row = matrix.getRow(i); + for (int j = 0; j < dim; j++) { + res.add(i, row[j] * vector.get(j)); + } + } + } + + public static void vecAddVec(DenseVector dv1, DenseVector dv2, int dim) { + for (int i = 0; i < dim; i++) { + dv2.add(i, dv1.get(i)); + } + } + + private static double dot(DenseVector d1, DenseVector d2, int dim) { + double res = 0; + for (int i = 0; i < dim; i++) { + res += d1.get(i) * d2.get(i); + } + return res; + } + + private static int countInequalNum(boolean[] activeSet) { + int num = 0; + for (boolean b : activeSet) { + if (b) { + num++; + } + } + return num; + } + + private static int calcKktSize(int dim, DenseVector equalityItem, boolean[] activeSet) { + return dim + equalityItem.size() + countInequalNum(activeSet); + } + + public static Tuple3 + calcDir(double retryTime, int dim, ConstraintObjFunc sqpObjFunc, DenseVector dir, + DenseVector weight, DenseMatrix hessian, DenseVector grad, + double l2Weight, double minL2Weight, boolean hasIntercept, boolean[] activeSet) { + DenseMatrix equalityConstraint = sqpObjFunc.equalityConstraint; + DenseMatrix inequalityConstraint = sqpObjFunc.inequalityConstraint; + DenseVector equalityItem = sqpObjFunc.equalityItem; + DenseVector inequalityItem = sqpObjFunc.inequalityItem; + for (int i = 0; i < retryTime; i++) { + boolean pass = solveActiveSetProblem(equalityConstraint, equalityItem, inequalityConstraint, + inequalityItem, + dim, dir, activeSet, hessian, grad, weight); + if (pass) { + break; + } + //update gradient and hessian + int begin = 0; + double l2 = l2Weight + minL2Weight; + if (hasIntercept) { + begin = 1; + } + for (int j = begin; j < dim; j++) { + grad.add(j, l2 * weight.get(j)); + hessian.add(j, j, l2); + } + if (hasIntercept) { + 
grad.add(0, minL2Weight * weight.get(0)); + hessian.add(0, 0, minL2Weight); + } + minL2Weight *= 10; + + } + if (null == dir) { + throw new RuntimeException("sqp fail to calculate the best dir!"); + } + // loss, dir, grad, hessian + return Tuple3.of(dir, grad, hessian); + } + + public static DenseVector getStartDir(ConstraintObjFunc sqpObjFunc, DenseVector weight, + DenseVector icmBias, DenseVector ecmBias) { + int dim = weight.size(); + DenseMatrix equalityConstraint = sqpObjFunc.equalityConstraint; + DenseMatrix inequalityConstraint = sqpObjFunc.inequalityConstraint; + DenseVector equalityItem = sqpObjFunc.equalityItem; + DenseVector inequalityItem = sqpObjFunc.inequalityItem; + + for (int row = 0; row < inequalityItem.size(); row++) { + double sum = 0; + double[] inequalRow = inequalityConstraint.getRow(row); + for (int col = 0; col < dim; col++) { + sum += weight.get(col) * inequalRow[col]; + } + inequalityItem.set(row, icmBias.get(row) - sum); + } + + for (int row = 0; row < equalityItem.size(); row++) { + double sum = 0; + double[] equalRow = equalityConstraint.getRow(row); + for (int col = 0; col < dim; col++) { + sum += weight.get(col) * equalRow[col]; + } + equalityItem.set(row, ecmBias.get(row) - sum); + } + + return new DenseVector(dim); + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/optim/activeSet/SqpUtil.java b/core/src/main/java/com/alibaba/alink/operator/common/optim/activeSet/SqpUtil.java new file mode 100644 index 000000000..ddc0335f9 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/optim/activeSet/SqpUtil.java @@ -0,0 +1,65 @@ +package com.alibaba.alink.operator.common.optim.activeSet; + +import org.apache.flink.api.java.tuple.Tuple2; + +import com.alibaba.alink.common.linalg.DenseMatrix; +import com.alibaba.alink.common.linalg.DenseVector; + +public class SqpUtil { + + public static void fillMatrix(double[][] targetMatrix, int rowStart, int colStart, double[][] values) { + int rowLength = values.length; + if (rowLength != 0) { + int colLength = values[0].length; + if (colLength != 0) { + for (int i = 0; i < rowLength; i++) { + System.arraycopy(values[i], 0, targetMatrix[rowStart + i], colStart, colLength); + } + } + } + } + + // public static DenseVector slice(DenseVector dv, int start, int length) { + // double[] d = dv.getData(); + // double[] e = new double[length]; + // System.arraycopy(d, start, e, 0, length); + // return new DenseVector(e); + // } + + public static DenseMatrix concatMatrixRow(DenseMatrix equalMatrix, DenseMatrix inequalMatrix) { + int r1 = equalMatrix.numRows(); + int r2 = inequalMatrix.numRows(); + int c = equalMatrix.numCols(); + double[][] res = new double[r1 + r2][c]; + fillMatrix(res, 0, 0, equalMatrix.getArrayCopy2D()); + fillMatrix(res, r1, 0, inequalMatrix.getArrayCopy2D()); + return new DenseMatrix(res); + } + + public static DenseVector generateMaxVector(DenseVector x, double value) { + double[] data = x.getData(); + int length = data.length; + for (int i = 0; i < length; i++) { + data[i] = Math.max(data[i], value); + } + return x; + } + + public static DenseVector copyVec(DenseVector vector, int start, int length) { + double[] res = new double[length]; + System.arraycopy(vector.getData(), start, res, 0, length); + return new DenseVector(res); + } + + public static Tuple2 findMin(DenseVector vector, int start, int length) { + double min = Double.MAX_VALUE; + int index = 0; + for (int i = start; i < (start + length); i++) { + if (vector.get(i) < min) { + min = vector.get(i); + index = i; + 
} + } + return Tuple2.of(min, index - start); + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/optim/activeSet/SqpVariable.java b/core/src/main/java/com/alibaba/alink/operator/common/optim/activeSet/SqpVariable.java new file mode 100644 index 000000000..71b420153 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/optim/activeSet/SqpVariable.java @@ -0,0 +1,6 @@ +package com.alibaba.alink.operator.common.optim.activeSet; + +public class SqpVariable extends ConstraintVariable { + public final static String icmBias = "icmBias"; + public final static String ecmBias = "ecmBias"; +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/optim/barrierIcq/BarrierData.java b/core/src/main/java/com/alibaba/alink/operator/common/optim/barrierIcq/BarrierData.java new file mode 100644 index 000000000..26f17fba9 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/optim/barrierIcq/BarrierData.java @@ -0,0 +1,6 @@ +package com.alibaba.alink.operator.common.optim.barrierIcq; + +public class BarrierData { + public double loss; + public double lastLoss; +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/optim/barrierIcq/BarrierOpt.java b/core/src/main/java/com/alibaba/alink/operator/common/optim/barrierIcq/BarrierOpt.java new file mode 100644 index 000000000..f16010ff2 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/optim/barrierIcq/BarrierOpt.java @@ -0,0 +1,21 @@ +package com.alibaba.alink.operator.common.optim.barrierIcq; + +import com.alibaba.alink.common.linalg.DenseMatrix; +import com.alibaba.alink.common.linalg.DenseVector; + +public class BarrierOpt { + public boolean hasConst; + public int hessianDim; + public int dimension;//opt.dimension + public boolean hasIntercept; + public final double minL2Weight = 1e-5; + public DenseVector icb; + public DenseMatrix icm; + public DenseMatrix ecm; + public DenseVector ecb; + public boolean phaseOne; + public double l2Weight; + public BarrierData data; + public final double convergence_tolerance = 1e-6; + public final int lineSearchRetryTimes = 40; +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/optim/barrierIcq/BarrierVariable.java b/core/src/main/java/com/alibaba/alink/operator/common/optim/barrierIcq/BarrierVariable.java new file mode 100644 index 000000000..d2ff4d994 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/optim/barrierIcq/BarrierVariable.java @@ -0,0 +1,10 @@ +package com.alibaba.alink.operator.common.optim.barrierIcq; + +import com.alibaba.alink.operator.common.optim.activeSet.ConstraintVariable; + +public class BarrierVariable extends ConstraintVariable { + public final static String t = "t"; + public final static String divideT = "divideT"; + public final static String hessianNotConvergence = "hessianNotConvergence"; + public final static String localIterTime = "localIterTime"; +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/optim/barrierIcq/LogBarrier.java b/core/src/main/java/com/alibaba/alink/operator/common/optim/barrierIcq/LogBarrier.java new file mode 100644 index 000000000..25fa6f458 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/optim/barrierIcq/LogBarrier.java @@ -0,0 +1,316 @@ +package com.alibaba.alink.operator.common.optim.barrierIcq; + +import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.ml.api.misc.param.Params; 
+import org.apache.flink.types.Row; + +import com.alibaba.alink.common.comqueue.ComContext; +import com.alibaba.alink.common.comqueue.CompareCriterionFunction; +import com.alibaba.alink.common.comqueue.ComputeFunction; +import com.alibaba.alink.common.comqueue.IterativeComQueue; +import com.alibaba.alink.common.comqueue.communication.AllReduce; +import com.alibaba.alink.common.linalg.DenseMatrix; +import com.alibaba.alink.common.linalg.DenseVector; +import com.alibaba.alink.common.linalg.Vector; +import com.alibaba.alink.operator.common.optim.Optimizer; +import com.alibaba.alink.operator.common.optim.activeSet.ConstraintObjFunc; +import com.alibaba.alink.operator.common.optim.activeSet.ConstraintVariable; +import com.alibaba.alink.operator.common.optim.activeSet.Sqp; +import com.alibaba.alink.operator.common.optim.activeSet.SqpPai; +import com.alibaba.alink.operator.common.optim.activeSet.SqpUtil; +import com.alibaba.alink.operator.common.optim.objfunc.OptimObjFunc; +import com.alibaba.alink.operator.common.optim.subfunc.OptimVariable; +import com.alibaba.alink.operator.common.optim.subfunc.PreallocateMatrix; +import com.alibaba.alink.operator.common.optim.subfunc.PreallocateVector; +import com.alibaba.alink.params.shared.iter.HasMaxIterDefaultAs100; +import com.alibaba.alink.params.shared.linear.HasEpsilonDefaultAs0000001; +import com.alibaba.alink.params.shared.linear.HasL2; +import com.alibaba.alink.params.shared.linear.HasWithIntercept; + +import java.util.List; + +//https://www.stat.cmu.edu/~ryantibs/convexopt-S15/scribes/15-barr-method-scribed.pdf +public class LogBarrier extends Optimizer { + private static final int MAX_FEATURE_NUM = 3000; + + /** + * construct function. + * + * @param objFunc object function, calc loss and grad. + * @param trainData data for training. + * @param coefDim the dimension of features. + * @param params some parameters of optimization method. + */ + public LogBarrier(DataSet objFunc, DataSet > trainData, + DataSet coefDim, Params params) { + super(objFunc, trainData, coefDim, params); + } + + @Override + public DataSet > optimize() { + //use glpk + //in barrier, maxIter is the local parameter in one loop. + this.coefVec = this.coefDim.map(new Sqp.InitialCoef()).withBroadcastSet(this.objFuncSet, "objFunc"); + int maxIter = params.get(HasMaxIterDefaultAs100.MAX_ITER); + DataSet model = new IterativeComQueue() + .initWithPartitionedData(OptimVariable.trainData, trainData) + .initWithBroadcastData(OptimVariable.model, coefVec) + .initWithBroadcastData(OptimVariable.objFunc, objFuncSet) + .initWithBroadcastData(ConstraintVariable.weightDim, coefDim) + .add(new InitializeParams()) + .add(new PreallocateVector(OptimVariable.grad, new double[2]))//initial grad + .add(new PreallocateMatrix(OptimVariable.hessian, MAX_FEATURE_NUM))//initial hessian + .add(new Sqp.CalcGradAndHessian())//update grad and hessian with up-to-date coef. 
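+            // Background sketch (following the lecture notes linked above): the log-barrier method
+            // replaces the inequality-constrained problem
+            //     min f(w)   s.t.  A_ineq * w <= b_ineq,  A_eq * w = b_eq
+            // by the equality-constrained problem
+            //     min f(w) - mu * sum_i log(b_ineq_i - a_i^T * w)   s.t.  A_eq * w = b_eq,
+            // solved with Newton steps (RunNewtonStep below); the barrier weight mu (divideT = 1/t here)
+            // is shrunk between outer rounds by IterTermination, which multiplies t by 50.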
+ .add(new AllReduce(OptimVariable.gradHessAllReduce)) + .add(new Sqp.GetGradientAndHessian()) + .add(new RunNewtonStep(params))//add constraint and solve the dir + .add(new Sqp.LineSearch(params.get(HasL2.L_2))) + .add(new AllReduce(ConstraintVariable.lossAllReduce)) + .add(new Sqp.GetMinCoef()) + .add(new CalcConvergence()) + .setCompareCriterionOfNode0(new IterTermination(maxIter, params.get(HasEpsilonDefaultAs0000001.EPSILON))) + .closeWith(new Sqp.BuildModel()) + .exec(); + return model.mapPartition(new Sqp.ParseRowModel()); + } + + public static class InitializeParams extends ComputeFunction { + + private static final long serialVersionUID = 1857803292287152190L; + + @Override + public void calc(ComContext context) { + if (context.getStepNo() == 1) { + ConstraintObjFunc objFunc = (ConstraintObjFunc) ((List ) context.getObj( + OptimVariable.objFunc)).get(0); + double t = objFunc.inequalityConstraint.numRows(); + double divideT; + if (t == 0) { + divideT = 0; + } else { + divideT = 1 / t; + } + context.putObj(BarrierVariable.t, t); + context.putObj(BarrierVariable.divideT, divideT); + context.putObj(BarrierVariable.localIterTime, 0); + context.putObj(BarrierVariable.hessianNotConvergence, false); + context.putObj(ConstraintVariable.newtonRetryTime, 12); + context.putObj(ConstraintVariable.minL2Weight, 1e-8); + context.putObj(ConstraintVariable.linearSearchTimes, 40); + DenseVector weight = ((List ) context.getObj(OptimVariable.model)).get(0); + context.putObj(ConstraintVariable.weight, weight); + context.putObj(ConstraintVariable.loss, 0.0); + context.putObj(ConstraintVariable.lastLoss, Double.MAX_VALUE); + } + } + } + + public static class RunNewtonStep extends ComputeFunction { + private static final long serialVersionUID = 4802057437164571355L; + private boolean hasIntercept; + private double l2; + + public RunNewtonStep(Params params) { + hasIntercept = params.get(HasWithIntercept.WITH_INTERCEPT); + this.l2 = params.get(HasL2.L_2); + } + + @Override + public void calc(ComContext context) { + ConstraintObjFunc objFunc = (ConstraintObjFunc) ((List ) context.getObj( + OptimVariable.objFunc)).get(0); + int hessianDim = ((List ) context.getObj(ConstraintVariable.weightDim)).get(0); + int begin = 0; + if (hasIntercept) { + begin = 1; + } + double minL2Weight = context.getObj(ConstraintVariable.minL2Weight); + Double loss = context.getObj(ConstraintVariable.loss); + Tuple2 grad = context.getObj(OptimVariable.grad); + DenseVector gradient = grad.f0; + DenseMatrix hessian = context.getObj(ConstraintVariable.hessian); + DenseVector weight = context.getObj(ConstraintVariable.weight); + final int retryTime = context.getObj(ConstraintVariable.newtonRetryTime); + double t = context.getObj(BarrierVariable.t); + int constraintNum = objFunc.equalityItem.size() + objFunc.inequalityItem.size(); + addInequalityConstraint(objFunc, constraintNum, gradient, weight, hessian, t); + for (int j = 0; j < retryTime; j++) { + try { + int hSize = objFunc.equalityItem.size() + hessianDim; + DenseVector g = new DenseVector(hSize); + SqpPai.vecAddVec(gradient, g, hessianDim); + DenseMatrix h = new DenseMatrix( + buildH(hessian, weight, g, objFunc.equalityConstraint, objFunc.equalityItem)); + double norm = 1 / g.normL1(); + h.scaleEqual(norm); + g.scaleEqual(norm); + // DenseVector dir = QpProblem.subProblem(hessian, gradient, objFunc + // .equalityConstraint, objFunc.equalityItem)[0]; + DenseVector dir = new DenseVector(hessianDim); + // System.out.println(h); + // System.out.println(g); + 
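+                    // Sketch of the linear system solved just below: buildH assembles the bordered
+                    // KKT-style matrix and right-hand side
+                    //     [ H      A_eq^T ] [ dx ]   [ grad            ]
+                    //     [ A_eq     0   ] [ nu ] = [ A_eq * w - b_eq ],
+                    // both sides are scaled by 1/||g||_1, and the first `hessianDim` entries of
+                    // h.inverse().multiplies(g) are taken as the search direction; the size and sign
+                    // of the actual update are then chosen by the line search.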
SqpPai.vecAddVec(h.inverse().multiplies(g), dir, hessianDim); + context.putObj(ConstraintVariable.dir, dir); + // System.out.println("dir_ori: "+dir); + break; + } catch (Exception e) { + double l2Weight = l2 + minL2Weight; + for (int i = begin; i < hessianDim; i++) { + loss += 0.5 * l2Weight * Math.pow(weight.get(i), 2); + gradient.add(i, l2Weight * weight.get(i)); + hessian.add(i, i, l2Weight); + } + if (hasIntercept) { + loss += 0.5 * minL2Weight * Math.pow(weight.get(0), 2); + gradient.add(0, minL2Weight * weight.get(0)); + hessian.add(0, 0, minL2Weight); + } + minL2Weight *= 10; + } + } + grad.f0 = gradient; + context.putObj(ConstraintVariable.grad, grad); + context.putObj(ConstraintVariable.hessian, hessian); + loss = constrainedLoss(loss, weight, objFunc, t, constraintNum); + context.putObj(ConstraintVariable.loss, loss); + } + + private static double[][] buildH(DenseMatrix hessian, DenseVector weight, DenseVector g, + DenseMatrix equalityConstraint, DenseVector equalityItem) { + int equalNum = equalityItem.size(); + int dim = weight.size(); + int hSize = equalNum + dim; + double[][] h = new double[hSize][hSize]; + double[][] hessianArray = hessian.getArrayCopy2D(); + SqpUtil.fillMatrix(h, 0, 0, hessianArray); + for (int i = 0; i < equalNum; i++) { + int row = i + dim; + double sum = 0; + for (int j = 0; j < dim; j++) { + h[row][j] = equalityConstraint.get(i, j); + h[j][row] = equalityConstraint.get(i, j); + sum += equalityConstraint.get(i, j) * weight.get(j); + } + g.set(row, sum - equalityItem.get(i)); + } + + return h; + } + + private static double constrainedLoss(double loss, DenseVector weight, ConstraintObjFunc objFunc, + double t, double constraintNum) { + if (constraintNum == 0) { + return loss; + } + int numRow = objFunc.inequalityConstraint.numRows(); + for (int i = 0; i < numRow; i++) { + loss -= t * Math.log(sumInequality(objFunc.inequalityConstraint, objFunc.inequalityItem, weight, i)); + } + return loss; + } + + private static double sumInequality(DenseMatrix icm, DenseVector icb, DenseVector w, int row) { + double[] wData = w.getData(); + double s = icb.get(row); + int colNum = icm.numCols(); + for (int i = 0; i < colNum; i++) { + s -= wData[i] * icm.get(row, i); + } + if (s == 0) { + s = 1e-6; + } + return s; + } + + private static void addInequalityConstraint(ConstraintObjFunc objFunc, int constraintNum, DenseVector gradient, + DenseVector weight, DenseMatrix hessian, double t) { + if (constraintNum == 0) { + return; + } + updateGradForInequalityConstraint(objFunc, gradient, weight, hessian, 1 / t); + } + + private static void updateGradForInequalityConstraint(ConstraintObjFunc objFunc, DenseVector gradient, + DenseVector weight, DenseMatrix hessian, double t) { + DenseMatrix icm = objFunc.inequalityConstraint; + int rowNum = icm.numRows(); + int colNum = icm.numCols(); + for (int k = 0; k < rowNum; k++) { + double sum = sumInequality(icm, objFunc.inequalityItem, weight, k); + for (int i = 0; i < colNum; i++) { + double val = icm.get(k, i); + gradient.add(i, t * val / sum); + } + } + for (int k = 0; k < rowNum; k++) { + double sum = Math.pow(sumInequality(icm, objFunc.inequalityItem, weight, k), 2); + for (int i = 0; i < colNum; i++) { + double val1 = icm.get(k, i); + for (int j = 0; j < colNum; j++) { + double val2 = icm.get(k, j); + hessian.add(i, j, t * val1 * val2 / sum); + } + } + } + } + + } + + public static class CalcConvergence extends ComputeFunction { + + private static final long serialVersionUID = 4453719204627742833L; + + @Override + public void 
calc(ComContext context) { + int iter = context.getObj(BarrierVariable.localIterTime); + double convergence; + if (iter == 0) { + convergence = 100; + } else { + double lastLoss = context.getObj(ConstraintVariable.lastLoss); + double loss = context.getObj(ConstraintVariable.loss); + int lossStep = 5; + if (iter <= lossStep) { + convergence = (lastLoss - loss) / (Math.abs(loss) * iter); + } else { + convergence = (lastLoss - loss) / (Math.abs(loss) * lossStep); + } + } + context.putObj(BarrierVariable.localIterTime, iter + 1); + // if (context.getTaskId() == 0) { + // System.out.println("iter: " + iter + ", convergence: " + convergence); + // } + context.putObj(ConstraintVariable.convergence, convergence); + } + } + + public static class IterTermination extends CompareCriterionFunction { + private static final long serialVersionUID = -3313706116254321792L; + private int maxIter; + private double epsilon; + + IterTermination(int maxIter, double epsilon) { + this.maxIter = maxIter; + this.epsilon = epsilon; + } + + @Override + public boolean calc(ComContext context) { + double convergence = context.getObj(ConstraintVariable.convergence); + if (convergence < this.epsilon || + (int) context.getObj(BarrierVariable.localIterTime) >= maxIter) { + context.putObj(BarrierVariable.localIterTime, 0); + double t = context.getObj(BarrierVariable.t); + double divideT = context.getObj(BarrierVariable.divideT); + t *= 50; + divideT /= 50; + context.putObj(BarrierVariable.t, t); + context.putObj(BarrierVariable.divideT, divideT); + return divideT < this.epsilon; + } + return false; + } + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/optim/local/ConstrainedLocalOptimizer.java b/core/src/main/java/com/alibaba/alink/operator/common/optim/local/ConstrainedLocalOptimizer.java new file mode 100644 index 000000000..344e61737 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/optim/local/ConstrainedLocalOptimizer.java @@ -0,0 +1,209 @@ +package com.alibaba.alink.operator.common.optim.local; + +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.tuple.Tuple4; +import org.apache.flink.ml.api.misc.param.Params; + +import com.alibaba.alink.common.linalg.DenseMatrix; +import com.alibaba.alink.common.linalg.DenseVector; +import com.alibaba.alink.common.linalg.Vector; +import com.alibaba.alink.operator.common.linear.BaseConstrainedLinearModelTrainBatchOp; +import com.alibaba.alink.operator.common.linear.LinearModelType; +import com.alibaba.alink.operator.common.optim.FeatureConstraint; +import com.alibaba.alink.operator.common.optim.activeSet.ConstraintObjFunc; +import com.alibaba.alink.operator.common.optim.activeSet.Sqp; +import com.alibaba.alink.params.feature.HasConstraint; +import com.alibaba.alink.params.shared.linear.LinearTrainParams; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class ConstrainedLocalOptimizer { + + //pass the param information into this function. + public static Tuple4 optimizeWithHessian( + List > trainData, + LinearModelType modelType, + Params params) { + int dim = trainData.get(0).f2.size();//the intercept has been added. 
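+        // Flow of this method, in brief: parse the HasConstraint.CONSTRAINT string into the
+        // equality/inequality matrices of the ConstraintObjFunc, use Sqp.phaseOne to find a feasible
+        // starting point, then run LocalSqp.sqpWithHessian, which returns (coef, gradient, hessian, loss).
+        // A minimal call sketch (names and values here are only illustrative):
+        //   Params p = new Params()
+        //       .set(HasConstraint.CONSTRAINT, "...")   // constraint spec, in the format accepted by FeatureConstraint
+        //       .set(LinearTrainParams.WITH_INTERCEPT, true);
+        //   Tuple4<DenseVector, DenseVector, DenseMatrix, Double> r =
+        //       ConstrainedLocalOptimizer.optimizeWithHessian(trainData, LinearModelType.LR, p);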
+ ConstraintObjFunc objFunc = (ConstraintObjFunc) BaseConstrainedLinearModelTrainBatchOp.getObjFunction(modelType, null); + String constraint = ""; + if (params.contains(HasConstraint.CONSTRAINT)) { + constraint = params.get(HasConstraint.CONSTRAINT); + } + extractConstraintsForFeatureAndBin( + FeatureConstraint.fromJson(constraint), + objFunc, null, dim, true, null, null); + double[] coefData = Sqp.phaseOne( + objFunc.equalityConstraint.getArrayCopy2D(), objFunc.equalityItem.getData(), + objFunc.inequalityConstraint.getArrayCopy2D(), objFunc.inequalityItem.getData(), dim); + DenseVector coef = new DenseVector(coefData); + Tuple4 coefVector = + LocalSqp.sqpWithHessian(trainData, coef, objFunc, params); + return coefVector; + } + + //count zero is only for default constraint of else and null + public static void extractConstraintsForFeatureAndBin(FeatureConstraint constraint, ConstraintObjFunc objFunc, + String[] featureColNames, int dim, boolean hasInterceptItem, + DenseVector countZero, Map hasElse) { + int size = constraint.getBinConstraintSize(); + Tuple4 cons; + if (size != 0) { + //for bin, this is only for bin. + //todo:have not consider for features and bins. + constraint.setCountZero(countZero); + cons = constraint.getConstraintsForFeatureWithBin(); + } else { + if (featureColNames == null) { + //for feature in vector form. + //if set + if (hasInterceptItem) { + dim -= 1; + } + cons = constraint.getConstraintsForFeatures(dim); + } else { + //for feature in table form. + HashMap featureIndex = new HashMap <>(featureColNames.length); + for (int i = 0; i < featureColNames.length; i++) { + featureIndex.put(featureColNames[i], i); + } + cons = constraint.getConstraintsForFeatures(featureIndex); + } + } + if (hasInterceptItem) { + addIntercept(cons); + } + objFunc.inequalityConstraint = new DenseMatrix(cons.f0); + objFunc.inequalityItem = new DenseVector(cons.f1); + objFunc.equalityConstraint = new DenseMatrix(cons.f2); + objFunc.equalityItem = new DenseVector(cons.f3); + } + + private static void addIntercept(Tuple4 cons) { + cons.f0 = prefixMatrix(cons.f0); + cons.f2 = prefixMatrix(cons.f2); + } + + private static double[][] prefixMatrix(double[][] matrix) { + int row = matrix.length; + if (row == 0) {return matrix;} + int col = matrix[0].length; + for (int i = 0; i < row; i++) { + matrix[i] = prefixRow(matrix[i], col); + } + return matrix; + } + + private static double[] prefixRow(double[] row, int length) { + double[] r = new double[length + 1]; + if (length >= 0) {System.arraycopy(row, 0, r, 1, length);} + return r; + } + + // public Tuple4 trainWithHessian(List> originTrainData, + // LinearModelType modelType, + // Params params) { + // boolean hasIntercept = params.get(LinearTrainParams.WITH_INTERCEPT); + // boolean standardization = params.get(LinearTrainParams.STANDARDIZATION); + // if (standardization && constraints != null) { + // throw new RuntimeException("standardization can not be applied for linear model with constraints!"); + // } + // + // boolean lr = !modelType.equals(LinearModelType.LinearReg); + // Tuple2>, String[]> trainDataItem = LocalLogistRegression + // .getLabelValues(originTrainData, lr); + // List> trainData = trainDataItem.f0; + // preProcess(trainData, hasIntercept, standardization); + // //labels[0] is the positive label. 
+ // String[] labels = trainDataItem.f1; + // Tuple4 coefVectorSet = optimizeWithHessian(trainData, + // modelType, params); + // return coefVectorSet; + // } + + // public DenseVector train(List> originTrainData, + // LinearModelType modelType, Params params) { + // return trainWithHessian(originTrainData, modelType, params).f0; + // } + + // public Tuple4 trainWithHessian(List trainData, + // List sampleWeight, + // List label, + // int[] indices,//针对table的输入而言。 + // LinearModelType modelType, + // OptimMethod optimMethod, + // boolean hasIntercept, + // boolean standardization,//先不管 + // double l1,//先不考虑 + // double l2) { + // Params optParams = new Params() + // .set(LinearTrainParams.CONS_SEL_OPTIM_METHOD, optimMethod.name()) + // .set(LinearTrainParams.WITH_INTERCEPT, hasIntercept) + // .set(LinearTrainParams.STANDARDIZATION, standardization) + // .set(HasL1.L_1, l1) + // .set(HasL2.L_2, l2); + // List> vectorData = concatTrainData(trainData, sampleWeight, label, + // indices); + // return trainWithHessian(vectorData, modelType, optParams); + // } + + // + // private static List> concatTrainData(List trainData, + // List sampleWeight, + // List label, + // int[] indices) { + // int sampleNum = label.size(); + // int dim = indices.length; + // List> res = new ArrayList<>(sampleNum); + // for (int i = 0; i < sampleNum; i++) { + // DenseVector data = new DenseVector(dim); + // for (int j = 0; j < dim; j++) { + // data.set(j, (double) trainData.get(i).getField(indices[j])); + // res.set(i, Tuple3.of(sampleWeight.get(i), label.get(i), data)); + // } + // } + // return res; + // } + + @Deprecated + public static void preProcess(List > trainData, boolean hasIntercept, + boolean standardization) { + for (Tuple3 data : trainData) { + if (standardization) { + //to add + if (hasIntercept) { + data.f2 = data.f2.prefix(1); + } + } else { + //to add + if (hasIntercept) { + data.f2 = data.f2.prefix(1); + } + } + } + + } + + //f0 of data is weight, here replace it with predict label. + public List > predict(List > data, + LinearModelType modelType, + DenseVector coef, Params params) { + boolean hasIntercept = params.get(LinearTrainParams.WITH_INTERCEPT); + boolean standardization = params.get(LinearTrainParams.STANDARDIZATION); + preProcess(data, hasIntercept, standardization); + if (modelType.equals(LinearModelType.LinearReg)) { + for (Tuple3 value : data) { + value.f0 = value.f2.dot(coef); + } + } else { + for (Tuple3 value : data) { + value.f0 = value.f2.dot(coef) > 0 ? 1. 
: 0.; + } + } + return data; + } + +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/optim/local/LocalSqp.java b/core/src/main/java/com/alibaba/alink/operator/common/optim/local/LocalSqp.java new file mode 100644 index 000000000..cdfde5a0f --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/optim/local/LocalSqp.java @@ -0,0 +1,190 @@ +package com.alibaba.alink.operator.common.optim.local; + +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.tuple.Tuple4; +import org.apache.flink.ml.api.misc.param.Params; + +import com.alibaba.alink.common.linalg.BLAS; +import com.alibaba.alink.common.linalg.DenseMatrix; +import com.alibaba.alink.common.linalg.DenseVector; +import com.alibaba.alink.common.linalg.Vector; +import com.alibaba.alink.operator.common.optim.activeSet.ConstraintObjFunc; +import com.alibaba.alink.operator.common.optim.activeSet.SqpPai; +import com.alibaba.alink.operator.common.optim.objfunc.OptimObjFunc; +import com.alibaba.alink.params.shared.linear.HasL1; +import com.alibaba.alink.params.shared.linear.HasL2; +import com.alibaba.alink.params.shared.linear.LinearTrainParams; + +import java.util.List; + +public class LocalSqp { + public static Tuple2 sqp(List > labledVectors, + DenseVector initCoefs, + Params params, OptimObjFunc objFunc) { + Tuple4 hessianItem = sqpWithHessian(labledVectors, initCoefs, + objFunc, params); + return Tuple2.of(hessianItem.f0, hessianItem.f3); + } + + static Tuple4 + sqpWithHessian(List > labledVectors, + DenseVector initCoefs, OptimObjFunc objFunc, + Params params) { + double l2 = params.get(HasL2.L_2); + double l1 = params.get(HasL1.L_1); + boolean hasIntercept = params.get(LinearTrainParams.WITH_INTERCEPT); + int dim = initCoefs.size(); + DenseVector weight = initCoefs; + DenseMatrix hessian = new DenseMatrix(dim, dim); + DenseVector grad = new DenseVector(dim); + double loss = 0; + double lastLoss = -1; + //update grad and hessian + ConstraintObjFunc sqpObjFunc = (ConstraintObjFunc) objFunc; + final int retryTime = 12; + final double minL2Weight = 1e-8; + final int linearSearchTimes = 40; + //store constraint bias + if (sqpObjFunc.inequalityConstraint == null) { + sqpObjFunc.inequalityConstraint = new DenseMatrix(0, dim); + sqpObjFunc.inequalityItem = new DenseVector(0); + } + if (sqpObjFunc.equalityConstraint == null) { + sqpObjFunc.equalityConstraint = new DenseMatrix(0, dim); + sqpObjFunc.equalityItem = new DenseVector(0); + } + DenseVector icmBias = sqpObjFunc.inequalityItem.clone(); + DenseVector ecmBias = sqpObjFunc.equalityItem.clone(); + + //sqp iteration + for (int sqpIter = 0; sqpIter < 100; sqpIter++) { + double weightSumCoef = 1 / sqpObjFunc.calcHessianGradientLoss(labledVectors, weight, hessian, grad).f0; + grad.scaleEqual(weightSumCoef); + hessian.scaleEqual(weightSumCoef); + //initial for each iteration + DenseVector dir = SqpPai.getStartDir(sqpObjFunc, weight, icmBias, ecmBias); + boolean[] activeSet = SqpPai.getActiveSet(sqpObjFunc.inequalityConstraint, sqpObjFunc.inequalityItem, dir, + dim); + //get the current best direction + Tuple3 dirItems = + SqpPai.calcDir(retryTime, dim, sqpObjFunc, dir, weight, + hessian, grad, l2, minL2Weight, hasIntercept, activeSet); + dir = dirItems.f0; + grad = dirItems.f1; + hessian = dirItems.f2; + //linear search + double[] losses = sqpObjFunc.calcLineSearch(labledVectors, weight, dir, linearSearchTimes, + l2 + minL2Weight); + loss = lineSearch(losses, weight, grad, dir); + //check 
convergence + if (sqpIter != 0) { + lastLoss = loss; + } + //restore item + sqpObjFunc.inequalityItem = icmBias; + sqpObjFunc.equalityItem = ecmBias; + double convergence = 100; + if (sqpIter != 0) { + int lossStep = 5; + if (sqpIter <= lossStep) { + convergence = (lastLoss - loss) / (Math.abs(loss) * sqpIter); + } else { + convergence = (lastLoss - loss) / (Math.abs(loss) * lossStep); + } + } + if (convergence <= 1e-6) { + break; + } + } + return Tuple4.of(weight, grad, hessian, loss); + } + + public static Double lineSearch(double[] losses, DenseVector weight, DenseVector grad, DenseVector dir) { + double beta = 1e-4; + double alpha = 1; + double backOff = 0.1; + int i = 1; + int index = 0; + int size = losses.length; + int origin = size / 2; + int retryTime = origin - 1; + double gd = -BLAS.dot(grad, dir); + double betaTemp = beta * gd * alpha; + boolean brea = false; + for (i = 1; i < retryTime; i++) { + index = origin - i; + if (losses[index] <= losses[origin] + betaTemp) { + brea = true; + break; + } + // index = origin + i; + // if (losses[index] <= losses[origin] + betaTemp) { + // brea = true; + // break; + // } + betaTemp *= backOff; + } + if (!brea) { + betaTemp = beta * gd * alpha; + for (i = 1; i < retryTime; i++) { + index = origin + i; + if (losses[index] <= losses[origin] + betaTemp) { + break; + } + betaTemp *= backOff; + } + } + if (index < origin) { + weight.minusEqual(dir.scale(Math.pow(10, 1 - i))); + } else { + weight.plusEqual(dir.scale(Math.pow(10, 1 - i))); + } + return losses[index]; + } + + // public static Tuple3 calcDir( + // double retryTime, int dim, ConstraintObjFunc sqpObjFunc, DenseVector weight, + // DenseMatrix hessian, DenseVector grad, double loss, + // double l2Weight, double minL2Weight, boolean hasIntercept) { + // DenseVector dir = null; + // DenseMatrix equalityConstraint = sqpObjFunc.equalityConstraint; + // DenseMatrix inequalityConstraint = sqpObjFunc.inequalityConstraint; + // DenseVector equalityItem = sqpObjFunc.equalityItem; + // DenseVector inequalityItem = sqpObjFunc.inequalityItem; + // for (int i = 0; i < retryTime; i++) { + // + // DenseVector x0 = new DenseVector(dim); + // try { + // dir = QpProblem.qpact(hessian, grad, + // equalityConstraint, equalityItem, + // inequalityConstraint, inequalityItem, + // x0).f0; + // break; + // } catch (Exception e) { + // //update gradient and hessian + // int begin = 0; + // double l2 = l2Weight + minL2Weight; + // if (hasIntercept) { + // begin = 1; + // } + // for (int j = begin; j < dim; j++) { + //// loss += 0.5 * l2 * Math.pow(weight.get(j), 2); + // grad.add(j, l2 * weight.get(j)); + // hessian.add(j, j, l2); + // } + // if (hasIntercept) { + //// loss += 0.5 * minL2Weight * Math.pow(weight.get(0), 2); + // grad.add(0, minL2Weight * weight.get(0)); + // hessian.add(0, 0, minL2Weight); + // } + // minL2Weight *= 10; + // } + // } + // if (null == dir) { + // throw new RuntimeException("sqp fail to calculate the best dir!"); + // } + // // loss, dir, grad, hessian + // return Tuple3.of(dir, grad, hessian); + // } +} \ No newline at end of file diff --git a/core/src/main/java/com/alibaba/alink/operator/common/optim/objfunc/OptimObjFunc.java b/core/src/main/java/com/alibaba/alink/operator/common/optim/objfunc/OptimObjFunc.java index f2d1a4c4e..87f2eceb1 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/optim/objfunc/OptimObjFunc.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/optim/objfunc/OptimObjFunc.java @@ -4,16 +4,29 @@ import 
org.apache.flink.api.java.tuple.Tuple3; import org.apache.flink.ml.api.misc.param.Params; +import com.alibaba.alink.common.exceptions.AkUnimplementedOperationException; import com.alibaba.alink.common.exceptions.AkUnsupportedOperationException; import com.alibaba.alink.common.linalg.DenseMatrix; import com.alibaba.alink.common.linalg.DenseVector; import com.alibaba.alink.common.linalg.MatVecOp; import com.alibaba.alink.common.linalg.SparseVector; import com.alibaba.alink.common.linalg.Vector; +import com.alibaba.alink.operator.common.linear.AftRegObjFunc; +import com.alibaba.alink.operator.common.linear.LinearModelType; +import com.alibaba.alink.operator.common.linear.SoftmaxObjFunc; +import com.alibaba.alink.operator.common.linear.UnaryLossObjFunc; +import com.alibaba.alink.operator.common.linear.unarylossfunc.LogLossFunc; +import com.alibaba.alink.operator.common.linear.unarylossfunc.PerceptronLossFunc; +import com.alibaba.alink.operator.common.linear.unarylossfunc.SmoothHingeLossFunc; +import com.alibaba.alink.operator.common.linear.unarylossfunc.SquareLossFunc; +import com.alibaba.alink.operator.common.linear.unarylossfunc.SvrLossFunc; +import com.alibaba.alink.params.regression.LinearSvrTrainParams; import com.alibaba.alink.params.shared.linear.HasL1; import com.alibaba.alink.params.shared.linear.HasL2; import java.io.Serializable; +import java.util.Arrays; +import java.util.List; /** * Abstract object function for optimization. This class provides the function api to calculate gradient, loss, hessian @@ -56,8 +69,8 @@ public double getL2() { * @param coefVector coefficient of current time. * @return the loss value */ - protected abstract double calcLoss(Tuple3 labelVector, - DenseVector coefVector); + public abstract double calcLoss(Tuple3 labelVector, + DenseVector coefVector); /** * Update gradient. @@ -66,9 +79,9 @@ protected abstract double calcLoss(Tuple3 labelVector, * @param coefVector coefficient of current time. * @param updateGrad gradient need to update. */ - protected abstract void updateGradient(Tuple3 labelVector, - DenseVector coefVector, - DenseVector updateGrad); + public abstract void updateGradient(Tuple3 labelVector, + DenseVector coefVector, + DenseVector updateGrad); /** * Update hessian matrix by one sample. @@ -77,7 +90,7 @@ protected abstract void updateGradient(Tuple3 labelVect * @param coefVector coefficient of current time. * @param updateHessian hessian matrix need to update. */ - protected abstract void updateHessian(Tuple3 labelVector, + public abstract void updateHessian(Tuple3 labelVector, DenseVector coefVector, DenseMatrix updateHessian); @@ -105,6 +118,16 @@ public Tuple2 calcObjValue( fVal += loss * labelVector.f0; weightSum += labelVector.f0; } + return finalizeObjValue(coefVector, fVal, weightSum); + } + + /** + * Calculate object value. + * + * @param coefVector coefficient of current time. + * @return Tuple2: objectValue, weightSum. 
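+	 * @param fVal      accumulated weighted loss over the samples (before normalization).
+	 * @param weightSum sum of sample weights; when non-zero, fVal is divided by it.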
+ */ + public Tuple2 finalizeObjValue(DenseVector coefVector, double fVal, double weightSum) { if (0.0 != weightSum) { fVal /= weightSum; } @@ -128,9 +151,7 @@ public Tuple2 calcObjValue( public double calcGradient(Iterable > labelVectors, DenseVector coefVector, DenseVector grad) { double weightSum = 0.0; - for (int i = 0; i < grad.size(); i++) { - grad.set(i, 0.0); - } + Arrays.fill(grad.getData(), 0.0); for (Tuple3 labelVector : labelVectors) { if (labelVector.f2 instanceof SparseVector) { ((SparseVector) (labelVector.f2)).setSize(coefVector.size()); @@ -138,6 +159,13 @@ public double calcGradient(Iterable > labelVecto weightSum += labelVector.f0; updateGradient(labelVector, coefVector, grad); } + + finalizeGradient(coefVector, grad, weightSum); + + return weightSum; + } + + public void finalizeGradient(DenseVector coefVector, DenseVector grad, double weightSum) { if (weightSum > 0.0) { grad.scaleEqual(1.0 / weightSum); } @@ -150,7 +178,7 @@ public double calcGradient(Iterable > labelVecto grad.add(i, Math.signum(coefArray[i]) * this.l1); } } - return weightSum; + } /** @@ -167,14 +195,9 @@ public Tuple2 calcHessianGradientLoss(Iterable labledVector : labelVectors) { @@ -183,20 +206,9 @@ public Tuple2 calcHessianGradientLoss(Iterable calcHessianGradientLoss(Iterable calcHessianGradientLoss(Iterable > labelVectors, DenseVector coefVector, DenseVector dirVec, double beta, int numStep) { double[] losses = new double[numStep + 1]; - DenseVector[] stepVec = new DenseVector[numStep + 1]; stepVec[0] = coefVector.clone(); DenseVector vecDelta = dirVec.scale(beta); @@ -243,7 +272,7 @@ public double[] calcSearchValues(Iterable > labe * @return double[] losses. */ public double[] constraintCalcSearchValues( - Iterable > labelVectors, + List > labelVectors, DenseVector coefVector, DenseVector dirVec, double beta, int numStep) { double[] losses = new double[numStep + 1]; double[] coefArray = coefVector.getData(); @@ -266,4 +295,44 @@ public double[] constraintCalcSearchValues( } return losses; } + + /** + * Get obj function. + * + * @param modelType Model type. + * @param params Parameters for train. + * @return Obj function. + */ + public static OptimObjFunc getObjFunction(LinearModelType modelType, Params params) { + OptimObjFunc objFunc; + // For different model type, we must set corresponding loss object function. 
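+        // Usage sketch (illustrative values; only the model types handled by the switch below are valid):
+        //   OptimObjFunc lrObj  = OptimObjFunc.getObjFunction(LinearModelType.LR,  new Params());
+        //   OptimObjFunc svrObj = OptimObjFunc.getObjFunction(LinearModelType.SVR,
+        //       new Params().set(LinearSvrTrainParams.TAU, 0.3));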
+ switch (modelType) { + case LinearReg: + objFunc = new UnaryLossObjFunc(new SquareLossFunc(), params); + break; + case SVR: + double svrTau = params.get(LinearSvrTrainParams.TAU); + objFunc = new UnaryLossObjFunc(new SvrLossFunc(svrTau), params); + break; + case LR: + objFunc = new UnaryLossObjFunc(new LogLossFunc(), params); + break; + case SVM: + objFunc = new UnaryLossObjFunc(new SmoothHingeLossFunc(), params); + break; + case Perceptron: + objFunc = new UnaryLossObjFunc(new PerceptronLossFunc(), params); + break; + case AFT: + objFunc = new AftRegObjFunc(params); + break; + case Softmax: + objFunc = new SoftmaxObjFunc(params); + break; + default: + throw new AkUnimplementedOperationException("Linear model type is Not implemented yet!"); + } + return objFunc; + } + } \ No newline at end of file diff --git a/core/src/main/java/com/alibaba/alink/operator/common/optim/subfunc/CalcLosses.java b/core/src/main/java/com/alibaba/alink/operator/common/optim/subfunc/CalcLosses.java index efbb57068..1c19573ec 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/optim/subfunc/CalcLosses.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/optim/subfunc/CalcLosses.java @@ -39,7 +39,7 @@ public CalcLosses(OptimMethod method, int numSearchStep) { @Override public void calc(ComContext context) { - Iterable > labledVectors = context.getObj(OptimVariable.trainData); + List > labledVectors = context.getObj(OptimVariable.trainData); Tuple2 dir = context.getObj(OptimVariable.dir); Tuple2 coef = context.getObj(OptimVariable.currentCoef); if (objFunc == null) { diff --git a/core/src/main/java/com/alibaba/alink/operator/common/outlier/BaseModelOutlierPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/common/outlier/BaseModelOutlierPredictBatchOp.java index 613df2244..6e8811090 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/outlier/BaseModelOutlierPredictBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/outlier/BaseModelOutlierPredictBatchOp.java @@ -5,11 +5,13 @@ import org.apache.flink.util.function.TriFunction; import com.alibaba.alink.common.annotation.Internal; +import com.alibaba.alink.common.annotation.NameCn; import com.alibaba.alink.common.mapper.ModelMapper; import com.alibaba.alink.operator.batch.utils.ModelMapBatchOp; import com.alibaba.alink.params.outlier.ModelOutlierParams; @Internal +@NameCn("异常检测基类") public class BaseModelOutlierPredictBatchOp> extends ModelMapBatchOp implements ModelOutlierParams { diff --git a/core/src/main/java/com/alibaba/alink/operator/common/outlier/BaseModelOutlierPredictStreamOp.java b/core/src/main/java/com/alibaba/alink/operator/common/outlier/BaseModelOutlierPredictStreamOp.java index fd8f77f7a..e31af84cc 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/outlier/BaseModelOutlierPredictStreamOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/outlier/BaseModelOutlierPredictStreamOp.java @@ -5,12 +5,14 @@ import org.apache.flink.util.function.TriFunction; import com.alibaba.alink.common.annotation.Internal; +import com.alibaba.alink.common.annotation.NameCn; import com.alibaba.alink.common.mapper.ModelMapper; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.stream.utils.ModelMapStreamOp; import com.alibaba.alink.params.outlier.ModelOutlierParams; @Internal +@NameCn("异常检测基类") public class BaseModelOutlierPredictStreamOp> extends ModelMapStreamOp implements ModelOutlierParams { diff --git 
a/core/src/main/java/com/alibaba/alink/operator/common/outlier/BaseModelOutlierWithSeriesPredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/common/outlier/BaseModelOutlierWithSeriesPredictBatchOp.java new file mode 100644 index 000000000..539d09793 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/outlier/BaseModelOutlierWithSeriesPredictBatchOp.java @@ -0,0 +1,71 @@ +package com.alibaba.alink.operator.common.outlier; + +import org.apache.flink.api.java.DataSet; +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.TableSchema; +import org.apache.flink.types.Row; +import org.apache.flink.util.function.TriFunction; + +import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.mapper.ModelMapper; +import com.alibaba.alink.operator.batch.BatchOperator; +import com.alibaba.alink.operator.batch.utils.ModelMapBatchOp; +import com.alibaba.alink.params.outlier.HasDetectLast; +import com.alibaba.alink.params.outlier.HasInputMTableCol; +import com.alibaba.alink.params.outlier.HasOutputMTableCol; +import com.alibaba.alink.params.outlier.HasWithSeriesInfo; + +@NameCn("异常检测基类") +public class BaseModelOutlierWithSeriesPredictBatchOp> + extends ModelMapBatchOp implements ModelOutlierWithSeriesDetectorParams { + + public BaseModelOutlierWithSeriesPredictBatchOp( + TriFunction mapperBuilder, + Params params) { + super(mapperBuilder, params); + } + + @Override + public T linkFrom(BatchOperator ... inputs) { + checkOpSize(2, inputs); + + try { + if (getParams().get(HasWithSeriesInfo.WITH_SERIES_INFO)) { + //Step 1 : Grouped the input rows into MTables + BatchOperator in_grouped = BaseOutlierBatchOp.group2MTables(inputs[1], getParams()); + + //Step 2 : detect the outlier for each MTable + ModelMapper mapper = this.mapperBuilder.apply(inputs[0].getSchema(), inputs[1].getSchema(), + getParams().clone() + .set(HasInputMTableCol.INPUT_MTABLE_COL, OutlierDetector.TEMP_MTABLE_COL) + .set(HasOutputMTableCol.OUTPUT_MTABLE_COL, OutlierDetector.TEMP_MTABLE_COL) + .set(HasDetectLast.DETECT_LAST, false) + ); + DataSet resultRows = ModelMapBatchOp.calcResultRows(inputs[0], in_grouped, mapper, getParams()); + + //Step 3 : Flatten the MTables to final results + Table resultTable = BaseOutlierBatchOp.flattenMTable( + resultRows, inputs[1].getSchema(), mapper.getOutputSchema(), getParams(), getMLEnvironmentId() + ); + + this.setOutputTable(resultTable); + } else { + final ModelMapper mapper = this.mapperBuilder.apply( + inputs[0].getSchema(), + inputs[1].getSchema(), + this.getParams()); + + DataSet resultRows = ModelMapBatchOp.calcResultRows(inputs[0], inputs[1], mapper, getParams()); + + TableSchema outputSchema = mapper.getOutputSchema(); + this.setOutput(resultRows, outputSchema); + } + + return (T) this; + } catch (Exception ex) { + throw new RuntimeException(ex); + } + } + +} \ No newline at end of file diff --git a/core/src/main/java/com/alibaba/alink/operator/common/outlier/BaseModelOutlierWithSeriesPredictStreamOp.java b/core/src/main/java/com/alibaba/alink/operator/common/outlier/BaseModelOutlierWithSeriesPredictStreamOp.java new file mode 100644 index 000000000..35b7e4a52 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/outlier/BaseModelOutlierWithSeriesPredictStreamOp.java @@ -0,0 +1,93 @@ +package com.alibaba.alink.operator.common.outlier; + +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.ml.api.misc.param.Params; +import 
org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.TableSchema; +import org.apache.flink.types.Row; +import org.apache.flink.util.function.TriFunction; + +import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.io.directreader.DataBridge; +import com.alibaba.alink.common.mapper.ModelMapper; +import com.alibaba.alink.operator.batch.BatchOperator; +import com.alibaba.alink.operator.stream.StreamOperator; +import com.alibaba.alink.operator.stream.utils.ModelMapStreamOp; +import com.alibaba.alink.params.io.ModelFileSinkParams; +import com.alibaba.alink.params.outlier.HasDetectLast; +import com.alibaba.alink.params.outlier.HasInputMTableCol; +import com.alibaba.alink.params.outlier.HasOutputMTableCol; +import com.alibaba.alink.params.outlier.HasWithSeriesInfo; +import com.alibaba.alink.params.shared.HasModelFilePath; + +@NameCn("异常检测基类") +public class BaseModelOutlierWithSeriesPredictStreamOp> + extends ModelMapStreamOp implements ModelOutlierWithSeriesDetectorParams , HasModelFilePath { + + public BaseModelOutlierWithSeriesPredictStreamOp(BatchOperator model, + TriFunction mapperBuilder, + Params params) { + super(model, mapperBuilder, params); + } + + @Override + public T linkFrom(StreamOperator ... inputs) { + checkMinOpSize(1, inputs); + + StreamOperator input_data = inputs[0]; + + StreamOperator input_model_stream = inputs.length > 1 ? inputs[1] : null; + + try { + if (getParams().get(HasWithSeriesInfo.WITH_SERIES_INFO)) { + //Step 1 : Grouped the input rows into MTables + StreamOperator in_grouped = BaseOutlierStreamOp.group2MTables(input_data, getParams()); + + Tuple2 dataBridge = createDataBridge( + getParams().get(ModelFileSinkParams.MODEL_FILE_PATH), + model + ); + + //Step 2 : detect the outlier for each MTable + ModelMapper mapper = this.mapperBuilder.apply(dataBridge.f1, input_data.getSchema(), + getParams().clone() + .set(HasInputMTableCol.INPUT_MTABLE_COL, OutlierDetector.TEMP_MTABLE_COL) + .set(HasOutputMTableCol.OUTPUT_MTABLE_COL, OutlierDetector.TEMP_MTABLE_COL) + .set(HasDetectLast.DETECT_LAST, false) + ); + DataStream resultRows = ModelMapStreamOp.calcResultRows( + dataBridge.f0, dataBridge.f1, in_grouped, input_model_stream, + mapper, getParams(), getMLEnvironmentId(), mapperBuilder); + + //Step 3 : Flatten the MTables to final results + Table resultTable = BaseOutlierStreamOp.flattenMTable( + resultRows, input_data.getSchema(), mapper.getOutputSchema(), getParams(), getMLEnvironmentId() + ); + + this.setOutputTable(resultTable); + } else { + + Tuple2 dataBridge = createDataBridge( + getParams().get(ModelFileSinkParams.MODEL_FILE_PATH), + model + ); + + final ModelMapper mapper = this.mapperBuilder.apply(dataBridge.f1, input_data.getSchema(), + this.getParams()); + + DataStream resultRows = ModelMapStreamOp.calcResultRows( + dataBridge.f0, dataBridge.f1, input_data, input_model_stream, + mapper, getParams(), getMLEnvironmentId(), mapperBuilder); + + TableSchema outputSchema = mapper.getOutputSchema(); + this.setOutput(resultRows, outputSchema); + } + + return (T) this; + } catch (Exception ex) { + throw new RuntimeException(ex); + } + } + +} \ No newline at end of file diff --git a/core/src/main/java/com/alibaba/alink/operator/common/outlier/BaseOutlierBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/common/outlier/BaseOutlierBatchOp.java index 6988e2719..76f29aa86 100644 --- 
a/core/src/main/java/com/alibaba/alink/operator/common/outlier/BaseOutlierBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/outlier/BaseOutlierBatchOp.java @@ -10,7 +10,7 @@ import org.apache.flink.types.Row; import org.apache.flink.util.Collector; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.MTable; import com.alibaba.alink.common.annotation.InputPorts; import com.alibaba.alink.common.annotation.Internal; @@ -21,7 +21,7 @@ import com.alibaba.alink.common.exceptions.AkColumnNotFoundException; import com.alibaba.alink.common.mapper.FlatMapperAdapter; import com.alibaba.alink.common.mapper.Mapper; -import com.alibaba.alink.common.utils.DataSetConversionUtil; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; import com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.batch.source.TableSourceBatchOp; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/outlier/BaseOutlierStreamOp.java b/core/src/main/java/com/alibaba/alink/operator/common/outlier/BaseOutlierStreamOp.java index 384f2bee7..4016b40a1 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/outlier/BaseOutlierStreamOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/outlier/BaseOutlierStreamOp.java @@ -7,11 +7,11 @@ import org.apache.flink.table.api.TableSchema; import org.apache.flink.types.Row; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.annotation.Internal; import com.alibaba.alink.common.mapper.FlatMapperAdapter; import com.alibaba.alink.common.mapper.Mapper; -import com.alibaba.alink.common.utils.DataStreamConversionUtil; +import com.alibaba.alink.operator.stream.utils.DataStreamConversionUtil; import com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.operator.common.dataproc.FlattenMTableMapper; import com.alibaba.alink.operator.stream.StreamOperator; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/outlier/ModelOutlierDetector.java b/core/src/main/java/com/alibaba/alink/operator/common/outlier/ModelOutlierDetector.java index 17f394b1c..4c4f05b1e 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/outlier/ModelOutlierDetector.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/outlier/ModelOutlierDetector.java @@ -6,7 +6,7 @@ import org.apache.flink.ml.api.misc.param.Params; import org.apache.flink.table.api.TableSchema; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.mapper.ModelMapper; import com.alibaba.alink.common.utils.JsonConverter; import com.alibaba.alink.params.outlier.ModelOutlierParams; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/outlier/ModelOutlierWithSeriesDetector.java b/core/src/main/java/com/alibaba/alink/operator/common/outlier/ModelOutlierWithSeriesDetector.java new file mode 100644 index 000000000..a66101aff --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/outlier/ModelOutlierWithSeriesDetector.java @@ -0,0 +1,103 @@ +package com.alibaba.alink.operator.common.outlier; + +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.tuple.Tuple4; +import org.apache.flink.ml.api.misc.param.Params; +import 
org.apache.flink.table.api.TableSchema; +import org.apache.flink.types.Row; + +import com.alibaba.alink.common.type.AlinkTypes; +import com.alibaba.alink.common.MTable; +import com.alibaba.alink.common.mapper.ModelMapper; +import com.alibaba.alink.params.outlier.HasDetectLast; +import com.alibaba.alink.params.outlier.HasWithSeriesInfo; +import com.alibaba.alink.params.shared.colname.HasSelectedCol; +import com.alibaba.alink.params.shared.colname.HasSelectedColsDefaultAsNull; + +import java.util.Map; + +import static com.alibaba.alink.operator.common.outlier.OutlierDetector.TEMP_MTABLE_COL; +import static com.alibaba.alink.operator.common.outlier.OutlierDetector.appendPreds2MTable; + +public abstract class ModelOutlierWithSeriesDetector extends ModelMapper { + + /** + * The condition that the mapper output the prediction detail or not. + */ + private final boolean isPredDetail; + private final boolean detectWithSeriesInfo; + + public ModelOutlierWithSeriesDetector(TableSchema modelSchema, TableSchema dataSchema, + Params params) { + super(modelSchema, dataSchema, params); + isPredDetail = params.contains(ModelOutlierWithSeriesDetectorParams.PREDICTION_DETAIL_COL); + if (params.contains(HasWithSeriesInfo.WITH_SERIES_INFO)) { + detectWithSeriesInfo = params.get(HasWithSeriesInfo.WITH_SERIES_INFO); + } else { + detectWithSeriesInfo = false; + } + } + + @Override + protected void map(SlicedSelectedSample selection, SlicedResult result) throws Exception { + if (detectWithSeriesInfo) { + MTable mt = (MTable) selection.get(0); + + Tuple3 >[] preds = detectWithSeries(mt, params.get(HasDetectLast.DETECT_LAST)); + + result.set(0, appendPreds2MTable(mt, preds, params, isPredDetail)); + } else { + Row row = new Row(selection.length()); + selection.fillRow(row); + Tuple3 > t2 = detectByModel(row); + if (isPredDetail) { + result.set(0, t2.f0); + result.set(1, t2.f1); + } else { + result.set(0, t2.f0); + } + } + } + + @Override + protected Tuple4 [], String[]> prepareIoSchema(TableSchema modelSchema, + TableSchema dataSchema, + Params params) { + if (params.contains(HasWithSeriesInfo.WITH_SERIES_INFO) && params.get(HasWithSeriesInfo.WITH_SERIES_INFO)) { + return new Tuple4 <>( + new String[] {TEMP_MTABLE_COL}, + new String[] {TEMP_MTABLE_COL}, + new TypeInformation [] {AlinkTypes.M_TABLE}, + new String[0] + ); + } else { + String[] selectedCols = null; + if (params.contains(HasSelectedCol.SELECTED_COL)) { + selectedCols = new String[] {params.get(HasSelectedCol.SELECTED_COL)}; + } else { + selectedCols = params.get(HasSelectedColsDefaultAsNull.SELECTED_COLS); + if (null == selectedCols) { + selectedCols = dataSchema.getFieldNames(); + } + } + String[] outputCols; + TypeInformation [] outputTypes; + String predResultColName = params.get(ModelOutlierWithSeriesDetectorParams.PREDICTION_COL); + boolean isPredDetail = params.contains(ModelOutlierWithSeriesDetectorParams.PREDICTION_DETAIL_COL); + if (isPredDetail) { + String predDetailColName = params.get(ModelOutlierWithSeriesDetectorParams.PREDICTION_DETAIL_COL); + outputCols = new String[] {predResultColName, predDetailColName}; + outputTypes = new TypeInformation [] {AlinkTypes.BOOLEAN, AlinkTypes.STRING}; + } else { + outputCols = new String[] {predResultColName}; + outputTypes = new TypeInformation [] {AlinkTypes.BOOLEAN}; + } + return Tuple4.of(selectedCols, outputCols, outputTypes, new String[0]); + } + } + + public abstract Tuple3 >[] detectWithSeries(MTable series, boolean detectLast) throws Exception; + + public abstract Tuple3 > detectByModel(Row 
selection) throws Exception; +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/outlier/ModelOutlierWithSeriesDetectorParams.java b/core/src/main/java/com/alibaba/alink/operator/common/outlier/ModelOutlierWithSeriesDetectorParams.java new file mode 100644 index 000000000..c831a5242 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/outlier/ModelOutlierWithSeriesDetectorParams.java @@ -0,0 +1,6 @@ +package com.alibaba.alink.operator.common.outlier; + +import com.alibaba.alink.params.outlier.OutlierDetectorParams; + +public interface ModelOutlierWithSeriesDetectorParams extends OutlierDetectorParams { +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/outlier/NGramModelDataConverter.java b/core/src/main/java/com/alibaba/alink/operator/common/outlier/NGramModelDataConverter.java new file mode 100644 index 000000000..6502778a6 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/outlier/NGramModelDataConverter.java @@ -0,0 +1,46 @@ +package com.alibaba.alink.operator.common.outlier; + +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.ml.api.misc.param.Params; + +import com.alibaba.alink.common.model.SimpleModelDataConverter; + +import java.util.Arrays; +import java.util.HashMap; + +import static com.alibaba.alink.common.utils.JsonConverter.gson; + +/** + * @author lqb + * @ClassName NgramModel + * @description NgramModel is + * @date 2019/04/10 + */ +public class NGramModelDataConverter + extends SimpleModelDataConverter , HashMap > { + public final static String NGRAM = "ngram"; + public final static String TEXT_NGRAM_CNT = "textNgramCnt"; + public final static String TEXT_LENGTH = "textLength"; + + public final static String WORD_TYPE = "word"; + public final static String STRING_TYPE = "string"; + + final static String NGRAM_CNT = "ngram_cnt"; + final static String TEXT_CNT = "text_cnt"; + final static String MAX_TEXT_LENGTH = "max_text_length"; + final static String MIN_TEXT_LENGTH = "min_text_length"; + final static String AVG_TEXT_LENGTH = "avg_text_length"; + + public NGramModelDataConverter() { + } + + @Override + public Tuple2 > serializeModel(HashMap data) { + return Tuple2.of(new Params(), Arrays.asList(gson.toJson(data))); + } + + @Override + public HashMap deserializeModel(Params meta, Iterable modelData) { + return gson.fromJson(modelData.iterator().next(), HashMap.class); + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/outlier/OcsvmModelDataConverter.java b/core/src/main/java/com/alibaba/alink/operator/common/outlier/OcsvmModelDataConverter.java new file mode 100644 index 000000000..599a6b52b --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/outlier/OcsvmModelDataConverter.java @@ -0,0 +1,63 @@ +package com.alibaba.alink.operator.common.outlier; + +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.ml.api.misc.param.Params; + +import com.alibaba.alink.common.model.ModelParamName; +import com.alibaba.alink.common.model.SimpleModelDataConverter; +import com.alibaba.alink.common.utils.JsonConverter; +import com.alibaba.alink.operator.common.outlier.OcsvmModelData.SvmModelData; +import com.alibaba.alink.params.outlier.OcsvmModelTrainParams; +import com.alibaba.alink.params.shared.colname.HasSelectedColsDefaultAsNull; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; + +public class OcsvmModelDataConverter + extends SimpleModelDataConverter { + + public OcsvmModelDataConverter() { + } 
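+    // Layout used by this converter: the meta Params records kernelType/degree/gamma/coef0/nu, the
+    // feature/vector columns and the bagging number, while the data part holds one JSON string per
+    // bagged SvmModelData. A round-trip sketch (types assumed from SimpleModelDataConverter):
+    //   OcsvmModelDataConverter conv = new OcsvmModelDataConverter();
+    //   Tuple2<Params, Iterable<String>> serialized = conv.serializeModel(modelData);
+    //   OcsvmModelData restored = conv.deserializeModel(serialized.f0, serialized.f1);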
+ + @Override + public Tuple2 > serializeModel(OcsvmModelData data) { + Params meta = new Params(); + meta.set(OcsvmModelTrainParams.KERNEL_TYPE, data.kernelType); + if (data.degree >= 1) { + meta.set(OcsvmModelTrainParams.DEGREE, data.degree); + } + meta.set(OcsvmModelTrainParams.GAMMA, data.gamma); + meta.set(OcsvmModelTrainParams.COEF0, data.coef0); + meta.set(OcsvmModelTrainParams.NU, data.nu); + meta.set(OcsvmModelTrainParams.FEATURE_COLS, data.featureColNames); + meta.set(OcsvmModelTrainParams.VECTOR_COL, data.vectorCol); + meta.set(ModelParamName.BAGGING_NUMBER, data.baggingNumber); + List modelData = new ArrayList <>(); + for (int i = 0; i < data.models.length; ++i) { + String json = JsonConverter.toJson(data.models[i]); + modelData.add(json); + } + return Tuple2.of(meta, modelData); + } + + @Override + public OcsvmModelData deserializeModel(Params meta, Iterable data) { + OcsvmModelData modelData = new OcsvmModelData(); + modelData.baggingNumber = meta.get(ModelParamName.BAGGING_NUMBER); + modelData.models = new SvmModelData[modelData.baggingNumber]; + Iterator dataIterator = data.iterator(); + for (int i = 0; i < modelData.baggingNumber; ++i) { + modelData.models[i] = JsonConverter.fromJson(dataIterator.next(), SvmModelData.class); + } + modelData.featureColNames = meta.get(HasSelectedColsDefaultAsNull.SELECTED_COLS); + modelData.kernelType = meta.get(OcsvmModelTrainParams.KERNEL_TYPE); + modelData.degree = meta.get(OcsvmModelTrainParams.DEGREE); + modelData.gamma = meta.get(OcsvmModelTrainParams.GAMMA); + modelData.coef0 = meta.get(OcsvmModelTrainParams.COEF0); + modelData.nu = meta.get(OcsvmModelTrainParams.NU); + modelData.featureColNames = meta.get(OcsvmModelTrainParams.FEATURE_COLS); + modelData.vectorCol = meta.get(OcsvmModelTrainParams.VECTOR_COL); + return modelData; + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/outlier/OcsvmModelDetector.java b/core/src/main/java/com/alibaba/alink/operator/common/outlier/OcsvmModelDetector.java new file mode 100644 index 000000000..527dcb3a7 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/outlier/OcsvmModelDetector.java @@ -0,0 +1,82 @@ +package com.alibaba.alink.operator.common.outlier; + +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.table.api.TableSchema; +import org.apache.flink.types.Row; + +import com.alibaba.alink.common.linalg.DenseVector; +import com.alibaba.alink.common.linalg.Vector; +import com.alibaba.alink.common.linalg.VectorUtil; +import com.alibaba.alink.common.utils.TableUtil; +import com.alibaba.alink.operator.common.outlier.OcsvmModelData.SvmModelData; +import com.alibaba.alink.params.outlier.HaskernelType.KernelType; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static com.alibaba.alink.operator.common.outlier.OcsvmKernel.svmPredict; + +public class OcsvmModelDetector extends ModelOutlierDetector { + private static final long serialVersionUID = 6504098446269455446L; + private int[] featureIdx; + private int vectorIndex = -1; + private OcsvmModelData modelData; + private DenseVector localX; + private double gamma; + private double coef0; + private int degree; + private KernelType kernelType; + + public OcsvmModelDetector(TableSchema modelSchema, TableSchema dataSchema, Params params) { + super(modelSchema, dataSchema, params); + } + + @Override + public void loadModel(List modelRows) { + modelData = new OcsvmModelDataConverter().load(modelRows); + 
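+        // The fields cached below (kernel parameters, feature column indices, the optional vector
+        // column index and a reusable DenseVector) let detect() score each row without re-reading the
+        // model rows; detect() then sums -svmPredict(...) over all bagged sub-models and flags the row
+        // as an outlier when that score is >= 0.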
this.gamma = modelData.gamma; + this.coef0 = modelData.coef0; + this.degree = modelData.degree; + this.kernelType = modelData.kernelType; + if (modelData.featureColNames != null) { + featureIdx = TableUtil.findColIndicesWithAssertAndHint( + getSelectedCols(), + modelData.featureColNames + ); + localX = new DenseVector(featureIdx.length); + } + String vectorCol = modelData.vectorCol; + if (vectorCol != null && !vectorCol.isEmpty()) { + this.vectorIndex = TableUtil.findColIndexWithAssertAndHint(getSelectedCols(), vectorCol); + } + } + + @Override + protected Tuple3 > detect(SlicedSelectedSample selection) { + double score = 0.0; + for (SvmModelData model : modelData.models) { + double pred = predictSingle(selection, model); + score -= pred; + } + boolean finalResult = score >= 0.0; + Map detail = new HashMap <>(); + detail.put("outlier_score", String.valueOf(score)); + return Tuple3.of(finalResult, score, detail); + } + + public double predictSingle(SlicedSelectedSample selection, SvmModelData model) { + Vector x; + if (this.vectorIndex != -1) { + Object obj = selection.get(this.vectorIndex); + x = VectorUtil.getVector(obj); + return svmPredict(model, x, kernelType, gamma, coef0, degree); + } else { + for (int i = 0; i < featureIdx.length; ++i) { + localX.set(i, ((Number) selection.get(featureIdx[i])).doubleValue()); + } + return svmPredict(model, localX, kernelType, gamma, coef0, degree); + } + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/outlier/OutlierDetector.java b/core/src/main/java/com/alibaba/alink/operator/common/outlier/OutlierDetector.java index 7214f99d8..3309d345f 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/outlier/OutlierDetector.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/outlier/OutlierDetector.java @@ -8,7 +8,7 @@ import org.apache.flink.table.api.TableSchema; import org.apache.flink.types.Row; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.MTable; import com.alibaba.alink.common.mapper.Mapper; import com.alibaba.alink.common.utils.JsonConverter; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/outlier/OutlierUtil.java b/core/src/main/java/com/alibaba/alink/operator/common/outlier/OutlierUtil.java index a4f0b25a2..e4cf55440 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/outlier/OutlierUtil.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/outlier/OutlierUtil.java @@ -8,7 +8,7 @@ import org.apache.flink.table.api.TableSchema; import org.apache.flink.types.Row; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.MTable; import com.alibaba.alink.common.linalg.DenseVector; import com.alibaba.alink.common.linalg.SparseVector; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/outlier/SHEsdDetector.java b/core/src/main/java/com/alibaba/alink/operator/common/outlier/SHEsdDetector.java new file mode 100644 index 000000000..cd714f97c --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/outlier/SHEsdDetector.java @@ -0,0 +1,129 @@ +package com.alibaba.alink.operator.common.outlier; + +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.table.api.TableSchema; + +import com.alibaba.alink.common.MTable; +import com.alibaba.alink.common.linalg.DenseVector; +import 
com.alibaba.alink.operator.common.outlier.tsa.tsacalculator.STLDecomposerCalc; +import com.alibaba.alink.params.outlier.HasDirection; + +import java.util.HashSet; +import java.util.Map; + +public class SHEsdDetector extends OutlierDetector { + + private final int inputMaxIter; + private double alpha; + private HasDirection.Direction direction; + private int frequency; + private final String selectedCol; + + public SHEsdDetector(TableSchema dataSchema, Params params) { + super(dataSchema, params); + inputMaxIter = params.contains(SHEsdDetectorParams.MAX_ITER) ? params.get(SHEsdDetectorParams.MAX_ITER) : -1; + alpha = params.get(SHEsdDetectorParams.SHESD_ALPHA); + direction = params.get(SHEsdDetectorParams.DIRECTION); + frequency = params.get(SHEsdDetectorParams.FREQUENCY); + this.selectedCol = params.contains(SHEsdDetectorParams.SELECTED_COL) ? params.get( + SHEsdDetectorParams.SELECTED_COL) + : dataSchema.getFieldNames()[0]; + } + + @Override + public Tuple3 >[] detect(MTable series, boolean detectLast) throws Exception { + double[] data = OutlierUtil.getNumericArray(series, selectedCol); + int dataNum = data.length; + int maxIter = Math.min(this.inputMaxIter, dataNum / 2); + if (maxIter < 0) { + maxIter = (dataNum + 9) / 10; + } + Tuple3 >[] results = new Tuple3[dataNum]; + for (int i = 0; i < dataNum; i++) { + results[i] = Tuple3.of(false, 0.1, null); + } + + DenseVector[] components = STLDecomposerCalc.decompose(data, frequency); + double[] trend = components[0].getData(); + double[] seasonal = components[1].getData(); + double dataMedian = CalcMidian.tempMedian(data); + double[] dataDecomp = new double[dataNum]; + //here minus median and seasonal of data. + for (int i = 0; i < dataNum; i++) { + data[i] -= (dataMedian + seasonal[i]); + dataDecomp[i] = trend[i] + seasonal[i]; + } + CalcMidian calcMidian = new CalcMidian(data); + int[] outlierIndex = new int[maxIter]; + double[] ares = new double[dataNum]; + HashSet excludedIndices = new HashSet <>(); + for (int i = 1; i < maxIter + 1; i++) { + double median = calcMidian.median(); + //area is the deviation value of each data. + switch (direction) { + case POSITIVE: + for (int j = 0; j < dataNum; j++) { + ares[j] = data[j] - median; + } + break; + case NEGATIVE: + for (int j = 0; j < dataNum; j++) { + ares[j] = median - data[j]; + } + break; + case BOTH: + for (int j = 0; j < dataNum; j++) { + ares[j] = Math.abs(data[j] - median); + } + break; + default: + } + double dataSigma = TimeSeriesAnomsUtils.mad(calcMidian); + + if (Math.abs(dataSigma) < 1e-4) { + break; + } + double maxValue = -Double.MAX_VALUE; + int maxIndex = -1; + for (int j = 0; j < dataNum; j++) { + if (!excludedIndices.contains(j)) { + if (ares[j] > maxValue) { + maxValue = ares[j]; + maxIndex = j; + } + } + } + //添加的时候就意味着计算中位数的地方要扣除相应的。 + calcMidian.remove(data[maxIndex]); + excludedIndices.add(maxIndex); + + maxValue /= dataSigma; + outlierIndex[i - 1] = maxIndex; + double p; + //tempNum is the sample num? + int tempNum = dataNum - i + 1; + if (direction == HasDirection.Direction.BOTH) { + p = 1 - alpha / (2 * tempNum); + } else { + p = 1 - alpha / tempNum; + } + double t = TimeSeriesAnomsUtils.tppf(p, tempNum - 2); + //lam is the hypothesis test condition. 
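+ // Generalized ESD critical value: with m = tempNum remaining samples and t the Student-t quantile computed above, + // lambda = t * (m - 1) / sqrt((m - 2 + t^2) * m); the candidate point is flagged only while its scaled deviation exceeds lambda.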
+ double lam = t * (tempNum - 1) / Math.sqrt((tempNum - 2 + t * t) * tempNum); + if (maxValue > lam) { + results[maxIndex].f0 = true; + results[maxIndex].f1 = EsdDetector.cdfBoth(maxValue, tempNum); + } else { + break; + } + } + + if (detectLast) { + return new Tuple3[] {results[results.length - 1]}; + } else { + return results; + } + } + +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/outlier/SHEsdDetectorParams.java b/core/src/main/java/com/alibaba/alink/operator/common/outlier/SHEsdDetectorParams.java new file mode 100644 index 000000000..edb4feb89 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/outlier/SHEsdDetectorParams.java @@ -0,0 +1,19 @@ +package com.alibaba.alink.operator.common.outlier; + +import com.alibaba.alink.params.outlier.HasDirection; +import com.alibaba.alink.params.outlier.HasMaxOutlierRatio; +import com.alibaba.alink.params.outlier.tsa.HasMaxAnoms; +import com.alibaba.alink.params.outlier.tsa.HasSHESDAlpha; +import com.alibaba.alink.params.shared.colname.HasSelectedCol; +import com.alibaba.alink.params.shared.colname.HasTimeCol; +import com.alibaba.alink.params.shared.iter.HasMaxIter; +import com.alibaba.alink.params.timeseries.HasFrequency; + +public interface SHEsdDetectorParams extends + HasSelectedCol, + HasTimeCol, + HasFrequency , + HasSHESDAlpha , + HasDirection , + HasMaxIter { +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/outlier/SHEsdOutlierBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/common/outlier/SHEsdOutlierBatchOp.java new file mode 100644 index 000000000..c5ea6a210 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/outlier/SHEsdOutlierBatchOp.java @@ -0,0 +1,21 @@ +package com.alibaba.alink.operator.common.outlier; + +import org.apache.flink.ml.api.misc.param.Params; + +import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.annotation.NameEn; + +@NameCn("SHEsd异常检测") +@NameEn("SHEsd Outlier") +public class SHEsdOutlierBatchOp extends BaseOutlierBatchOp + implements SHEsdDetectorParams { + + public SHEsdOutlierBatchOp() { + this(null); + } + + public SHEsdOutlierBatchOp(Params params) { + super(SHEsdDetector::new, params); + } + +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/outlier/String2NgramRow.java b/core/src/main/java/com/alibaba/alink/operator/common/outlier/String2NgramRow.java new file mode 100644 index 000000000..5efb0008b --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/outlier/String2NgramRow.java @@ -0,0 +1,46 @@ +package com.alibaba.alink.operator.common.outlier; + +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.table.functions.TableFunction; +import org.apache.flink.types.Row; + +public class String2NgramRow extends TableFunction { + private static final long serialVersionUID = -466143644625699898L; + private int n = 5; + + public String2NgramRow() { + } + + public String2NgramRow(int n) { + this.n = n; + } + + public void eval(String str) { + try { + if (null == str || str.isEmpty()) { + Row r = new Row(2); + r.setField(0, ""); + r.setField(1, 0); + collect(r); + return; + } + int length = str.length(); + for (int i = 0; i < length - n + 1; i++) { + Row r = new Row(2); + r.setField(0, str.substring(i, i + n)); + r.setField(1, length); + collect(r); + } + } catch (Exception ex) { + ex.printStackTrace(); + } + +
} + + @Override + public TypeInformation getResultType() { + return new RowTypeInfo(new TypeInformation[] {BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.INT_TYPE_INFO}); + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/outlier/TimeSeriesAnomsUtils.java b/core/src/main/java/com/alibaba/alink/operator/common/outlier/TimeSeriesAnomsUtils.java index 6c2030bcf..b298d905d 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/outlier/TimeSeriesAnomsUtils.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/outlier/TimeSeriesAnomsUtils.java @@ -1,12 +1,8 @@ package com.alibaba.alink.operator.common.outlier; import org.apache.flink.api.common.functions.RichMapFunction; -import org.apache.flink.api.common.typeinfo.TypeInformation; -import org.apache.flink.api.java.DataSet; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.tuple.Tuple3; -import org.apache.flink.streaming.api.datastream.DataStream; -import org.apache.flink.table.api.Table; import org.apache.flink.types.Row; import com.alibaba.alink.common.exceptions.AkIllegalDataException; @@ -15,10 +11,7 @@ import com.alibaba.alink.common.linalg.Vector; import com.alibaba.alink.common.linalg.VectorUtil; import com.alibaba.alink.common.probabilistic.IDF; -import com.alibaba.alink.common.utils.DataSetConversionUtil; -import com.alibaba.alink.common.utils.DataStreamConversionUtil; -import java.util.Arrays; import java.util.Comparator; import java.util.List; @@ -103,52 +96,6 @@ public Row map(Row value) throws Exception { } } - //this is for time series algo. - public static Table getStreamTable(long id, DataStream data, - Tuple2 schema, TypeInformation[] groupTypes) { - String[] colNames = schema.f0; - TypeInformation[] colTypes = schema.f1; - return getStreamTable(id, data, colNames, colTypes, groupTypes); - } - - //this is for no use graph algo. - public static Table getStreamTable(long id, DataStream data, - String[] colNames, TypeInformation[] colTypes, TypeInformation[] groupTypes) { - boolean containsIdCol = colNames[0] != null; - if (!containsIdCol) { - colNames[0] = "tempId"; - colTypes[0] = TypeInformation.of(Comparable.class); - } - if (groupTypes != null) { - System.arraycopy(groupTypes, 0, colTypes, 0, groupTypes.length); - } - Table table = DataStreamConversionUtil.toTable(id, data, colNames, colTypes); - if (!containsIdCol) { - String[] outColNames = Arrays.copyOfRange(colNames, 1, colNames.length); - table = table.select(join(outColNames, ",")); - } - return table; - } - - public static Table getBatchTable(long id, DataSet data, - Tuple2 schema, TypeInformation[] groupTypes) { - String[] colNames = schema.f0; - TypeInformation[] colTypes = schema.f1; - boolean containsIdCol = groupTypes != null; - if (!containsIdCol) { - colNames[0] = "tempId"; - colTypes[0] = TypeInformation.of(Comparable.class); - } else { - System.arraycopy(groupTypes, 0, colTypes, 0, groupTypes.length); - } - Table table = DataSetConversionUtil.toTable(id, data, colNames, colTypes); - if (!containsIdCol) { - String[] outColNames = Arrays.copyOfRange(colNames, 1, colNames.length); - table = table.select(join(outColNames, ",")); - } - return table; - } - //transform double array data to string. 
public static String transformData(double[] data) { StringBuilder sb = new StringBuilder(String.valueOf(data[0])); @@ -158,7 +105,7 @@ public static String transformData(double[] data) { return sb.toString(); } - private static String join(String[] stringArray, String on) { + public static String join(String[] stringArray, String on) { int length = stringArray.length; if (length == 1) { diff --git a/core/src/main/java/com/alibaba/alink/operator/common/outlier/TimeSeriesDecomposeDetectBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/common/outlier/TimeSeriesDecomposeDetectBatchOp.java new file mode 100644 index 000000000..b87967025 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/outlier/TimeSeriesDecomposeDetectBatchOp.java @@ -0,0 +1,22 @@ +package com.alibaba.alink.operator.common.outlier; + +import org.apache.flink.ml.api.misc.param.Params; + +import com.alibaba.alink.common.annotation.Internal; +import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.params.shared.colname.HasTimeCol; + +@Internal +@NameCn("批式时序分解检测") +public class TimeSeriesDecomposeDetectBatchOp extends BaseOutlierBatchOp + implements TimeSeriesDecomposeParams , + HasTimeCol { + + public TimeSeriesDecomposeDetectBatchOp() { + this(null); + } + + public TimeSeriesDecomposeDetectBatchOp(Params params) { + super(TimeSeriesDecomposeDetector::new, params); + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/outlier/TimeSeriesDecomposeDetectStreamOp.java b/core/src/main/java/com/alibaba/alink/operator/common/outlier/TimeSeriesDecomposeDetectStreamOp.java new file mode 100644 index 000000000..066d2fb46 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/outlier/TimeSeriesDecomposeDetectStreamOp.java @@ -0,0 +1,19 @@ +package com.alibaba.alink.operator.common.outlier; + +import org.apache.flink.ml.api.misc.param.Params; + +import com.alibaba.alink.common.annotation.Internal; +import com.alibaba.alink.common.annotation.NameCn; + +@Internal +@NameCn("流式时序分解检测") +public class TimeSeriesDecomposeDetectStreamOp extends BaseOutlierStreamOp + implements TimeSeriesDecomposeParams { + public TimeSeriesDecomposeDetectStreamOp() { + this(null); + } + + public TimeSeriesDecomposeDetectStreamOp(Params params) { + super(TimeSeriesDecomposeDetector::new, params); + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/outlier/TimeSeriesDecomposeDetector.java b/core/src/main/java/com/alibaba/alink/operator/common/outlier/TimeSeriesDecomposeDetector.java new file mode 100644 index 000000000..0df5d214e --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/outlier/TimeSeriesDecomposeDetector.java @@ -0,0 +1,93 @@ +package com.alibaba.alink.operator.common.outlier; + +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.table.api.TableSchema; + +import com.alibaba.alink.common.MTable; +import com.alibaba.alink.common.linalg.DenseVector; +import com.alibaba.alink.common.utils.TableUtil; +import com.alibaba.alink.operator.common.outlier.tsa.tsacalculator.BoxPlotDetectorCalc; +import com.alibaba.alink.operator.common.outlier.tsa.tsacalculator.ConvolutionDecomposerCalc; +import com.alibaba.alink.operator.common.outlier.tsa.tsacalculator.DecomposeOutlierDetectorCalc; +import com.alibaba.alink.operator.common.outlier.tsa.tsacalculator.KSigmaDetectorCalc; +import com.alibaba.alink.operator.common.outlier.tsa.tsacalculator.SHESDDetectorCalc; +import 
com.alibaba.alink.operator.common.outlier.tsa.tsacalculator.STLDecomposerCalc; +import com.alibaba.alink.operator.common.outlier.tsa.tsacalculator.TimeSeriesDecomposerCalc; +import com.alibaba.alink.params.shared.colname.HasTimeColDefaultAsNull; + +import java.util.Map; + +public class TimeSeriesDecomposeDetector extends OutlierDetector { + + private final String selectedCol; + private final String timeCol; + + public TimeSeriesDecomposeDetector(TableSchema dataSchema, + Params params) { + super(dataSchema, params); + this.selectedCol = params.contains(TimeSeriesDecomposeParams.SELECTED_COL) ? params.get( + TimeSeriesDecomposeParams.SELECTED_COL) + : dataSchema.getFieldNames()[0]; + this.timeCol = params.get(HasTimeColDefaultAsNull.TIME_COL); + } + + @Override + public Tuple3 >[] detect(MTable series, boolean detectLast) throws Exception { + if (null != timeCol) { + series.orderBy(timeCol); + } + int colIndex = TableUtil.findColIndex(series.getSchema(), selectedCol); + int n = series.getNumRow(); + double[] data = new double[n]; + for (int i = 0; i < n; i++) { + data[i] = ((Number) series.getEntry(i, colIndex)).doubleValue(); + } + + TimeSeriesDecomposerCalc timeSeriesDecomposer = null; + switch (params.get(TimeSeriesDecomposeParams.DECOMPOSE_METHOD)) { + case CONVOLUTION: + timeSeriesDecomposer = new ConvolutionDecomposerCalc(params); + break; + default: + timeSeriesDecomposer = new STLDecomposerCalc(params); + } + + DenseVector[] decomposedData = null; + String errorDetail = ""; + try { + decomposedData = timeSeriesDecomposer.decompose(data); + } catch (Exception ex) { + //ex.printStackTrace(); + errorDetail = ex.getMessage(); + } + + int[] outlierIndexes = null; + if (decomposedData != null) { + DecomposeOutlierDetectorCalc decomposeOutlierDetector = null; + switch (params.get(TimeSeriesDecomposeParams.DETECT_METHOD)) { + case SHESD: + decomposeOutlierDetector = new SHESDDetectorCalc(params); + break; + case BoxPlot: + decomposeOutlierDetector = new BoxPlotDetectorCalc(params); + break; + default: + decomposeOutlierDetector = new KSigmaDetectorCalc(params); + } + outlierIndexes = decomposeOutlierDetector.detect(decomposedData[2].getData()); + } + + Tuple3 >[] tuple3s = new Tuple3[n]; + for (int i = 0; i < n; i++) { + tuple3s[i] = Tuple3.of(false, null, null); + } + if (outlierIndexes != null) { + for (int idx : outlierIndexes) { + tuple3s[idx].f0 = true; + } + } + + return tuple3s; + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/outlier/TimeSeriesDecomposeParams.java b/core/src/main/java/com/alibaba/alink/operator/common/outlier/TimeSeriesDecomposeParams.java new file mode 100644 index 000000000..6def553d4 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/outlier/TimeSeriesDecomposeParams.java @@ -0,0 +1,73 @@ +package com.alibaba.alink.operator.common.outlier; + +import org.apache.flink.ml.api.misc.param.ParamInfo; +import org.apache.flink.ml.api.misc.param.ParamInfoFactory; + +import com.alibaba.alink.common.annotation.DescCn; +import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.params.ParamUtil; +import com.alibaba.alink.params.outlier.OutlierDetectorParams; +import com.alibaba.alink.params.shared.colname.HasSelectedColDefaultAsNull; +import com.alibaba.alink.params.timeseries.HasFrequency; +import com.alibaba.alink.params.timeseries.holtwinters.HasSeasonalType; + +import java.io.Serializable; + +public interface TimeSeriesDecomposeParams extends + OutlierDetectorParams , + //HasTimeCol, + 
HasSelectedColDefaultAsNull , + HasSeasonalType , + HasFrequency { + + @NameCn("时序分解方法") + @DescCn("时序分解方法") + ParamInfo DECOMPOSE_METHOD = ParamInfoFactory + .createParamInfo("decomposeMethod", DecomposeMethod.class) + .setDescription("Method to decompose the time series.") + .setHasDefaultValue(DecomposeMethod.STL) + .build(); + + default DecomposeMethod getDecomposeMethod() { + return get(DECOMPOSE_METHOD); + } + + default T setDecomposeMethod(DecomposeMethod value) { + return set(DECOMPOSE_METHOD, value); + } + + default T setDecomposeMethod(String value) { + return set(DECOMPOSE_METHOD, ParamUtil.searchEnum(DECOMPOSE_METHOD, value)); + } + + enum DecomposeMethod implements Serializable { + STL, + CONVOLUTION + } + + @NameCn("时序分解结果的检测方法") + @DescCn("时序分解结果的检测方法") + ParamInfo DETECT_METHOD = ParamInfoFactory + .createParamInfo("detectMethod", DetectMethod.class) + .setDescription("Detect method for the decomposition of time series.") + .setHasDefaultValue(DetectMethod.KSigma) + .build(); + + default DetectMethod getDetectMethod() { + return get(DETECT_METHOD); + } + + default T setDetectMethod(DetectMethod value) { + return set(DETECT_METHOD, value); + } + + default T setDetectMethod(String value) { + return set(DETECT_METHOD, ParamUtil.searchEnum(DETECT_METHOD, value)); + } + + enum DetectMethod implements Serializable { + KSigma, + SHESD, + BoxPlot + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/outlier/tsa/STLMethod.java b/core/src/main/java/com/alibaba/alink/operator/common/outlier/tsa/STLMethod.java new file mode 100644 index 000000000..4565e5512 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/outlier/tsa/STLMethod.java @@ -0,0 +1,451 @@ +package com.alibaba.alink.operator.common.outlier.tsa; + +import com.alibaba.alink.operator.common.outlier.CalcMidian; +import com.alibaba.alink.operator.common.outlier.TimeSeriesAnomsUtils; + +import java.util.Arrays; + +//todo compare params with the paper. +public class STLMethod { + //Each jump must be at least 1; larger values speed up the corresponding smoother, with linear interpolation between every jump-th point. + //t(trend), s(seasonal), l(loess) + //loess window width for extracting the seasonal component; must be odd. + private int sWindow; + //degree (0 or 1) of the local polynomial fitted when extracting the seasonal component. + private int sDegree = 0; + private int sJump; + //loess window width for extracting the trend component; must be odd. + private int tWindow; + private int tJump; + //degree (0 or 1) of the local polynomial fitted when extracting the trend component. + private int tDegree = 1; + //loess window width for the low-pass filtering of the cycle-subseries. + private int lWindow; + private int lJump; + //degree (0 or 1) of the local polynomial fitted in the low-pass filtering step. + private int lDegree; + private int inner;//number of inner-loop iterations. + private int outer;//number of outer (robustness) iterations. + public double[] trend; + public double[] seasonal; + public double[] remainder; + + //freq is the number of observations per seasonal period. + public STLMethod(double[] ts, int freq, String sWindowString, int sDegree, Integer sJump, + Integer tWindow, Integer tJump, Integer tDegree, + Integer lWindow, Integer lJump, Integer lDegree, + boolean robust, Integer inner, Integer outer) { + int n = ts.length; + //sanity checks on freq and the series length. + if (freq < 2) { + throw new RuntimeException("The frequency must be greater than 1."); + } + if (n <= freq * 2) { + throw new RuntimeException( + String.format( + "The time series must contain more than 2 full periods of data. n:%s, freq: %s", n, freq)); + } + if ("periodic".equals(sWindowString)) { // todo: only "periodic" is supported now; what should sWindow be otherwise?
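+ // Assumption noted for clarity: this mirrors R's stl(), where s.window = "periodic" sets the seasonal loess window + // to 10 * n + 1, far longer than the series, so each cycle-subseries is smoothed to a nearly constant value.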
+ sWindow = 10 * n + 1; + } + this.sDegree = sDegree; + if (sJump == null) { + this.sJump = new Double(Math.ceil((this.sWindow / 10.0))).intValue(); + } else { + this.sJump = sJump; + } + if (tWindow == null) { + this.tWindow = nextOdd(Math.ceil(1.5 * freq / (1 - 1.5 / sWindow))); + } else { + this.tWindow = tWindow; + } + if (tJump == null) { + this.tJump = new Double(Math.ceil((this.tWindow / 10.0))).intValue(); + } else { + this.tJump = tJump; + } + if (tDegree != null) { + this.tDegree = tDegree; + } + if (lWindow == null) { + this.lWindow = nextOdd(freq); + } else { + this.lWindow = lWindow; + } + if (lJump == null) { + this.lJump = new Double(Math.ceil((this.lWindow / 10.0))).intValue(); + } else { + this.lJump = lJump; + } + if (lDegree == null) { + this.lDegree = this.tDegree; + } else { + this.lDegree = lDegree; + } + //robust: whether robustness weighting is used in the loess fits. + if (inner == null) { + this.inner = robust ? 1 : 2; + } else { + this.inner = inner; + } + if (outer == null) { + this.outer = robust ? 15 : 0; + } else { + this.outer = outer; + } + //initialize + double[] weights = new double[n]; + double[] seasonal = new double[n]; + double[] trend = new double[n]; + this.remainder = new double[n]; + double[][] work = new double[5][n + (freq << 1)]; + + this.sWindow = maxOdd(this.sWindow); + this.tWindow = maxOdd(this.tWindow); + this.lWindow = maxOdd(this.lWindow); + //this function is the inner loop of the paper. + //useW: whether the robustness weights are applied during smoothing (false on this first pass). + stlstp(ts, n, freq, sWindow, this.tWindow, this.lWindow, this.sDegree, this.tDegree, this.lDegree, + this.sJump, this.tJump, this.lJump, this.inner, false, weights, seasonal, trend, work); + + //the outer loop only recomputes the robustness weights and then repeats the inner loop. + for (int nullTemp = 0; nullTemp < this.outer; nullTemp++) { + for (int i = 0; i < n; i++) { + work[0][i] = trend[i] + seasonal[i];//the fitted value (trend + seasonal), excluding the remainder. + } + //work0 is the fitted series and abs(work0 - ts) the remainder; this step updates the robustness weights, + //which is exactly what the outer loop does in the paper. + stlrwt(ts, n, work[0], weights); + //re-run the inner loop with the updated robustness weights. + stlstp(ts, n, freq, sWindow, this.tWindow, this.lWindow, this.sDegree, this.tDegree, this.lDegree, + this.sJump, this.tJump, this.lJump, this.inner, true, weights, seasonal, trend, work); + } + if (this.outer <= 0) { + Arrays.fill(weights, 1.0); + } + for (int i = 0; i < n; i++) { + this.remainder[i] = ts[i] - trend[i] - seasonal[i]; + } + this.trend = trend; + this.seasonal = seasonal; + } + + private static int maxOdd(int value) { + value = Math.max(value, 3); + if (value % 2 == 0) { + value += 1; + } + return value; + } + + private static int nextOdd(double x) { + int temp = (int) Math.round(x); + if (temp % 2 == 0) { + temp += 1; + } + return temp; + } + + //itdeg is tDegree + //this is the inner iteration. + private static void stlstp(double[] y, int n, int np, int sWindow, int tWindow, int lWindow, + int sDegree, int itdeg, int lDegree, + int sJump, int tJump, int lJump, + int ni, boolean userW, + double[] weights, double[] seasonal, double[] trend, double[][] work) { + //ni is the number of inner iterations. + for (int nullTemp = 0; nullTemp < ni; nullTemp++) { + //first step: detrending; the detrended series is stored in work0.
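+ // Subsequent steps of the pass: smooth each cycle-subseries (stlss), low-pass filter the smoothed subseries with + // three moving averages (stlfts) plus a loess pass (stless), set seasonal = smoothed subseries minus low-pass, + // deseasonalize, then loess-smooth the deseasonalized series (stless) to update the trend.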
+ for (int i = 0; i < n; i++) { + work[0][i] = y[i] - trend[i];//work0 holds the detrended series. + } + + stlss(work[0], n, np, sWindow, sDegree, sJump, userW, weights, work[1], work[2], work[3], work[4], + seasonal); + stlfts(work[1], n + 2 * np, np, work[2], work[0]); + stless(work[2], n, lWindow, lDegree, lJump, false, work[3], work[0], work[4]); + for (int i = 0; i < n; i++) { + //work1 holds the smoothed cycle-subseries, work0 the low-pass filtered version. + seasonal[i] = work[1][np + i] - work[0][i];//subtracting the low-pass result removes the residual trend from the seasonal component. + } + for (int i = 0; i < n; i++) { + work[0][i] = y[i] - seasonal[i];//deseasonalize the series. + } + stless(work[0], n, tWindow, itdeg, tJump, userW, weights, trend, work[2]); + } + } + + //isdeg is sDegree. y is the detrended data, i.e. seasonal + remainder. + private static void stlss(double[] y, int n, int np, + int sWindow, int isdeg, int sJump, boolean userW, double[] weights, + double[] season, double[] work1, double[] work2, double[] work3, double[] work4) { + for (int j = 0; j < np; j++) { + int k = (n - j - 1) / np + 1;//k is the number of observations in the j-th cycle-subseries. + for (int i = 0; i < k; i++) { + work1[i] = y[i * np + j];//work1 collects the j-th cycle-subseries. + } + if (userW) {//false on the first pass; true once the outer loop supplies robustness weights. + for (int i = 0; i < k; i++) { + work3[i] = weights[i * np + j]; + } + } + double[] work2From1 = Arrays.copyOfRange(work2, 1, work2.length); + stless(work1, k, sWindow, isdeg, sJump, userW, work3, work2From1, work4); + System.arraycopy(work2From1, 0, work2, 1, work2From1.length); + int nRight = Math.min(sWindow, k); + Double nVal = stlest(work1, k, sWindow, isdeg, 0, work2[0], 1, nRight, work4, userW, work3); + if (nVal != null) { + work2[0] = nVal; + } else { + work2[0] = work2[1]; + } + int nLeft = Math.max(1, k - sWindow + 1); + nVal = stlest(work1, k, sWindow, isdeg, k + 1, work2[k + 1], nLeft, k, work4, userW, work3); + if (nVal != null) { + work2[k + 1] = nVal; + } else { + work2[k + 1] = work2[k]; + } + for (int m = 0; m < k + 2; m++) { + season[m * np + j] = work2[m]; + } + } + } + + private static void stlfts(double[] x, int n, int np, double[] trend, double[] work) { + stlma(x, n, np, trend); + stlma(trend, n - np + 1, np, work); + stlma(work, n - 2 * np + 2, 3, trend); + } + + private static void stlma(double[] x, int n, int length, double[] ave) { + double v = TimeSeriesAnomsUtils.sumArray(x, length); + ave[0] = v / length; + + int newN = n - length + 1; + if (newN > 1) { + int k = length; + int m = 0; + for (int j = 1; j < newN; j++) { + k += 1; + m += 1; + v = v - x[m - 1] + x[k - 1]; + ave[j] = v / length; + } + } + } + + private static void stless(double[] y, int n, int length, int ideg, + int nJump, boolean userW, double[] weights, double[] ys, double[] res) { + if (n < 2) { + ys[0] = y[0]; + return; + } + int newNJump = Math.min(nJump, n - 1); + int nLeft = 0;//ini + int nRight = 0;//ini + if (length >= n) { + nLeft = 1; + nRight = n; + for (int i = 0; i < n; i += newNJump) { + Double nys = stlest(y, n, length, ideg, i + 1, ys[i], nLeft, nRight, res, userW, weights); + if (nys != null) { + ys[i] = nys; + } else { + ys[i] = y[i]; + } + } + } else { + if (newNJump == 1) { + int nsh = (length + 1) / 2; + nLeft = 1; + nRight = length; + for (int i = 0; i < n; i++) { + if (i + 1 > nsh && nRight != n) { + nLeft++; + nRight++; + } + Double nys = stlest(y, n, length, ideg, i + 1, ys[i], nLeft, nRight, res, userW, weights); + if (nys != null) { + ys[i] = nys; + } else { + ys[i] = y[i]; + } + } + } else { + int nsh = (length + 1) / 2; + for (int i = 1; i < n + 1; i += newNJump) { + if (i < nsh) { + nLeft = 1; + nRight = length; + } else if (i >= (n - nsh + 1))
{ + nLeft = n - length + 1; + nRight = n; + } else { + nLeft = i - nsh + 1; + nRight = length + i - nsh; + } + Double nys = stlest(y, n, length, ideg, i, ys[i - 1], nLeft, nRight, res, userW, weights); + if (nys != null) { + ys[i - 1] = nys; + } else { + ys[i - 1] = y[i - 1]; + } + } + } + } + if (newNJump != 1) { + double delta; + for (int i = 0; i < n - newNJump; i += newNJump) { + delta = (ys[i + newNJump] - ys[i]) * 1.0 / newNJump; + for (int j = 1; j < newNJump; j++) { + ys[i + j] = ys[i] + delta * j; + } + } + int k = ((n - 1) / newNJump) * newNJump + 1; + if (k != n) { + Double nys = stlest(y, n, length, ideg, n, ys[n - 1], nLeft, nRight, res, userW, weights); + if (nys != null) { + ys[n - 1] = nys; + } else { + ys[n - 1] = y[n - 1]; + } + if (k != n - 1) { + delta = (ys[n - 1] - ys[k - 1]) * 1.0 / (n - k); + for (int j = 0; j < n - 1 - k; j++) { + ys[k + j] = ys[k - 1] + delta * (1 + j); + } + } + } + + } + } + + //todo fit only get the first n data. + //weights is the robustness in the paper. + //这个是外层训练做的事情 + private static void stlrwt(double[] y, int n, double[] fit, double[] weights) { + double[] r = new double[n]; + for (int i = 0; i < n; i++) { + r[i] = Math.abs(y[i] - fit[i]); + } + double median = 6 * CalcMidian.tempMedian(r);//r is the reminder. + double lowThre = 0.001 * median; + double highThre = 0.999 * median; + for (int i = 0; i < n; i++) { + if (r[i] <= lowThre) { + weights[i] = 1; + } else if (r[i] > highThre) { + weights[i] = 0; + } else { + weights[i] = Math.pow(1 - Math.pow(r[i] / median, 2), 2); + } + } + } + + private static Double stlest(double[] y, int n, int length, int ideg, int xs, double ys, + int nLeft, int nRight, double[] w, boolean userW, double[] weights) { + int h = Math.max(xs - nLeft, nRight - xs); + if (length > n) { + h += (length - n) / 2; + } + int[] r = generateArange(nLeft - xs, nRight - xs + 1, Type.ABS); + int[] window = generateArange(nLeft - 1, nRight, Type.SELF); + int rLength = nRight + 1 - nLeft; + double lowThre = 0.001 * h; + double highThre = 0.999 * h; + int[] judge = new int[rLength]; + for (int i = 0; i < rLength; i++) { + if (r[i] <= lowThre) { + judge[i] = 0;//low + } else if (r[i] > highThre) { + judge[i] = 1;//high + } else { + judge[i] = 2;//middle + } + } + double a = 0; + for (int i = 0; i < rLength; i++) { + int num = judge[i]; + if (num == 0) { + w[window[i]] = 1; + } else if (num == 2) { + w[window[i]] = Math.pow(1 - Math.pow(r[i] * 1.0 / h, 3), 3);//w is the neighborhood weight + } + if (num != 1) { + if (userW) { + w[window[i]] *= weights[window[i]]; + } + a += w[window[i]]; + } + if (num == 1) { + w[window[i]] = 0; + } + } + Double ret; + if (a <= 0) { + ret = null; + } else { + for (int i = nLeft - 1; i < nRight; i++) { + w[i] /= a; + } + if (h > 0 && ideg > 0) { + a = 0; + for (int i = 0; i < rLength; i++) { + a += w[nLeft - 1 + i] * (nLeft + i); + } + double b = xs - a; + double c = 0; + for (int i = 0; i < rLength; i++) { + c += w[nLeft - 1 + i] * Math.pow(nLeft - a + i, 2); + } + if (Math.sqrt(c) > 0.001 * (n - 1)) { + b /= c; + for (int i = 0; i < rLength; i++) { + w[nLeft - 1 + i] *= (b * (nLeft - a + i) + 1); + } + } + } + ret = 0.; + for (int i = 0; i < rLength; i++) { + ret += w[nLeft - 1 + i] * y[nLeft - 1 + i]; + } + } + return ret; + } + + //arange + private static int[] generateArange(int start, int end, Type type) { + int length = end - start; + int[] res = new int[length]; + if (type == Type.ABS) { //abs + for (int i = 0; i < length; i++) { + res[i] = Math.abs(start + i); + } + } else if (type == 
Type.SQUARE) { //square + for (int i = 0; i < length; i++) { + res[i] = (int) Math.pow(start + i, 2); + } + } else if (type == Type.SELF) { + for (int i = 0; i < length; i++) { + res[i] = start + i; + } + } + return res; + } + + enum Type { + /** + * calculate the absolute value + */ + ABS, + /** + * calculate the square value + */ + SQUARE, + /** + * calculate itself + */ + SELF + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/outlier/tsa/tsacalculator/ArimaPredictorCalc.java b/core/src/main/java/com/alibaba/alink/operator/common/outlier/tsa/tsacalculator/ArimaPredictorCalc.java new file mode 100644 index 000000000..8b90a7ee7 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/outlier/tsa/tsacalculator/ArimaPredictorCalc.java @@ -0,0 +1,131 @@ +package com.alibaba.alink.operator.common.outlier.tsa.tsacalculator; + +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.types.Row; + +import com.alibaba.alink.common.linalg.DenseVector; +import com.alibaba.alink.operator.common.timeseries.arima.Arima; +import com.alibaba.alink.operator.common.timeseries.arima.ArimaModel; +import com.alibaba.alink.operator.common.timeseries.sarima.Sarima; +import com.alibaba.alink.operator.common.timeseries.sarima.SarimaModel; +import com.alibaba.alink.params.timeseries.ArimaParamsOld; +import com.alibaba.alink.params.timeseries.HasEstmateMethod; + +import java.util.ArrayList; + +public class ArimaPredictorCalc extends TimeSeriesPredictorCalc { + + private static final long serialVersionUID = -5965170186690618039L; + HasEstmateMethod.EstMethod estMethod; + Integer[] order; + int[] seasonality; + + int ifIntercept; + + ArimaPredictorCalc() {} + + public ArimaPredictorCalc(Params params) { + estMethod = params.get(ArimaParamsOld.EST_METHOD); + predictNum = params.get(ArimaParamsOld.PREDICT_NUM); + order = params.get(ArimaParamsOld.ORDER); + seasonality = params.get(ArimaParamsOld.SEASONAL_ORDER); + + int d = order[1]; + ifIntercept = 1; + if (d > 0) { + ifIntercept = 0; + } + + if (predictNum > 21) { + throw new RuntimeException("Long step prediction is not meaningful. " + + "The limitation is 20 steps. 
" + + "Please set forecasteStep to be smaller than 21"); + } + } + + @Override + public double[] forecastWithoutException(double[] data, int forecastStep, boolean trainBeforeForecast) { + + ArimaModel model = Arima.fit(data, order[0], order[1], order[2], estMethod); + + return model.forecast(forecastStep).get(0); + } + + @Override + public double[] predict(double[] data) { + return forecast(data, predictNum, true).f0; + } + + @Override + public Row map(Row in) { + double[] data = (double[]) in.getField(groupNumber + 1); + if (seasonality == null) { + return arima(in, data); + } else { + return seasonalArima(in, data); + } + } + + @Override + public ArimaPredictorCalc clone() { + ArimaPredictorCalc calc = new ArimaPredictorCalc(); + calc.estMethod = estMethod; + if (order != null) { + calc.order = order.clone(); + } + if (seasonality != null) { + calc.seasonality = seasonality.clone(); + } + calc.ifIntercept = ifIntercept; + calc.predictNum = predictNum; + calc.groupNumber = groupNumber; + return calc; + } + + private Row arima(Row in, double[] data) { + ArimaModel model = Arima.fit(data, order[0], order[1], order[2], estMethod); + + ArrayList forecast = model.forecast(this.predictNum); + return getData(in, + new DenseVector(forecast.get(0)), + new DenseVector(forecast.get(1)), + new DenseVector(forecast.get(2)), + new DenseVector(forecast.get(3)), + new DenseVector(model.arma.estimate.arCoef), + new DenseVector(model.arma.estimate.arCoefStdError), + new DenseVector(model.arma.estimate.maCoef), + new DenseVector(model.arma.estimate.maCoefStdError), + model.arma.estimate.intercept, + model.arma.estimate.interceptStdError, + model.arma.estimate.variance, + model.arma.estimate.varianceStdError, + model.ic, + model.arma.estimate.logLikelihood); + } + + private Row seasonalArima(Row in, double[] data) { + + SarimaModel as = Sarima.fit(data, order[0], order[1], order[2], seasonality[0], seasonality[1], seasonality[1], + estMethod, ifIntercept, 2); + + ArrayList forecast = as.forecast(this.predictNum); + return getData(in, + new DenseVector(forecast.get(0)), + new DenseVector(forecast.get(1)), + new DenseVector(forecast.get(2)), + new DenseVector(forecast.get(3)), + new DenseVector(as.sARCoef), + new DenseVector(as.sArStdError), + new DenseVector(as.sMACoef), + new DenseVector(as.sMaStdError), + new DenseVector(as.arima.arma.estimate.arCoef), + new DenseVector(as.arima.arma.estimate.arCoefStdError), + new DenseVector(as.arima.arma.estimate.maCoef), + new DenseVector(as.arima.arma.estimate.maCoefStdError), + as.arima.arma.estimate.intercept, + as.arima.arma.estimate.interceptStdError, + as.arima.arma.estimate.variance, + as.arima.ic, + as.arima.arma.estimate.logLikelihood); + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/outlier/tsa/tsacalculator/AutoArimaGarchPredictorCalc.java b/core/src/main/java/com/alibaba/alink/operator/common/outlier/tsa/tsacalculator/AutoArimaGarchPredictorCalc.java new file mode 100644 index 000000000..9bb0dd849 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/outlier/tsa/tsacalculator/AutoArimaGarchPredictorCalc.java @@ -0,0 +1,115 @@ +package com.alibaba.alink.operator.common.outlier.tsa.tsacalculator; + +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.types.Row; + +import com.alibaba.alink.common.linalg.DenseVector; +import com.alibaba.alink.common.utils.JsonConverter; +import com.alibaba.alink.operator.common.timeseries.arimagarch.ArimaGarch; +import 
com.alibaba.alink.operator.common.timeseries.arimagarch.ArimaGarchModel; +import com.alibaba.alink.operator.common.timeseries.arimagarch.ModelInfo; +import com.alibaba.alink.params.outlier.tsa.baseparams.BaseStreamPredictParams; +import com.alibaba.alink.params.timeseries.AutoArimaGarchParams; +import com.alibaba.alink.params.timeseries.HasArimaGarchMethod; +import com.alibaba.alink.params.timeseries.HasIcType; + +import java.util.ArrayList; + +public class AutoArimaGarchPredictorCalc extends TimeSeriesPredictorCalc { + private static final long serialVersionUID = 2069937892958311944L; + HasIcType.IcType ic; + boolean ifGARCH11; + int maxARIMA; + int maxGARCH; + HasArimaGarchMethod.ArimaGarchMethod arimaGarchMethod; + + AutoArimaGarchPredictorCalc() {} + + public AutoArimaGarchPredictorCalc(Params params) { + ic = params.get(AutoArimaGarchParams.IC_TYPE); + maxARIMA = params.get(AutoArimaGarchParams.MAX_ARIMA); + maxGARCH = params.get(AutoArimaGarchParams.MAX_GARCH); + predictNum = params.get(BaseStreamPredictParams.PREDICT_NUM); + ifGARCH11 = params.get(AutoArimaGarchParams.IF_GARCH11); + arimaGarchMethod = params.get(AutoArimaGarchParams.ARIMA_GARCH_METHOD); + } + + @Override + public double[] forecastWithoutException(double[] data, int forecastStep, boolean trainBeforeForecast) { + ArimaGarchModel aag = ArimaGarch.autoFit(data, ic, arimaGarchMethod, maxARIMA, maxGARCH, ifGARCH11); + ArrayList forecast = aag.forecast(forecastStep); + return forecast.get(0); + } + + @Override + public double[] predict(double[] data) { + ArimaGarchModel aag = ArimaGarch.autoFit(data, ic, arimaGarchMethod, maxARIMA, maxGARCH, ifGARCH11); + + if (aag.isGoodFit()) { + ArrayList forecast = aag.forecast(this.predictNum); + return forecast.get(0); + } else { + return null; + } + } + + @Override + public Row map(Row in) { + double[] data = (double[]) in.getField(groupNumber + 1); + + ArimaGarchModel aag = ArimaGarch.autoFit(data, ic, arimaGarchMethod, maxARIMA, maxGARCH, ifGARCH11); + + if (aag.isGoodFit()) { + ArrayList forecast = aag.forecast(this.predictNum); + ModelInfo mi = aag.mi; + return getData( + in, + new DenseVector(forecast.get(0)), + new DenseVector(forecast.get(1)), + new DenseVector(forecast.get(2)), + new DenseVector(forecast.get(3)), + JsonConverter.toJson(mi.order), + mi.ic, + mi.loglike, + new DenseVector(mi.arCoef), + new DenseVector(mi.seARCoef), + new DenseVector(mi.maCoef), + new DenseVector(mi.seMACoef), + mi.intercept, + mi.seIntercept, + new DenseVector(mi.alpha), + new DenseVector(mi.seAlpha), + new DenseVector(mi.beta), + new DenseVector(mi.seBeta), + mi.c, + mi.seC, + new DenseVector(mi.estResidual), + new DenseVector(mi.hHat), + mi.ifHetero + ); + + } else { + return getData(in, + null, null, null, null, + null, null, null, null, + null, null, null, null, + null, null, null, null, + null, null, null, null, + null, null); + } + } + + @Override + public AutoArimaGarchPredictorCalc clone() { + AutoArimaGarchPredictorCalc calc = new AutoArimaGarchPredictorCalc(); + calc.ic = ic; + calc.ifGARCH11 = ifGARCH11; + calc.maxARIMA = maxARIMA; + calc.maxGARCH = maxGARCH; + calc.arimaGarchMethod = arimaGarchMethod; + calc.predictNum = predictNum; + calc.groupNumber = groupNumber; + return calc; + } + +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/outlier/tsa/tsacalculator/AutoArimaPredictorCalc.java b/core/src/main/java/com/alibaba/alink/operator/common/outlier/tsa/tsacalculator/AutoArimaPredictorCalc.java new file mode 100644 index 000000000..872c40706 --- /dev/null +++ 
b/core/src/main/java/com/alibaba/alink/operator/common/outlier/tsa/tsacalculator/AutoArimaPredictorCalc.java @@ -0,0 +1,155 @@ +package com.alibaba.alink.operator.common.outlier.tsa.tsacalculator; + +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.types.Row; + +import com.alibaba.alink.common.linalg.DenseVector; +import com.alibaba.alink.operator.common.timeseries.arima.Arima; +import com.alibaba.alink.operator.common.timeseries.arima.ArimaModel; +import com.alibaba.alink.operator.common.timeseries.sarima.Sarima; +import com.alibaba.alink.operator.common.timeseries.sarima.SarimaModel; +import com.alibaba.alink.params.outlier.tsa.baseparams.BaseStreamPredictParams; +import com.alibaba.alink.params.timeseries.AutoArimaParams; +import com.alibaba.alink.params.timeseries.HasEstmateMethod; +import com.alibaba.alink.params.timeseries.HasIcType; + +import java.util.ArrayList; + +public class AutoArimaPredictorCalc extends TimeSeriesPredictorCalc { + private static final long serialVersionUID = 1472791946534368311L; + private HasEstmateMethod.EstMethod estMethod; + private HasIcType.IcType ic; + private int maxOrder; + private int seasonalPeriod; + private int maxSeasonalOrder; + + AutoArimaPredictorCalc() {} + + public AutoArimaPredictorCalc(Params params) { + estMethod = params.get(AutoArimaParams.EST_METHOD); + ic = params.get(AutoArimaParams.IC_TYPE); + predictNum = params.get(BaseStreamPredictParams.PREDICT_NUM); + maxOrder = params.get(AutoArimaParams.MAX_ORDER); + seasonalPeriod = params.get(AutoArimaParams.MAX_SEASONAL_ORDER); + + if (predictNum > 21) { + throw new RuntimeException("Long step prediction is not meaningful. " + + "The limitation is 20 steps. " + + "Please set forecasteStep to be smaller than 21"); + } + } + + private Row arima(Row in, double[] data) { + ArimaModel model = Arima.autoFit(data, maxOrder, estMethod, ic,-1); + + if (model.isGoodFit()) { + ArrayList forecast = model.forecast(this.predictNum); + return getData(in, + new DenseVector(forecast.get(0)), + new DenseVector(forecast.get(1)), + new DenseVector(forecast.get(2)), + new DenseVector(forecast.get(3)), + new DenseVector(model.arma.estimate.arCoef), + new DenseVector(model.arma.estimate.arCoefStdError), + new DenseVector(model.arma.estimate.maCoef), + new DenseVector(model.arma.estimate.maCoefStdError), + model.arma.estimate.intercept, + model.arma.estimate.interceptStdError, + model.arma.estimate.variance, + model.arma.estimate.varianceStdError, + model.ic, + model.arma.estimate.logLikelihood); + } else { + return getData(in, + null, null, null, null, + null, null, null, null, + null, null, null, null, + null, null); + } + + } + + private Row seasonalArima(Row key, double[] data) { + SarimaModel bestModel = Sarima.autoFit( + data, maxOrder, maxSeasonalOrder, + estMethod, ic, seasonalPeriod); + + Row newRow = new Row(18); + if (bestModel.isGoodFit()) { + newRow.setField(0, key); + ArrayList forecast = bestModel.forecast(this.predictNum); + return getData(key, + new DenseVector(forecast.get(0)), + new DenseVector(forecast.get(1)), + new DenseVector(forecast.get(2)), + new DenseVector(forecast.get(3)), + new DenseVector(bestModel.sARCoef), + new DenseVector(bestModel.sArStdError), + new DenseVector(bestModel.sMACoef), + new DenseVector(bestModel.sMaStdError), + new DenseVector(bestModel.arima.arma.estimate.arCoef), + new DenseVector(bestModel.arima.arma.estimate.arCoefStdError), + new DenseVector(bestModel.arima.arma.estimate.maCoef), + new 
DenseVector(bestModel.arima.arma.estimate.maCoefStdError), + bestModel.arima.arma.estimate.intercept, + bestModel.arima.arma.estimate.variance, + bestModel.ic, + bestModel.arima.arma.estimate.logLikelihood); + } else { + newRow.setField(0, key); + for (int i = 1; i < newRow.getArity(); i++) { + newRow.setField(i, null); + } + } + return newRow; + } + + @Override + public double[] forecastWithoutException(double[] data, int forecastStep, boolean trainBeforeForecast) { + if (seasonalPeriod == 1) { + ArimaModel model = Arima.autoFit(data, maxOrder, estMethod, ic, -1); + if (model.isGoodFit()) { + return model.forecast(forecastStep).get(0); + } else { + return null; + } + } else { + SarimaModel bestModel = Sarima.autoFit(data, maxOrder, + maxSeasonalOrder, estMethod, ic, + seasonalPeriod); + if (bestModel.isGoodFit()) { + return bestModel.forecast(forecastStep).get(0); + } else { + return null; + } + } + } + + @Override + public double[] predict(double[] data) { + return forecast(data, this.predictNum, true).f0; + } + + @Override + public Row map(Row in) { + + double[] data = (double[]) in.getField(groupNumber + 1); + if (seasonalPeriod == 1) { + return arima(in, data); + } else { + return seasonalArima(in, data); + } + } + + @Override + public AutoArimaPredictorCalc clone() { + AutoArimaPredictorCalc calc = new AutoArimaPredictorCalc(); + calc.predictNum = predictNum; + calc.groupNumber = groupNumber; + calc.estMethod = estMethod; + calc.ic = ic; + calc.maxOrder = maxOrder; + calc.seasonalPeriod = seasonalPeriod; + return calc; + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/outlier/tsa/tsacalculator/AutoGarchPredictorCalc.java b/core/src/main/java/com/alibaba/alink/operator/common/outlier/tsa/tsacalculator/AutoGarchPredictorCalc.java new file mode 100644 index 000000000..fc093990f --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/outlier/tsa/tsacalculator/AutoGarchPredictorCalc.java @@ -0,0 +1,87 @@ +package com.alibaba.alink.operator.common.outlier.tsa.tsacalculator; + +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.types.Row; + +import com.alibaba.alink.common.linalg.DenseVector; +import com.alibaba.alink.operator.common.timeseries.garch.Garch; +import com.alibaba.alink.operator.common.timeseries.garch.GarchModel; +import com.alibaba.alink.params.outlier.tsa.baseparams.BaseStreamPredictParams; +import com.alibaba.alink.params.timeseries.AutoGarchParams; +import com.alibaba.alink.params.timeseries.HasIcType; + +public class AutoGarchPredictorCalc extends TimeSeriesPredictorCalc { + private static final long serialVersionUID = -446589576860978836L; + private HasIcType.IcType ic; + private int upperbound; + private boolean ifGARCH11; + private boolean minusMean; + + AutoGarchPredictorCalc() {} + + public AutoGarchPredictorCalc(Params params) { + ic = params.get(AutoGarchParams.IC_TYPE); + upperbound = params.get(AutoGarchParams.MAX_ORDER); + predictNum = params.get(BaseStreamPredictParams.PREDICT_NUM); + ifGARCH11 = params.get(AutoGarchParams.IF_GARCH11); + minusMean = params.get(AutoGarchParams.MINUS_MEAN); + } + + @Override + public double[] forecastWithoutException(double[] data, int forecastStep, boolean trainBeforeForecast) { + GarchModel ag = Garch.autoFit(data, upperbound, minusMean, ic, ifGARCH11); + if (!ag.isGoodFit()) { + throw new RuntimeException("fail to fit the Garch model."); + } + return ag.forecast(forecastStep); + } + + @Override + public double[] predict(double[] data) { + GarchModel ag = 
Garch.autoFit(data, upperbound, minusMean, ic, ifGARCH11); + if (ag.isGoodFit()) { + return ag.forecast(this.predictNum); + } + return null; + } + + @Override + public Row map(Row in) { + double[] data = (double[]) in.getField(groupNumber + 1); + GarchModel ag = Garch.autoFit(data, upperbound, minusMean, ic, ifGARCH11); + + if (ag.isGoodFit()) { + double[] forecast = ag.forecast(this.predictNum); + return getData(in, + new DenseVector(forecast), + new DenseVector(ag.alpha), + new DenseVector(ag.seAlpha), + new DenseVector(ag.beta), + new DenseVector(ag.seBeta), + ag.c, + ag.seC, + ag.unconSigma2, + ag.ic, + ag.loglike, + new DenseVector(ag.hHat), + new DenseVector(ag.residual)); + + } else { + return getData(in, null, null, null, null, + null, null, null, null, + null, null, null, null); + } + } + + @Override + public AutoGarchPredictorCalc clone() { + AutoGarchPredictorCalc calc = new AutoGarchPredictorCalc(); + calc.predictNum = predictNum; + calc.groupNumber = groupNumber; + calc.ic = ic; + calc.upperbound = upperbound; + calc.ifGARCH11 = ifGARCH11; + calc.minusMean = minusMean; + return calc; + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/outlier/tsa/tsacalculator/BoxPlotDetectorCalc.java b/core/src/main/java/com/alibaba/alink/operator/common/outlier/tsa/tsacalculator/BoxPlotDetectorCalc.java new file mode 100644 index 000000000..d3cab35f2 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/outlier/tsa/tsacalculator/BoxPlotDetectorCalc.java @@ -0,0 +1,93 @@ +package com.alibaba.alink.operator.common.outlier.tsa.tsacalculator; + +import org.apache.flink.ml.api.misc.param.Params; + +import com.alibaba.alink.operator.common.outlier.TimeSeriesAnomsUtils; +import com.alibaba.alink.params.outlier.HasBoxPlotK; +import com.alibaba.alink.params.outlier.tsa.HasBoxPlotRoundMode; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +public class BoxPlotDetectorCalc extends DecomposeOutlierDetectorCalc { + private static final long serialVersionUID = 2367292760089913511L; + private HasBoxPlotRoundMode.RoundMode roundMode; + private double k; + + public BoxPlotDetectorCalc(Params params) { + roundMode = params.get(HasBoxPlotRoundMode.ROUND_MODE); + k = params.get(HasBoxPlotK.K); + } + + @Override + public int[] detect(double[] data) { + return calcBoxPlot(data, roundMode, k, false); + } + + /** + * Here we consider the fact that data may contains NaN. + * If detectLast, we detect the last data which is not NaN. 
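+ * Quartiles q1 and q3 are taken from the sorted NaN-free values at the fractional positions 0.25 * n and 0.75 * n, + * with roundMode (CEIL, FLOOR or AVERAGE) deciding how those fractional indices are rounded.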
+ */ + public static int[] calcBoxPlot(double[] data, HasBoxPlotRoundMode.RoundMode roundMode, double k, + boolean detectLast) { + int length = data.length; + List listData = new ArrayList <>(length); + for (double v : data) { + if (Double.isNaN(v)) { + length -= 1; + continue; + } + listData.add(v); + } + + if (detectLast && length <= 4) { + throw new RuntimeException("in detectLast mode, the data size must be larger than 4."); + } + if (length <= 3) { + return new int[0]; + } + double[] sortedData; + double lastData = 0; + if (detectLast) { + length -= 1; + lastData = listData.get(length); + } + sortedData = new double[length]; + for (int i = 0; i < length; i++) { + sortedData[i] = listData.get(i); + } + Arrays.sort(sortedData); + double q1; + double q3; + switch (roundMode) { + case CEIL: + q1 = sortedData[(int) Math.ceil(length * 0.25)]; + q3 = sortedData[(int) Math.ceil(length * 0.75)]; + break; + case FLOOR: + q1 = sortedData[(int) Math.floor(length * 0.25)]; + q3 = sortedData[(int) Math.floor(length * 0.75)]; + break; + case AVERAGE: + q1 = (sortedData[(int) Math.ceil(length * 0.25)] + sortedData[(int) Math.floor(length * 0.25)]) / 2; + q3 = (sortedData[(int) Math.ceil(length * 0.75)] + sortedData[(int) Math.floor(length * 0.75)]) / 2; + break; + default: + throw new RuntimeException("Only support ceil, floor and average strategy."); + } + List indices = new ArrayList <>(); + if (detectLast) { + if (TimeSeriesAnomsUtils.judgeBoxPlotAnom(lastData, q1, q3, k)) { + indices.add(length); + } + } else { + for (int i = 0; i < length; i++) { + if (TimeSeriesAnomsUtils.judgeBoxPlotAnom(data[i], q1, q3, k)) { + indices.add(i); + } + } + } + return indices.stream().mapToInt(a -> a).toArray(); + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/outlier/tsa/tsacalculator/ConvolutionDecomposerCalc.java b/core/src/main/java/com/alibaba/alink/operator/common/outlier/tsa/tsacalculator/ConvolutionDecomposerCalc.java new file mode 100644 index 000000000..925719959 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/outlier/tsa/tsacalculator/ConvolutionDecomposerCalc.java @@ -0,0 +1,119 @@ +package com.alibaba.alink.operator.common.outlier.tsa.tsacalculator; + +import org.apache.flink.ml.api.misc.param.Params; + +import com.alibaba.alink.common.linalg.DenseVector; +import com.alibaba.alink.operator.common.outlier.TimeSeriesAnomsUtils; +import com.alibaba.alink.params.timeseries.HasFrequency; +import com.alibaba.alink.params.timeseries.holtwinters.HasSeasonalType; + +import java.util.Arrays; + +public class ConvolutionDecomposerCalc extends TimeSeriesDecomposerCalc { + + private boolean isAddType; + private int frequency; + + public ConvolutionDecomposerCalc(Params params) { + frequency = params.get(HasFrequency.FREQUENCY); + isAddType = params.get(HasSeasonalType.SEASONAL_TYPE) == HasSeasonalType.SeasonalType.ADDITIVE; + } + + //used in holt winters decompose. + @Override + public DenseVector[] decompose(double[] data) { + boolean addType = this.isAddType; + double[] filter; + double filterData = 1.0 / frequency; + if (frequency % 2 == 0) { + filter = new double[frequency + 1]; + Arrays.fill(filter, filterData); + filter[0] /= 2; + filter[frequency] /= 2; + } else { + filter = new double[frequency]; + Arrays.fill(filter, filterData); + } + //the length of filter equals to frequency. + //trend is the convolution of data. 
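+ // The trend is a centered moving average of the data: for an even frequency the filter has frequency + 1 taps with + // the two end taps halved, otherwise frequency equal taps of 1 / frequency; points within half a window of either + // end are left as NaN by cFilter.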
+ double[] trend = cFilter(data, filter); + double[] seasonal = data.clone(); + if (addType) { + for (int i = 0; i < seasonal.length; i++) { + seasonal[i] -= trend[i]; + } + } else { + for (int i = 0; i < seasonal.length; i++) { + seasonal[i] /= trend[i]; + } + } + + int period = data.length / frequency; + double[] figure = new double[frequency]; + double tmp1; + int num1; + for (int i = 0; i < figure.length; i++) { + tmp1 = 0; + num1 = 0; + for (int j = 0; j < period; j++) { + if (Double.isNaN(seasonal[i + j * frequency])) { + continue; + } + tmp1 += seasonal[i + j * frequency]; + num1 += 1; + } + figure[i] = tmp1 / num1; + } + + double mean = TimeSeriesAnomsUtils.mean(figure); + double[] reminder = new double[data.length]; + if (addType) { + for (int i = 0; i < figure.length; i++) { + figure[i] -= mean; + } + for (int i = 0; i < data.length; i++) { + seasonal[i] = figure[i % frequency]; + reminder[i] = data[i] - trend[i] - seasonal[i]; + } + } else { + for (int i = 0; i < figure.length; i++) { + figure[i] /= mean; + } + for (int i = 0; i < data.length; i++) { + seasonal[i] = figure[i % frequency]; + reminder[i] = data[i] / trend[i] / seasonal[i]; + } + } + return new DenseVector[] {new DenseVector(trend), new DenseVector(seasonal), new DenseVector(reminder)}; + } + + //used in holt winters decompose. + private static double[] cFilter(double[] data, double[] filter) { + int start = 0; + int end = data.length; + int allSize = end - start; + double[] filterData = new double[allSize]; + Arrays.fill(filterData, Double.NaN); + int nf = filter.length; + int nShift = nf / 2; + double z; + boolean flag = true; + for (int i = (start + nShift); i < (end - nShift); i++) { + z = 0; + //卷积操作 + for (int j = (i - nShift); j < (i + nShift + 1); j++) { + if (Double.isNaN(data[j])) { + flag = false; + break; + } + z += data[j] * filter[j + nShift - i]; + } + if (flag) { + filterData[i] = z; + } + flag = true; + } + return filterData; + } + +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/outlier/tsa/tsacalculator/DecomposeOutlierDetectorCalc.java b/core/src/main/java/com/alibaba/alink/operator/common/outlier/tsa/tsacalculator/DecomposeOutlierDetectorCalc.java new file mode 100644 index 000000000..059017e88 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/outlier/tsa/tsacalculator/DecomposeOutlierDetectorCalc.java @@ -0,0 +1,11 @@ +package com.alibaba.alink.operator.common.outlier.tsa.tsacalculator; + +import java.io.Serializable; + +public abstract class DecomposeOutlierDetectorCalc implements Serializable { + private static final long serialVersionUID = -1356304484256542192L; + + public abstract int[] detect(double[] data); + + public void reset() {} +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/outlier/tsa/tsacalculator/HoltWintersPredictorCalc.java b/core/src/main/java/com/alibaba/alink/operator/common/outlier/tsa/tsacalculator/HoltWintersPredictorCalc.java new file mode 100644 index 000000000..ec645f8de --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/outlier/tsa/tsacalculator/HoltWintersPredictorCalc.java @@ -0,0 +1,160 @@ +package com.alibaba.alink.operator.common.outlier.tsa.tsacalculator; + +import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.types.Row; + +import com.alibaba.alink.common.linalg.DenseVector; +import com.alibaba.alink.operator.common.timeseries.holtwinter.HoltWinters; +import com.alibaba.alink.operator.common.timeseries.holtwinter.HoltWintersModel; +import 
com.alibaba.alink.params.timeseries.HasFrequency; +import com.alibaba.alink.params.outlier.tsa.HasPredictNum; +import com.alibaba.alink.params.timeseries.holtwinters.HasAlpha; +import com.alibaba.alink.params.timeseries.holtwinters.HasBeta; +import com.alibaba.alink.params.timeseries.holtwinters.HasDoSeasonal; +import com.alibaba.alink.params.timeseries.holtwinters.HasDoTrend; +import com.alibaba.alink.params.timeseries.holtwinters.HasGamma; +import com.alibaba.alink.params.timeseries.holtwinters.HasLevelStart; +import com.alibaba.alink.params.timeseries.holtwinters.HasSeasonalStart; +import com.alibaba.alink.params.timeseries.holtwinters.HasSeasonalType; +import com.alibaba.alink.params.timeseries.holtwinters.HasSeasonalType.SeasonalType; +import com.alibaba.alink.params.timeseries.holtwinters.HasTrendStart; + +public class HoltWintersPredictorCalc extends TimeSeriesPredictorCalc { + private static final long serialVersionUID = 6298595998408725962L; + private double alpha; + private double beta; + private double gamma; + private int size; + private boolean isAddType; + private SeasonalType seasonalType; + boolean doTrend; + boolean doSeasonal; + private int frequency; + //a, b, s is the initial data of level, trend and seasonalPeriod. + private Double a; + private Double b; + private double[] s; + private DenseVector res; + private Double sse; + //for reset + private Double saveA; + private Double saveB; + private double[] saveS; + private DenseVector saveRes; + + HoltWintersPredictorCalc() {} + + public HoltWintersPredictorCalc(Params params) { + initParams(params); + saveA = a; + saveB = b; + saveS = s; + saveRes = res; + } + + @Override + public void reset() { + a = saveA; + b = saveB; + s = saveS; + res = saveRes; + } + + @Override + public HoltWintersPredictorCalc clone() { + HoltWintersPredictorCalc calc = new HoltWintersPredictorCalc(); + calc.alpha = alpha; + calc.beta = beta; + calc.gamma = gamma; + calc.size = size; + calc.isAddType = isAddType; + calc.frequency = frequency; + calc.a = a; + calc.b = b; + if (s != null) { + calc.s = s.clone(); + } + if (res != null) { + calc.res = res.clone(); + } + calc.saveA = saveA; + calc.saveB = saveB; + if (saveS != null) { + calc.saveS = saveS.clone(); + } + if (saveRes != null) { + calc.saveRes = saveRes.clone(); + } + return calc; + } + + @Override + public double[] forecastWithoutException(double[] data, int forecastStep, boolean overWritten) { + HoltWintersModel model = HoltWinters.fit(data, frequency, + alpha, + beta, + gamma, + doTrend, + doSeasonal, + seasonalType, + a, + b, + s); + + return model.forecast(predictNum); + } + + @Override + public double[] predict(double[] data) { + if (predictNum == null) { + throw new RuntimeException("Please set forecast number first!"); + } + return forecast(data, predictNum, false).f0; + } + + @Override + public Row map(Row in) { + double[] data = (double[]) in.getField(groupNumber + 1); + data = predict(data); + return getData(in, new DenseVector(data)); + } + + + private void initParams(Params params) { + frequency = params.get(HasFrequency.FREQUENCY); + alpha = params.get(HasAlpha.ALPHA); + beta = params.get(HasBeta.BETA); + gamma = params.get(HasGamma.GAMMA); + predictNum = params.get(HasPredictNum.PREDICT_NUM); + doTrend = params.get(HasDoTrend.DO_TREND); + doSeasonal = params.get(HasDoSeasonal.DO_SEASONAL); + //level是一定会有的。而且alpha, beta, gamma有初始值。所以只需要判断这一个就好了。 + if (doSeasonal && !doTrend) { + throw new RuntimeException("seasonal time serial must have trend."); + } + isAddType = 
params.get(HasSeasonalType.SEASONAL_TYPE) == HasSeasonalType.SeasonalType.ADDITIVE; + if (doSeasonal) { + size = 3; + } else if (doTrend) { + size = 2; + } else { + size = 1; + } + + if (params.contains(HasLevelStart.LEVEL_START)) { + a = params.get(HasLevelStart.LEVEL_START); + } + if (params.contains(HasTrendStart.TREND_START)) { + b = params.get(HasTrendStart.TREND_START); + } + if (params.contains(HasSeasonalStart.SEASONAL_START)) { + s = params.get(HasSeasonalStart.SEASONAL_START); + if (s.length != frequency) { + throw new RuntimeException("the length of " + + "seasonal start data must equal to frequency."); + } + } + seasonalType = params.get(HasSeasonalType.SEASONAL_TYPE); + } + +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/outlier/tsa/tsacalculator/KSigmaDetectorCalc.java b/core/src/main/java/com/alibaba/alink/operator/common/outlier/tsa/tsacalculator/KSigmaDetectorCalc.java new file mode 100644 index 000000000..881ac0316 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/outlier/tsa/tsacalculator/KSigmaDetectorCalc.java @@ -0,0 +1,70 @@ +package com.alibaba.alink.operator.common.outlier.tsa.tsacalculator; + +import org.apache.flink.ml.api.misc.param.Params; + +import com.alibaba.alink.operator.common.outlier.TimeSeriesAnomsUtils; +import com.alibaba.alink.params.outlier.HasKSigmaK; + +import java.util.ArrayList; +import java.util.List; + +public class KSigmaDetectorCalc extends DecomposeOutlierDetectorCalc { + + private double k; + + public KSigmaDetectorCalc(Params params) { + this.k = params.get(HasKSigmaK.K); + } + + @Override + public int[] detect(double[] data) { + return calcKSigma(data, this.k, false); + } + + //calc ksigma in the whole time serial. + public static int[] calcKSigma(double[] data, double k, boolean detectLast) { + double sum = 0.0; + double squareSum = 0.0; + int length = data.length; + + if (detectLast) { + for (int i = 0; i < data.length - 1; i++) { + sum += data[i]; + squareSum += Math.pow(data[i], 2); + } + } else { + for (double datum : data) { + if (Double.isNaN(datum)) { + length -= 1; + continue; + } + sum += datum; + squareSum += Math.pow(datum, 2); + } + } + double mean = sum / length; + double variance; + if (length == 0 || length == 1) { + variance = 0; + } else { + variance = Math.max((squareSum - Math.pow(sum, 2) / length) / (length - 1), 0); + } + List indices = new ArrayList <>(); + if (detectLast) { + int i = length - 1; + double score = TimeSeriesAnomsUtils.calcKSigmaScore(data[i], mean, variance); + if (score >= k) { + indices.add(i); + } + } else { + for (int i = 0; i < length; i++) { + double score = TimeSeriesAnomsUtils.calcKSigmaScore(data[i], mean, variance); + if (score >= k) { + indices.add(i); + } + } + } + + return indices.stream().mapToInt(a -> a).toArray(); + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/outlier/tsa/tsacalculator/PredictOutlierDetectorCalc.java b/core/src/main/java/com/alibaba/alink/operator/common/outlier/tsa/tsacalculator/PredictOutlierDetectorCalc.java new file mode 100644 index 000000000..c9830abcb --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/outlier/tsa/tsacalculator/PredictOutlierDetectorCalc.java @@ -0,0 +1,44 @@ +package com.alibaba.alink.operator.common.outlier.tsa.tsacalculator; + +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.ml.api.misc.param.Params; + +import com.alibaba.alink.common.linalg.SparseVector; +import com.alibaba.alink.params.outlier.tsa.HasTrainNum; + +import 
java.io.Serializable; + +public abstract class PredictOutlierDetectorCalc implements Serializable { + private static final long serialVersionUID = 8451571015820252352L; + public int trainNum; + + public PredictOutlierDetectorCalc() {} + public PredictOutlierDetectorCalc(Params params) { + this.trainNum = params.get(HasTrainNum.TRAIN_NUM); + } + + //this is used in fitted data detect, and the temp sum and square sum will be saved to speed up calculation. + //return whether the last data is outlier or not. If outlier, return the modified data. + // todo only used in test. + public abstract Tuple2 predictBatchLast(double[] data); + + public void setTrainNum(int trainNum) { + this.trainNum = trainNum; + } + + /** + * @param realData the real data. + * @param formerData the former residual data which is used to help detect outliers. In this function, it will be + * updated with current residual data and help in the next batch data. + * @param predData the predicted data. + */ + public abstract int[] detectAndUpdateFormerData(double[] realData, double[] formerData, double[] predData); + + public abstract SparseVector detect(double[] data); + + public abstract void trainModel(double[] data); + + public abstract PredictOutlierDetectorCalc clone(); + + public abstract void reset(); +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/outlier/tsa/tsacalculator/SHESDDetectorCalc.java b/core/src/main/java/com/alibaba/alink/operator/common/outlier/tsa/tsacalculator/SHESDDetectorCalc.java new file mode 100644 index 000000000..f912ca0a9 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/outlier/tsa/tsacalculator/SHESDDetectorCalc.java @@ -0,0 +1,126 @@ +package com.alibaba.alink.operator.common.outlier.tsa.tsacalculator; + +import org.apache.flink.ml.api.misc.param.Params; + +import com.alibaba.alink.common.linalg.DenseVector; +import com.alibaba.alink.operator.common.outlier.CalcMidian; +import com.alibaba.alink.operator.common.outlier.TimeSeriesAnomsUtils; +import com.alibaba.alink.params.outlier.HasDirection; +import com.alibaba.alink.params.outlier.tsa.HasMaxAnoms; +import com.alibaba.alink.params.outlier.tsa.HasSHESDAlpha; +import com.alibaba.alink.params.outlier.tsa.TsaAlgoParams.SHESDAlgoParams; + +import java.util.HashSet; + +public class SHESDDetectorCalc extends DecomposeOutlierDetectorCalc { + private static final long serialVersionUID = 8617646058413698191L; + private double maxAnoms; + private double alpha; + private HasDirection.Direction direction; + + private int frequency; + + public SHESDDetectorCalc(Params params) { + maxAnoms = params.get(HasMaxAnoms.MAX_ANOMS); + alpha = params.get(HasSHESDAlpha.SHESD_ALPHA); + direction = params.get(HasDirection.DIRECTION); + frequency = params.get(SHESDAlgoParams.FREQUENCY); + } + + @Override + public int[] detect(double[] data) { + return detect(data, frequency, maxAnoms, alpha, direction); + } + + public static int[] detect(double[] data, int frequency, double maxAnoms, double alpha, + HasDirection.Direction direction) { + DenseVector[] components = STLDecomposerCalc.decompose(data, frequency); + return shesdMethod(data, components, maxAnoms, alpha, direction); + } + + //the indices of sv is the indices of anomalies data, and value is the reminder. 
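// A minimal usage sketch of this detector (the parameter values below are illustrative choices,
// not defaults taken from the code): build the Params that the constructor reads, then call detect.
//   Params p = new Params()
//       .set(HasMaxAnoms.MAX_ANOMS, 0.1)
//       .set(HasSHESDAlpha.SHESD_ALPHA, 0.05)
//       .set(HasDirection.DIRECTION, HasDirection.Direction.BOTH)
//       .set(SHESDAlgoParams.FREQUENCY, 12);
//   int[] anomalyIndices = new SHESDDetectorCalc(p).detect(series);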
+ public static int[] shesdMethod(double[] data, DenseVector[] components, + double maxAnoms, double alpha, + HasDirection.Direction direction) { + double[] trend = components[0].getData(); + double[] seasonal = components[1].getData(); + int dataNum = data.length; + double dataMedian = CalcMidian.tempMedian(data); + double[] dataDecomp = new double[dataNum]; + //here minus median and seasonal of data. + for (int i = 0; i < dataNum; i++) { + data[i] -= (dataMedian + seasonal[i]); + dataDecomp[i] = trend[i] + seasonal[i]; + } + CalcMidian calcMidian = new CalcMidian(data); + int maxOutliers = (int) Math.floor(dataNum * maxAnoms); + int[] outlierIndex = new int[maxOutliers]; + double[] ares = new double[dataNum]; + HashSet excludedIndices = new HashSet <>(); + int numAnoms = 0; + for (int i = 1; i < maxOutliers + 1; i++) { + double median = calcMidian.median(); + //area is the deviation value of each data. + switch (direction) { + case POSITIVE: + for (int j = 0; j < dataNum; j++) { + ares[j] = data[j] - median; + } + break; + case NEGATIVE: + for (int j = 0; j < dataNum; j++) { + ares[j] = median - data[j]; + } + break; + case BOTH: + for (int j = 0; j < dataNum; j++) { + ares[j] = Math.abs(data[j] - median); + } + break; + default: + } + double dataSigma = TimeSeriesAnomsUtils.mad(calcMidian); + + if (Math.abs(dataSigma) < 1e-4) { + break; + } + double maxValue = -Double.MAX_VALUE; + int maxIndex = -1; + for (int j = 0; j < dataNum; j++) { + if (!excludedIndices.contains(j)) { + if (ares[j] > maxValue) { + maxValue = ares[j]; + maxIndex = j; + } + } + } + //添加的时候就意味着计算中位数的地方要扣除相应的。 + calcMidian.remove(data[maxIndex]); + excludedIndices.add(maxIndex); + + maxValue /= dataSigma; + outlierIndex[i - 1] = maxIndex; + double p; + //tempNum is the sample num? + int tempNum = dataNum - i + 1; + if (direction == HasDirection.Direction.BOTH) { + p = 1 - alpha / (2 * tempNum); + } else { + p = 1 - alpha / tempNum; + } + double t = TimeSeriesAnomsUtils.tppf(p, tempNum - 2); + //lam is the hypothesis test condition. 
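// Spelled out, the line below computes the generalized ESD critical value: with m = dataNum - i + 1
// and t the Student-t quantile at probability p with m - 2 degrees of freedom,
// lam = t * (m - 1) / sqrt((m - 2 + t * t) * m); numAnoms is advanced to i whenever the studentized
// deviation maxValue exceeds this bound.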
+ double lam = t * (tempNum - 1) / Math.sqrt((tempNum - 2 + t * t) * tempNum); + if (maxValue > lam) { + numAnoms = i; + } + } + if (numAnoms == 0) { + return new int[0]; + } else { + int[] indices = new int[numAnoms]; + System.arraycopy(outlierIndex, 0, indices, 0, numAnoms); + return indices; + } + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/outlier/tsa/tsacalculator/STLDecomposerCalc.java b/core/src/main/java/com/alibaba/alink/operator/common/outlier/tsa/tsacalculator/STLDecomposerCalc.java new file mode 100644 index 000000000..89c3f95d6 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/outlier/tsa/tsacalculator/STLDecomposerCalc.java @@ -0,0 +1,28 @@ +package com.alibaba.alink.operator.common.outlier.tsa.tsacalculator; + +import org.apache.flink.ml.api.misc.param.Params; + +import com.alibaba.alink.common.linalg.DenseVector; +import com.alibaba.alink.operator.common.outlier.tsa.STLMethod; +import com.alibaba.alink.params.timeseries.HasFrequency; + +public class STLDecomposerCalc extends TimeSeriesDecomposerCalc { + private static final long serialVersionUID = 6666098921428864819L; + private int frequency; + + public STLDecomposerCalc(Params params) { + frequency = params.get(HasFrequency.FREQUENCY); + } + + @Override + public DenseVector[] decompose(double[] data) { + return decompose(data, frequency); + } + + public static DenseVector[] decompose(double[] data, int frequency) { + STLMethod stl = new STLMethod(data, frequency, "periodic", 0, null, null, + null, null, null, null, null, true, null, null); + return new DenseVector[] {new DenseVector(stl.trend), + new DenseVector(stl.seasonal), new DenseVector(stl.remainder)}; + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/outlier/tsa/tsacalculator/ShortMoMDetectorCalc.java b/core/src/main/java/com/alibaba/alink/operator/common/outlier/tsa/tsacalculator/ShortMoMDetectorCalc.java new file mode 100644 index 000000000..f77756f2b --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/outlier/tsa/tsacalculator/ShortMoMDetectorCalc.java @@ -0,0 +1,202 @@ +package com.alibaba.alink.operator.common.outlier.tsa.tsacalculator; + +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.ml.api.misc.param.Params; + +import com.alibaba.alink.common.linalg.SparseVector; +import com.alibaba.alink.operator.common.outlier.TimeSeriesAnomsUtils; +import com.alibaba.alink.params.outlier.tsa.TsaAlgoParams.ShortMoMAlgoParams; +import org.apache.commons.lang3.ArrayUtils; + +import java.util.ArrayList; +import java.util.List; +import java.util.PriorityQueue; + +public class ShortMoMDetectorCalc extends PredictOutlierDetectorCalc { + private static final long serialVersionUID = -1173292168357077056L; + private int anomsNum; + private double influence; + // + private PriorityQueue smallHeap = new PriorityQueue <>(); + private PriorityQueue largeHeap = new PriorityQueue <>(); + private double sum; + private Double lastData; + + ShortMoMDetectorCalc() {} + + public ShortMoMDetectorCalc(Params params) { + super(params); + anomsNum = params.get(ShortMoMAlgoParams.ANOMS_NUM); + influence = params.get(ShortMoMAlgoParams.INFLUENCE); + if (anomsNum > trainNum) { + throw new RuntimeException("the trainNum should be set larger than the anomaly number."); + } + } + + @Override + public ShortMoMDetectorCalc clone() { + ShortMoMDetectorCalc calc = new ShortMoMDetectorCalc(); + calc.trainNum = trainNum; + calc.anomsNum = anomsNum; + calc.influence = influence; + if 
(smallHeap != null) { + calc.smallHeap = new PriorityQueue <>(); + calc.smallHeap.addAll(smallHeap); + } + if (largeHeap != null) { + calc.largeHeap = new PriorityQueue <>(); + calc.largeHeap.addAll(largeHeap); + } + calc.sum = sum; + calc.lastData = lastData; + return calc; + } + + @Override + public void reset() { + smallHeap = new PriorityQueue <>(); + largeHeap = new PriorityQueue <>(); + sum = 0; + lastData = null; + } + + @Override + public Tuple2 predictBatchLast(double[] data) { + int trainLength = data.length - 1; + boolean isOutlier = false; + double fittedData = data[trainLength]; + if (lastData == null) { + for (int i = 0; i < trainLength; i++) { + largeHeap.add(-data[i]); + smallHeap.add(data[i]); + sum += data[i]; + } + } else { + largeHeap.remove(-lastData); + smallHeap.remove(lastData); + largeHeap.add(-data[trainLength - 1]); + smallHeap.add(data[trainLength - 1]); + sum += (data[trainLength - 1] - lastData); + } + lastData = data[0]; + double max = -largeHeap.peek(); + double min = smallHeap.peek(); + double mean = sum / trainLength; + double threshold = Math.min(max - mean, mean - min); + int num = countNum(0, trainLength, threshold, data); + if (num >= anomsNum) { + fittedData = influence * data[trainLength] + (1 - influence) * data[trainLength - 1]; + isOutlier = true; + } + return Tuple2.of(isOutlier, fittedData); + } + + @Override + public int[] detectAndUpdateFormerData(double[] realData, double[] formerData, double[] predData) { + int length = realData.length; + List anomsIndices = new ArrayList <>(); + for (int i = 0; i < length; i++) { + double max = -largeHeap.peek(); + double min = smallHeap.peek(); + double mean = sum / length; + double threshold = Math.min(max - mean, mean - min); + int num = countNum(formerData, predData[i] - realData[i], threshold); + if (num >= anomsNum) { + if (i == 0) { + predData[i] = influence * predData[i] + (1 - influence) * formerData[length - 1]; + } else { + predData[i] = influence * predData[i] + (1 - influence) * predData[i - 1]; + } + anomsIndices.add(i); + } + double detectData = predData[i] - realData[i]; + largeHeap.remove(-formerData[i]); + smallHeap.remove(formerData[i]); + largeHeap.add(-detectData); + smallHeap.add(detectData); + sum += detectData - formerData[i]; + formerData[i] = detectData;//直接将新的值写在formerData中。 + } + return ArrayUtils.toPrimitive(anomsIndices.toArray(new Integer[0])); + } + + @Override + public void trainModel(double[] data) { + for (double v : data) { + largeHeap.add(-v); + smallHeap.add(v); + sum += v; + } + } + + @Override + public SparseVector detect(double[] data) { + return detectAnoms(data, trainNum, anomsNum, influence); + } + + public static SparseVector detectAnoms(double[] data, int trainNum, int anomsNum, double influence) { + List indices = new ArrayList <>(); + + PriorityQueue littleHeap = new PriorityQueue <>(); + PriorityQueue largeHeap = new PriorityQueue <>(); + littleHeap.add(data[0]); + largeHeap.add(-data[0]); + double sum = data[0]; + double threshold; + int length = data.length; + //对于初始的样本数目小于trainNum的,则将全部样本用于计算。 + int num; + for (int i = 1; i < length - 1; i++) { + if (i < trainNum) { + largeHeap.add(-data[i]); + littleHeap.add(data[i]); + double max = -largeHeap.peek(); + double min = littleHeap.peek(); + sum += data[i]; + double mean = sum / (i + 1); + threshold = Math.min(max - mean, mean - min); + num = countNum(0, i + 1, threshold, data); + } else { + int preIndex = i - trainNum; + littleHeap.remove(data[preIndex]); + largeHeap.remove(-data[preIndex]); + 
littleHeap.add(data[i]); + largeHeap.add(-data[i]); + double max = -largeHeap.peek(); + double min = littleHeap.peek(); + sum += (data[i] - data[preIndex]); + double mean = sum / trainNum; + threshold = Math.min(max - mean, mean - min); + num = countNum(preIndex + 1, i + 1, threshold, data); + } + if (num >= anomsNum) { + indices.add(i + 1); + //修正异常值 + data[i + 1] = influence * data[i + 1] + (1 - influence) * data[i]; + } + } + SparseVector sv = TimeSeriesAnomsUtils.generateOutput(indices, data); + return sv; + } + + private static int countNum(int startIndex, int currentIndex, double threshold, double[] data) { + int num = 0; + for (int i = startIndex; i < currentIndex; i++) { + if (Math.abs(data[currentIndex] - data[i]) > threshold) { + num++; + } + } + return num; + } + + private static int countNum(double[] formerData, double currentData, double threshold) { + int num = 0; + int length = formerData.length; + for (int i = 0; i < length; i++) { + if (Math.abs(currentData - formerData[i]) > threshold) { + num++; + } + } + return num; + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/outlier/tsa/tsacalculator/SmoothZScoreDetectorCalc.java b/core/src/main/java/com/alibaba/alink/operator/common/outlier/tsa/tsacalculator/SmoothZScoreDetectorCalc.java new file mode 100644 index 000000000..b81bdfb86 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/outlier/tsa/tsacalculator/SmoothZScoreDetectorCalc.java @@ -0,0 +1,201 @@ +package com.alibaba.alink.operator.common.outlier.tsa.tsacalculator; + +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.ml.api.misc.param.Params; + +import com.alibaba.alink.common.linalg.SparseVector; +import com.alibaba.alink.params.outlier.tsa.TsaAlgoParams.SmoothZScoreAlgoParams; +import org.apache.commons.lang3.ArrayUtils; + +import java.util.ArrayList; +import java.util.List; + +public class SmoothZScoreDetectorCalc extends PredictOutlierDetectorCalc { + private static final long serialVersionUID = -1510748793393605203L; + private double threshold; + private double influence; + //辅助计算,减少计算量。 + private double lastSum; + private double lastSquareSum; + private Double lastData; + + @Override + public SmoothZScoreDetectorCalc clone() { + SmoothZScoreDetectorCalc calc = new SmoothZScoreDetectorCalc(); + calc.trainNum = trainNum; + calc.threshold = threshold; + calc.influence = influence; + calc.lastSum = lastSum; + calc.lastSquareSum = lastSquareSum; + calc.lastData = lastData; + return calc; + } + + @Override + public void reset() { + this.lastSum = 0; + this.lastSquareSum = 0; + this.lastData = null; + } + + SmoothZScoreDetectorCalc() {} + + public SmoothZScoreDetectorCalc(Params params) { + super(params); + threshold = params.get(SmoothZScoreAlgoParams.THRESHOLD); + influence = params.get(SmoothZScoreAlgoParams.INFLUENCE); + } + + @Override + public SparseVector detect(double[] data) { + return calcZScore(data, trainNum, threshold, influence, false); + } + + @Override + public void trainModel(double[] data) { + double sum = 0; + double squareSum = 0; + for (double datum : data) { + sum += datum; + squareSum += Math.pow(datum, 2); + } + lastSum = sum; + lastSquareSum = squareSum; + } + + //detect last just use all to detect the last one. + //train the initial model with the first trainNum data. 
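// The rule used throughout this class: a point is flagged when |x - mean| > threshold * std, where
// mean and std come from the current training window. For instance, with threshold = 3, a window
// mean of 10 and a std of 2, any value below 4 or above 16 is flagged.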
+ public static SparseVector calcZScore(double[] data, int trainNum, double threshold, double influence, + boolean detectLast) { + if (detectLast) { + int length = data.length - 1; + double sum = 0; + double squareSum = 0; + for (int i = 0; i < length; i++) { + sum += data[i]; + squareSum += Math.pow(data[i], 2); + } + double avg = sum / length; + double std = Math.sqrt(squareSum / length - Math.pow(avg, 2)); + + if (Math.abs((data[length] - avg)) > threshold * std) { + return new SparseVector(1, new int[] {0}, new double[] {0}); + } else { + return new SparseVector(0); + } + } else { + if (data.length <= trainNum) { + return new SparseVector(); + } + Tuple2 > res = analyzeDataForSignals(data, trainNum, threshold, influence); + int length = res.f1.size(); + int[] indices = new int[length]; + double[] values = new double[length]; + for (int i = 0; i < length; i++) { + indices[i] = res.f1.get(i); + values[i] = data[indices[i]]; + } + return new SparseVector(data.length, indices, values); + } + } + + @Override + public Tuple2 predictBatchLast(double[] data) { + int trainLength = data.length - 1; + double avg; + double std; + boolean isOutlier = false; + double fittedData = data[trainLength]; + if (lastData == null) { + double sum = 0; + double squareSum = 0; + for (int i = 0; i < trainLength; i++) { + sum += data[i]; + squareSum += Math.pow(data[i], 2); + } + lastSum = sum; + lastSquareSum = squareSum; + } else { + lastSum += (data[trainLength - 1] - lastData); + lastSquareSum += (Math.pow(data[trainLength - 1], 2) - Math.pow(lastData, 2)); + } + lastData = data[0]; + avg = lastSum / trainLength; + std = Math.sqrt(lastSquareSum / trainLength - Math.pow(avg, 2)); + if (Math.abs((data[trainLength] - avg)) > threshold * std) { + fittedData = influence * data[trainLength] + (1 - influence) * data[trainLength - 1]; + isOutlier = true; + } + return Tuple2.of(isOutlier, fittedData); + } + + @Override + public int[] detectAndUpdateFormerData(double[] realData, double[] formerData, double[] predData) { + int predictNum = realData.length; + + List anomsIndices = new ArrayList <>(); + for (int i = 0; i < predictNum; i++) { + double avg = lastSum / predictNum; + double std = Math.sqrt(lastSquareSum / predictNum - Math.pow(avg, 2)); + double detectData = predData[i] - realData[i]; + if (Math.abs((detectData - avg)) > threshold * std) { + if (i == 0) { + predData[i] = influence * predData[i] + (1 - influence) * formerData[predictNum - 1]; + } else { + predData[i] = influence * predData[i] + (1 - influence) * predData[i - 1]; + } + detectData = predData[i] - realData[i]; + anomsIndices.add(i); + } + lastSum += detectData - formerData[i]; + lastSquareSum += Math.pow(detectData, 2) - Math.pow(formerData[i], 2); + //构造新的former。被替换成了新一轮预测中的残差数据。 + formerData[i] = predData[i] - realData[i]; + } + return ArrayUtils.toPrimitive(anomsIndices.toArray(new Integer[0])); + } + + //返回拟合了的数据以及异常点的id。 + public static Tuple2 > analyzeDataForSignals(double[] data, int trainNum, + double threshold, double influence) { + + // the results of our algorithm + ArrayList signals = new ArrayList <>(); + + // filter out the signals (peaks) from our original list (using influence arg) + double[] filteredData = data.clone();//the initial data (which is in the window count) is important. 
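// The influence parameter damps flagged points before they enter the rolling statistics:
// filteredData[i] = influence * data[i] + (1 - influence) * filteredData[i - 1], so influence = 0
// keeps peaks entirely out of the window mean/std and influence = 1 lets them through unchanged.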
+ // init avgFilter and stdFilter + double sum = 0; + double squareSum = 0; + for (int i = 0; i < trainNum; i++) { + sum += filteredData[i]; + squareSum += Math.pow(filteredData[i], 2); + } + + double avg = sum / trainNum; + double std = Math.sqrt(squareSum / trainNum - Math.pow(avg, 2)); + + // loop input starting at end of rolling window + for (int i = trainNum + 1; i < data.length; i++) { + // if the distance between the current value and average is enough standard deviations (threshold) away + if (Math.abs((data[i] - avg)) > threshold * std) { + // this is a signal (i.e. peak), determine if it is a positive or negative signal + signals.add(i); + // filter this signal out using influence + filteredData[i] = (influence * data[i]) + ((1 - influence) * filteredData[i - 1]); + } else { + // ensure this value is not filtered + filteredData[i] = data[i]; + } + sum -= filteredData[i - trainNum]; + sum += filteredData[i]; + squareSum -= Math.pow(filteredData[i - trainNum], 2); + squareSum += Math.pow(filteredData[i], 2); + + avg = sum / trainNum; + std = Math.sqrt(squareSum / trainNum - Math.pow(avg, 2)); + } + + return Tuple2.of(filteredData, signals); + } // end +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/outlier/tsa/tsacalculator/TimeSeriesDecomposerCalc.java b/core/src/main/java/com/alibaba/alink/operator/common/outlier/tsa/tsacalculator/TimeSeriesDecomposerCalc.java new file mode 100644 index 000000000..e87ef306b --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/outlier/tsa/tsacalculator/TimeSeriesDecomposerCalc.java @@ -0,0 +1,13 @@ +package com.alibaba.alink.operator.common.outlier.tsa.tsacalculator; + +import com.alibaba.alink.common.linalg.DenseVector; + +import java.io.Serializable; + +public abstract class TimeSeriesDecomposerCalc implements Serializable { + private static final long serialVersionUID = 7014060875543471259L; + + public abstract DenseVector[] decompose(double[] data); + + public void reset() {} +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/outlier/tsa/tsacalculator/TimeSeriesPredictorCalc.java b/core/src/main/java/com/alibaba/alink/operator/common/outlier/tsa/tsacalculator/TimeSeriesPredictorCalc.java new file mode 100644 index 000000000..0fb9b0973 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/outlier/tsa/tsacalculator/TimeSeriesPredictorCalc.java @@ -0,0 +1,72 @@ +package com.alibaba.alink.operator.common.outlier.tsa.tsacalculator; + +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.types.Row; +import java.io.Serializable; +import java.util.Arrays; + +public abstract class TimeSeriesPredictorCalc implements Serializable { + + private static final long serialVersionUID = 7329549906350002422L; + public Integer predictNum; + public int groupNumber; + + public void setGroupNumber(int groupNumber) { + this.groupNumber = groupNumber; + } + + // todo: for the train of time series predict algo may fail, we catch the exception and + // fill all the predicted data with mean of train data. 
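// To make the fallback concrete: if model fitting throws for data = {1, 2, 3, 6} and forecastStep = 2,
// forecast(...) returns Tuple2.of(new double[] {3.0, 3.0}, true), i.e. the series mean repeated, with
// the second field marking that the exception path was taken; on success it returns the model's
// forecast and false.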
+ public Tuple2 forecast(double[] data, int forecastStep, boolean trainBeforeForecast) { + try { + return Tuple2.of(forecastWithoutException(data, forecastStep, trainBeforeForecast), false); + } catch (Exception e) { + double[] res = new double[forecastStep]; + double sum = 0; + int count = 0; + for (double datum : data) { + ++count; + sum += datum; + } + sum /= count;//mean + Arrays.fill(res, sum); + return Tuple2.of(res, true); + } + } + + public abstract double[] forecastWithoutException(double[] data, int forecastStep, boolean trainBeforeForecast); + + public abstract double[] predict(double[] data); + + //this helps map. If output has more column than just predict, than use this. + public abstract Row map(Row in); + + public void reset() {} + + public Row getData(Row group, Object... others) { + return getData(group, groupNumber, others); + } + + + public abstract TimeSeriesPredictorCalc clone(); + + public String getCurrentModel() { + return null; + } + + public void setCurrentModel(String currentModel) { + + } + + public static Row getGroupData(Row group, int groupNumber, Object... others) { + int othersLength = others.length; + Row res = new Row(groupNumber + othersLength); + for (int i = 0; i < groupNumber; i++) { + res.setField(i, group.getField(i)); + } + for (int i = 0; i < othersLength; i++) { + res.setField(groupNumber + i, others[i]); + } + return res; + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/recommendation/HugeMfAlsImpl.java b/core/src/main/java/com/alibaba/alink/operator/common/recommendation/HugeMfAlsImpl.java index 7604aa2ad..128ccb2ba 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/recommendation/HugeMfAlsImpl.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/recommendation/HugeMfAlsImpl.java @@ -34,7 +34,7 @@ import com.alibaba.alink.common.linalg.DenseVector; import com.alibaba.alink.common.linalg.NormalEquation; import com.alibaba.alink.common.linalg.VectorUtil; -import com.alibaba.alink.common.utils.DataSetConversionUtil; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; import com.alibaba.alink.common.utils.JsonConverter; import com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.operator.batch.BatchOperator; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/recommendation/RecommKernel.java b/core/src/main/java/com/alibaba/alink/operator/common/recommendation/RecommKernel.java index 415383198..2b6dac23d 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/recommendation/RecommKernel.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/recommendation/RecommKernel.java @@ -1,13 +1,12 @@ package com.alibaba.alink.operator.common.recommendation; import org.apache.flink.api.common.typeinfo.TypeInformation; -import org.apache.flink.table.api.Types; import org.apache.flink.ml.api.misc.param.Params; import org.apache.flink.table.api.TableSchema; import org.apache.flink.table.types.DataType; import org.apache.flink.types.Row; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.MTable; import com.alibaba.alink.common.exceptions.AkUnsupportedOperationException; import com.alibaba.alink.common.utils.OutputColsHelper; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/recommendation/RecommMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/recommendation/RecommMapper.java index 457a338a4..8b717798d 100644 --- 
a/core/src/main/java/com/alibaba/alink/operator/common/recommendation/RecommMapper.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/recommendation/RecommMapper.java @@ -7,7 +7,7 @@ import org.apache.flink.table.api.TableSchema; import org.apache.flink.types.Row; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.MTable; import com.alibaba.alink.common.exceptions.AkUnsupportedOperationException; import com.alibaba.alink.common.mapper.ModelMapper; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/recommendation/RecommendationRankingMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/recommendation/RecommendationRankingMapper.java index 078121acf..0b771b99b 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/recommendation/RecommendationRankingMapper.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/recommendation/RecommendationRankingMapper.java @@ -8,7 +8,7 @@ import org.apache.flink.types.Row; import com.alibaba.alink.common.MTable; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.MTableUtil; import com.alibaba.alink.common.mapper.MapperChain; import com.alibaba.alink.common.mapper.ModelMapper; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/regression/LinearReg.java b/core/src/main/java/com/alibaba/alink/operator/common/regression/LinearReg.java new file mode 100644 index 000000000..b67ec91b3 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/regression/LinearReg.java @@ -0,0 +1,261 @@ +package com.alibaba.alink.operator.common.regression; + +import com.alibaba.alink.common.linalg.DenseMatrix; +import com.alibaba.alink.common.linalg.jama.JMatrixFunc; +import com.alibaba.alink.common.probabilistic.CDF; +import com.alibaba.alink.common.utils.TableUtil; +import com.alibaba.alink.operator.common.statistics.statistics.SummaryResultTable; + +import java.util.ArrayList; + +/** + * @author yangxu + */ +public class LinearReg { + + /** + * * + * 线性回归训练 + * + * @param srt 数据表的基本统计结果 + * @param nameY 因变量名称 + * @param nameX 自变量名称 + * @return 线性回归模型 + * @throws Exception + */ + public static LinearRegressionModel train(SummaryResultTable srt, String nameY, String[] nameX) throws Exception { + if (srt == null) { + throw new Exception("srt must not null!"); + } + String[] colNames = srt.colNames; + Class[] types = new Class[colNames.length]; + for (int i = 0; i < colNames.length; i++) { + types[i] = srt.col(i).dataType; + } + int indexY = TableUtil.findColIndexWithAssertAndHint(colNames, nameY); + Class typeY = types[indexY]; + if (typeY != Double.class && typeY != Long.class && typeY != Integer.class) { + throw new Exception("col type must be double or bigint!"); + } + if (nameX.length == 0) { + throw new Exception("nameX must input!"); + } + for (int i = 0; i < nameX.length; i++) { + int indexX = TableUtil.findColIndexWithAssertAndHint(colNames, nameX[i]); + Class typeX = types[indexX]; + if (typeX != Double.class && typeX != Long.class && typeX != Integer.class) { + throw new Exception("col type must be double or bigint!"); + } + } + int nx = nameX.length; + int[] indexX = new int[nx]; + for (int i = 0; i < nx; i++) { + indexX[i] = TableUtil.findColIndexWithAssert(srt.colNames, nameX[i]); + } + + return train(srt, indexY, indexX, nameY, nameX); + } + + /** + * * + * 加载线性回归模型 + * + * @param inputModelTableName 模型表 + * @return 线性回归模型 + * 
@throws Exception + */ + public static LinearRegressionModel loadModel(String inputModelTableName) throws Exception { + // if (!isModel(inputModelTableName)) { + // throw new Exception("model must be linear regression model!"); + // } + // OdpsTable otable = new OdpsTable(inputModelTableName); + // int count = (int) otable.getRecordCount(); + // long nRecord = 0; + // String nameY; + // String[] nameX = new String[count - 2]; + // double[] beta = new double[count - 1]; + // + // OdpsTableInputStream xis = new OdpsTableInputStream(otable); + // ArrayList r = xis.getRecordTemplate(); + // xis.read(r); + // beta[0] = (Double) r.get(1); + // for (int i = 0; i < count - 2; i++) { + // xis.read(r); + // beta[i + 1] = (Double) r.get(1); + // nameX[i] = (String) r.get(0); + // } + // xis.read(r); + // nameY = (String) r.get(0); + // xis.close(); + // LinearRegressionModel lrm = new LinearRegressionModel(nRecord, nameY, nameX); + // lrm.beta = beta; + // return lrm; + return null; + } + + /** + * * + * 线性回归预测 + * + * @param model 线性回归模型 + * @param predictTableName 预测输入表 + * @param selectedPartitions 预测输入表的分区 + * @param appendColNames 输出表添加预测表的列名 + * @param resultTableName 输出表 + * @param resultTablePartitionName 输出表的分区 + * @return + * @throws Exception + */ + public static void predict(LinearRegressionModel model, String predictTableName, String[] selectedPartitions, + String[] appendColNames, String resultTableName, String resultTablePartitionName) + throws Exception { + + } + + static void write(LinearRegressionModel model, String outModelTableName) throws Exception { + // String[] colNames = new String[]{"colName", "coefficient"}; + // Class[] types = new Class[]{String.class, Double.class}; + // Object[][] data = new Object[model.nameX.size + 2][2]; + // data[0][0] = "constant term"; + // data[0][1] = model.beta[0]; + // for (int i = 0; i < model.nameX.size; i++) { + // data[i + 1][0] = model.nameX[i]; + // data[i + 1][1] = model.beta[i + 1]; + // } + // data[model.nameX.size + 1][0] = model.nameY; + // data[model.nameX.size + 1][1] = 0.0; + // MTable mt = new MTable(colNames, types, data); + // boolean res = OdpsTableWriter.write(outModelTableName, mt); + // if (!res) { + // throw new Exception("write model error!"); + // } + // setModelMeta(outModelTableName); + } + + static LinearRegressionModel train(SummaryResultTable srt, int indexY, int[] indexX) throws Exception { + int nx = indexX.length; + String[] nameX = new String[nx]; + for (int i = 0; i < nx; i++) { + nameX[i] = srt.colNames[indexX[i]]; + } + String nameY = srt.colNames[indexY]; + return train(srt, indexY, indexX, nameY, nameX); + } + + private static LinearRegressionModel train(SummaryResultTable srt, int indexY, int[] indexX, String nameY, + String[] nameX) throws Exception { + //check if has missing value or nan value + if (srt.col(indexY).countMissValue > 0 || srt.col(indexY).countNanValue > 0) { + throw new Exception("col " + nameY + " has null value or nan value!"); + } + for (int i = 0; i < indexX.length; i++) { + if (srt.col(indexX[i]).countMissValue > 0 || srt.col(indexX[i]).countNanValue > 0) { + throw new Exception("col " + nameX[i] + " has null value or nan value!"); + } + } + + if (srt.col(0).countTotal == 0) { + throw new Exception("table is empty!"); + } + if (srt.col(0).countTotal < nameX.length) { + throw new Exception("record size Less than features size!"); + } + + int nx = indexX.length; + long N = srt.col(indexY).count; + if (N == 0) { + throw new Exception("Y valid value num is zero!"); + } + + ArrayList 
nameXList = new ArrayList (); + for (int i = 0; i < indexX.length; i++) { + if (srt.col(indexX[i]).count != 0) { + nameXList.add(nameX[i]); + } + } + nameXList.toArray(nameX); + LinearRegressionModel lrr = new LinearRegressionModel(N, nameY, nameX); + + double[] XBar = new double[nx]; + for (int i = 0; i < nx; i++) { + XBar[i] = srt.col(indexX[i]).mean(); + } + double yBar = srt.col(indexY).mean(); + + double[][] cov = srt.getCov(); + DenseMatrix A = new DenseMatrix(nx, nx); + for (int i = 0; i < nx; i++) { + for (int j = 0; j < nx; j++) { + A.set(i, j, cov[indexX[i]][indexX[j]]); + } + } + DenseMatrix C = new DenseMatrix(nx, 1); + for (int i = 0; i < nx; i++) { + C.set(i, 0, cov[indexX[i]][indexY]); + } + + DenseMatrix BetaMatrix = A.solveLS(C); + + double d = yBar; + for (int i = 0; i < nx; i++) { + lrr.beta[i + 1] = BetaMatrix.get(i, 0); + d -= XBar[i] * lrr.beta[i + 1]; + } + lrr.beta[0] = d; + + double S = srt.col(nameY).variance() * (srt.col(nameY).count - 1); + double alpha = lrr.beta[0] - yBar; + double U = 0.0; + U += alpha * alpha * N; + for (int i = 0; i < nx; i++) { + U += 2 * alpha * srt.col(indexX[i]).sum * lrr.beta[i + 1]; + } + for (int i = 0; i < nx; i++) { + for (int j = 0; j < nx; j++) { + U += lrr.beta[i + 1] * lrr.beta[j + 1] * (cov[indexX[i]][indexX[j]] * (N - 1) + srt.col(indexX[i]) + .mean() * srt.col(indexX[j]).mean() * N); + } + } + + lrr.SST = S; + lrr.SSR = U; + lrr.SSE = S - U; + lrr.dfSST = N - 1; + lrr.dfSSR = nx; + lrr.dfSSE = N - nx - 1; + lrr.R2 = Math.max(0.0, Math.min(1.0, lrr.SSR / lrr.SST)); + lrr.R = Math.sqrt(lrr.R2); + lrr.MST = lrr.SST / lrr.dfSST; + lrr.MSR = lrr.SSR / lrr.dfSSR; + lrr.MSE = lrr.SSE / lrr.dfSSE; + lrr.Ra2 = 1 - lrr.MSE / lrr.MST; + lrr.s = Math.sqrt(lrr.MSE); + lrr.F = lrr.MSR / lrr.MSE; + if (lrr.F < 0) { + lrr.F = 0; + } + lrr.AIC = N * Math.log(lrr.SSE) + 2 * nx; + + A.scaleEqual(N - 1); + // DenseMatrix invA = A.Inverse(); + DenseMatrix invA = A.solveLS(JMatrixFunc.identity(A.numRows(), A.numRows())); + + for (int i = 0; i < nx; i++) { + lrr.FX[i] = lrr.beta[i + 1] * lrr.beta[i + 1] / (lrr.MSE * invA.get(i, i)); + lrr.TX[i] = lrr.beta[i + 1] / (lrr.s * Math.sqrt(invA.get(i, i))); + } + + try { + int p = nameX.length; + lrr.pEquation = 1 - CDF.F(lrr.F, p, N - p - 1); + lrr.pX = new double[nx]; + for (int i = 0; i < nx; i++) { + lrr.pX[i] = (1 - CDF.studentT(Math.abs(lrr.TX[i]), N - p - 1)) * 2; + } + } catch (Exception ex) { + + } + return lrr; + } + +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/regression/LinearRegressionModel.java b/core/src/main/java/com/alibaba/alink/operator/common/regression/LinearRegressionModel.java new file mode 100644 index 000000000..c176e064a --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/regression/LinearRegressionModel.java @@ -0,0 +1,216 @@ +package com.alibaba.alink.operator.common.regression; + +import com.alibaba.alink.common.annotation.NameCn; +import com.alibaba.alink.common.utils.JsonConverter; + +/** + * @author yangxu + */ +@NameCn("线性回归模型") +public class LinearRegressionModel implements RegressionModelInterface { + + /** + * * + * 因变量名称 + */ + public String nameY; + /** + * * + * 自变量名称 + */ + public String[] nameX = new String[0]; + /** + * * + * 记录的总个数 + */ + public long n; + /** + * * + * 总离差平方和 + */ + public double SST; + /** + * * + * 回归平方和 + */ + public double SSR; + /** + * * + * 剩余平方和 + */ + public double SSE; + /** + * * + * 总离差平方和的自由度 + */ + public double dfSST; + /** + * * + * 回归平方和的自由度 + */ + public double dfSSR; + /** + * * + * 
剩余平方和的自由度 + */ + public double dfSSE; + /** + * * + * 多重判定系数 + */ + public double R2; + /** + * * + * 多重相关系数 + */ + public double R; + /** + * * + * 修正的多重判定系数 + */ + public double Ra2; + /** + * * + * 总离差均方 + */ + public double MST; + /** + * * + * 回归均方 + */ + public double MSR; + /** + * * + * 剩余均方 + */ + public double MSE; + /** + * * + * 剩余标准差 + */ + public double s; + /** + * * + * 回归方程F检验值 + */ + public double F; + /** + * * + * 回归方程F检验的P-值 + */ + public double pEquation; + /** + * * + * 回归系数 + */ + public double[] beta = null; + /** + * * + * 各变量F检验值 + */ + public double[] FX = null; + /** + * * + * 各变量T检验值 + */ + public double[] TX = null; + /** + * * + * 各变量双侧T检验的P-值 + */ + public double[] pX = null; + /** + * * + * AIC信息统计量,Akaike Information Criterion AIC = n*Ln(SSE)+2*p + */ + public double AIC; + + public LinearRegressionModel(long nRecord, String nameY, String[] nameX) { + int nx = nameX.length; + n = nRecord; + beta = new double[nx + 1]; + FX = new double[nx]; + TX = new double[nx]; + this.nameY = nameY; + this.nameX = new String[nameX.length]; + System.arraycopy(nameX, 0, this.nameX, 0, nameX.length); + } + + /** + * * + * 计算Cp统计量 Cp = ( n - m - 1 )*( SSEp / SSEm )- n + 2*( p + 1 ) + * + * @param m 全部变量的个数 + * @param SSEm 对于全部变量的误差和 + * @return Cp统计量 + */ + public double getCp(int m, double SSEm) { + int p = nameX.length; + if (p > m) { + throw new RuntimeException(); + } + return (n - m - 1) * this.SSE / SSEm - n + 2 * (p + 1); + } + + public String toJson() { + return JsonConverter.gson.toJson(this); + } + + @Override + public String toString() { + int m = nameX.length; + java.io.CharArrayWriter cw = new java.io.CharArrayWriter(); + java.io.PrintWriter pw = new java.io.PrintWriter(cw, true); + if (nameX.length > 0) { + pw.print(nameY + " = " + beta[0]); + for (int i = 0; i < nameX.length; i++) { + pw.print(" + " + beta[i + 1] + " * " + nameX[i]); + } + pw.println(); + } + if (TX != null) { + pw.print("RegCoef ="); + for (int i = 0; i <= m; i++) { + pw.print(" \t"); + pw.print(beta[i]); + } + pw.println(); + pw.print("R = "); + pw.print(R); + pw.print(" \tR2 = "); + pw.print(R2); + pw.print(" \tRa2 = "); + pw.println(Ra2); + pw.print("F = "); + pw.print(F); + pw.print(" \tp_value = "); + pw.println(this.pEquation); + + pw.print("FX ="); + for (int i = 0; i < m; i++) { + pw.print(" \t"); + pw.print(FX[i]); + } + + pw.println(); + + pw.print("TX ="); + for (int i = 0; i < m; i++) { + pw.print(" \t"); + pw.print(TX[i]); + } + + pw.println(); + + pw.print("pX ="); + for (int i = 0; i < m; i++) { + pw.print(" \t"); + pw.print(pX[i]); + } + + } + pw.println(); + return cw.toString(); + } + +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/regression/LinearRegressionStepwise.java b/core/src/main/java/com/alibaba/alink/operator/common/regression/LinearRegressionStepwise.java new file mode 100644 index 000000000..e9fa10440 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/regression/LinearRegressionStepwise.java @@ -0,0 +1,312 @@ +/* + * To change this template, choose Tools | Templates + * and open the template in the editor. 
+ */ +package com.alibaba.alink.operator.common.regression; + +import com.alibaba.alink.common.AlinkGlobalConfiguration; +import com.alibaba.alink.operator.common.statistics.statistics.SummaryResultTable; +import com.alibaba.alink.params.regression.LinearRegStepwiseTrainParams.Method; + +import java.util.ArrayList; + +/** + * @author yangxu + */ +public class LinearRegressionStepwise { + + public static LinearRegressionStepwiseModel step(SummaryResultTable srt, String nameY, String[] nameX, + Method selection) throws Exception { + return step(srt, nameY, nameX, selection, Double.NaN, Double.NaN); + } + + public static LinearRegressionStepwiseModel step(SummaryResultTable srt, String nameY, String[] nameX, + Method selection, double alphaEntry, double alphaStay) + throws Exception { + if (null == nameX || null == nameY) { + throw new RuntimeException(); + } + LinearRegressionStepwiseModel lrsr = null; + switch (selection) { + case Forward: + if (Double.isNaN(alphaEntry)) { + alphaEntry = 0.2; + } + lrsr = stepForward(srt, nameY, nameX, alphaEntry, alphaStay); + break; + case Backward: + if (Double.isNaN(alphaStay)) { + alphaStay = 0.1; + } + lrsr = stepBackward(srt, nameY, nameX, alphaEntry, alphaStay); + break; + case Stepwise: + if (Double.isNaN(alphaEntry)) { + alphaEntry = 0.1; + } + if (Double.isNaN(alphaStay)) { + alphaStay = 0.15; + } + lrsr = stepStepwise(srt, nameY, nameX, alphaEntry, alphaStay); + break; + default: + throw new RuntimeException("Not implemented yet!"); + } + + return lrsr; + } + + public static LinearRegressionStepwiseModel step(SummaryResultTable srt, String nameY, String[] nameX, + StepCriterion crit) throws Exception { + java.io.CharArrayWriter cw = new java.io.CharArrayWriter(); + java.io.PrintWriter pw = new java.io.PrintWriter(cw, true); + pw.println("\nSelected Criterion : " + crit + "\n"); + int nx = nameX.length; + if (nx > 13) { + throw new RuntimeException("Not implemented yet!"); + } + int nSet = (1 << nx) - 1; + + LinearRegressionModel lrrall = LinearReg.train(srt, nameY, nameX); + + double maxValue = Double.NEGATIVE_INFINITY; + switch (crit) { + //下面两个准则的值,希望达到最大 + case R2: + case AdjR2: + maxValue = Double.NEGATIVE_INFINITY; + break; + //以下准则的值,希望达到最小 + case SSE: + case MSE: + case AIC: + case Cp: + maxValue = Double.POSITIVE_INFINITY; + break; + default: + throw new RuntimeException("Not implemented yet!"); + } + + LinearRegressionModel lrrbest = null; + for (int k = 1; k <= nSet; k++) { + ArrayList cols = new ArrayList (); + int subk = k; + for (int i = 0; i < nx; i++) { + if (subk % 2 == 1) { + cols.add(nameX[i]); + } + subk = subk >> 1; + } + LinearRegressionModel lrrcur = LinearReg.train(srt, nameY, cols.toArray(new String[0])); + double valCrit = Double.NaN; + switch (crit) { + case SSE: + valCrit = lrrcur.SSE; + break; + case MSE: + valCrit = lrrcur.MSE; + break; + case R2: + valCrit = lrrcur.R2; + break; + case AdjR2: + valCrit = lrrcur.Ra2; + break; + case AIC: + valCrit = lrrcur.AIC; + break; + case Cp: + valCrit = lrrcur.getCp(nx, lrrall.SSE); + break; + default: + throw new RuntimeException("Not implemented yet!"); + } + pw.println(cols + " : " + valCrit); + switch (crit) { + //下面两个准则的值,希望达到最大 + case R2: + case AdjR2: + if (valCrit > maxValue) { + maxValue = valCrit; + lrrbest = lrrcur; + } + break; + //以下准则的值,希望达到最小 + case SSE: + case MSE: + case AIC: + case Cp: + if (valCrit < maxValue) { + maxValue = valCrit; + lrrbest = lrrcur; + } + break; + default: + throw new RuntimeException("Not implemented yet!"); + } + } + + return new 
LinearRegressionStepwiseModel(lrrbest, cw.toString()); + } + + private static LinearRegressionStepwiseModel stepForward(SummaryResultTable srt, String nameY, String[] nameX, + double alphaEntry, double alphaStay) throws Exception { + java.io.CharArrayWriter cw = new java.io.CharArrayWriter(); + java.io.PrintWriter pw = new java.io.PrintWriter(cw, true); + int nx = nameX.length; + double maxValue = Double.NEGATIVE_INFINITY; + LinearRegressionModel lrrbest = null; + ArrayList selected = new ArrayList (); + ArrayList rest = new ArrayList (); + for (int k = 0; k < nx; k++) { + rest.add(nameX[k]); + } + for (int k = 0; k < nx; k++) { + int indexRestset = -1; + for (int i = 0; i < rest.size(); i++) { + ArrayList cols = new ArrayList (); + cols.addAll(selected); + cols.add(rest.get(i)); + LinearRegressionModel lrrcur = LinearReg.train(srt, nameY, cols.toArray(new String[0])); + pw.println(cols + " : " + k + " : " + i + " : " + lrrcur.F); + if (lrrcur.F > maxValue && lrrcur.pEquation < alphaEntry) { + maxValue = lrrcur.F; + lrrbest = lrrcur; + indexRestset = i; + } + } + if (indexRestset >= 0) { + selected.add(rest.get(indexRestset)); + rest.remove(indexRestset); + } else { + break; + } + } + + return new LinearRegressionStepwiseModel(lrrbest, cw.toString()); + } + + private static LinearRegressionStepwiseModel stepBackward(SummaryResultTable srt, String nameY, String[] nameX, + double alphaEntry, double alphaStay) throws Exception { + java.io.CharArrayWriter cw = new java.io.CharArrayWriter(); + java.io.PrintWriter pw = new java.io.PrintWriter(cw, true); + int nx = nameX.length; + double maxValue = Double.NEGATIVE_INFINITY; + LinearRegressionModel lrrbest = null; + ArrayList selected = new ArrayList (); + ArrayList rest = new ArrayList (); + for (int k = 0; k < nx; k++) { + selected.add(nameX[k]); + } + for (int k = 0; k < nx; k++) { + int indexRestset = -1; + for (int i = 0; i < selected.size(); i++) { + ArrayList cols = new ArrayList (); + cols.addAll(selected); + cols.remove(i); + LinearRegressionModel lrrcur = LinearReg.train(srt, nameY, cols.toArray(new String[0])); + pw.println(cols + " : " + k + " : " + i + " : " + lrrcur.F); + if (AlinkGlobalConfiguration.isPrintProcessInfo()) { + System.out.println(cols + " : " + k + " : " + i + " : " + lrrcur.F); + } + if (lrrcur.F > maxValue && lrrcur.pEquation < alphaStay) { + maxValue = lrrcur.F; + lrrbest = lrrcur; + indexRestset = i; + } + } + if (indexRestset >= 0) { + rest.add(selected.get(indexRestset)); + selected.remove(indexRestset); + } else { + break; + } + } + + return new LinearRegressionStepwiseModel(lrrbest, cw.toString()); + } + + private static LinearRegressionStepwiseModel stepStepwise(SummaryResultTable srt, String nameY, String[] nameX, + double alphaEntry, double alphaStay) throws Exception { + java.io.CharArrayWriter cw = new java.io.CharArrayWriter(); + java.io.PrintWriter pw = new java.io.PrintWriter(cw, true); + int nx = nameX.length; + double maxValue = Double.NEGATIVE_INFINITY; + LinearRegressionModel lrrbest = null; + ArrayList selected = new ArrayList (); + ArrayList rest = new ArrayList (); + for (int k = 0; k < nx; k++) { + rest.add(nameX[k]); + } + for (int k = 0; k < nx; k++) { + int indexRestset = -1; + LinearRegressionModel lrr = null; + for (int i = 0; i < rest.size(); i++) { + ArrayList cols = new ArrayList (); + cols.addAll(selected); + cols.add(rest.get(i)); + LinearRegressionModel lrrcur = LinearReg.train(srt, nameY, cols.toArray(new String[0])); + pw.println(cols + " : " + k + " : " + i + " : " + lrrcur.F); + 
if (AlinkGlobalConfiguration.isPrintProcessInfo()) { + System.out.println(cols + " : " + k + " : " + i + " : " + lrrcur.F + " " + lrrcur.pEquation); + } + if (lrrcur.F > maxValue && lrrcur.pEquation < alphaEntry) { + maxValue = lrrcur.F; + lrr = lrrcur; + indexRestset = i; + } + } + if (indexRestset >= 0) { + String coladd = rest.get(indexRestset); + selected.add(coladd); + rest.remove(indexRestset); + ArrayList deleted = new ArrayList (); + for (int i = 0; i < lrr.nameX.length; i++) { + if (lrr.pX[i] > alphaStay) { + String colrmv = lrr.nameX[i]; + deleted.add(colrmv); + selected.remove(colrmv); + rest.add(colrmv); + } + } + if (deleted.size() == 1 && deleted.get(0).equals(coladd)) { + break; + } else { + lrrbest = lrr; + } + } else { + break; + } + } + + return new LinearRegressionStepwiseModel(lrrbest, cw.toString()); + } + + public enum StepCriterion { + + /*** + * 多重判定系数, 该值最大为准则 + */ + R2, + /*** + * 调整的多重判定系数, 该值最大为准则 + */ + AdjR2, + /*** + * 误差平方和, 该值最小为准则 + */ + SSE, + /*** + * 误差均方,, 该值最小为准则 + */ + MSE, + /*** + * AIC(Akaike Information Criterion)信息统计量, 该值最小为准则 + */ + AIC, + /*** + * Cp统计量, 该值最小为准则 + */ + Cp; + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/regression/LinearRegressionStepwiseModel.java b/core/src/main/java/com/alibaba/alink/operator/common/regression/LinearRegressionStepwiseModel.java new file mode 100644 index 000000000..6c215b855 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/regression/LinearRegressionStepwiseModel.java @@ -0,0 +1,35 @@ +/* + * To change this template, choose Tools | Templates + * and open the template in the editor. + */ +package com.alibaba.alink.operator.common.regression; + +import com.alibaba.alink.common.utils.JsonConverter; + +/** + * @author yangxu + */ +public class LinearRegressionStepwiseModel implements RegressionModelInterface { + + public LinearRegressionModel lrr; + public String stepInfo; + + LinearRegressionStepwiseModel(LinearRegressionModel lrr, String stepInfo) { + this.lrr = lrr; + this.stepInfo = stepInfo; + } + + @Override + public String toString() { + return lrr.toString() + "\nLinear Regression Step Info:\n" + stepInfo; + } + + public String __repr__() { + return toString(); + } + + public String toJson() { + return JsonConverter.gson.toJson(this); + + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/regression/RegressionModelInterface.java b/core/src/main/java/com/alibaba/alink/operator/common/regression/RegressionModelInterface.java new file mode 100644 index 000000000..3d7c1188d --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/regression/RegressionModelInterface.java @@ -0,0 +1,16 @@ +/* + * To change this template, choose Tools | Templates + * and open the template in the editor. 
+ */ +package com.alibaba.alink.operator.common.regression; + +import com.alibaba.alink.common.utils.AlinkSerializable; + +/** + * @author yangxu + */ +interface RegressionModelInterface extends AlinkSerializable { + + public String toJson(); + +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/regression/RidgeRegressionProcess.java b/core/src/main/java/com/alibaba/alink/operator/common/regression/RidgeRegressionProcess.java new file mode 100644 index 000000000..cb6321e44 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/regression/RidgeRegressionProcess.java @@ -0,0 +1,195 @@ +package com.alibaba.alink.operator.common.regression; + +import com.alibaba.alink.common.linalg.DenseMatrix; +import com.alibaba.alink.common.linalg.jama.JMatrixFunc; +import com.alibaba.alink.common.probabilistic.CDF; +import com.alibaba.alink.common.utils.TableUtil; +import com.alibaba.alink.operator.common.statistics.statistics.SummaryResultTable; + +import java.util.ArrayList; + +/** + * @author yangxu + */ +public class RidgeRegressionProcess { + + private String nameY = null; + private String[] nameX = null; + private SummaryResultTable srt = null; + + public RidgeRegressionProcess(SummaryResultTable srt, String nameY, String[] nameX) { + this.srt = srt; + this.nameY = nameY; + if (null != nameX) { + this.nameX = new String[nameX.length]; + System.arraycopy(nameX, 0, this.nameX, 0, nameX.length); + } + } + + public static RidgeRegressionProcessResult calc(SummaryResultTable srt, String nameY, String[] nameX, + double[] kVals) throws Exception { + RidgeRegressionProcess rrp = new RidgeRegressionProcess(srt, nameY, nameX); + return rrp.calc(kVals); + } + + public RidgeRegressionProcessResult calc(double[] kVals) throws Exception { + + ////////////////////////////////////////////////////////////// + String[] colNames = srt.colNames; + Class[] types = new Class[colNames.length]; + for (int i = 0; i < colNames.length; i++) { + types[i] = srt.col(i).dataType; + } + int indexY = TableUtil.findColIndexWithAssertAndHint(colNames, nameY); + Class typeY = types[indexY]; + if (typeY != Double.class && typeY != Long.class && typeY != Integer.class) { + throw new Exception("col type must be double or bigint!"); + } + if (nameX.length == 0) { + throw new Exception("nameX must input!"); + } + for (int i = 0; i < nameX.length; i++) { + int indexX = TableUtil.findColIndexWithAssertAndHint(colNames, nameX[i]); + Class typeX = types[indexX]; + if (typeX != Double.class && typeX != Long.class && typeX != Integer.class) { + throw new Exception("col type must be double or bigint!"); + } + } + int nx = nameX.length; + int[] indexX = new int[nx]; + for (int i = 0; i < nx; i++) { + indexX[i] = TableUtil.findColIndexWithAssert(srt.colNames, nameX[i]); + } + + ////////////////////////////////////////////////////////////// + if (srt.col(indexY).countMissValue > 0 || srt.col(indexY).countNanValue > 0) { + throw new Exception("col " + nameY + " has null value or nan value!"); + } + for (int i = 0; i < indexX.length; i++) { + if (srt.col(indexX[i]).countMissValue > 0 || srt.col(indexX[i]).countNanValue > 0) { + throw new Exception("col " + nameX[i] + " has null value or nan value!"); + } + } + + if (srt.col(0).countTotal == 0) { + throw new Exception("table is empty!"); + } + if (srt.col(0).countTotal < nameX.length) { + throw new Exception("record size Less than features size!"); + } + + long N = srt.col(indexY).count; + if (N == 0) { + throw new Exception("Y valid value num is zero!"); + } + + 
ArrayList nameXList = new ArrayList (); + for (int i = 0; i < indexX.length; i++) { + if (srt.col(indexX[i]).count != 0) { + nameXList.add(nameX[i]); + } + } + nameXList.toArray(nameX); + + double[] XBar = new double[nx]; + for (int i = 0; i < nx; i++) { + XBar[i] = srt.col(indexX[i]).mean(); + } + double yBar = srt.col(indexY).mean(); + + double[][] cov = srt.getCov(); + + DenseMatrix C = new DenseMatrix(nx, 1); + for (int i = 0; i < nx; i++) { + C.set(i, 0, cov[indexX[i]][indexY]); + } + + RidgeRegressionProcessResult ridgeResult = new RidgeRegressionProcessResult(kVals.length); + + for (int k = 0; k < kVals.length; k++) { + double kval = kVals[k]; + ridgeResult.kVals[k] = kval; + // System.out.println(kval); + + ridgeResult.lrModels[k] = new LinearRegressionModel(N, nameY, nameX); + + DenseMatrix A = new DenseMatrix(nx, nx); + for (int i = 0; i < nx; i++) { + for (int j = 0; j < nx; j++) { + if (i == j) { + A.set(i, j, cov[indexX[i]][indexX[j]] + kval); + // A.set(i, j, cov[indexX[i]][indexX[j]] + kval / (N - 1)); + } else { + A.set(i, j, cov[indexX[i]][indexX[j]]); + } + } + } + + DenseMatrix BetaMatrix = A.solveLS(C); + + double d = yBar; + for (int i = 0; i < nx; i++) { + ridgeResult.lrModels[k].beta[i + 1] = BetaMatrix.get(i, 0); + d -= XBar[i] * ridgeResult.lrModels[k].beta[i + 1]; + } + ridgeResult.lrModels[k].beta[0] = d; + + double S = srt.col(nameY).variance() * (srt.col(nameY).count - 1); + double alpha = ridgeResult.lrModels[k].beta[0] - yBar; + double U = 0.0; + U += alpha * alpha * N; + for (int i = 0; i < nx; i++) { + U += 2 * alpha * srt.col(indexX[i]).sum * ridgeResult.lrModels[k].beta[i + 1]; + } + for (int i = 0; i < nx; i++) { + for (int j = 0; j < nx; j++) { + U += ridgeResult.lrModels[k].beta[i + 1] * ridgeResult.lrModels[k].beta[j + 1] * ( + cov[indexX[i]][indexX[j]] * (N - 1) + srt.col(indexX[i]).mean() * srt.col(indexX[j]).mean() + * N); + } + } + + ridgeResult.lrModels[k].SST = S; + ridgeResult.lrModels[k].SSR = U; + ridgeResult.lrModels[k].SSE = S - U; + ridgeResult.lrModels[k].dfSST = N - 1; + ridgeResult.lrModels[k].dfSSR = nx; + ridgeResult.lrModels[k].dfSSE = N - nx - 1; + ridgeResult.lrModels[k].R2 = Math.max(0.0, + Math.min(1.0, ridgeResult.lrModels[k].SSR / ridgeResult.lrModels[k].SST)); + ridgeResult.lrModels[k].R = Math.sqrt(ridgeResult.lrModels[k].R2); + ridgeResult.lrModels[k].MST = ridgeResult.lrModels[k].SST / ridgeResult.lrModels[k].dfSST; + ridgeResult.lrModels[k].MSR = ridgeResult.lrModels[k].SSR / ridgeResult.lrModels[k].dfSSR; + ridgeResult.lrModels[k].MSE = ridgeResult.lrModels[k].SSE / ridgeResult.lrModels[k].dfSSE; + ridgeResult.lrModels[k].Ra2 = 1 - ridgeResult.lrModels[k].MSE / ridgeResult.lrModels[k].MST; + ridgeResult.lrModels[k].s = Math.sqrt(ridgeResult.lrModels[k].MSE); + ridgeResult.lrModels[k].F = ridgeResult.lrModels[k].MSR / ridgeResult.lrModels[k].MSE; + if (ridgeResult.lrModels[k].F < 0) { + ridgeResult.lrModels[k].F = 0; + } + ridgeResult.lrModels[k].AIC = N * Math.log(ridgeResult.lrModels[k].SSE) + 2 * nx; + + A.scaleEqual(N - 1); + // DenseMatrix invA = A.Inverse(); + DenseMatrix invA = A.solveLS(JMatrixFunc.identity(A.numRows(), A.numRows())); + + for (int i = 0; i < nx; i++) { + ridgeResult.lrModels[k].FX[i] = + ridgeResult.lrModels[k].beta[i + 1] * ridgeResult.lrModels[k].beta[i + 1] / ( + ridgeResult.lrModels[k].MSE * invA.get(i, i)); + ridgeResult.lrModels[k].TX[i] = ridgeResult.lrModels[k].beta[i + 1] / (ridgeResult.lrModels[k].s * Math + .sqrt(invA.get(i, i))); + } + + int p = nameX.length; + 
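// p-value of the overall F statistic with (p, N - p - 1) degrees of freedom, plus two-sided t-test p-values for each coefficient +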
ridgeResult.lrModels[k].pEquation = 1 - CDF.F(ridgeResult.lrModels[k].F, p, N - p - 1); + ridgeResult.lrModels[k].pX = new double[nx]; + for (int i = 0; i < nx; i++) { + ridgeResult.lrModels[k].pX[i] = (1 - CDF.studentT(Math.abs(ridgeResult.lrModels[k].TX[i]), N - p - 1)) + * 2; + } + + } + return ridgeResult; + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/regression/RidgeRegressionProcessResult.java b/core/src/main/java/com/alibaba/alink/operator/common/regression/RidgeRegressionProcessResult.java new file mode 100644 index 000000000..879fd507d --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/regression/RidgeRegressionProcessResult.java @@ -0,0 +1,65 @@ +package com.alibaba.alink.operator.common.regression; + +/** + * @author yangxu + */ +public class RidgeRegressionProcessResult { + + public int n; + public double[] kVals = null; + public LinearRegressionModel[] lrModels = null; + + public RidgeRegressionProcessResult(int n) { + if (n <= 0) { + throw new RuntimeException(); + } + this.n = n; + this.kVals = new double[n]; + this.lrModels = new LinearRegressionModel[n]; + } + + @Override + public String toString() { + java.io.CharArrayWriter cw = new java.io.CharArrayWriter(); + java.io.PrintWriter pw = new java.io.PrintWriter(cw, true); + if (lrModels[0].nameX.length > 0) { + pw.print(" K \t\t Intercept"); + for (int i = 0; i < lrModels[0].nameX.length; i++) { + pw.print("\t\t " + lrModels[0].nameX[i]); + } + pw.println(); + } + for (int k = 0; k < kVals.length; k++) { + pw.print(kVals[k]); + pw.print("\t"); + for (double betaVal : lrModels[k].beta) { + pw.print(betaVal); + pw.print("\t"); + } + pw.println(); + } + pw.println(); + return cw.toString(); + } + + public boolean saveLinearRegressionModel(double kVal, String outModelTableName) throws Exception { + boolean bSaved = false; + for (int k = 0; k < kVals.length; k++) { + if (kVals[k] == kVal) { + LinearReg.write(lrModels[k], outModelTableName); + bSaved = true; + break; + } + } + return bSaved; + } + + public boolean saveLinearRegressionModel(int kValIndex, String outModelTableName) throws Exception { + if (kValIndex < 0 || kValIndex > lrModels.length) { + throw new Exception("kVal not exists"); + } + LinearReg.write(lrModels[kValIndex], outModelTableName); + return true; + } + +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/regression/tensorflow/TFTableModelRegressionFlatModelMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/regression/tensorflow/TFTableModelRegressionFlatModelMapper.java index 1a3095d35..37fddccfb 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/regression/tensorflow/TFTableModelRegressionFlatModelMapper.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/regression/tensorflow/TFTableModelRegressionFlatModelMapper.java @@ -7,7 +7,7 @@ import org.apache.flink.types.Row; import org.apache.flink.util.Collector; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.dl.plugin.TFPredictorClassLoaderFactory; import com.alibaba.alink.common.exceptions.AkPreconditions; import com.alibaba.alink.common.exceptions.AkUnclassifiedErrorException; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/regression/tensorflow/TFTableModelRegressionModelMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/regression/tensorflow/TFTableModelRegressionModelMapper.java index dcf6fe994..8070d1001 100644 --- 
a/core/src/main/java/com/alibaba/alink/operator/common/regression/tensorflow/TFTableModelRegressionModelMapper.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/regression/tensorflow/TFTableModelRegressionModelMapper.java @@ -7,7 +7,7 @@ import org.apache.flink.table.api.TableSchema; import org.apache.flink.types.Row; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.dl.plugin.TFPredictorClassLoaderFactory; import com.alibaba.alink.common.exceptions.AkPreconditions; import com.alibaba.alink.common.linalg.tensor.FloatTensor; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/similarity/LocalitySensitiveHashApproxFunctions.java b/core/src/main/java/com/alibaba/alink/operator/common/similarity/LocalitySensitiveHashApproxFunctions.java index c6df3d345..d891012f4 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/similarity/LocalitySensitiveHashApproxFunctions.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/similarity/LocalitySensitiveHashApproxFunctions.java @@ -13,7 +13,7 @@ import com.alibaba.alink.operator.common.similarity.lsh.BaseLSH; import com.alibaba.alink.operator.common.similarity.lsh.BucketRandomProjectionLSH; import com.alibaba.alink.operator.common.similarity.lsh.MinHashLSH; -import com.alibaba.alink.operator.common.statistics.StatisticsHelper; +import com.alibaba.alink.operator.batch.statistics.utils.StatisticsHelper; import com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary; import com.alibaba.alink.params.shared.HasMLEnvironmentId; import com.alibaba.alink.params.similarity.VectorApproxNearestNeighborTrainParams; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/similarity/SerializableComparator.java b/core/src/main/java/com/alibaba/alink/operator/common/similarity/SerializableComparator.java new file mode 100644 index 000000000..65fc7ea70 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/similarity/SerializableComparator.java @@ -0,0 +1,7 @@ +package com.alibaba.alink.operator.common.similarity; + +import java.io.Serializable; +import java.util.Comparator; + +public interface SerializableComparator extends Comparator , Serializable { +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/similarity/dataConverter/KDTreeModelDataConverter.java b/core/src/main/java/com/alibaba/alink/operator/common/similarity/dataConverter/KDTreeModelDataConverter.java index b8af4182e..d38a63d20 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/similarity/dataConverter/KDTreeModelDataConverter.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/similarity/dataConverter/KDTreeModelDataConverter.java @@ -21,7 +21,7 @@ import com.alibaba.alink.operator.common.distance.FastDistanceVectorData; import com.alibaba.alink.operator.common.similarity.KDTree; import com.alibaba.alink.operator.common.similarity.modeldata.KDTreeModelData; -import com.alibaba.alink.operator.common.statistics.StatisticsHelper; +import com.alibaba.alink.operator.batch.statistics.utils.StatisticsHelper; import com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary; import com.alibaba.alink.params.similarity.VectorApproxNearestNeighborTrainParams; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/sql/LocalOpCalciteSqlExecutor.java b/core/src/main/java/com/alibaba/alink/operator/common/sql/LocalOpCalciteSqlExecutor.java index 53bf45e94..23a79088f 100644 
--- a/core/src/main/java/com/alibaba/alink/operator/common/sql/LocalOpCalciteSqlExecutor.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/sql/LocalOpCalciteSqlExecutor.java @@ -5,6 +5,7 @@ import org.apache.flink.types.Row; import com.alibaba.alink.common.LocalMLEnvironment; +import com.alibaba.alink.operator.batch.sql.BatchSqlOperators; import com.alibaba.alink.operator.local.LocalOperator; import com.alibaba.alink.operator.local.source.TableSourceLocalOp; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/sql/MTableCalciteSqlExecutor.java b/core/src/main/java/com/alibaba/alink/operator/common/sql/MTableCalciteSqlExecutor.java index 63983bdc6..58f0640da 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/sql/MTableCalciteSqlExecutor.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/sql/MTableCalciteSqlExecutor.java @@ -1,6 +1,7 @@ package com.alibaba.alink.operator.common.sql; import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.table.api.TableSchema; import org.apache.flink.table.functions.ScalarFunction; import org.apache.flink.table.functions.TableFunction; import org.apache.flink.types.Row; @@ -10,23 +11,33 @@ import com.alibaba.alink.common.exceptions.AkIllegalStateException; import com.alibaba.alink.common.exceptions.AkUnclassifiedErrorException; import com.alibaba.alink.common.io.plugin.TemporaryClassLoaderContext; +import com.alibaba.alink.operator.batch.sql.BatchSqlOperators; import com.alibaba.alink.operator.common.io.types.JdbcTypeConverter; import com.alibaba.alink.operator.local.sql.CalciteFunctionCompiler; +import org.apache.calcite.avatica.AvaticaResultSetMetaData; import org.apache.calcite.avatica.util.Casing; import org.apache.calcite.avatica.util.Quoting; import org.apache.calcite.config.CalciteConnectionProperty; import org.apache.calcite.config.NullCollation; import org.apache.calcite.jdbc.CalciteConnection; +import org.apache.calcite.jdbc.CalcitePrepare.CalciteSignature; import org.apache.calcite.jdbc.CalciteSchema; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactoryImpl.JavaType; +import org.apache.calcite.rel.type.RelDataTypeField; import org.apache.calcite.schema.SchemaPlus; import org.apache.calcite.schema.impl.ScalarFunctionImpl; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import java.lang.reflect.Field; import java.sql.Connection; import java.sql.DriverManager; import java.sql.ResultSet; import java.sql.ResultSetMetaData; import java.sql.SQLException; import java.sql.Statement; +import java.sql.Types; import java.util.ArrayList; import java.util.Arrays; import java.util.HashSet; @@ -40,6 +51,9 @@ * Execute SQL on MTables with local Calcite engine, which is similar to {@link BatchSqlOperators}. 
*/ public class MTableCalciteSqlExecutor implements SqlExecutor { + + private static final Logger LOG = LoggerFactory.getLogger(MTableCalciteSqlExecutor.class); + private final Connection connection; private final SchemaPlus rootSchema; private final CalciteSchema calciteSchema; @@ -103,21 +117,72 @@ public void addFunction(String name, TableFunction function) { } } + private TableSchema extractSchema(ResultSetMetaData metaData) throws SQLException { + int numCols = metaData.getColumnCount(); + String[] colNames = new String[numCols]; + TypeInformation [] colTypes = new TypeInformation[numCols]; + for (int i = 0; i < numCols; i += 1) { + colNames[i] = metaData.getColumnLabel(i + 1); + colTypes[i] = JdbcTypeConverter.getFlinkType(metaData.getColumnType(i + 1)); + } + //noinspection deprecation + return new TableSchema(colNames, colTypes); + } + + /** + * When user-defined types are in the results, {@link ResultSetMetaData#getColumnType} returns {@link Types#OTHER}, + * which makes Alink unable to get the right {@link TypeInformation} for these fields. + *

+ * To obtain the right {@link TypeInformation} for these fields, we have to use reflection to access a private field of + * {@link ResultSetMetaData}. As a fallback, the legacy method is still used when reflection cannot work. + */ + private TableSchema extractSchemaByReflection(ResultSetMetaData metaData) throws SQLException { + try { + int numCols = metaData.getColumnCount(); + String[] colNames = new String[numCols]; + TypeInformation<?>[] colTypes = new TypeInformation[numCols]; + + AvaticaResultSetMetaData avaticaResultSetMetaData = metaData.unwrap(AvaticaResultSetMetaData.class); + Field signatureField = AvaticaResultSetMetaData.class.getDeclaredField("signature"); + signatureField.setAccessible(true); + CalciteSignature<?> signature = (CalciteSignature<?>) signatureField.get(avaticaResultSetMetaData); + RelDataType rowType = signature.rowType; + List<RelDataTypeField> fields = rowType.getFieldList(); + + for (int i = 0; i < fields.size(); i++) { + colNames[i] = fields.get(i).getName(); + RelDataType relDataType = fields.get(i).getType(); + boolean isUdt = false; + if (relDataType instanceof JavaType) { + JavaType javaType = (JavaType) relDataType; + Class<?> clazz = javaType.getJavaClass(); + if (clazz.getCanonicalName().startsWith("com.alibaba.alink.")) { + colTypes[i] = TypeInformation.of(clazz); + isUdt = true; + } + } + if (!isUdt) { + colTypes[i] = JdbcTypeConverter.getFlinkType(metaData.getColumnType(i + 1)); + } + } + //noinspection deprecation + return new TableSchema(colNames, colTypes); + } catch (Exception ignored) { + LOG.info("Failed to extract schema from meta data by reflection, so fallback to the legacy approach: " + + metaData.toString()); + return extractSchema(metaData); + } + } + + @Override + public MTable query(String sql) { - try (TemporaryClassLoaderContext context = + try (TemporaryClassLoaderContext ignored = TemporaryClassLoaderContext.of(calciteFunctionCompiler.getClassLoader())) { Statement statement = connection.createStatement(); ResultSet resultSet = statement.executeQuery(sql); ResultSetMetaData metaData = resultSet.getMetaData(); - + TableSchema schema = extractSchemaByReflection(metaData); int numCols = metaData.getColumnCount(); - String[] colNames = new String[numCols]; - TypeInformation<?>[] colTypes = new TypeInformation[numCols]; - for (int i = 0; i < numCols; i += 1) { - colNames[i] = metaData.getColumnLabel(i + 1); - colTypes[i] = JdbcTypeConverter.getFlinkType(metaData.getColumnType(i + 1)); - } List<Row> data = new ArrayList<>(); while (resultSet.next()) { Row row = new Row(numCols); @@ -126,7 +191,7 @@ public MTable query(String sql) { } data.add(row); } - return new MTable(data, colNames, colTypes); + return new MTable(data, schema); } catch (SQLException e) { throw new AkUnclassifiedErrorException("Failed to execute query: " + sql, e); } diff --git a/core/src/main/java/com/alibaba/alink/operator/common/sql/MTableTable.java b/core/src/main/java/com/alibaba/alink/operator/common/sql/MTableTable.java index 0a07d20fe..bfc4b9e60 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/sql/MTableTable.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/sql/MTableTable.java @@ -12,7 +12,6 @@ import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.schema.ScannableTable; import org.apache.calcite.schema.impl.AbstractTable; -import org.apache.calcite.sql.type.SqlTypeUtil; import org.apache.calcite.util.Pair; import java.util.Arrays; @@ -29,7 +28,7 @@ public RelDataType getRowType(RelDataTypeFactory relDataTypeFactory) { String[] names =
mTable.getColNames(); final JavaTypeFactory typeFactory = (JavaTypeFactory) relDataTypeFactory; RelDataType[] types = Arrays.stream(mTable.getColTypes()) - .map(d -> SqlTypeUtil.addCharsetAndCollation(typeFactory.createJavaType(d.getTypeClass()), typeFactory)) + .map(d -> typeFactory.createJavaType(d.getTypeClass())) .toArray(RelDataType[]::new); return typeFactory.createStructType(Pair.zip(names, types)); } diff --git a/core/src/main/java/com/alibaba/alink/operator/common/statistics/ChiSquareTest.java b/core/src/main/java/com/alibaba/alink/operator/common/statistics/ChiSquareTest.java index c579d62c4..0ffffb39c 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/statistics/ChiSquareTest.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/statistics/ChiSquareTest.java @@ -12,7 +12,7 @@ import org.apache.flink.types.Row; import org.apache.flink.util.Collector; -import com.alibaba.alink.common.utils.DataSetConversionUtil; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; import org.apache.commons.math3.distribution.ChiSquaredDistribution; import java.util.HashMap; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/statistics/ChiSquareTestUtil.java b/core/src/main/java/com/alibaba/alink/operator/common/statistics/ChiSquareTestUtil.java index 5581e6f62..5bfbd2c19 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/statistics/ChiSquareTestUtil.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/statistics/ChiSquareTestUtil.java @@ -7,7 +7,8 @@ import org.apache.flink.table.api.Table; import org.apache.flink.types.Row; -import com.alibaba.alink.common.utils.DataSetConversionUtil; +import com.alibaba.alink.operator.batch.statistics.utils.StatisticsHelper; +import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; import com.alibaba.alink.common.utils.JsonConverter; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.params.feature.BasedChisqSelectorParams; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/statistics/CorrespondenceAnalysis.java b/core/src/main/java/com/alibaba/alink/operator/common/statistics/CorrespondenceAnalysis.java new file mode 100644 index 000000000..351a751b8 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/statistics/CorrespondenceAnalysis.java @@ -0,0 +1,235 @@ +package com.alibaba.alink.operator.common.statistics; + +import org.apache.flink.types.Row; + +import com.alibaba.alink.common.linalg.DenseMatrix; +import com.alibaba.alink.common.linalg.jama.JMatrixFunc; +import com.alibaba.alink.common.utils.TableUtil; + +import java.util.ArrayList; +import java.util.Comparator; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.TreeMap; + +public class CorrespondenceAnalysis { + + public static CorrespondenceAnalysisResult calc(Iterable row, String rowName, String colName, + String[] colNames) throws Exception { + int rowIdx = TableUtil.findColIndexWithAssertAndHint(colNames, rowName); + int colIdx = TableUtil.findColIndexWithAssertAndHint(colNames, colName); + + Map groupCount = getGroupCount(row, rowIdx, colIdx); + List > distinctValue = getDistinctValue(groupCount); + String[] rowTags = distinctValue.get(0).toArray(new String[0]); + String[] colTags = distinctValue.get(1).toArray(new String[1]); + double[][] tabs = getPivotTable(rowTags, colTags, groupCount); + + CorrespondenceAnalysisResult car = calc(tabs); + car.rowLegend = rowName; + 
car.colLegend = colName; + car.rowTags = rowTags; + car.colTags = colTags; + return car; + } + + private static Map<GroupFactor, Long> getGroupCount(Iterable<Row> row, int rowIdx, int colIdx) { + Map<GroupFactor, Long> groupCount = new TreeMap<GroupFactor, Long>( + new Comparator<GroupFactor>() { + public int compare(GroupFactor left, GroupFactor right) { + int r = left.rowExpr.compareTo(right.rowExpr); + if (r == 0) { + return left.colExpr.compareTo(right.colExpr); + } + return r; + } + }); + + for (Row objs : row) { + GroupFactor factor = new GroupFactor(objs.getField(rowIdx).toString(), objs.getField(colIdx).toString()); + if (groupCount.containsKey(factor)) { + groupCount.put(factor, groupCount.get(factor) + 1); + } else { + groupCount.put(factor, 1L); + } + } + + return groupCount; + } + + private static List<Set<String>> getDistinctValue(Map<GroupFactor, Long> groupCount) { + Set<String> rowExprs = new HashSet<String>(); + Set<String> colExprs = new HashSet<String>(); + Set<GroupFactor> groupFactors = groupCount.keySet(); + for (GroupFactor i : groupFactors) { + rowExprs.add(i.rowExpr); + colExprs.add(i.colExpr); + } + List<Set<String>> Exprs = new ArrayList<Set<String>>(); + Exprs.add(rowExprs); + Exprs.add(colExprs); + return Exprs; + } + + static double[][] getPivotTable(String[] rowTags, + String[] colTags, + Map<GroupFactor, Long> groupCount) { + double[][] tabs = new double[rowTags.length][colTags.length]; + for (int i = 0; i < rowTags.length; i++) { + for (int j = 0; j < colTags.length; j++) { + GroupFactor factor = new GroupFactor(rowTags[i], colTags[j]); + if (groupCount.containsKey(factor)) { + tabs[i][j] = groupCount.get(factor); + } else { + tabs[i][j] = 0; + } + } + } + return tabs; + } + + static CorrespondenceAnalysisResult calc(double[][] X) throws Exception { + int nrow = X.length; + int ncol = X[0].length; + if (nrow * ncol == 1) { + throw new Exception("(the number of column expr) * ( number of row expr) must Greater than 2.!"); + } + + double T = 0.0; + for (double[] aX : X) { + for (int j = 0; j < ncol; j++) { + T += aX[j]; + } + } + + double[] wrow = new double[nrow]; + double[] wcol = new double[ncol]; + + for (int i = 0; i < nrow; i++) { + double s = 0; + for (int j = 0; j < ncol; j++) { + s += X[i][j]; + } + wrow[i] = s; + } + for (int j = 0; j < ncol; j++) { + double s = 0; + for (double[] aX : X) { + s += aX[j]; + } + wcol[j] = s; + } + + double[][] P = new double[nrow][ncol]; + for (int i = 0; i < nrow; i++) { + for (int j = 0; j < ncol; j++) { + P[i][j] = X[i][j] / T; + } + } + + double[] Pr = new double[nrow]; + double[] Pc = new double[ncol]; + for (int i = 0; i < nrow; i++) { + double s = 0; + for (int j = 0; j < ncol; j++) { + s += P[i][j]; + } + Pr[i] = s; + } + for (int j = 0; j < ncol; j++) { + double s = 0; + for (int i = 0; i < nrow; i++) { + s += P[i][j]; + } + Pc[j] = s; + } + + double[][] Z = new double[nrow][ncol]; + for (int i = 0; i < nrow; i++) { + for (int j = 0; j < ncol; j++) { + double t = Pr[i] * Pc[j]; + Z[i][j] = (P[i][j] - t) / Math.sqrt(t); + } + } + + double chi2 = 0; + for (int i = 0; i < nrow; i++) { + for (int j = 0; j < ncol; j++) { + chi2 += Z[i][j] * Z[i][j]; + } + } + chi2 *= T; + + DenseMatrix[] ed = JMatrixFunc.svd(new DenseMatrix(Z)); + + int p = Math.min(nrow, ncol); + + DenseMatrix Mr = ed[0].multiplies(ed[1]); + for (int i = 0; i < nrow; i++) { + double t = Math.sqrt(Pr[i]); + for (int j = 0; j < p; j++) { + Mr.set(i, j, Mr.get(i, j) / t); + } + } + + DenseMatrix Mc = ed[2].multiplies(ed[1]); + for (int i = 0; i < ncol; i++) { + double t = Math.sqrt(Pc[i]); + for (int j = 0; j < p; j++) { + Mc.set(i, j, Mc.get(i, j) / t); + } + } + + CorrespondenceAnalysisResult ca = new
CorrespondenceAnalysisResult(); + ca.nrow = nrow; + ca.ncol = ncol; + ca.rowPos = new double[nrow][2]; + if (Mr.numCols() > 1) { + for (int i = 0; i < nrow; i++) { + ca.rowPos[i][0] = Mr.get(i, 0); + ca.rowPos[i][1] = Mr.get(i, 1); + } + } else { + for (int i = 0; i < nrow; i++) { + ca.rowPos[i][0] = Mr.get(i, 0); + ca.rowPos[i][1] = 0; + } + } + ca.colPos = new double[ncol][2]; + if (Mr.numCols() > 1) { + for (int i = 0; i < ncol; i++) { + ca.colPos[i][0] = Mc.get(i, 0); + ca.colPos[i][1] = Mc.get(i, 1); + } + } else { + for (int i = 0; i < ncol; i++) { + ca.colPos[i][0] = Mc.get(i, 0); + ca.colPos[i][1] = 0; + } + } + ca.sv = new double[2]; + ca.sv[0] = ed[1].get(0, 0); + if (Mr.numCols() > 1) { + ca.sv[1] = ed[1].get(1, 1); + } else { + ca.sv[1] = 0; + } + ca.pct = new double[2]; + ca.pct[0] = T * ca.sv[0] * ca.sv[0] / chi2; + ca.pct[1] = T * ca.sv[1] * ca.sv[1] / chi2; + + return ca; + } + + private static class GroupFactor { + String rowExpr; + String colExpr; + + public GroupFactor(String rowExpr, String colExpr) { + this.rowExpr = rowExpr; + this.colExpr = colExpr; + } + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/statistics/CorrespondenceAnalysisResult.java b/core/src/main/java/com/alibaba/alink/operator/common/statistics/CorrespondenceAnalysisResult.java new file mode 100644 index 000000000..6b194b70a --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/statistics/CorrespondenceAnalysisResult.java @@ -0,0 +1,16 @@ +package com.alibaba.alink.operator.common.statistics; + +import com.alibaba.alink.common.utils.AlinkSerializable; + +public class CorrespondenceAnalysisResult implements AlinkSerializable { + public int nrow; + public int ncol; + public double[] sv; + public double[] pct; + public String[] rowTags; + public String[] colTags; + public double[][] rowPos; + public double[][] colPos; + public String rowLegend; + public String colLegend; +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/statistics/SomJni.java b/core/src/main/java/com/alibaba/alink/operator/common/statistics/SomJni.java new file mode 100644 index 000000000..2dd6edb9b --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/statistics/SomJni.java @@ -0,0 +1,82 @@ +package com.alibaba.alink.operator.common.statistics; + +public class SomJni { + + public static int getNeuronPos(int x, int y, int xdim, int ydim, int vdim) { + return (y * xdim + x) * vdim; + } + + public void updateBatchJava(float[] w, float[] batch, int cnt, double lr, double sig, + int xdim, int ydim, int vdim) { + float[] d2 = new float[xdim * ydim]; + int[] bmu = new int[2]; + float[] v = new float[vdim]; + + double[] xGaussian = new double[xdim]; + double[] yGaussian = new double[ydim]; + + double d = 2.0 * Math.PI * sig * sig; + for (int i = 0; i < xdim; i++) { + xGaussian[i] = Math.exp(-1.0 * i * i / d); + } + for (int i = 0; i < ydim; i++) { + yGaussian[i] = Math.exp(-1.0 * i * i / d); + } + + for (int c = 0; c < cnt; c++) { + int p = c * vdim; + for (int j = 0; j < vdim; j++) { + v[j] = batch[p + j]; + } + findBmuJava(w, d2, v, bmu, xdim, ydim, vdim); + + // update neurons one by one + for (int i = 0; i < xdim; i++) { + for (int j = 0; j < ydim; j++) { + int ii = Math.abs(i - bmu[0]); + int jj = Math.abs(j - bmu[1]); + double g = lr * xGaussian[ii] * yGaussian[jj]; + int pos = getNeuronPos(i, j, xdim, ydim, vdim); + for (int k = 0; k < vdim; k++) { + float delta = batch[p + k] - w[pos + k]; + w[pos + k] = w[pos + k] + delta * (float) g; + } + } + } + } + } + 
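+ // A minimal illustrative sketch of driving the batch update above; the grid size, learning rate
+ // and sigma here are assumed example values, not taken from this patch:
+ //
+ //   SomJni som = new SomJni();
+ //   int xdim = 4, ydim = 3, vdim = 2;              // 4x3 map of 2-dimensional neurons
+ //   float[] w = new float[xdim * ydim * vdim];     // flat codebook, initialized elsewhere
+ //   float[] batch = {0.1f, 0.9f, 0.8f, 0.2f};      // cnt = 2 samples, laid out sample-major
+ //   som.updateBatchJava(w, batch, 2, 0.5, 1.0, xdim, ydim, vdim);  // lr = 0.5, sigma = 1.0
+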
+ public float findBmuJava(float[] w, float[] d2, float[] v, int[] bmu, int xdim, int ydim, int vdim) { + for (int i = 0; i < xdim; i++) { + for (int j = 0; j < ydim; j++) { + int pos = (j * xdim + i) * vdim; + int pos2 = j * xdim + i; + d2[pos2] = 0.F; + for (int k = 0; k < vdim; k++) { + float delta = v[k] - w[pos + k]; + d2[pos2] += delta * delta; + } + } + } + + float minValue = Float.MAX_VALUE; + int x = -1; + int y = -1; + + for (int i = 0; i < xdim; i++) { + for (int j = 0; j < ydim; j++) { + int pos2 = j * xdim + i; + float d = d2[pos2]; + if (d < minValue) { + minValue = d; + x = i; + y = j; + } + } + } + + bmu[0] = x; + bmu[1] = y; + return minValue; + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/statistics/basicstat/Percentile.java b/core/src/main/java/com/alibaba/alink/operator/common/statistics/basicstat/Percentile.java deleted file mode 100644 index bbe5baf0d..000000000 --- a/core/src/main/java/com/alibaba/alink/operator/common/statistics/basicstat/Percentile.java +++ /dev/null @@ -1,14 +0,0 @@ -package com.alibaba.alink.operator.common.statistics.basicstat; - -import com.alibaba.alink.common.utils.AlinkSerializable; - -public class Percentile implements AlinkSerializable { - public String colName; - public String colType; - public Object[] items = new Object[101]; - public Object median; - public Object Q1; - public Object Q3; - public Object min; - public Object max; -} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/statistics/basicstat/Quantile.java b/core/src/main/java/com/alibaba/alink/operator/common/statistics/basicstat/Quantile.java deleted file mode 100644 index 1b0d59983..000000000 --- a/core/src/main/java/com/alibaba/alink/operator/common/statistics/basicstat/Quantile.java +++ /dev/null @@ -1,11 +0,0 @@ -package com.alibaba.alink.operator.common.statistics.basicstat; - -import java.io.Serializable; - -public class Quantile implements Serializable { - private static final long serialVersionUID = 5690612202964994601L; - public String colName; - public String colType; - public int q; - public Object[] items = null; -} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/statistics/basicstat/QuantileWindowFunction.java b/core/src/main/java/com/alibaba/alink/operator/common/statistics/basicstat/QuantileWindowFunction.java deleted file mode 100644 index 9bc0e2398..000000000 --- a/core/src/main/java/com/alibaba/alink/operator/common/statistics/basicstat/QuantileWindowFunction.java +++ /dev/null @@ -1,117 +0,0 @@ -package com.alibaba.alink.operator.common.statistics.basicstat; - -import org.apache.flink.api.common.typeinfo.TypeInformation; -import org.apache.flink.api.common.typeinfo.Types; -import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction; -import org.apache.flink.streaming.api.windowing.windows.TimeWindow; -import org.apache.flink.types.Row; -import org.apache.flink.util.Collector; - -import com.alibaba.alink.operator.batch.feature.QuantileDiscretizerTrainBatchOp.QIndex; -import com.alibaba.alink.operator.common.dataproc.SortUtils; -import com.alibaba.alink.params.statistics.HasRoundMode; - -import java.sql.Timestamp; -import java.text.DateFormat; -import java.text.SimpleDateFormat; -import java.util.ArrayList; -import java.util.Collections; -import java.util.Iterator; -import java.util.List; - -import static com.alibaba.alink.common.utils.JsonConverter.gson; - -public class QuantileWindowFunction implements AllWindowFunction { - private static final long serialVersionUID = 
-3504832156252458350L; - private String[] selectedColNames; - private int[] selectedColIdx; - private int quantileNum; - private double timeInterval; - private int timeColIdx; - private TypeInformation timeColType; - - public QuantileWindowFunction(String[] selectedColNames, int[] selectedColIdx, - int quantileNum, - int timeColIdx, double timeInterval, TypeInformation timeColType) { - this.selectedColNames = selectedColNames; - this.selectedColIdx = selectedColIdx; - this.quantileNum = quantileNum; - this.timeInterval = timeInterval; - this.timeColIdx = timeColIdx; - this.timeColType = timeColType; - } - - @Override - public void apply(TimeWindow timeWindow, Iterable iterable, Collector collector) throws Exception { - long startTime = timeWindow.getStart(); - long endTime = startTime + Math.round(timeInterval * 1000); - - //save data - List > data = new ArrayList <>(); - int len = this.selectedColNames.length; - for (int i = 0; i < len; i++) { - data.add(new ArrayList <>()); - } - Iterator iterator = iterable.iterator(); - if (timeColIdx < 0) { - while (iterator.hasNext()) { - Row row = iterator.next(); - for (int i = 0; i < len; i++) { - data.get(i).add(row.getField(selectedColIdx[i])); - } - } - } else { - while (iterator.hasNext()) { - Row row = iterator.next(); - long timestamp = getTime(row.getField(timeColIdx), timeColType); - if (timestamp >= startTime && timestamp < endTime) { - for (int i = 0; i < len; i++) { - data.get(i).add(row.getField(selectedColIdx[i])); - } - } - } - } - - //sort - int t = 0; - for (List colData : data) { - if (!colData.isEmpty()) { - Collections.sort(colData, new SortUtils.ComparableComparator()); - QIndex QIndex = new QIndex( - colData.size(), - quantileNum, - HasRoundMode.RoundMode.ROUND - ); - - Object[] quantileItems = new Object[quantileNum + 1]; - for (int i = 0; i <= quantileNum; i++) { - quantileItems[i] = colData.get((int) QIndex.genIndex(i)); - } - - Row row = new Row(4); - row.setField(0, timeStamp2Str(startTime)); - row.setField(1, timeStamp2Str(endTime)); - row.setField(2, selectedColNames[t]); - row.setField(3, gson.toJson(quantileItems)); - - collector.collect(row); - } - - t++; - } - } - - private String timeStamp2Str(long timestamp) { - DateFormat sdf = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss"); - return sdf.format(timestamp); - } - - private long getTime(Object val, TypeInformation type) { - if (Types.LONG.getTypeClass().getName() == type.getTypeClass().getName()) { - return (long) val; - } else { - return ((Timestamp) val).getTime(); - } - } -} - diff --git a/core/src/main/java/com/alibaba/alink/operator/common/statistics/basicstat/SetPartitionBasicStat.java b/core/src/main/java/com/alibaba/alink/operator/common/statistics/basicstat/SetPartitionBasicStat.java deleted file mode 100644 index 93c0eb9f4..000000000 --- a/core/src/main/java/com/alibaba/alink/operator/common/statistics/basicstat/SetPartitionBasicStat.java +++ /dev/null @@ -1,73 +0,0 @@ -package com.alibaba.alink.operator.common.statistics.basicstat; - -import org.apache.flink.api.common.functions.MapPartitionFunction; -import org.apache.flink.table.api.TableSchema; -import org.apache.flink.types.Row; -import org.apache.flink.util.Collector; - -import com.alibaba.alink.operator.common.statistics.statistics.Summary2; -import com.alibaba.alink.operator.common.statistics.statistics.SummaryResultTable; -import com.alibaba.alink.params.statistics.HasStatLevel_L1; - -/** - * @author yangxu - */ -public class SetPartitionBasicStat implements MapPartitionFunction { - - private static 
final long serialVersionUID = -5607403479996476267L; - private String[] colNames; - private Class[] colTypes; - private HasStatLevel_L1.StatLevel statLevel; - private String[] selectedColNames = null; - - public SetPartitionBasicStat(TableSchema schema) { - this(schema, HasStatLevel_L1.StatLevel.L1); - } - - /** - * @param schema - * @param statLevel: L1,L2,L3: 默认是L1 - * L1 has basic statistic; - * L2 has simple statistic and cov/corr; - * L3 has simple statistic, cov/corr, histogram, freq, topk, bottomk; - */ - public SetPartitionBasicStat(TableSchema schema, HasStatLevel_L1.StatLevel statLevel) { - this.colNames = schema.getFieldNames(); - int n = this.colNames.length; - this.colTypes = new Class[n]; - for (int i = 0; i < n; i++) { - colTypes[i] = schema.getFieldTypes()[i].getTypeClass(); - } - this.statLevel = statLevel; - this.selectedColNames = this.colNames; - } - - /** - * @param schema - * @param statLevel: L1,L2,L3: 默认是L1 - * L1 has basic statistic; - * L2 has simple statistic and cov/corr; - * L3 has simple statistic, cov/corr, histogram, freq, topk, bottomk; - * @param selectedColNames - */ - public SetPartitionBasicStat(TableSchema schema, String[] selectedColNames, HasStatLevel_L1.StatLevel statLevel) { - this.colNames = schema.getFieldNames(); - int n = this.colNames.length; - this.colTypes = new Class[n]; - for (int i = 0; i < n; i++) { - colTypes[i] = schema.getFieldTypes()[i].getTypeClass(); - } - this.statLevel = statLevel; - this.selectedColNames = selectedColNames; - } - - @Override - public void mapPartition(Iterable itrbl, Collector clctr) throws Exception { - WindowTable wt = new WindowTable(this.colNames, this.colTypes, itrbl); - SummaryResultTable srt = Summary2.batchSummary(wt, this.selectedColNames, 10, 10, 1000, 100, this.statLevel); - if (srt != null) { - clctr.collect(srt); - } - } - -} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/statistics/basicstat/StatOnTimeWindowAll.java b/core/src/main/java/com/alibaba/alink/operator/common/statistics/basicstat/StatOnTimeWindowAll.java deleted file mode 100644 index 6814cdc19..000000000 --- a/core/src/main/java/com/alibaba/alink/operator/common/statistics/basicstat/StatOnTimeWindowAll.java +++ /dev/null @@ -1,85 +0,0 @@ -package com.alibaba.alink.operator.common.statistics.basicstat; - -import org.apache.flink.api.common.typeinfo.TypeInformation; -import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction; -import org.apache.flink.streaming.api.windowing.windows.TimeWindow; -import org.apache.flink.table.api.TableSchema; -import org.apache.flink.types.Row; -import org.apache.flink.util.Collector; - -import com.alibaba.alink.common.utils.TableUtil; -import com.alibaba.alink.operator.common.statistics.statistics.Summary2; -import com.alibaba.alink.operator.common.statistics.statistics.SummaryResultTable; -import com.alibaba.alink.params.statistics.HasStatLevel_L1.StatLevel; - -/** - * @author yangxu - */ -public class StatOnTimeWindowAll implements AllWindowFunction { - - private static final long serialVersionUID = 3329375326702908192L; - private String[] colNames; - private Class[] colTypes; - private String[] statColNames; - private Class[] statColTypes; - private StatLevel statLevel; - /** - * ouput or not when window is empty. 
- */ - private Boolean allowEmptyOutput = false; - - public StatOnTimeWindowAll(TableSchema schema) { - this(schema, null, StatLevel.L1); - } - - public StatOnTimeWindowAll(TableSchema schema, String[] statColNames) { - this(schema, statColNames, StatLevel.L3); - } - - public StatOnTimeWindowAll(TableSchema schema, String[] statColNames, StatLevel statLevel) { - this.colNames = schema.getFieldNames(); - int n = this.colNames.length; - this.colTypes = new Class[n]; - for (int i = 0; i < n; i++) { - colTypes[i] = schema.getFieldTypes()[i].getTypeClass(); - } - if ((statColNames != null) && (statColNames.length > 0)) { - this.statColNames = statColNames; - } else { - this.statColNames = this.colNames; - } - TypeInformation[] selectedTypes = TableUtil.findColTypesWithAssertAndHint(schema, this.statColNames); - n = this.statColNames.length; - this.statColTypes = new Class[n]; - for (int i = 0; i < n; ++i) { - this.statColTypes[i] = selectedTypes[i].getTypeClass(); - } - this.statLevel = statLevel; - } - - public StatOnTimeWindowAll(TableSchema schema, String[] statColNames, StatLevel statLevel, - Boolean allowEmptyOutput) { - this(schema, statColNames, statLevel); - this.allowEmptyOutput = allowEmptyOutput; - } - - @Override - public void apply(TimeWindow window, Iterable values, Collector out) throws Exception { - if ((values != null) && (values.iterator().hasNext())) { - WindowTable wt = new WindowTable(this.colNames, this.colTypes, values); - SummaryResultTable srt = Summary2.streamSummary(wt, statColNames, - 10, 10, 100, 10, statLevel); - out.collect(srt); - } else { - if (this.allowEmptyOutput) { - out.collect( - new SummaryResultTable( - this.statColNames, - this.statColTypes, - this.statLevel - ) - ); - } - } - } -} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/statistics/basicstat/StatOnTimeWindowAllByKey.java b/core/src/main/java/com/alibaba/alink/operator/common/statistics/basicstat/StatOnTimeWindowAllByKey.java deleted file mode 100644 index 0b0c13fea..000000000 --- a/core/src/main/java/com/alibaba/alink/operator/common/statistics/basicstat/StatOnTimeWindowAllByKey.java +++ /dev/null @@ -1,59 +0,0 @@ -package com.alibaba.alink.operator.common.statistics.basicstat; - -import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction; -import org.apache.flink.streaming.api.windowing.windows.TimeWindow; -import org.apache.flink.table.api.TableSchema; -import org.apache.flink.types.Row; -import org.apache.flink.util.Collector; - -import com.alibaba.alink.common.AlinkGlobalConfiguration; -import com.alibaba.alink.operator.common.statistics.statistics.Summary2; -import com.alibaba.alink.operator.common.statistics.statistics.SummaryResultTable; -import com.alibaba.alink.params.statistics.HasStatLevel_L1; - -import java.util.HashMap; -import java.util.Map; - -/** - * @author yangxu - */ -public class StatOnTimeWindowAllByKey implements AllWindowFunction , TimeWindow> { - - private static final long serialVersionUID = -5715960707128472124L; - public String[] values; - private String[] colNames; - private Class[] colTypes; - private String[] statColNames; - private String groupColName; - - public StatOnTimeWindowAllByKey(TableSchema schema, String groupColName, String[] values, String[] statColNames) { - this.colNames = schema.getFieldNames(); - int n = this.colNames.length; - this.colTypes = new Class[n]; - for (int i = 0; i < n; i++) { - colTypes[i] = schema.getFieldTypes()[i].getTypeClass(); - } - this.statColNames = statColNames; - this.groupColName = 
groupColName; - this.values = values; - } - - @Override - public void apply(TimeWindow window, Iterable values, Collector > out) - throws Exception { - Map srtMap = new HashMap <>(); - for (int i = 0; i < this.values.length; i++) { - String groupValue = this.values[i]; - WindowTable wt = new WindowTable(this.colNames, this.colTypes, values, this.groupColName, groupValue); - SummaryResultTable srt = Summary2.streamSummary(wt, statColNames, 10, - 10, 100, 1000, HasStatLevel_L1.StatLevel.L3); - - if (AlinkGlobalConfiguration.isPrintProcessInfo()) { - System.out.println(window.toString() + " \t " + String.valueOf(srt.col(0).count)); - } - srtMap.put(this.values[i], srt); - } - out.collect(srtMap); - } - -} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/statistics/basicstat/VectorUtil.java b/core/src/main/java/com/alibaba/alink/operator/common/statistics/basicstat/VectorUtil.java deleted file mode 100644 index f3235be75..000000000 --- a/core/src/main/java/com/alibaba/alink/operator/common/statistics/basicstat/VectorUtil.java +++ /dev/null @@ -1,192 +0,0 @@ -package com.alibaba.alink.operator.common.statistics.basicstat; - -import com.alibaba.alink.common.linalg.DenseMatrix; -import com.alibaba.alink.common.linalg.DenseVector; - -/** - * Util of vector operations. - * - *

- DenseVector extension. - * - *

- DenseMatrix extension. - */ -public class VectorUtil { - - /** - * return left + normL1(right) - * it will change left, right will not be change - */ - public static DenseVector plusNormL1(DenseVector left, DenseVector right) { - double[] leftData = left.getData(); - double[] rightData = right.getData(); - if (left.size() >= right.size()) { - for (int i = 0; i < right.size(); i++) { - leftData[i] = sum(leftData[i], Math.abs(rightData[i])); - } - return left; - } else { - double[] data = rightData.clone(); - for (int i = 0; i < left.size(); i++) { - data[i] = sum(Math.abs(data[i]), leftData[i]); - } - - return new DenseVector(data); - } - } - - /** - * return left + right * right. - * it will change left, right will not be change. - */ - public static DenseVector plusSum2(DenseVector left, DenseVector right) { - double[] leftData = left.getData(); - double[] rightData = right.getData(); - if (left.size() >= right.size()) { - for (int i = 0; i < right.size(); i++) { - leftData[i] = sum(leftData[i], rightData[i] * rightData[i]); - } - return left; - } else { - double[] data = rightData.clone(); - for (int i = 0; i < left.size(); i++) { - data[i] = sum(data[i] * data[i], leftData[i]); - } - for (int i = left.size(); i < right.size(); i++) { - data[i] = data[i] * data[i]; - } - - return new DenseVector(data); - } - } - - /** - * return left + right - * it will change left, right will not be change. - */ - public static DenseVector plusEqual(DenseVector left, DenseVector right) { - double[] leftData = left.getData(); - double[] rightData = right.getData(); - if (left.size() >= right.size()) { - for (int i = 0; i < right.size(); i++) { - leftData[i] = sum(leftData[i], rightData[i]); - } - return left; - } else { - double[] data = rightData.clone(); - for (int i = 0; i < left.size(); i++) { - data[i] = sum(data[i], leftData[i]); - } - - return new DenseVector(data); - } - } - - /** - * return min(left, right) - * it will change left, right will not be change. - */ - public static DenseVector minEqual(DenseVector left, DenseVector right) { - double[] leftData = left.getData(); - double[] rightData = right.getData(); - if (left.size() >= right.size()) { - for (int i = 0; i < right.size(); i++) { - leftData[i] = min(leftData[i], rightData[i]); - } - return left; - } else { - double[] data = rightData.clone(); - for (int i = 0; i < left.size(); i++) { - data[i] = min(leftData[i], rightData[i]); - } - - return new DenseVector(data); - } - } - - /** - * return max(left,right) - * it will change left, right will not be change. - */ - public static DenseVector maxEqual(DenseVector left, DenseVector right) { - double[] leftData = left.getData(); - double[] rightData = right.getData(); - if (left.size() >= right.size()) { - for (int i = 0; i < right.size(); i++) { - leftData[i] = max(leftData[i], rightData[i]); - } - return left; - } else { - double[] data = rightData.clone(); - for (int i = 0; i < left.size(); i++) { - data[i] = max(leftData[i], rightData[i]); - } - - return new DenseVector(data); - } - } - - /** - * return left + right - * row and col of right matrix is equal with or Less than left matrix. - * it will change left, right will not be change. - */ - public static DenseMatrix plusEqual(DenseMatrix left, DenseMatrix right) { - for (int i = 0; i < right.numRows(); i++) { - for (int j = 0; j < right.numCols(); j++) { - left.add(i, j, right.get(i, j)); - } - } - return left; - } - - /** - * deal with nan. 
- */ - private static double sum(double left, double right) { - Boolean leftNan = Double.isNaN(left); - Boolean rightNan = Double.isNaN(right); - if (leftNan && rightNan) { - return left; - } else if (leftNan && !rightNan) { - return right; - } else if (!leftNan && rightNan) { - return left; - } else { - return left + right; - } - } - - /** - * deal with nan. - */ - private static double min(double left, double right) { - Boolean leftNan = Double.isNaN(left); - Boolean rightNan = Double.isNaN(right); - if (leftNan && rightNan) { - return left; - } else if (leftNan && !rightNan) { - return right; - } else if (!leftNan && rightNan) { - return left; - } else { - return Math.min(left, right); - } - } - - /** - * deal with nan. - */ - private static double max(double left, double right) { - Boolean leftNan = Double.isNaN(left); - Boolean rightNan = Double.isNaN(right); - if (leftNan && rightNan) { - return left; - } else if (leftNan && !rightNan) { - return right; - } else if (!leftNan && rightNan) { - return left; - } else { - return Math.max(left, right); - } - } -} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/statistics/basicstatistic/SummaryDataConverter.java b/core/src/main/java/com/alibaba/alink/operator/common/statistics/basicstatistic/SummaryDataConverter.java index 1d0ce2da9..3b9459bf6 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/statistics/basicstatistic/SummaryDataConverter.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/statistics/basicstatistic/SummaryDataConverter.java @@ -3,9 +3,16 @@ import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.ml.api.misc.param.Params; +import com.alibaba.alink.common.exceptions.AkIllegalStateException; import com.alibaba.alink.common.linalg.VectorUtil; import com.alibaba.alink.common.model.SimpleModelDataConverter; import com.alibaba.alink.common.utils.JsonConverter; +import com.google.gson.Gson; +import com.google.gson.GsonBuilder; +import com.google.gson.JsonElement; +import com.google.gson.JsonObject; +import com.google.gson.JsonPrimitive; +import org.apache.commons.lang3.StringUtils; import java.util.ArrayList; import java.util.Iterator; @@ -29,13 +36,17 @@ public Tuple2 > serializeModel(TableSummary summary) { data.add(JsonConverter.toJson(summary.colNames)); data.add(String.valueOf(summary.count)); data.add(JsonConverter.toJson(summary.numericalColIndices)); - if(summary.count() != 0) { - data.add(VectorUtil.toString(summary.numMissingValue)); + if (summary.count() != 0) { + data.add(longVectorToString(summary.numMissingValue)); data.add(VectorUtil.toString(summary.sum)); - data.add(VectorUtil.toString(summary.squareSum)); - data.add(VectorUtil.toString(summary.min)); - data.add(VectorUtil.toString(summary.max)); + data.add(VectorUtil.toString(summary.sum2)); + data.add(VectorUtil.toString(summary.sum3)); + data.add(VectorUtil.toString(summary.sum4)); + data.add(VectorUtil.toString(summary.minDouble)); + data.add(VectorUtil.toString(summary.maxDouble)); data.add(VectorUtil.toString(summary.normL1)); + data.add(objectVectorToString(summary.min)); + data.add(objectVectorToString(summary.max)); } } @@ -60,17 +71,123 @@ public TableSummary deserializeModel(Params meta, Iterable data) { summary.colNames = JsonConverter.fromJson(dataIterator.next(), String[].class); summary.count = Long.parseLong(dataIterator.next()); summary.numericalColIndices = JsonConverter.fromJson(dataIterator.next(), int[].class); - if(summary.count != 0) { - summary.numMissingValue = 
VectorUtil.parseDense(dataIterator.next()); - if(dataIterator.hasNext()) { + if (summary.count != 0) { + summary.numMissingValue = stringToLongVector(dataIterator.next()); + if (dataIterator.hasNext()) { summary.sum = VectorUtil.parseDense(dataIterator.next()); - summary.squareSum = VectorUtil.parseDense(dataIterator.next()); - summary.min = VectorUtil.parseDense(dataIterator.next()); - summary.max = VectorUtil.parseDense(dataIterator.next()); + summary.sum2 = VectorUtil.parseDense(dataIterator.next()); + summary.sum3 = VectorUtil.parseDense(dataIterator.next()); + summary.sum4 = VectorUtil.parseDense(dataIterator.next()); + summary.minDouble = VectorUtil.parseDense(dataIterator.next()); + summary.maxDouble = VectorUtil.parseDense(dataIterator.next()); summary.normL1 = VectorUtil.parseDense(dataIterator.next()); + summary.min = stringToObjectVector(dataIterator.next()); + summary.max = stringToObjectVector(dataIterator.next()); } } return summary; } + + /** + * for serialize long vector. + */ + private static final char ELEMENT_DELIMITER = ' '; + + /** + * for serialize object vector. + */ + private static final String CLASS_NAME = "CLASS_NAME"; + + /** + * for serialize object vector. + */ + private static final String INSTANCE = "INSTANCE"; + + /** + * for serialize object vector. + */ + private final static Gson gson = new GsonBuilder() + .serializeNulls() + .disableHtmlEscaping() + .serializeSpecialFloatingPointValues() + .create(); + + /** + * serialize long vector . + */ + static String longVectorToString(long[] longVector) { + StringBuilder sbd = new StringBuilder(); + + for (int i = 0; i < longVector.length; i++) { + sbd.append(longVector[i]); + if (i < longVector.length - 1) { + sbd.append(ELEMENT_DELIMITER); + } + } + return sbd.toString(); + } + + /** + * deserialize long vector . + */ + static long[] stringToLongVector(String longVecStr) { + String[] longSVector = StringUtils.split(longVecStr, ELEMENT_DELIMITER); + long[] longVec = new long[longSVector.length]; + for (int i = 0; i < longSVector.length; i++) { + longVec[i] = Long.parseLong(longSVector[i]); + } + return longVec; + } + + /** + * serialize object vector. + * If JsonConvertor.toJson then JsonConvertor.fromJson(xxx, Object[].class), + * timestamp will convert to long, decimal will convert to double. so save class type when tojson. + */ + static String objectVectorToString(Object[] vec) { + JsonObject[] ojs = new JsonObject[vec.length]; + for (int i = 0; i < vec.length; i++) { + Object src = vec[i]; + if (src != null) { + JsonObject retValue = new JsonObject(); + String className = src.getClass().getName(); + + retValue.addProperty(CLASS_NAME, className); + JsonElement elem = gson.toJsonTree(src); + retValue.add(INSTANCE, elem); + + ojs[i] = retValue; + } + } + + return gson.toJson(ojs); + + } + + /** + * deserialize object vector. 
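+ * Each non-null element is expected in the wrapped form produced by objectVectorToString, i.e. a JSON object with a "CLASS_NAME" property (the fully qualified class name) and an "INSTANCE" property (the value serialized by Gson).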
+ */ + static Object[] stringToObjectVector(String vecJson) { + JsonElement[] jsonElements = gson.fromJson(vecJson, JsonElement[].class); + Object[] values = new Object[jsonElements.length]; + for (int i = 0; i < jsonElements.length; i++) { + JsonElement jsonElement = jsonElements[i]; + if (jsonElement instanceof JsonObject) { + JsonObject jsonObject = jsonElement.getAsJsonObject(); + JsonPrimitive prim = (JsonPrimitive) jsonObject.get(CLASS_NAME); + String className = prim.getAsString(); + + Class klass = null; + try { + klass = Class.forName(className); + } catch (ClassNotFoundException e) { + e.printStackTrace(); + throw new AkIllegalStateException(e.getMessage()); + } + values[i] = gson.fromJson(gson.toJson(jsonObject.get(INSTANCE)), klass); + } + } + return values; + } } diff --git a/core/src/main/java/com/alibaba/alink/operator/common/statistics/basicstatistic/TableSummarizer.java b/core/src/main/java/com/alibaba/alink/operator/common/statistics/basicstatistic/TableSummarizer.java index d559c1085..a43d57f1b 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/statistics/basicstatistic/TableSummarizer.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/statistics/basicstatistic/TableSummarizer.java @@ -1,14 +1,21 @@ package com.alibaba.alink.operator.common.statistics.basicstatistic; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.table.api.TableSchema; import org.apache.flink.types.Row; +import com.alibaba.alink.common.exceptions.AkIllegalStateException; import com.alibaba.alink.common.linalg.DenseMatrix; import com.alibaba.alink.common.linalg.DenseVector; -import com.alibaba.alink.common.linalg.MatVecOp; -import com.alibaba.alink.common.linalg.VectorUtil; import com.alibaba.alink.common.utils.TableUtil; +import com.alibaba.alink.operator.common.statistics.statistics.BaseMeasureIterator; +import com.alibaba.alink.operator.common.statistics.statistics.BooleanMeasureIterator; +import com.alibaba.alink.operator.common.statistics.statistics.DateMeasureIterator; +import com.alibaba.alink.operator.common.statistics.statistics.StatisticsIteratorFactory; +import com.alibaba.alink.operator.common.statistics.statistics.NumberMeasureIterator; -import java.util.Arrays; +import java.util.ArrayList; +import java.util.List; /** * It is summary for table, it will calculate statistics and return TableSummary. @@ -18,75 +25,60 @@ public class TableSummarizer extends BaseSummarizer { private static final long serialVersionUID = 4588962274305185787L; /** - * col names which will calculate. + * col names. */ public String[] colNames; /** - * the value of ith row and jth col is sum of the ith variance - * when the ith col and the jth col of row are both not null. - * xSum_i_j = sum(x_i) when x_i != null && x_j!=null. + * col types. */ - DenseMatrix xSum; - - /** - * the value of ith row and jth col is sum of the ith variance - * when the ith col and the jth col of row are both not null. - * xSum_i_j = sum(x_i) when x_i != null && x_j!=null. - */ - DenseMatrix xSquareSum; - - /** - * the value of ith row and jth col is the count of the ith variance is not null - * and the jth variance is not null at the same row. - */ - DenseMatrix xyCount; - - /** - * numerical col indices: - * if col is numerical, it will calculate all statistics, otherwise only count, numMissingValue. - */ - private int[] numericalColIndices; + TypeInformation [] colTypes; /** - * the number of missing value of all columns. + * simple statistics for a col. 
*/ - private DenseVector numMissingValue; + BaseMeasureIterator[] statIterators; /** - * sum_i = sum(x_i) when x_i is not null. + * col number. */ - protected DenseVector sum; + private int n; /** - * squareSum_i = sum(x_i * x_i) when x_i is not null. + * num of numerical cols, boolean cols and date cols. */ - protected DenseVector squareSum; + private int numberN; /** - * sum3_i = sum(x_i * x_i * x_i) when x_i is not null. + * numerical col and boolean col indices: + * if col is numerical or boolean , it will calculate cov and corr. */ - protected DenseVector sum3; + private int[] numericalColIndices; /** - * min_i = min(x_i) when x_i is not null. + * Intermediate variable which will used in Visit function. */ - protected DenseVector min; + private Double[] vals; /** - * max_i = max(x_i) when x_i is not null. + * the value of ith row and jth col is sum of the ith variance + * when the ith col and the jth col of row are both not null. + * xSum_i_j = sum(x_i) when x_i != null && x_j!=null. */ - protected DenseVector max; + DenseMatrix xSum; /** - * normL1_i = normL1(x_i) = sum(|x_i|) when x_i is not null. + * the value of ith row and jth col is sum of the ith variance + * when the ith col and the jth col of row are both not null. + * xSum_i_j = sum(x_i) when x_i != null && x_j!=null. */ - protected DenseVector normL1; + DenseMatrix xSquareSum; /** - * Intermediate variable which will used in Visit function. + * the value of ith row and jth col is the count of the ith variance is not null + * and the jth variance is not null at the same row. */ - private Double[] vals; + DenseMatrix xyCount; /** * default constructor. @@ -99,18 +91,22 @@ private TableSummarizer() { * if calculateOuterProduct is false, outerProduct,xSum, xSquareSum, xyCount are not be used, * these are for correlation and covariance. */ - public TableSummarizer(String[] selectedColNames, int[] numericalColIndices, boolean calculateOuterProduct) { - this.colNames = selectedColNames; + public TableSummarizer(TableSchema tableSchema, boolean calculateOuterProduct) { + this.colNames = tableSchema.getFieldNames(); + this.colTypes = tableSchema.getFieldTypes(); this.calculateOuterProduct = calculateOuterProduct; - this.numericalColIndices = numericalColIndices; + this.n = this.colNames.length; + this.numericalColIndices = calcCovColIndices(new TableSchema(this.colNames, this.colTypes)); + this.numberN = this.numericalColIndices.length; } /** * given row, incremental calculate statistics. */ public BaseSummarizer visit(Row row) { - int n = row.getArity(); - int numberN = numericalColIndices.length; + if (this.n != row.getArity()) { + throw new AkIllegalStateException("row size is not equal with table col num."); + } if (count == 0) { init(); @@ -118,39 +114,27 @@ public BaseSummarizer visit(Row row) { count++; - for (int i = 0; i < n; i++) { - Object obj = row.getField(i); - if (obj == null) { - numMissingValue.add(i, 1); - } + for (int i = 0; i < this.n; i++) { + this.statIterators[i].visit(row.getField(i)); } - for (int i = 0; i < numberN; i++) { - Object obj = row.getField(numericalColIndices[i]); - if (obj != null) { - if (obj instanceof Boolean) { - vals[i] = (boolean) obj ? 1.0 : 0.0; + if (calculateOuterProduct) { + for (int i = 0; i < numberN; i++) { + Object obj = row.getField(numericalColIndices[i]); + if (obj != null) { + if (obj instanceof Boolean) { + vals[i] = (boolean) obj ? 
1.0 : 0.0; + } else { + vals[i] = ((Number) obj).doubleValue(); + } } else { - vals[i] = ((Number) obj).doubleValue(); + vals[i] = null; } - } else { - vals[i] = null; } - } - for (int i = 0; i < numberN; i++) { - if (vals[i] != null) { - double val = vals[i]; - - max.set(i, Math.max(val, max.get(i))); - min.set(i, Math.min(val, min.get(i))); + for (int i = 0; i < numberN; i++) { + if (vals[i] != null) { + double val = vals[i]; - sum.add(i, val); - squareSum.add(i, val * val); - sum3.add(i, val * val * val); - - normL1.add(i, Math.abs(val)); - - if (calculateOuterProduct) { for (int j = i; j < numberN; j++) { if (vals[j] != null) { outerProduct.add(i, j, val * vals[j]); @@ -170,73 +154,6 @@ public BaseSummarizer visit(Row row) { return this; } - /** - * n is the number of columns participating in the calculation. - */ - private void init() { - int n = colNames.length; - int numberN = numericalColIndices.length; - - numMissingValue = new DenseVector(n); - sum = new DenseVector(numberN); - squareSum = new DenseVector(numberN); - sum3 = new DenseVector(numberN); - normL1 = new DenseVector(numberN); - - double[] minVals = new double[numberN]; - Arrays.fill(minVals, Double.MAX_VALUE); - min = new DenseVector(minVals); - - double[] maxVals = new double[numberN]; - Arrays.fill(maxVals, -Double.MAX_VALUE); - max = new DenseVector(maxVals); - - if (calculateOuterProduct) { - outerProduct = new DenseMatrix(numberN, numberN); - xSum = new DenseMatrix(numberN, numberN); - xSquareSum = new DenseMatrix(numberN, numberN); - xyCount = new DenseMatrix(numberN, numberN); - } - - vals = new Double[numberN]; - } - - /** - * merge left and right, return a new summary. left will be changed. - */ - public static TableSummarizer merge(TableSummarizer left, TableSummarizer right) { - if (right.count == 0) { - return left; - } - - if (left.count == 0) { - return right.copy(); - } - - left.count += right.count; - left.numMissingValue.plusEqual(right.numMissingValue); - left.sum.plusEqual(right.sum); - left.squareSum.plusEqual(right.squareSum); - left.sum3.plusEqual(right.sum3); - left.normL1.plusEqual(right.normL1); - MatVecOp.apply(left.min, right.min, left.min, Math::min); - MatVecOp.apply(left.max, right.max, left.max, Math::max); - - if (left.outerProduct != null && right.outerProduct != null) { - left.outerProduct.plusEquals(right.outerProduct); - left.xSum.plusEquals(right.xSum); - left.xSquareSum.plusEquals(right.xSquareSum); - left.xyCount.plusEquals(right.xyCount); - } else if (left.outerProduct == null && right.outerProduct != null) { - left.outerProduct = right.outerProduct.clone(); - left.xSum = right.xSum.clone(); - left.xSquareSum = right.xSquareSum.clone(); - left.xyCount = right.xyCount.clone(); - } - - return left; - } - @Override public String toString() { StringBuilder sbd = new StringBuilder() @@ -244,17 +161,9 @@ public String toString() { .append(count) .append("\n"); if (count != 0) { - sbd.append("sum: ") - .append(VectorUtil.toString(sum)) - .append("\n") - .append("squareSum: ") - .append(VectorUtil.toString(squareSum)) - .append("\n") - .append("min: ") - .append(VectorUtil.toString(min)) - .append("\n") - .append("max: ") - .append(VectorUtil.toString(max)); + for (int i = 0; i < this.n; i++) { + sbd.append(this.colNames[i]).append(": ").append(this.statIterators[i]); + } } return sbd.toString(); @@ -266,71 +175,63 @@ public String toString() { public TableSummary toSummary() { TableSummary summary = new TableSummary(); - summary.count = count; - summary.sum = sum; - summary.squareSum = 
squareSum; - summary.sum3 = sum3; - summary.normL1 = normL1; - summary.min = min; - summary.max = max; - - summary.numMissingValue = numMissingValue; summary.numericalColIndices = numericalColIndices; - summary.colNames = colNames; - - return summary; - } - - /** - * get summary result of selected columns. - */ - public TableSummary toSummary(String[] selectedColNames) { - if (selectedColNames.length == 0) { - return toSummary(); - } - TableSummary summary = new TableSummary(); - int[] selectedColIndex = TableUtil.findColIndices(colNames, selectedColNames); - int n = selectedColNames.length; - summary.count = count; - summary.sum = new DenseVector(n); - summary.squareSum = new DenseVector(n); - summary.sum3 = new DenseVector(n); - summary.normL1 = new DenseVector(n); - summary.min = new DenseVector(n); - summary.max = new DenseVector(n); - summary.numMissingValue = new DenseVector(n); - summary.numericalColIndices = new int[n]; - - for (int i = 0; i < selectedColIndex.length; i++) { - int targetIndex = selectedColIndex[i]; - int targetIndexInDenseVector = -1; + summary.sum = new DenseVector(this.numberN); + summary.sum2 = new DenseVector(this.numberN); + summary.sum3 = new DenseVector(this.numberN); + summary.sum4 = new DenseVector(this.numberN); + summary.normL1 = new DenseVector(this.numberN); + summary.minDouble = new DenseVector(this.numberN); + summary.maxDouble = new DenseVector(this.numberN); + summary.numMissingValue = new long[this.n]; + summary.min = new Object[this.numberN]; + summary.max = new Object[this.numberN]; + + if (count > 0) { + for (int i = 0; i < this.n; i++) { + summary.numMissingValue[i] = this.statIterators[i].missingCount(); + } - for (int j = 0; j < numericalColIndices.length; j++) { - if (targetIndex == numericalColIndices[j]) { - targetIndexInDenseVector = j; + for (int i = 0; i < this.numberN; i++) { + BaseMeasureIterator iterator = this.statIterators[this.numericalColIndices[i]]; + if (iterator instanceof NumberMeasureIterator) { + NumberMeasureIterator numberIterator = (NumberMeasureIterator ) iterator; + summary.sum.set(i, numberIterator.sum); + summary.sum2.set(i, numberIterator.sum2); + summary.sum3.set(i, numberIterator.sum3); + summary.sum4.set(i, numberIterator.sum4); + summary.minDouble.set(i, numberIterator.min.doubleValue()); + summary.maxDouble.set(i, numberIterator.max.doubleValue()); + summary.normL1.set(i, numberIterator.normL1); + summary.min[i] = numberIterator.min; + summary.max[i] = numberIterator.max; + } else if (iterator instanceof BooleanMeasureIterator) { + BooleanMeasureIterator boolIterator = (BooleanMeasureIterator) iterator; + summary.sum.set(i, boolIterator.countTrue); + summary.sum2.set(i, boolIterator.countTrue); + summary.sum3.set(i, boolIterator.countTrue); + summary.sum4.set(i, boolIterator.countTrue); + summary.normL1.set(i, boolIterator.countTrue); + summary.minDouble.set(i, boolIterator.countFalse > 0 ? 0.0 : 1.0); + summary.maxDouble.set(i, boolIterator.countTrue > 0 ? 
1.0 : 0.0); + summary.min[i] = boolIterator.countFalse <= 0; + summary.max[i] = boolIterator.countTrue > 0; + } else if (iterator instanceof DateMeasureIterator) { + DateMeasureIterator dateIterator = (DateMeasureIterator ) iterator; + summary.sum.set(i, Double.NaN); + summary.sum2.set(i, Double.NaN); + summary.sum3.set(i, Double.NaN); + summary.sum4.set(i, Double.NaN); + summary.minDouble.set(i, dateIterator.min.getTime()); + summary.maxDouble.set(i, dateIterator.max.getTime()); + summary.min[i] = dateIterator.min; + summary.max[i] = dateIterator.max; } } - if (targetIndexInDenseVector == -1) { - summary.numericalColIndices[i] = -1; - continue; - } - - summary.sum.set(i, sum.get(targetIndexInDenseVector)); - summary.squareSum.set(i, squareSum.get(targetIndexInDenseVector)); - summary.sum3.set(i, sum3.get(targetIndexInDenseVector)); - summary.normL1.set(i, normL1.get(targetIndexInDenseVector)); - summary.min.set(i, min.get(targetIndexInDenseVector)); - summary.max.set(i, max.get(targetIndexInDenseVector)); - - summary.numMissingValue.set(i, numMissingValue.get(targetIndexInDenseVector)); - summary.numericalColIndices[i] = TableUtil.findColIndex(selectedColNames, - colNames[numericalColIndices[targetIndexInDenseVector]]); } - summary.colNames = selectedColNames; - return summary; } @@ -397,11 +298,10 @@ public DenseMatrix covariance() { if (outerProduct == null) { return null; } - int nStat = numMissingValue.size(); - double[][] cov = new double[nStat][nStat]; - for (int i = 0; i < nStat; i++) { - for (int j = 0; j < nStat; j++) { + double[][] cov = new double[this.n][this.n]; + for (int i = 0; i < this.n; i++) { + for (int j = 0; j < this.n; j++) { cov[i][j] = Double.NaN; } } @@ -420,30 +320,98 @@ public DenseMatrix covariance() { } /** - * + * clone. */ TableSummarizer copy() { TableSummarizer srt = new TableSummarizer(); srt.colNames = colNames.clone(); srt.count = count; - srt.numericalColIndices = numericalColIndices.clone(); if (count != 0) { - srt.numMissingValue = numMissingValue.clone(); - srt.sum = sum.clone(); - srt.squareSum = squareSum.clone(); - srt.sum3 = sum3.clone(); - srt.normL1 = normL1.clone(); - srt.min = min.clone(); - srt.max = max.clone(); + srt.statIterators = new BaseMeasureIterator[this.n]; + for (int i = 0; i < this.n; i++) { + srt.statIterators[i] = this.statIterators[i].clone(); + } } - if (outerProduct != null) { - srt.outerProduct = outerProduct.clone(); - srt.xSum = xSum.clone(); - srt.xSquareSum = xSquareSum.clone(); - srt.xyCount = xyCount.clone(); + if (this.outerProduct != null) { + srt.numericalColIndices = this.numericalColIndices.clone(); + srt.outerProduct = this.outerProduct.clone(); + srt.xSum = this.xSum.clone(); + srt.xSquareSum = this.xSquareSum.clone(); + srt.xyCount = this.xyCount.clone(); } return srt; } + + /** + * n is the number of columns participating in the calculation. + */ + private void init() { + this.statIterators = new BaseMeasureIterator[this.n]; + for (int i = 0; i < this.n; i++) { + this.statIterators[i] = StatisticsIteratorFactory.getMeasureIterator(colTypes[i]); + } + + if (calculateOuterProduct) { + vals = new Double[numberN]; + outerProduct = new DenseMatrix(numberN, numberN); + xSum = new DenseMatrix(numberN, numberN); + xSquareSum = new DenseMatrix(numberN, numberN); + xyCount = new DenseMatrix(numberN, numberN); + } + } + + /** + * col indices which col type is number or boolean or date. 
+ */ + private int[] calcCovColIndices(TableSchema tableSchema) { + List indicesList = new ArrayList <>(); + for (int i = 0; i < tableSchema.getFieldNames().length; i++) { + TypeInformation type = tableSchema.getFieldType(i).get(); + if (TableUtil.isSupportedNumericType(type) + || TableUtil.isSupportedBoolType(type) + || TableUtil.isSupportedDateType(type)) { + indicesList.add(i); + } + } + return indicesList.stream().mapToInt(Integer::valueOf).toArray(); + } + + /** + * merge left and right, return a new summary. left will be changed. + */ + public static TableSummarizer merge(TableSummarizer left, TableSummarizer right) { + if (right.count == 0) { + return left; + } + + if (left.count == 0) { + return right.copy(); + } + + left.count += right.count; + + if (left.n != right.n) { + throw new AkIllegalStateException("left stat cols is not equal with right stat cols"); + } + + for (int i = 0; i < left.n; i++) { + left.statIterators[i].merge(right.statIterators[i]); + } + + if (left.outerProduct != null && right.outerProduct != null) { + left.outerProduct.plusEquals(right.outerProduct); + left.xSum.plusEquals(right.xSum); + left.xSquareSum.plusEquals(right.xSquareSum); + left.xyCount.plusEquals(right.xyCount); + } else if (left.outerProduct == null && right.outerProduct != null) { + left.outerProduct = right.outerProduct.clone(); + left.xSum = right.xSum.clone(); + left.xSquareSum = right.xSquareSum.clone(); + left.xyCount = right.xyCount.clone(); + } + + return left; + } } diff --git a/core/src/main/java/com/alibaba/alink/operator/common/statistics/basicstatistic/TableSummary.java b/core/src/main/java/com/alibaba/alink/operator/common/statistics/basicstatistic/TableSummary.java index f0b434c4b..606999525 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/statistics/basicstatistic/TableSummary.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/statistics/basicstatistic/TableSummary.java @@ -1,6 +1,7 @@ package com.alibaba.alink.operator.common.statistics.basicstatistic; import com.alibaba.alink.common.linalg.DenseVector; +import com.alibaba.alink.common.linalg.tensor.LongTensor; import com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.operator.common.utils.PrettyDisplayUtils; @@ -20,7 +21,7 @@ public class TableSummary extends BaseSummary { /** * the number of missing value. */ - DenseVector numMissingValue; + long[] numMissingValue; /** * sum_i = sum(x_i) @@ -28,24 +29,39 @@ public class TableSummary extends BaseSummary { DenseVector sum; /** - * squareSum_i = sum(x_i * x_i) + * sum2_i = sum(x_i * x_i) */ - DenseVector squareSum; + DenseVector sum2; /** * sum3_i = sum(x_i * x_i * x_i) */ DenseVector sum3; + /** + * sum4_i = sum(x_i * x_i * x_i * x_i) + */ + DenseVector sum4; + /** * min_i = min(x_i) */ - DenseVector min; + DenseVector minDouble; /** * max_i = max(x_i) */ - DenseVector max; + DenseVector maxDouble; + + /** + * min. + */ + Object[] min; + + /** + * max. + */ + Object[] max; /** * normL1_i = sum(|x_i|) @@ -58,7 +74,7 @@ public class TableSummary extends BaseSummary { int[] numericalColIndices; /** - * It will generated by summary. + * It will be generated by summary. 
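For reference, the refactored TableSummarizer is now driven by a TableSchema instead of pre-computed numerical column indices, and the intended call pattern is: visit each row, merge partial summarizers, then freeze the result with toSummary. A minimal driver sketch follows; the column names and data are invented for illustration, and it assumes the Alink and Flink classes appearing in this patch are on the classpath.

import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;

import com.alibaba.alink.operator.common.statistics.basicstatistic.TableSummarizer;
import com.alibaba.alink.operator.common.statistics.basicstatistic.TableSummary;

public final class TableSummarizerDemo {
    public static void main(String[] args) {
        TableSchema schema = new TableSchema(
            new String[] {"name", "age", "weight"},
            new TypeInformation<?>[] {Types.STRING, Types.INT, Types.DOUBLE});

        // One summarizer per partition; outer products disabled (no cov/corr needed here).
        TableSummarizer part1 = new TableSummarizer(schema, false);
        part1.visit(Row.of("a", 18, 60.5));
        part1.visit(Row.of("b", null, 72.0));   // the null Integer counts as a missing value

        TableSummarizer part2 = new TableSummarizer(schema, false);
        part2.visit(Row.of("c", 30, 55.0));

        // Partial results are combined with merge, then turned into an immutable TableSummary.
        TableSummary summary = TableSummarizer.merge(part1, part2).toSummary();
        System.out.println(summary.mean("age"));             // mean over the two non-null ages
        System.out.println(summary.numMissingValue("age"));  // missing count for the age column
    }
}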
*/ TableSummary() { @@ -116,7 +132,7 @@ public double sum(String colName) { public double mean(String colName) { int idx = findIdx(colName); if (idx >= 0) { - if(isEmpty(colName)) { + if (isEmpty(colName)) { return Double.NaN; } return sum.get(idx) / numValidValue(colName); @@ -135,7 +151,7 @@ public double variance(String colName) { if (0 == numVaildValue || 1 == numVaildValue) { return 0; } - return Math.max(0.0, (squareSum.get(idx) - sum.get(idx) * sum.get(idx) / numVaildValue) / (numVaildValue + return Math.max(0.0, (sum2.get(idx) - sum.get(idx) * sum.get(idx) / numVaildValue) / (numVaildValue - 1)); } else { return Double.NaN; @@ -149,16 +165,23 @@ public double standardDeviation(String colName) { return Math.sqrt(variance(colName)); } + /** + * given colName, return standardError of the column. + */ + public double standardError(String colName) { + return standardDeviation(colName) / Math.sqrt(count); + } + /** * given colName, return min of the column. */ - public double min(String colName) { + public double minDouble(String colName) { int idx = findIdx(colName); if (idx >= 0) { if (isEmpty(colName)) { return Double.NaN; } - return min.get(idx); + return minDouble.get(idx); } else { return Double.NaN; } @@ -167,18 +190,34 @@ public double min(String colName) { /** * given colName, return max of the column. */ - public double max(String colName) { + public double maxDouble(String colName) { int idx = findIdx(colName); if (idx >= 0) { if (isEmpty(colName)) { return Double.NaN; } - return max.get(idx); + return maxDouble.get(idx); } else { return Double.NaN; } } + public Object min(String colName) { + int idx = findIdx(colName); + if (idx < 0) { + return Double.NaN; + } + return isEmpty(colName) ? Double.NaN : min[idx]; + } + + public Object max(String colName) { + int idx = findIdx(colName); + if (idx < 0) { + return Double.NaN; + } + return isEmpty(colName) ? Double.NaN : max[idx]; + } + /** * given colName, return l1 norm of the column. 
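For reference, the derived statistics TableSummary exposes (variance and standard error above, and the moment-based skewness, kurtosis and cv a little further down) all reduce to the accumulated power sums sum, sum2, sum3 and sum4. The self-contained sketch below recomputes them from a plain double[] using the same expressions as this file; the demo class and sample values are made up.

public final class MomentFormulasDemo {
    public static void main(String[] args) {
        double[] x = {1.0, 2.0, 4.0, 8.0};
        double n = x.length;
        double sum = 0, sum2 = 0, sum3 = 0, sum4 = 0;
        for (double v : x) {
            sum += v;
            sum2 += v * v;
            sum3 += v * v * v;
            sum4 += v * v * v * v;
        }
        double mean = sum / n;
        // Sample variance, same shape as (sum2 - sum * sum / n) / (n - 1).
        double variance = Math.max(0.0, (sum2 - sum * sum / n) / (n - 1));
        double sd = Math.sqrt(variance);
        double standardError = sd / Math.sqrt(n);
        double cv = sd / mean;
        // Central moments expressed through the power sums.
        double m2 = sum2 / n - mean * mean;
        double m3 = (sum3 - 3 * sum2 * mean + 2 * sum * mean * mean) / n;
        double m4 = (sum4 - 4 * sum3 * mean + 6 * sum2 * mean * mean
            - 3 * sum * mean * mean * mean) / n;
        double skewness = m3 / (m2 * Math.sqrt(m2));
        double kurtosis = m4 / (m2 * m2) - 3;   // excess kurtosis
        System.out.printf("mean=%.4f sd=%.4f se=%.4f cv=%.4f skew=%.4f kurt=%.4f%n",
            mean, sd, standardError, cv, skewness, kurtosis);
    }
}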
*/ @@ -203,7 +242,7 @@ public double normL2(String colName) { if (isEmpty(colName)) { return Double.NaN; } - return Math.sqrt(squareSum.get(idx)); + return Math.sqrt(sum2.get(idx)); } else { return Double.NaN; } @@ -219,7 +258,7 @@ public double centralMoment2(String colName) { int idx = findIdx(colName); double mean = mean(colName); if (idx >= 0) { - return squareSum.get(idx) / count - mean * mean; + return sum2.get(idx) / count - mean * mean; } else { return Double.NaN; } @@ -235,12 +274,39 @@ public double centralMoment3(String colName) { int idx = findIdx(colName); double mean = mean(colName); if (idx >= 0) { - return (sum3.get(idx) - 3 * squareSum.get(idx) * mean + 2 * sum.get(idx) * mean * mean) / count; + return (sum3.get(idx) - 3 * sum2.get(idx) * mean + 2 * sum.get(idx) * mean * mean) / count; + } else { + return Double.NaN; + } + } + + public double centralMoment4(String colName) { + if (isEmpty(colName)) { + return Double.NaN; + } + int idx = findIdx(colName); + double mean = mean(colName); + if (idx >= 0) { + return (sum4.get(idx) - 4 * sum3.get(idx) * mean + 6 * sum2.get(idx) * mean * mean - 3 * sum.get(idx) * mean * mean * mean) / count; } else { return Double.NaN; } } + /** + * Skewness + */ + public double skewness(String colName) { + return centralMoment3(colName) / (centralMoment2(colName) * Math.sqrt(centralMoment2(colName))); + } + + /** + * Kurtosis + */ + public double kurtosis(String colName) { + return centralMoment4(colName) / (centralMoment2(colName) * centralMoment2(colName)) - 3; + } + /** * given colName, return the number of valid value. */ @@ -256,7 +322,11 @@ public long numMissingValue(String colName) { if (this.count == 0) { return 0; } - return Math.round(numMissingValue.get(idx)); + return numMissingValue[idx]; + } + + public double cv(String colName) { + return standardDeviation(colName) / mean(colName); } /** diff --git a/core/src/main/java/com/alibaba/alink/operator/common/statistics/interval/BaseIntervalCalculator.java b/core/src/main/java/com/alibaba/alink/operator/common/statistics/interval/BaseIntervalCalculator.java new file mode 100644 index 000000000..23ed73333 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/statistics/interval/BaseIntervalCalculator.java @@ -0,0 +1,654 @@ +package com.alibaba.alink.operator.common.statistics.interval; + +import com.alibaba.alink.common.exceptions.AkIllegalArgumentException; +import com.alibaba.alink.common.exceptions.AkIllegalDataException; +import com.alibaba.alink.operator.common.statistics.statistics.IntervalMeasureCalculator; + +import java.io.Serializable; +import java.math.BigDecimal; +import java.math.BigInteger; +import java.math.RoundingMode; + +/** + * @author yangxu + */ +abstract class BaseIntervalCalculator implements Cloneable, Serializable { + + /** + * * + * 日期类型数据的基本区间长度的可选值数组 + */ + public static final long[] constSteps4DateType = new long[] { + 1,//1 millisecond + 10L, + 10L * 10, + 10L * 10 * 10, // 1 sec + 10L * 10 * 10 * 10, + 10L * 10 * 10 * 10 * 6,//1 min + 10L * 10 * 10 * 10 * 6 * 10, + 10L * 10 * 10 * 10 * 6 * 10 * 3,//half an hour + 10L * 10 * 10 * 10 * 6 * 10 * 3 * 2,//1 hour + 10L * 10 * 10 * 10 * 6 * 10 * 3 * 2 * 6, + 10L * 10 * 10 * 10 * 6 * 10 * 3 * 2 * 6 * 2,//half a day + 10L * 10 * 10 * 10 * 6 * 10 * 3 * 2 * 6 * 2 * 2,//1 day + 10L * 10 * 10 * 10 * 6 * 10 * 3 * 2 * 6 * 2 * 2 * 10,//10 day + 10L * 10 * 10 * 10 * 6 * 10 * 3 * 2 * 6 * 2 * 2 * 100,//100 day + 10L * 10 * 10 * 10 * 6 * 10 * 3 * 2 * 6 * 2 * 2 * 1000,//1000 day + 10L * 10 * 10 * 10 * 6 * 10 * 
3 * 2 * 6 * 2 * 2 * 10000,//10000 day + 10L * 10 * 10 * 10 * 6 * 10 * 3 * 2 * 6 * 2 * 2 * 100000,//100000 day + }; + + protected final static int DefaultMagnitude = 1000; + + public int n; + public long[] count = null; + public int nCol = -1; + + /** + * * + * 每个基本区间内数据的基本统计计算量 + */ + public IntervalMeasureCalculator[][] mcs = null; + public int magnitude; // magnitude < n <= 10 * magnitude + public long startIndex; + + /** + * * + * step positive: for Long and Date type step negative: for Double type + * 10^(step+1000) is the real step value + */ + public long step; + public BigDecimal stepBD = null; + + /////////////////////////////////////////////////////////////////////////////////// + // IntervalCalculator 构造函数 + /////////////////////////////////////////////////////////////////////////////////// + + BaseIntervalCalculator(long start, long step, long[] count, IntervalMeasureCalculator[][] mcs, int magnitude) { + if ((10 * (long) magnitude > Integer.MAX_VALUE) || (magnitude < 1)) { + throw new AkIllegalArgumentException(""); + } else { + this.magnitude = magnitude; + } + this.step = step; + this.n = count.length; + this.count = count; + if (null != mcs) { + this.nCol = mcs[0].length; + this.mcs = mcs; + } + } + + protected static long divideInt(long k, long m) { + if (k >= 0) { + return k / m; + } else { + return (k - m + 1) / m; + } + } + + public static long calcIntervalVal(long val, long curStep) { + return divideInt(val, curStep); + } + + /////////////////////////////////////////////////////////////////////////////////// + // 获取区间数据 + /////////////////////////////////////////////////////////////////////////////////// + + public static long calcIntervalVal(double val, BigDecimal curStep) { + return calcIntervalVal(new BigDecimal(val), curStep); + } + + public static long calcIntervalVal(BigDecimal val, BigDecimal curStep) { + BigInteger k = calcIntervalValBD(val, curStep); + if (BigInteger.valueOf(k.longValue()).subtract(k).signum() == 0) { + return k.longValue(); + } else { + throw new AkIllegalArgumentException(""); + } + } + + private static BigInteger calcIntervalValBD(BigDecimal valBD, BigDecimal curStep) { + // return valBD.divide(curStep, 2, RoundingMode.FLOOR).toBigInteger(); + BigInteger bd = valBD.divide(curStep, 2, RoundingMode.FLOOR).toBigInteger(); + if (valBD.subtract(curStep.multiply(new BigDecimal(bd))).signum() < 0) { + return bd.subtract(BigInteger.ONE); + } else { + return bd; + } + } + + private static BigInteger calcIntervalValBD(double val, BigDecimal curStep) { + return calcIntervalValBD(new BigDecimal(val), curStep); + } + + /////////////////////////////////////////////////////////////////////////////////// + // 由直方图目标数据和其他需要参加统计的数据,创建IntervalCalculator + /////////////////////////////////////////////////////////////////////////////////// + //public static BaseIntervalCalculator create(long[] vals, int magnitude) { + // return create(vals, null, magnitude); + //} + // + //public static BaseIntervalCalculator create(long[] vals, double[][] colvals) { + // return create(vals, colvals, BaseIntervalCalculator.DefaultMagnitude); + //} + // + //public static BaseIntervalCalculator create(long[] vals, double[][] colvals, int magnitude) { + // if (null == vals || vals.length == 0) { + // throw new AkIllegalDataException(""); + // } + // + // long minVal = vals[0]; + // long maxVal = vals[0]; + // for (int i = 0; i < vals.length; i++) { + // if (minVal > vals[i]) { + // minVal = vals[i]; + // } + // if (maxVal < vals[i]) { + // maxVal = vals[i]; + // } + // } + // int 
nCol = -1; + // if (null != colvals) { + // nCol = colvals[0].length; + // } + // BaseIntervalCalculator xi = getEmptyInterval(minVal, maxVal, nCol, magnitude); + // xi.calculate(vals, colvals); + // + // return xi; + //} + // + //public static BaseIntervalCalculator create(double[] vals, int magnitude) { + // return create(vals, null, magnitude); + //} + // + //public static BaseIntervalCalculator create(double[] vals, double[][] colvals) { + // return create(vals, colvals, BaseIntervalCalculator.DefaultMagnitude); + //} + // + //public static BaseIntervalCalculator create(double[] vals, double[][] colvals, int magnitude) { + // if (null == vals || vals.length == 0) { + // throw new AkIllegalDataException(""); + // } + // + // double minVal = vals[0]; + // double maxVal = vals[0]; + // for (int i = 0; i < vals.length; i++) { + // if (minVal > vals[i]) { + // minVal = vals[i]; + // } + // if (maxVal < vals[i]) { + // maxVal = vals[i]; + // } + // } + // int nCol = -1; + // if (null != colvals) { + // nCol = colvals[0].length; + // } + // BaseIntervalCalculator xi = getEmptyInterval(minVal, maxVal, nCol, magnitude); + // xi.calculate(vals, colvals); + // + // return xi; + //} + + /////////////////////////////////////////////////////////////////////////////////// + // 由直方图目标数据的最小值和最大值,创建空的IntervalCalculator + /////////////////////////////////////////////////////////////////////////////////// + + //public static BaseIntervalCalculator getEmptyInterval(long min, long max, int nCol) { + // return getEmptyInterval(min, max, nCol, BaseIntervalCalculator.DefaultMagnitude); + //} + // + //public static BaseIntervalCalculator getEmptyInterval(long min, long max, int nCol, int magnitude) { + // MeasureCalculator[][] tmpmcs = null; + // if (nCol > 0) { + // tmpmcs = new MeasureCalculator[1][nCol]; + // for (int i = 0; i < nCol; i++) { + // tmpmcs[0][i] = new MeasureCalculator(); + // } + // } + // + // return new BaseIntervalCalculator(Long.class, min, 1, new long[] {0}, tmpmcs, magnitude); + // + //} + // + //public static BaseIntervalCalculator getEmptyInterval(double min, double max, int nCol, int magnitude) { + // if (Double.NEGATIVE_INFINITY < min && min <= max && max < Double.POSITIVE_INFINITY) { + // int k = -300; //double类型的最小精度 + // if (0 != min || 0 != max) { + // int k1 = (int) Math.log10(Math.abs(min) + Math.abs(max)); + // k = Math.max(k, k1 - 19);//long型数据大约19个有效数字 + // + // if (min != max) { + // int k2 = (int) (Math.log10(max - min) - Math.log10(magnitude)); + // k = Math.max(k, k2); + // } + // } + // BigDecimal stepBD = new BigDecimal(1); + // if (k > 1) { + // stepBD = BigDecimal.TEN.pow(k - 1); + // } else if (k <= 0) { + // stepBD = new BigDecimal(1).divide(BigDecimal.TEN.pow(1 - k)); + // } + // + // long minBD = calcIntervalValBD(min, stepBD).longValue(); + // MeasureCalculator[][] tmpmcs = null; + // if (nCol > 0) { + // tmpmcs = new MeasureCalculator[1][nCol]; + // for (int i = 0; i < nCol; i++) { + // tmpmcs[0][i] = new MeasureCalculator(); + // } + // } + // + // return new BaseIntervalCalculator(Double.class, minBD, k - 1 - 1000, new long[] {0}, tmpmcs, magnitude); + // + // } else { + // throw new AkIllegalDataException(""); + // } + //} + + /** + * 区间合并 + * + * @param ia :参加合并的区间 + * @param ib :参加合并的区间 + * @return 合并后的区间 + * @throws CloneNotSupportedException + */ + public static BaseIntervalCalculator combine(BaseIntervalCalculator ia, BaseIntervalCalculator ib) { + if (null == ia || null == ib) { + return null; + } + if (ia.magnitude != ib.magnitude) { + throw new 
AkIllegalDataException("Two merge XInterval must have same magnitude!"); + } + BaseIntervalCalculator x = null; + BaseIntervalCalculator y = null; + try { + if (ia.step > ib.step) { + x = (BaseIntervalCalculator) ia.clone(); + y = (BaseIntervalCalculator) ib.clone(); + } else { + x = (BaseIntervalCalculator) ib.clone(); + y = (BaseIntervalCalculator) ia.clone(); + } + } catch (Exception ex) { + throw new AkIllegalDataException(ex.getMessage()); + } + + while (x.step > y.step) { + y.upgrade(); + } + + long min = Math.min(x.startIndex, y.startIndex); + long max = Math.max(x.startIndex + x.n - 1, y.startIndex + y.n - 1); + + x.upgrade(min, max); + y.upgrade(min, max); + + for (int i = 0; i < x.n; i++) { + x.count[i] += y.count[i]; + } + + if (null != x.mcs && null != y.mcs) { + for (int i = 0; i < x.n; i++) { + for (int j = 0; j < x.nCol; j++) { + if (y.mcs[i][j] == null) { + x.mcs[i][j] = null; + } else { + x.mcs[i][j].calculate(y.mcs[i][j]); + } + } + } + } else { + x.mcs = null; + x.nCol = 0; + } + + return x; + } + + /** + * 将一组区间的步长统一为其中最大者 + * + * @param ics 区间组 + * @return + */ + public static boolean update2MaxStep(BaseIntervalCalculator[] ics) { + if (null == ics || ics.length == 0) { + throw new AkIllegalDataException(""); + } + + long maxstep = ics[0].step; + for (int i = 1; i < ics.length; i++) { + if (maxstep < ics[i].step) { + maxstep = ics[i].step; + } + } + + for (int i = 0; i < ics.length; i++) { + while (maxstep > ics[i].step) { + ics[i].upgrade(); + } + } + + return true; + } + + /** + * 将一组区间的步长统一为其中最大者,并将表示的区间范围统一 + * + * @param ics 区间组 + * @return + */ + public static boolean update2MaxStepSameRange(BaseIntervalCalculator[] ics) { + + update2MaxStep(ics); + + long min = ics[0].startIndex; + long max = ics[0].startIndex + ics[0].n - 1; + for (int i = 1; i < ics.length; i++) { + min = Math.min(min, ics[i].startIndex); + max = Math.max(max, ics[i].startIndex + ics[i].n - 1); + } + + for (int i = 0; i < ics.length; i++) { + ics[i].upgrade(min, max); + } + + return true; + } + + @Override + public Object clone() throws CloneNotSupportedException { + BaseIntervalCalculator sd = (BaseIntervalCalculator) super.clone(); + sd.count = this.count.clone(); + if (null != this.mcs) { + sd.mcs = this.mcs.clone(); + } + return sd; + } + + @Override + public String toString() { + StringBuilder sbd = new StringBuilder(); + sbd.append("startIndex=" + startIndex + ", step=" + step + ", n=" + n + ", magnitude=" + magnitude + '\n'); + for (int i = 0; i < n; i++) { + sbd.append("count[" + i + "] = " + count[i] + "\n"); + } + return sbd.toString(); + } + + /** + * * + * 获取数据类型 + * + * @return 数据类型 + */ + abstract String getDataType(); + + /** + * * + * 获取左边界值 + * + * @return 左边界值 + */ + abstract BigDecimal getLeftBound(); + + /** + * * + * 获取基本步长 + * + * @return 基本步长 + */ + abstract BigDecimal getStep(); + + /** + * * + * 获取指定分界点的值 + * + * @param index 指定分界点的索引 + * @return 指定分界点的值 + */ + abstract BigDecimal getTag(long index); + + /** + * * + * 获取每个基本区间内数据的个数 + * + * @return 每个基本区间内数据的个数 + */ + public long[] getCount() { + return this.count.clone(); + } + + public IntervalMeasureCalculator[] updateMeasureCalculatorsByCol(int idx) throws Exception { + IntervalMeasureCalculator[] measureCalculators = new IntervalMeasureCalculator[n]; + for (int i = 0; i < this.mcs.length; i++) { + measureCalculators[i] = mcs[i][idx]; + } + return measureCalculators; + + } + + /////////////////////////////////////////////////////////////////////////////////// + // 增量计算新增数据 + 
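For reference, the bucketing in this class depends on divideInt being a floor division (rounding toward negative infinity), so that negative values land in the correct interval; Java's / operator truncates toward zero and would not. A small sketch of that behaviour, mirroring the divideInt/calcIntervalVal formulas defined earlier in this class; the demo class name is illustrative.

public final class FloorDivideDemo {
    // Same shape as BaseIntervalCalculator.divideInt: floor division for longs.
    static long divideInt(long k, long m) {
        return k >= 0 ? k / m : (k - m + 1) / m;
    }

    public static void main(String[] args) {
        long step = 10;
        // Java's '/' truncates toward zero: -7 / 10 == 0, which would put -7 in the wrong bucket.
        System.out.println(-7 / 10);                 // 0
        System.out.println(divideInt(-7, step));     // -1  -> interval [-10, 0)
        System.out.println(divideInt(23, step));     // 2   -> interval [20, 30)
        // Equivalent to Math.floorDiv for positive divisors.
        System.out.println(Math.floorDiv(-7L, 10L)); // -1
    }
}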
/////////////////////////////////////////////////////////////////////////////////// + abstract void updateStepBD(); + + public void calculate(long val) { + calculate(new long[]{val}, null); + } + + // public void calculate(long[] vals) { + // calculate(vals, null); + // } + // + public void calculate(long val, double[] colvals) { + calculate(new long[]{val}, new double[][]{colvals}); + } + + public void calculate(long[] vals, double[][] colvals) { + if (null == vals || vals.length == 0) { + return; + } + + long minVal = vals[0]; + long maxVal = vals[0]; + for (int i = 0; i < vals.length; i++) { + if (minVal > vals[i]) { + minVal = vals[i]; + } + if (maxVal < vals[i]) { + maxVal = vals[i]; + } + } + + long min = toIntervalVal(minVal); + long max = toIntervalVal(maxVal); + if ((min < startIndex) || (max >= startIndex + n)) { + // min 或者 max 不在区间内,需要重新设计区间分划 + min = Math.min(min, startIndex); + max = Math.max(max, startIndex + n - 1); + upgrade(min, max); + } + + for (int i = 0; i < vals.length; i++) { + int t = (int) (toIntervalVal(vals[i]) - startIndex); + count[t]++; + if (null != this.mcs && null != colvals) { + for (int j = 0; j < this.nCol; j++) { + mcs[t][j].calculate(colvals[i][j]); + } + } + } + return; + } + + // + // public void calculate(Date date) { + // calculate(new Date[]{date}, null); + // } + // + // public void calculate(Date[] dates) { + // calculate(dates, null); + // } + // + // public void calculate(Date date, double[] colvals) { + // calculate(date.getTime(), colvals); + // } + // + // public void calculate(Date[] dates, double[][] colvals) { + // long[] ds = new long[dates.length]; + // for (int i = 0; i < dates.length; i++) { + // ds[i] = dates[i].getTime(); + // } + // calculate(ds, colvals); + // } + // + // public void calculate(double val) { + // calculate(new double[]{val}, null); + // } + // + // public void calculate(double[] vals) { + // calculate(vals, null); + // } + // + public void calculate(double val, double[] colvals) { + calculate(new double[]{val}, new double[][]{colvals}); + } + + public void calculate(double[] vals, double[][] colvals) { + if (null == vals || vals.length == 0) { + return; + } + + double minVal = vals[0]; + double maxVal = vals[0]; + for (int i = 0; i < vals.length; i++) { + if (minVal > vals[i]) { + minVal = vals[i]; + } + if (maxVal < vals[i]) { + maxVal = vals[i]; + } + } + + while (!hasValidIntervalVal(minVal)) { + upgrade(); + } + while (!hasValidIntervalVal(maxVal)) { + upgrade(); + } + + long min = toIntervalVal(minVal); + long max = toIntervalVal(maxVal); + if ((min < startIndex) || (max >= startIndex + n)) { + // val 不在区间内,需要重新设计区间分划 + min = Math.min(min, startIndex); + max = Math.max(max, startIndex + n - 1); + + upgrade(min, max); + } + + for (int i = 0; i < vals.length; i++) { + int t = (int) (toIntervalVal(vals[i]) - startIndex); + count[t]++; + if (null != this.mcs && null != colvals) { + for (int j = 0; j < this.nCol; j++) { + mcs[t][j].calculate(colvals[i][j]); + } + } + } + return; + + } + + abstract long getNextScale(); + + protected void upgrade() { + long scale = getNextScale(); + long startNew = divideInt(startIndex, scale); + long endNew = divideInt(startIndex + n - 1, scale) + 1; + subUpgrade(scale, startNew, endNew); + } + + protected void upgrade(long min, long max) { + long scale = getScale4Upgrade(min, max); + long startNew = divideInt(min, scale); + long endNew = divideInt(max, scale) + 1; + subUpgrade(scale, startNew, endNew); + } + + abstract void adjustStepByScale(long scale); + + private void 
subUpgrade(long scale, long startNew, long endNew) { + adjustStepByScale(scale); + + int nNew = (int) (endNew - startNew); + long[] countNew = new long[nNew]; + for (int i = 0; i < n; i++) { + int t = (int) (divideInt(i + startIndex, scale) - startNew); + countNew[t] += count[i]; + } + + if (null != this.mcs) { + IntervalMeasureCalculator[][] mscNew = new IntervalMeasureCalculator[nNew][this.nCol]; + for (int i = 0; i < nNew; i++) { + for (int j = 0; j < this.nCol; j++) { + mscNew[i][j] = new IntervalMeasureCalculator(); + } + } + for (int i = 0; i < n; i++) { + int t = (int) (divideInt(i + startIndex, scale) - startNew); + for (int j = 0; j < nCol; j++) { + if (this.mcs[i][j] == null) { + mscNew[t][j] = null; + } else { + mscNew[t][j].calculate(this.mcs[i][j]); + } + } + } + this.mcs = mscNew; + } + + this.startIndex = startNew; + this.n = nNew; + this.count = countNew; + } + + private long getScale4Upgrade(long min, long max) { + if (min > max) { + throw new AkIllegalDataException(""); + } + + long s = 1; + for (int i = 0; i < 20; i++) { + long k = divideInt(max + s - 1, s) - divideInt(min, s); + if (k <= this.magnitude * 10) { + break; + } else { + s *= getNextScale(); + } + } + return s; + } + + private boolean hasValidIntervalVal(double val) { + BigInteger k = calcIntervalValBD(val, this.stepBD); + return BigInteger.valueOf(k.longValue()).subtract(k).signum() == 0; + } + + private long toIntervalVal(long val) { + return calcIntervalVal(val, this.step); + } + + private long toIntervalVal(double val) { + BigInteger k = calcIntervalValBD(val, this.stepBD); + if (BigInteger.valueOf(k.longValue()).subtract(k).signum() != 0) { + //有精度损失 + throw new AkIllegalDataException(""); + } + return k.longValue(); + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/statistics/interval/DateIntervalCalculator.java b/core/src/main/java/com/alibaba/alink/operator/common/statistics/interval/DateIntervalCalculator.java new file mode 100644 index 000000000..f81218b43 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/statistics/interval/DateIntervalCalculator.java @@ -0,0 +1,255 @@ +package com.alibaba.alink.operator.common.statistics.interval; + +import com.alibaba.alink.common.exceptions.AkIllegalDataException; +import com.alibaba.alink.operator.common.statistics.statistics.IntervalMeasureCalculator; + +import java.math.BigDecimal; +import java.util.Date; + +/** + * @author yangxu + */ +public class DateIntervalCalculator extends BaseIntervalCalculator { + + public DateIntervalCalculator(long start, long step, long[] count) { + this(start, step, count, DefaultMagnitude); + } + + public DateIntervalCalculator(long start, long step, long[] count, int magnitude) { + this(start, step, count, null, magnitude); + } + + public DateIntervalCalculator(long start, long step, long[] count, IntervalMeasureCalculator[][] mcs, int magnitude) { + super(start, step, count, mcs, magnitude); + this.startIndex = divideInt(start, step); + } + + public void calculate(T date, double[] colvals) { + calculate(date.getTime(), colvals); + } + + public void calculate(T[] dates, double[][] colvals) { + long[] ds = new long[dates.length]; + for (int i = 0; i < dates.length; i++) { + ds[i] = dates[i].getTime(); + } + calculate(ds, colvals); + } + + /////////////////////////////////////////////////////////////////////////////////// + // 由直方图目标数据的最小值和最大值,创建空的IntervalCalculator + /////////////////////////////////////////////////////////////////////////////////// + //public static 
DateIntervalCalculator getEmptyInterval(long min, long max, int nCol) { + // return getEmptyInterval(min, max, nCol, DateIntervalCalculator.DefaultMagnitude); + //} + // + //public static DateIntervalCalculator getEmptyInterval(long min, long max, int nCol, int magnitude) { + // MeasureCalculator[][] tmpmcs = null; + // if (nCol > 0) { + // tmpmcs = new MeasureCalculator[1][nCol]; + // for (int i = 0; i < nCol; i++) { + // tmpmcs[0][i] = new MeasureCalculator(); + // } + // } + // + // return new DateIntervalCalculator(Long.class, min, 1, new long[] {0}, tmpmcs, magnitude); + // + //} + // + //public static DateIntervalCalculator getEmptyInterval(double min, double max, int nCol, int magnitude) { + // if (Double.NEGATIVE_INFINITY < min && min <= max && max < Double.POSITIVE_INFINITY) { + // int k = -300; //double类型的最小精度 + // if (0 != min || 0 != max) { + // int k1 = (int) Math.log10(Math.abs(min) + Math.abs(max)); + // k = Math.max(k, k1 - 19);//long型数据大约19个有效数字 + // + // if (min != max) { + // int k2 = (int) (Math.log10(max - min) - Math.log10(magnitude)); + // k = Math.max(k, k2); + // } + // } + // BigDecimal stepBD = new BigDecimal(1); + // if (k > 1) { + // stepBD = BigDecimal.TEN.pow(k - 1); + // } else if (k <= 0) { + // stepBD = new BigDecimal(1).divide(BigDecimal.TEN.pow(1 - k)); + // } + // + // long minBD = calcIntervalValBD(min, stepBD).longValue(); + // MeasureCalculator[][] tmpmcs = null; + // if (nCol > 0) { + // tmpmcs = new MeasureCalculator[1][nCol]; + // for (int i = 0; i < nCol; i++) { + // tmpmcs[0][i] = new MeasureCalculator(); + // } + // } + // + // return new DateIntervalCalculator(Double.class, minBD, k - 1 - 1000, new long[] {0}, tmpmcs, magnitude); + // + // } else { + // throw new AkIllegalDataException(""); + // } + //} + + /** + * 区间合并 + * + * @param ia :参加合并的区间 + * @param ib :参加合并的区间 + * @return 合并后的区间 + * @throws CloneNotSupportedException + */ + public static DateIntervalCalculator combine(DateIntervalCalculator ia, DateIntervalCalculator ib) { + if (null == ia || null == ib) { + return null; + } + if (ia.magnitude != ib.magnitude) { + throw new AkIllegalDataException("Two merge XInterval must have same magnitude!"); + } + DateIntervalCalculator x = null; + DateIntervalCalculator y = null; + try { + if (ia.step > ib.step) { + x = (DateIntervalCalculator) ia.clone(); + y = (DateIntervalCalculator) ib.clone(); + } else { + x = (DateIntervalCalculator) ib.clone(); + y = (DateIntervalCalculator) ia.clone(); + } + } catch (Exception ex) { + throw new AkIllegalDataException(ex.getMessage()); + } + + while (x.step > y.step) { + y.upgrade(); + } + + long min = Math.min(x.startIndex, y.startIndex); + long max = Math.max(x.startIndex + x.n - 1, y.startIndex + y.n - 1); + + x.upgrade(min, max); + y.upgrade(min, max); + + for (int i = 0; i < x.n; i++) { + x.count[i] += y.count[i]; + } + + if (null != x.mcs && null != y.mcs) { + for (int i = 0; i < x.n; i++) { + for (int j = 0; j < x.nCol; j++) { + if (y.mcs[i][j] == null) { + x.mcs[i][j] = null; + } else { + x.mcs[i][j].calculate(y.mcs[i][j]); + } + } + } + } else { + x.mcs = null; + x.nCol = 0; + } + + return x; + } + + @Override + public Object clone() throws CloneNotSupportedException { + DateIntervalCalculator sd = (DateIntervalCalculator) super.clone(); + sd.count = this.count.clone(); + if (null != this.mcs) { + sd.mcs = this.mcs.clone(); + } + return sd; + } + + @Override + public String toString() { + StringBuilder sbd = new StringBuilder(); + sbd.append("startIndex=" + startIndex + ", step=" + step + ", n=" + 
n + ", magnitude=" + magnitude + '\n'); + for (int i = 0; i < n; i++) { + sbd.append("count[" + i + "] = " + count[i] + "\n"); + } + return sbd.toString(); + } + + /** + * * + * 获取数据类型 + * + * @return 数据类型 + */ + @Override + public String getDataType() { + return "Date"; + } + + /** + * * + * 获取左边界值 + * + * @return 左边界值 + */ + @Override + public BigDecimal getLeftBound() { + return BigDecimal.valueOf(startIndex * step); + } + + /** + * * + * 获取基本步长 + * + * @return 基本步长 + */ + @Override + public BigDecimal getStep() { + return BigDecimal.valueOf(this.step); + } + + /** + * * + * 获取指定分界点的值 + * + * @param index 指定分界点的索引 + * @return 指定分界点的值 + */ + @Override + public BigDecimal getTag(long index) { + return BigDecimal.valueOf((startIndex + index) * step); + } + + /////////////////////////////////////////////////////////////////////////////////// + // 增量计算新增数据 + /////////////////////////////////////////////////////////////////////////////////// + @Override + public void updateStepBD() { + this.stepBD = null; + } + + @Override + long getNextScale() { + if (this.step > 0) { + for (int i = 0; i < constSteps4DateType.length - 1; i++) { + if (constSteps4DateType[i] == this.step) { + return constSteps4DateType[i + 1] / constSteps4DateType[i]; + } + } + if (this.step * 10 / 10 == this.step) { + return 10; + } + } + + throw new AkIllegalDataException("Not support this data type or wrong step!"); + } + + @Override + void adjustStepByScale(long scale) { + if (1 < scale) { + if (this.step > 0) { + this.step *= scale; + } else { + throw new AkIllegalDataException("Not support this data type or wrong step!"); + } + } + } + +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/statistics/interval/FloatIntervalCalculator.java b/core/src/main/java/com/alibaba/alink/operator/common/statistics/interval/FloatIntervalCalculator.java new file mode 100644 index 000000000..6e249353b --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/statistics/interval/FloatIntervalCalculator.java @@ -0,0 +1,284 @@ +package com.alibaba.alink.operator.common.statistics.interval; + +import com.alibaba.alink.common.exceptions.AkIllegalDataException; +import com.alibaba.alink.operator.common.statistics.statistics.IntervalMeasureCalculator; + +import java.math.BigDecimal; + +/** + * @author yangxu + */ +public class FloatIntervalCalculator extends BaseIntervalCalculator { + + public FloatIntervalCalculator(double start, double step, long[] count) { + this(start, step, count, DefaultMagnitude); + } + + public FloatIntervalCalculator(double start, double step, long[] count, int magnitude) { + this(start, step, count, null, magnitude); + } + + public FloatIntervalCalculator(double start, double step, long[] count, IntervalMeasureCalculator[][] mcs, int magnitude) { + this(Math.round(start / step), Math.round(Math.log10(step)) - 1000, count, mcs, magnitude); + } + + public FloatIntervalCalculator(long start, long step, long[] count, int magnitude) { + this(start, step, count, null, magnitude); + } + + public FloatIntervalCalculator(long start, long step, long[] count, IntervalMeasureCalculator[][] mcs, int magnitude) { + super(start, step, count, mcs, magnitude); + updateStepBD(); + this.startIndex = start; + } + + public void calculate(T val, double[] colvals) { + calculate(val.doubleValue(), colvals); + } + + public void calculate(T[] vals, double[][] colvals) { + double[] ds = new double[vals.length]; + for (int i = 0; i < vals.length; i++) { + ds[i] = vals[i].doubleValue(); + } + calculate(ds, colvals); 
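For reference, for date columns the bin width can only widen along the fixed constSteps4DateType ladder (1 ms up to 100000 days), and DateIntervalCalculator.getNextScale returns the multiplier to the next rung, falling back to powers of ten once the ladder is exhausted. A small sketch of walking that ladder; the constants below are the same products as constSteps4DateType written out in milliseconds, and the demo class is illustrative only.

public final class DateStepLadderDemo {
    // The constSteps4DateType ladder, in milliseconds.
    static final long[] STEPS = {
        1L, 10L, 100L,
        1_000L,                  // 1 second
        10_000L,
        60_000L,                 // 1 minute
        600_000L,
        1_800_000L,              // half an hour
        3_600_000L,              // 1 hour
        21_600_000L,
        43_200_000L,             // half a day
        86_400_000L,             // 1 day
        864_000_000L,            // 10 days
        8_640_000_000L,          // 100 days
        86_400_000_000L,         // 1000 days
        864_000_000_000L,        // 10000 days
        8_640_000_000_000L       // 100000 days
    };

    // Multiplier from the current step to the next rung, mirroring getNextScale for dates.
    static long nextScale(long step) {
        for (int i = 0; i < STEPS.length - 1; i++) {
            if (STEPS[i] == step) {
                return STEPS[i + 1] / STEPS[i];
            }
        }
        return 10; // beyond the ladder, keep growing by powers of ten
    }

    public static void main(String[] args) {
        long step = 1_000L; // start with 1-second buckets
        for (int i = 0; i < 5; i++) {
            System.out.println("step = " + step + " ms");
            step *= nextScale(step);
        }
        // prints 1 s, 10 s, 1 min, 10 min, 30 min
    }
}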
+ } + + /////////////////////////////////////////////////////////////////////////////////// + // 由直方图目标数据和其他需要参加统计的数据,创建IntervalCalculator + /////////////////////////////////////////////////////////////////////////////////// + //public static BaseIntervalCalculator1 create(long[] vals, int magnitude) { + // return create(vals, null, magnitude); + //} + // + //public static BaseIntervalCalculator1 create(long[] vals, double[][] colvals) { + // return create(vals, colvals, BaseIntervalCalculator1.DefaultMagnitude); + //} + // + //public static BaseIntervalCalculator1 create(long[] vals, double[][] colvals, int magnitude) { + // if (null == vals || vals.length == 0) { + // throw new AkIllegalDataException(""); + // } + // + // long minVal = vals[0]; + // long maxVal = vals[0]; + // for (int i = 0; i < vals.length; i++) { + // if (minVal > vals[i]) { + // minVal = vals[i]; + // } + // if (maxVal < vals[i]) { + // maxVal = vals[i]; + // } + // } + // int nCol = -1; + // if (null != colvals) { + // nCol = colvals[0].length; + // } + // BaseIntervalCalculator1 xi = getEmptyInterval(minVal, maxVal, nCol, magnitude); + // xi.calculate(vals, colvals); + // + // return xi; + //} + // + //public static BaseIntervalCalculator1 create(double[] vals, int magnitude) { + // return create(vals, null, magnitude); + //} + // + //public static BaseIntervalCalculator1 create(double[] vals, double[][] colvals) { + // return create(vals, colvals, BaseIntervalCalculator1.DefaultMagnitude); + //} + // + ////public static BaseIntervalCalculator1 create(double[] vals, double[][] colvals, int magnitude) { + //// if (null == vals || vals.length == 0) { + //// throw new AkIllegalDataException(""); + //// } + //// + //// double minVal = vals[0]; + //// double maxVal = vals[0]; + //// for (int i = 0; i < vals.length; i++) { + //// if (minVal > vals[i]) { + //// minVal = vals[i]; + //// } + //// if (maxVal < vals[i]) { + //// maxVal = vals[i]; + //// } + //// } + //// int nCol = -1; + //// if (null != colvals) { + //// nCol = colvals[0].length; + //// } + //// BaseIntervalCalculator1 xi = getEmptyInterval(minVal, maxVal, nCol, magnitude); + //// xi.calculate(vals, colvals); + //// + //// return xi; + ////} + + /** + * 区间合并 + * + * @param ia :参加合并的区间 + * @param ib :参加合并的区间 + * @return 合并后的区间 + * @throws CloneNotSupportedException + */ + public static FloatIntervalCalculator combine(FloatIntervalCalculator ia, FloatIntervalCalculator ib) { + if (null == ia || null == ib) { + return null; + } + if (ia.magnitude != ib.magnitude) { + throw new AkIllegalDataException("Two merge XInterval must have same magnitude!"); + } + FloatIntervalCalculator x = null; + FloatIntervalCalculator y = null; + try { + if (ia.step > ib.step) { + x = (FloatIntervalCalculator) ia.clone(); + y = (FloatIntervalCalculator) ib.clone(); + } else { + x = (FloatIntervalCalculator) ib.clone(); + y = (FloatIntervalCalculator) ia.clone(); + } + } catch (Exception ex) { + throw new AkIllegalDataException(ex.getMessage()); + } + + while (x.step > y.step) { + y.upgrade(); + } + + long min = Math.min(x.startIndex, y.startIndex); + long max = Math.max(x.startIndex + x.n - 1, y.startIndex + y.n - 1); + + x.upgrade(min, max); + y.upgrade(min, max); + + for (int i = 0; i < x.n; i++) { + x.count[i] += y.count[i]; + } + + if (null != x.mcs && null != y.mcs) { + for (int i = 0; i < x.n; i++) { + for (int j = 0; j < x.nCol; j++) { + if (y.mcs[i][j] == null) { + x.mcs[i][j] = null; + } else { + x.mcs[i][j].calculate(y.mcs[i][j]); + } + } + } + } else { + x.mcs = null; + 
x.nCol = 0; + } + + return x; + } + + @Override + public Object clone() throws CloneNotSupportedException { + FloatIntervalCalculator sd = (FloatIntervalCalculator) super.clone(); + sd.count = this.count.clone(); + if (null != this.mcs) { + sd.mcs = this.mcs.clone(); + } + return sd; + } + + @Override + public String toString() { + StringBuilder sbd = new StringBuilder(); + sbd.append("startIndex=" + startIndex + ", step=" + step + ", n=" + n + ", magnitude=" + magnitude + '\n'); + for (int i = 0; i < n; i++) { + sbd.append("count[" + i + "] = " + count[i] + "\n"); + } + return sbd.toString(); + } + + /** + * * + * 获取数据类型 + * + * @return 数据类型 + */ + @Override + public String getDataType() { + return "Decimal"; + } + + /** + * * + * 获取左边界值 + * + * @return 左边界值 + */ + @Override + BigDecimal getLeftBound() { + return this.stepBD.multiply(BigDecimal.valueOf(startIndex)); + } + + /** + * * + * 获取基本步长 + * + * @return 基本步长 + */ + @Override + public BigDecimal getStep() { + return this.stepBD; + } + + /** + * * + * 获取指定分界点的值 + * + * @param index 指定分界点的索引 + * @return 指定分界点的值 + */ + @Override + public BigDecimal getTag(long index) { + return this.stepBD.multiply(BigDecimal.valueOf(startIndex + index)); + } + + /////////////////////////////////////////////////////////////////////////////////// + // 增量计算新增数据 + /////////////////////////////////////////////////////////////////////////////////// + @Override + void updateStepBD() { + if (this.step == -1000) { + stepBD = BigDecimal.ONE; + } else if (this.step > -1000) { + stepBD = BigDecimal.TEN.pow((int) (this.step + 1000)); + } else { + stepBD = new BigDecimal(1).divide(BigDecimal.TEN.pow((int) (0 - 1000 - this.step))); + } + } + + @Override + long getNextScale() { + if (this.step < 0) { + if (this.step < -1) { + return 10; + } + } + + throw new AkIllegalDataException("Not support this data type or wrong step!"); + } + + @Override + void adjustStepByScale(long scale) { + if (1 < scale) { + if (this.step < 0) { + long s = scale; + + while (s > 1) { + s /= 10; + this.step++; + this.updateStepBD(); + } + } else { + throw new AkIllegalDataException("Not support this data type or wrong step!"); + } + } + } + +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/statistics/interval/IntegerIntervalCalculator.java b/core/src/main/java/com/alibaba/alink/operator/common/statistics/interval/IntegerIntervalCalculator.java new file mode 100644 index 000000000..7346bdb1c --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/statistics/interval/IntegerIntervalCalculator.java @@ -0,0 +1,283 @@ +package com.alibaba.alink.operator.common.statistics.interval; + +import com.alibaba.alink.common.exceptions.AkIllegalDataException; +import com.alibaba.alink.operator.common.statistics.statistics.IntervalMeasureCalculator; + +import java.math.BigDecimal; + +/** + * @author yangxu + */ +public class IntegerIntervalCalculator extends BaseIntervalCalculator { + + public IntegerIntervalCalculator(long start, long step, long[] count) { + this(start, step, count, DefaultMagnitude); + } + + public IntegerIntervalCalculator(long start, long step, long[] count, int magnitude) { + this(start, step, count, null, magnitude); + } + + public IntegerIntervalCalculator(long start, long step, long[] count, IntervalMeasureCalculator[][] mcs, int magnitude) { + super(start, step, count, mcs, magnitude); + this.startIndex = divideInt(start, step); + } + + public void calculate(T val, double[] colvals) { + calculate(val.longValue(), colvals); + } + + public void 
calculate(T[] vals, double[][] colvals) { + long[] ds = new long[vals.length]; + for (int i = 0; i < vals.length; i++) { + ds[i] = vals[i].longValue(); + } + calculate(ds, colvals); + } + + /////////////////////////////////////////////////////////////////////////////////// + // 由直方图目标数据和其他需要参加统计的数据,创建IntervalCalculator + /////////////////////////////////////////////////////////////////////////////////// + //public static BaseIntervalCalculator1 create(long[] vals, int magnitude) { + // return create(vals, null, magnitude); + //} + // + //public static BaseIntervalCalculator1 create(long[] vals, double[][] colvals) { + // return create(vals, colvals, BaseIntervalCalculator1.DefaultMagnitude); + //} + // + //public static BaseIntervalCalculator1 create(long[] vals, double[][] colvals, int magnitude) { + // if (null == vals || vals.length == 0) { + // throw new AkIllegalDataException(""); + // } + // + // long minVal = vals[0]; + // long maxVal = vals[0]; + // for (int i = 0; i < vals.length; i++) { + // if (minVal > vals[i]) { + // minVal = vals[i]; + // } + // if (maxVal < vals[i]) { + // maxVal = vals[i]; + // } + // } + // int nCol = -1; + // if (null != colvals) { + // nCol = colvals[0].length; + // } + // BaseIntervalCalculator1 xi = getEmptyInterval(minVal, maxVal, nCol, magnitude); + // xi.calculate(vals, colvals); + // + // return xi; + //} + // + //public static BaseIntervalCalculator1 create(double[] vals, int magnitude) { + // return create(vals, null, magnitude); + //} + // + //public static BaseIntervalCalculator1 create(double[] vals, double[][] colvals) { + // return create(vals, colvals, BaseIntervalCalculator1.DefaultMagnitude); + //} + // + ////public static BaseIntervalCalculator1 create(double[] vals, double[][] colvals, int magnitude) { + //// if (null == vals || vals.length == 0) { + //// throw new AkIllegalDataException(""); + //// } + //// + //// double minVal = vals[0]; + //// double maxVal = vals[0]; + //// for (int i = 0; i < vals.length; i++) { + //// if (minVal > vals[i]) { + //// minVal = vals[i]; + //// } + //// if (maxVal < vals[i]) { + //// maxVal = vals[i]; + //// } + //// } + //// int nCol = -1; + //// if (null != colvals) { + //// nCol = colvals[0].length; + //// } + //// BaseIntervalCalculator1 xi = getEmptyInterval(minVal, maxVal, nCol, magnitude); + //// xi.calculate(vals, colvals); + //// + //// return xi; + ////} + + /////////////////////////////////////////////////////////////////////////////////// + // 由直方图目标数据的最小值和最大值,创建空的IntervalCalculator + /////////////////////////////////////////////////////////////////////////////////// + public static IntegerIntervalCalculator getEmptyInterval(long min, long max, int nCol) { + return getEmptyInterval(min, max, nCol, DefaultMagnitude); + } + + public static IntegerIntervalCalculator getEmptyInterval(long min, long max, int nCol, int magnitude) { + IntervalMeasureCalculator[][] tmpmcs = null; + if (nCol > 0) { + tmpmcs = new IntervalMeasureCalculator[1][nCol]; + for (int i = 0; i < nCol; i++) { + tmpmcs[0][i] = new IntervalMeasureCalculator(); + } + } + + return new IntegerIntervalCalculator(min, 1, new long[] {0}, tmpmcs, magnitude); + + } + + /** + * 区间合并 + * + * @param ia :参加合并的区间 + * @param ib :参加合并的区间 + * @return 合并后的区间 + * @throws CloneNotSupportedException + */ + public static IntegerIntervalCalculator combine(IntegerIntervalCalculator ia, IntegerIntervalCalculator ib) { + if (null == ia || null == ib) { + return null; + } + if (ia.magnitude != ib.magnitude) { + throw new AkIllegalDataException("Two 
merge XInterval must have same magnitude!"); + } + IntegerIntervalCalculator x = null; + IntegerIntervalCalculator y = null; + try { + if (ia.step > ib.step) { + x = (IntegerIntervalCalculator) ia.clone(); + y = (IntegerIntervalCalculator) ib.clone(); + } else { + x = (IntegerIntervalCalculator) ib.clone(); + y = (IntegerIntervalCalculator) ia.clone(); + } + } catch (Exception ex) { + throw new AkIllegalDataException(ex.getMessage()); + } + + while (x.step > y.step) { + y.upgrade(); + } + + long min = Math.min(x.startIndex, y.startIndex); + long max = Math.max(x.startIndex + x.n - 1, y.startIndex + y.n - 1); + + x.upgrade(min, max); + y.upgrade(min, max); + + for (int i = 0; i < x.n; i++) { + x.count[i] += y.count[i]; + } + + if (null != x.mcs && null != y.mcs) { + for (int i = 0; i < x.n; i++) { + for (int j = 0; j < x.nCol; j++) { + if (y.mcs[i][j] == null) { + x.mcs[i][j] = null; + } else { + x.mcs[i][j].calculate(y.mcs[i][j]); + } + } + } + } else { + x.mcs = null; + x.nCol = 0; + } + + return x; + } + + @Override + public Object clone() throws CloneNotSupportedException { + IntegerIntervalCalculator sd = (IntegerIntervalCalculator) super.clone(); + sd.count = this.count.clone(); + if (null != this.mcs) { + sd.mcs = this.mcs.clone(); + } + return sd; + } + + @Override + public String toString() { + StringBuilder sbd = new StringBuilder(); + sbd.append("startIndex=" + startIndex + ", step=" + step + ", n=" + n + ", magnitude=" + magnitude + '\n'); + for (int i = 0; i < n; i++) { + sbd.append("count[" + i + "] = " + count[i] + "\n"); + } + return sbd.toString(); + } + + /** + * * + * 获取数据类型 + * + * @return 数据类型 + */ + @Override + public String getDataType() { + return "Integer"; + } + + /** + * * + * 获取左边界值 + * + * @return 左边界值 + */ + @Override + public BigDecimal getLeftBound() { + return BigDecimal.valueOf(startIndex * step); + } + + /** + * * + * 获取基本步长 + * + * @return 基本步长 + */ + @Override + public BigDecimal getStep() { + return BigDecimal.valueOf(this.step); + } + + /** + * * + * 获取指定分界点的值 + * + * @param index 指定分界点的索引 + * @return 指定分界点的值 + */ + @Override + public BigDecimal getTag(long index) { + return BigDecimal.valueOf((startIndex + index) * step); + } + + /////////////////////////////////////////////////////////////////////////////////// + // 增量计算新增数据 + /////////////////////////////////////////////////////////////////////////////////// + @Override + void updateStepBD() { + this.stepBD = null; + } + + @Override + long getNextScale() { + if (this.step > 0) { + if (this.step * 10 / 10 == this.step) { + return 10; + } + } + + throw new AkIllegalDataException("Not support this data type or wrong step!"); + } + + @Override + void adjustStepByScale(long scale) { + if (1 < scale) { + if (this.step > 0) { + this.step *= scale; + } else { + throw new AkIllegalDataException("Not support this data type or wrong step!"); + } + } + } + +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/statistics/interval/Interval.java b/core/src/main/java/com/alibaba/alink/operator/common/statistics/interval/Interval.java new file mode 100644 index 000000000..e8784fbed --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/statistics/interval/Interval.java @@ -0,0 +1,235 @@ +/* + * To change this template, choose Tools | Templates + * and open the template in the editor. 
+ */ +package com.alibaba.alink.operator.common.statistics.interval; + +import com.alibaba.alink.common.exceptions.AkIllegalStateException; + +import java.math.BigDecimal; + +/** + * @author zhihan.gao + */ +class Interval implements Cloneable { + + public double start; + public double end; + public double step; + public int numIntervals; + public int numStartToEnd; + public boolean bEqualAtRight; + public boolean leftResult; + public boolean rightResult; + + public Interval(double start, int size, double step, boolean bEqualAtRight, boolean leftResult, boolean + rightResult) { + if (step <= 0) { + throw new AkIllegalStateException("step must be positive!"); + } + // System.out.println("Ok"); + this.start = start; + BigDecimal startBD = new BigDecimal(String.valueOf(start)); + BigDecimal sizeBD = new BigDecimal(size); + BigDecimal stepBD = new BigDecimal(String.valueOf(step)); + this.end = (startBD.add(stepBD.multiply(sizeBD))).doubleValue(); + this.step = step; + this.bEqualAtRight = bEqualAtRight; + this.leftResult = leftResult; + this.rightResult = rightResult; + numStartToEnd = size; + if (leftResult) { + size++; + } + if (rightResult) { + size++; + } + numIntervals = size; + } + + public static Interval findInterval(double min, double max, int N) { + if (min > max) { + return null; + } + if (max - min < Double.MIN_VALUE) { + BigDecimal stepBD = new BigDecimal(1); + int start = (int) (Math.floor(min / stepBD.doubleValue())); + BigDecimal startBD = new BigDecimal(start); + Object[] r = new Object[3]; + r[0] = startBD; + r[1] = stepBD; + r[2] = 1; + return new Interval(((BigDecimal) r[0]).doubleValue(), (Integer) (r[2]), ((BigDecimal) r[1]).doubleValue(), + false, false, false); + } + BigDecimal stepBD = new BigDecimal(100000000); + BigDecimal scaleBD = new BigDecimal(10); + int i = 0; + long k = 0; + long start = 0; + long end = 0; + for (; i < 17; i++) { + start = (int) (Math.floor(min / stepBD.doubleValue())); + end = (int) (Math.ceil(max / stepBD.doubleValue())); + if ((max / stepBD.doubleValue() - end) < Double.MIN_VALUE) { + end++; + } + k = end - start; + if (k > N && k <= N * 10) { + break; + } + stepBD = stepBD.divide(scaleBD); + } + if (i == 17) { + double step = Math.pow(10.0, Math.ceil(Math.log10((max - min) / N))); + return new Interval(Math.floor(min / step) * step, (int) (Math.ceil(max / step) - Math.floor(min / step)), + step, false, false, false); + } + BigDecimal startBD = new BigDecimal(start).multiply(stepBD); + Object[] r = new Object[3]; + r[0] = startBD; + r[1] = stepBD; + r[2] = k; + return new Interval(((BigDecimal) r[0]).doubleValue(), Integer.parseInt(String.valueOf(r[2])), + ((BigDecimal) r[1]).doubleValue(), false, false, false); + } + + public double getStart() { + return start; + } + + public double getEnd() { + return end; + } + + public double getStep() { + return step; + } + + public boolean getBEqualAtRight() { + return bEqualAtRight; + } + + public boolean getLeftResult() { + return leftResult; + } + + public boolean getRightResult() { + return rightResult; + } + + @Override + public Object clone() { + Interval o = null; + try { + o = (Interval) super.clone(); + } catch (CloneNotSupportedException e) { + e.printStackTrace(); + } + return o; + } + + public int Position(double valueID) { + int num = -1; + if (bEqualAtRight) { + if (valueID <= start) { + if (leftResult) { + num = 0; + } else { + num = -1; + } + } else if (valueID > end) { + if (rightResult) { + num = numIntervals - 1; + } else { + num = -1; + } + } else { + double temp = (valueID - start) 
/ step; + int tempInt = (int) (temp); + if ((temp - tempInt) < Double.MIN_VALUE) { + tempInt--; + } + num = tempInt; + if (leftResult) { + num++; + } + } + } else { + if (valueID < start) { + if (leftResult) { + num = 0; + } else { + num = -1; + } + } else if (valueID >= end) { + if (rightResult) { + num = numIntervals - 1; + } else { + num = -1; + } + } else { + double temp = (valueID - start) / step; + int tempInt = (int) (temp); + num = tempInt; + if (leftResult) { + num++; + } + } + } + if (num < 0) { + num = 0; + } + + return num; + } + + public int NumIntervals() { + return numIntervals; + } + + public String toString(int pos) { + if (pos < 0 || pos >= numIntervals) { + throw new AkIllegalStateException("Out of bound!"); + } + if (pos == numIntervals - 1 && rightResult) { + if (bEqualAtRight) { + return "(" + end + ", " + "infinity)"; + } else { + return "[" + end + ", " + "infinity)"; + } + } + BigDecimal posBD = new BigDecimal(pos); + BigDecimal stepBD = new BigDecimal(String.valueOf(step)); + BigDecimal startBD = new BigDecimal(String.valueOf(start)); + + double left; + double right; + if (leftResult) { + if (pos == 0) { + if (bEqualAtRight) { + return "(-infinity, " + start + "]"; + } else { + return "(-infinity, " + start + ")"; + } + } + pos--; + if (pos == numStartToEnd - 1) { + left = (startBD.add(posBD.multiply(stepBD))).doubleValue(); + right = end; + } else { + left = (startBD.add(posBD.multiply(stepBD))).doubleValue(); + right = (startBD.add(posBD.multiply(stepBD)).add(stepBD)).doubleValue(); + } + } else { + + left = (startBD.add(posBD.multiply(stepBD))).doubleValue(); + right = (startBD.add(posBD.multiply(stepBD)).add(stepBD)).doubleValue(); + } + if (bEqualAtRight) { + return "(" + left + ", " + right + "]"; + } else { + return "[" + left + ", " + right + ")"; + } + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/statistics/interval/Interval4Calc.java b/core/src/main/java/com/alibaba/alink/operator/common/statistics/interval/Interval4Calc.java new file mode 100644 index 000000000..829ccaad4 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/statistics/interval/Interval4Calc.java @@ -0,0 +1,577 @@ +package com.alibaba.alink.operator.common.statistics.interval; + +import com.alibaba.alink.common.exceptions.AkIllegalStateException; +import com.alibaba.alink.operator.common.statistics.statistics.IntervalCalculator; +import com.alibaba.alink.operator.common.statistics.statistics.IntervalMeasureCalculator; + +import java.math.BigDecimal; +import java.math.RoundingMode; +import java.text.SimpleDateFormat; +import java.util.Date; + +/** + * 直方图显示区间的计算 + *

+ * Supported types: Double.class, Long.class, Date.class
+ *
+ * Data sources for the histogram: 1. data for which frequency information is already available from the basic
+ * statistics pass, in which case the frequencies are used directly; 2. otherwise, the IntervalCalculator result
+ * produced by the basic statistics pass, which contains the fine-grained base intervals.
+ *
+ * 直方图绘图参数: 1. 绘图区间为半闭半开区间 2. + * 如果不指定左边界或者右边界,则从频率或IntervalCalculator类型结果中的最小值或最大值求出 3. + * 可以建议绘图步长或直方图区间个数,若两个参数同时给出,则优先考虑绘图步长 + * + * @author yangxu + */ +public class Interval4Calc { + + /** + * myformat: 定义日期类型格式,用来输出日期类型数据的间隔标签 + */ + public static final SimpleDateFormat myformat = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS"); + /** + * colType: 数据的类型,取值为三种Double.class, Long.class, Date.class + */ + String dataType = null; + /** + * 分块区间的个数 + */ + int nBlock; + /** + * Long.class, Date.class类型数据的起始位置和步长 + */ + long startLong; + long stepLong; + /** + * Double.class类型数据的起始位置和步长,用BigDecimal类型表示 + */ + BigDecimal startBD = null; + BigDecimal stepBD = null; + /** + * 如果是对直方图数据进行计算,下面2个变量需要赋值 ic指向原始直方图数据 sizeBlock:显示的单个分块区间对应原始直方图基本区间的个数 + */ + IntervalCalculator ic = null; + int sizeBlock = -1; + + /** + * 获取直方图显示区间 + * + * @param left :显示区域左边界 + * @param right :显示区域右边界 + * @param preferStep:建议步长 + * @param preferN :建议直方图区间个数 + * @param ic :IntervalCalculator类型数据,其中包含基本细分区间 + * @return 直方图显示区间 + */ + public static Interval4Display display(long left, long right, long preferStep, int preferN, IntervalCalculator + ic) { + if (null == ic || ic.getDataType().equals("Decimal")) { + throw new AkIllegalStateException(""); + } + Interval4Calc itvc = Interval4Calc.calculate(left, right, preferStep, preferN, ic.getDataType(), ic); + return itvc.toInterval4Display(); + } + + /////////////////////////////////////////// + // 获取直方图显示区间 + /////////////////////////////////////////// + + /** + * 获取直方图显示区间 + * + * @param left :显示区域左边界 + * @param right :显示区域右边界 + * @param preferStep:建议步长 + * @param preferN :建议直方图区间个数 + * @param ic :IntervalCalculator类型数据,其中包含基本细分区间 + * @return 直方图显示区间 + */ + public static Interval4Display display(String left, String right, String preferStep, int preferN, + IntervalCalculator ic) { + if (null == ic) { + throw new AkIllegalStateException(""); + } + if (ic.getDataType().equals("Decimal")) { + Interval4Calc itvc = Interval4Calc.calculate(left, right, preferStep, preferN, ic); + return itvc.toInterval4Display(); + } else { + long iLeft, iRight, iPreferStep; + if (null == left) { + iLeft = ic.getLeftBound().longValue(); + } else { + iLeft = Long.parseLong(left); + } + if (null == right) { + iRight = ic.getTag(ic.n).longValue(); + } else { + iRight = Long.parseLong(right); + } + if (null == preferStep) { + iPreferStep = -1; + } else { + iPreferStep = Long.parseLong(preferStep); + } + Interval4Calc itvc = Interval4Calc.calculate(iLeft, iRight, iPreferStep, preferN, ic.getDataType(), ic); + return itvc.toInterval4Display(); + } + } + + /** + * 获取直方图显示区间 + * + * @param left :显示区域左边界 + * @param right :显示区域右边界 + * @param preferStep:建议步长 + * @param preferN :建议直方图区间个数 + * @param dataType :数据类型 + * @param items :数据频率中的元素项 + * @param vals :数据频率中的频数项 + * @return 直方图显示区间 + */ + public static Interval4Display display(long left, long right, long preferStep, int preferN, String dataType, + long[] items, long[] vals) { + Interval4Calc itvc = Interval4Calc.calculate(left, right, preferStep, preferN, dataType, null); + return itvc.toInterval4Display(items, vals, dataType); + } + + /** + * 获取直方图显示区间 + * + * @param left :显示区域左边界 + * @param right :显示区域右边界 + * @param preferStep:建议步长 + * @param preferN :建议直方图区间个数 + * @param dataType :数据类型 + * @param items :数据频率中的元素项 + * @param vals :数据频率中的频数项 + * @return 直方图显示区间 + */ + public static Interval4Display display(String left, String right, String preferStep, int preferN, String dataType, + long[] items, long[] vals) { + long iLeft, iRight, iPreferStep; 
+ if (null == left) { + iLeft = Min(items);//items[0]; + } else { + iLeft = Long.parseLong(left); + } + + if (null == right) { + iRight = Max(items) + 1;//items[items.size - 1]; + } else { + iRight = Long.parseLong(right); + } + + if (null == preferStep) { + iPreferStep = -1; + } else { + iPreferStep = Long.parseLong(preferStep); + } + + return display(iLeft, iRight, iPreferStep, preferN, dataType, items, vals); + } + + /** + * 获取直方图显示区间 + * + * @param left :显示区域左边界 + * @param right :显示区域右边界 + * @param preferStep:建议步长 + * @param preferN :建议直方图区间个数 + * @param items :数据频率中的元素项 + * @param vals :数据频率中的频数项 + * @return 直方图显示区间 + */ + public static Interval4Display display(String left, String right, String preferStep, int preferN, double[] items, + long[] vals) { + String iLeft, iRight; + if (null == left) { + iLeft = Double.toString(Min(items)); + } else { + iLeft = left; + } + if (null == right) { + double tmin = Min(items); + double tmax = Max(items); + double te = 1; + if ((tmax - tmin) / 10000 < te) { + te = (tmax - tmin) / 10000; + } + iRight = Double.toString(Max(items) + te); + } else { + iRight = right; + } + + Interval4Calc itvc = Interval4Calc.calculate(iLeft, iRight, preferStep, preferN, null); + return itvc.toInterval4Display(items, vals); + } + + /////////////////////////////////////////// + // 计算直方图显示区间的划分方案 + /////////////////////////////////////////// + private static long getNextStepLong(long step, boolean isDateType) { + if (isDateType) { + for (int i = 0; i < IntervalCalculator.constSteps4DateType.length - 1; i++) { + if (IntervalCalculator.constSteps4DateType[i] > step) { + return IntervalCalculator.constSteps4DateType[i]; + } + } + if (step * 10 / 10 == step) { + return step * 10; + } else { + throw new AkIllegalStateException(""); + } + } else { + return step * 10; + } + } + + /** + * 计算直方图显示区间的划分方案 + * + * @param left :显示区域左边界 + * @param right :显示区域右边界 + * @param preferStep:建议步长 + * @param preferN :建议直方图区间个数 + * @param dataType :数据类型 + * @param ic :IntervalCalculator类型数据,其中包含基本细分区间 + * @return 直方图显示区间的划分方案 + */ + static Interval4Calc calculate(long left, long right, long preferStep, int preferN, String dataType, + IntervalCalculator ic) { + //if (colType.equals("Decimal") || colType.equals("Integer")) { + if (dataType.equals("Decimal")) { + throw new AkIllegalStateException(""); + } + if (null != ic && ic.getDataType() != dataType) { + throw new AkIllegalStateException(""); + } + + long icstep = 1; + if (null != ic) { + icstep = ic.getStep().longValue(); + } + long step = -1; + int sizeBlock = -1; + + //根据用户是否给出建议步长,分情况分析 + if (preferStep > 0) { + step = icstep; + boolean isDateType = (dataType.equals("Date")); + if (step < preferStep) { + long nextstep = getNextStepLong(step, isDateType); + while (nextstep < preferStep) { + step = nextstep; + nextstep = getNextStepLong(step, isDateType); + } + } + sizeBlock = Math.max(1, (int) (preferStep / step)); + } else { + if (preferN < 1) { + throw new AkIllegalStateException(""); + } + step = icstep; + boolean isDateType = (dataType.equals("Date")); + if ((right - left) / step > preferN) { + long nextstep = getNextStepLong(step, isDateType); + while ((right - left) / nextstep > preferN) { + step = nextstep; + nextstep = getNextStepLong(step, isDateType); + } + } + sizeBlock = Math.max(1, (int) ((right - left) / step / preferN)); + } + + long scale = step / icstep; + long offset = IntervalCalculator.calcIntervalVal(left, step); + long end = IntervalCalculator.calcIntervalVal(right - 1, step); + int n = (int) ((end - offset + 1 + 
sizeBlock - 1) / sizeBlock); + + Interval4Calc r = new Interval4Calc(); + r.nBlock = n; + r.startLong = offset * step; + r.stepLong = step * sizeBlock; + r.dataType = dataType; + if (null != ic) { + r.ic = ic; + r.sizeBlock = sizeBlock * (int) scale; + } + return r; + } + + /** + * 计算直方图显示区间的划分方案 + * + * @param left :显示区域左边界 + * @param right :显示区域右边界 + * @param preferStep:建议步长 + * @param preferN :建议直方图区间个数 + * @param ic :IntervalCalculator类型数据,其中包含基本细分区间 + * @return 直方图显示区间的划分方案 + */ + static Interval4Calc calculate(String left, String right, String preferStep, int preferN, IntervalCalculator ic) { + if (null != ic) { + if (null == left) { + left = ic.getLeftBound().toString(); + } + if (null == right) { + right = ic.getTag(ic.n).toString(); + } + } + if (null == left || null == right || Double.parseDouble(left) >= Double.parseDouble(right)) { + throw new AkIllegalStateException(""); + } + + BigDecimal stepBD = new BigDecimal(1); + double min = Double.parseDouble(left); + double max = Double.parseDouble(right); + if (null != ic) { + stepBD = ic.getStep(); + } else { + int k = -300; //double类型的最小精度 + if (0 != min || 0 != max) { + int k1 = (int) Math.log10(Math.abs(min) + Math.abs(max)); + k = Math.max(k, k1 - 19);//long型数据大约19个有效数字 + + if (min != max) { + int k2 = (int) (Math.log10(max - min) - Math.log10(preferN)); + k = Math.max(k, k2); + } + } + if (k > 1) { + stepBD = BigDecimal.TEN.pow(k - 1); + } else if (k <= 0) { + stepBD = new BigDecimal(1).divide(BigDecimal.TEN.pow(1 - k)); + } + } + long valLeft, valRight; + valLeft = IntervalCalculator.calcIntervalVal(min, stepBD); + valRight = IntervalCalculator.calcIntervalVal(max, stepBD); + int sizeBlock; + long scale; + if (null == preferStep) { + if (preferN < 1) { + throw new AkIllegalStateException(""); + } + scale = 1; + if ((valRight - valLeft) / scale > preferN) { + long nextscale = scale * 10; + while ((valRight - valLeft) / nextscale > preferN) { + scale = nextscale; + nextscale *= 10; + } + } + sizeBlock = Math.max(1, (int) Math.floor((valRight - valLeft) / scale / preferN)); + } else { + scale = 1; + BigDecimal preferStepBD = new BigDecimal(preferStep); + if (stepBD.multiply(BigDecimal.valueOf(scale)).subtract(preferStepBD).signum() < 0) { + long nextscale = scale * 10; + while (stepBD.multiply(BigDecimal.valueOf(nextscale)).subtract(preferStepBD).signum() < 0) { + scale = nextscale; + nextscale *= 10; + } + } + sizeBlock = Math.max(1, + preferStepBD.divide(stepBD.multiply(BigDecimal.valueOf(scale)), 2, RoundingMode.FLOOR).intValue()); + } + stepBD = stepBD.multiply(BigDecimal.valueOf(scale)); + + long offset = IntervalCalculator.calcIntervalVal(new BigDecimal(left), stepBD); + long end = IntervalCalculator.calcIntervalVal(new BigDecimal(right), stepBD); + if (stepBD.multiply(BigDecimal.valueOf(end)).subtract(new BigDecimal(right)).signum() == 0) { + end -= 1; + } + int n = (int) ((end - offset + 1 + sizeBlock - 1) / sizeBlock); + Interval4Calc r = new Interval4Calc(); + r.nBlock = n; + r.startBD = BigDecimal.valueOf(offset).multiply(stepBD); + r.stepBD = stepBD.multiply(BigDecimal.valueOf(sizeBlock)); + r.dataType = "Decimal"; + if (null != ic) { + r.sizeBlock = sizeBlock * (int) scale; + r.ic = ic; + } + return r; + } + + @Override + public String toString() { + return "Interval2{" + "nBlock=" + nBlock + ", startLong=" + startLong + ", stepLong=" + stepLong + ", startBD=" + + startBD + ", stepBD=" + stepBD + ", colType=" + dataType + ", ic=" + (ic != null) + ", sizeBlock=" + + sizeBlock + '}'; + } + + 
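+	// A minimal usage sketch, assuming an IntervalCalculator "ic" over Long-typed data produced by the basic
+	// statistics pass (the name "ic" and the bucket count 10 are illustrative, not part of this class): the
+	// String-parameter display() overload derives missing bounds and the step from "ic" itself, and the
+	// returned Interval4Display carries the tags and counts of the half-open display intervals.
+	//
+	//   Interval4Display d = Interval4Calc.display(null, null, null, 10, ic);
+	//   for (int i = 0; i < d.n; i++) {
+	//       System.out.println("[" + d.tags[i] + ", " + d.tags[i + 1] + ") : " + d.count[i]);
+	//   }
+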
///////////////////////////////////////////////////////////////// + // 由直方图显示区间的划分方案 + 频率数据或IntervalCalculator类型 + // 数据(其中包含基本细分区间),计算直方图显示区间的数据 + ///////////////////////////////////////////////////////////////// + + /** + * @param items :数据频率中的元素项 + * @param vals :数据频率中的频数项 + * @return + */ + Interval4Display toInterval4Display(double[] items, long[] vals) { + Interval4Display + dis = new Interval4Display(); + int n = this.nBlock; + dis.n = n; + dis.step = this.stepBD.stripTrailingZeros().toPlainString(); + dis.count = new long[n]; + dis.tags = new String[n + 1]; + for (int i = 0; i <= n; i++) { + dis.tags[i] = this.stepBD.multiply(BigDecimal.valueOf(i)).add(this.startBD).stripTrailingZeros() + .toPlainString(); + } + double left = this.startBD.doubleValue(); + double right = this.stepBD.multiply(BigDecimal.valueOf(n)).add(this.startBD).doubleValue(); + for (int i = 0; i < items.length; i++) { + if (items[i] >= left && items[i] < right) { + long idx = IntervalCalculator.calcIntervalVal(BigDecimal.valueOf(items[i]).subtract(this.startBD), + this.stepBD); + if (idx >= 0 && idx < n) { + dis.count[(int) idx] += vals[i]; + } + } + } + return dis; + } + + /** + * @param items :数据频率中的元素项 + * @param vals :数据频率中的频数项 + * @param dataType :数据类型 + * @return + */ + Interval4Display toInterval4Display(long[] items, long[] vals, String dataType) { + if (!this.dataType.equals(dataType)) { + throw new AkIllegalStateException(""); + } + Interval4Display + dis = new Interval4Display(); + int n = this.nBlock; + dis.n = n; + dis.step = Long.toString(stepLong); + dis.count = new long[n]; + dis.tags = new String[n + 1]; + for (int i = 0; i <= n; i++) { + if (this.dataType.equals("Date")) { + dis.tags[i] = myformat.format(new Date(this.startLong + i * this.stepLong)); + } else if (this.dataType.equals("Integer")) { + dis.tags[i] = String.valueOf(this.startLong + i * this.stepLong); + } else { + throw new AkIllegalStateException(""); + } + } + for (int i = 0; i < items.length; i++) { + long idx = IntervalCalculator.calcIntervalVal(items[i] - startLong, stepLong); + if (idx >= 0 && idx < n) { + dis.count[(int) idx] += vals[i]; + } + } + return dis; + } + + Interval4Display toInterval4Display() { + if (null == this.ic) { + throw new AkIllegalStateException(""); + } + Interval4Display dis = new Interval4Display(); + int n = this.nBlock; + dis.n = n; + if (this.dataType.equals("Decimal")) { + dis.step = this.stepBD.stripTrailingZeros().toPlainString(); + } else { + dis.step = Long.toString(this.stepLong); + } + dis.count = new long[n]; + dis.tags = new String[n + 1]; + if (ic.nCol > 0) { + dis.nCol = ic.nCol; + dis.mcs = new IntervalMeasureCalculator[n][dis.nCol]; + for (int i = 0; i < n; i++) { + for (int j = 0; j < dis.nCol; j++) { + dis.mcs[i][j] = new IntervalMeasureCalculator(); + } + } + } + long offset = -1; + if (ic.getDataType().equals("Decimal")) { + offset = IntervalCalculator.calcIntervalVal(this.startBD.subtract(ic.getLeftBound()), ic.getStep()); + } else { + offset = (this.startLong - ic.getLeftBound().longValue()) / ic.getStep().longValue(); + } + if (ic.getDataType().equals("Date")) { + for (int i = 0; i <= n; i++) { + BigDecimal bd = ic.getTag(offset + i * sizeBlock); + dis.tags[i] = myformat.format(new Date(bd.longValue())); + } + } else { + for (int i = 0; i <= n; i++) { + BigDecimal bd = ic.getTag(offset + i * sizeBlock); + if (bd.compareTo(BigDecimal.ZERO) == 0) { + dis.tags[i] = "0"; + } else { + dis.tags[i] = bd.stripTrailingZeros().toPlainString(); + } + } + } + for (int i = 0; i < n; i++) { + 
for (long idx = offset + i * sizeBlock; idx < offset + (i + 1) * sizeBlock; idx++) { + if (idx >= 0 && idx < ic.n) { + dis.count[i] += ic.count[(int) idx]; + if (null != dis.mcs) { + for (int j = 0; j < dis.nCol; j++) { + if (ic.mcs != null && ic.mcs[(int) idx] != null && ic.mcs[(int) idx][j] != null) { + dis.mcs[i][j].calculate(ic.mcs[(int) idx][j]); + } + } + } + } + } + } + + return dis; + } + + static long Max(long[] counts) { + long max = Long.MIN_VALUE; + for (int i = 0; i < counts.length; i++) { + if (max < counts[i]) { + max = counts[i]; + } + } + return max; + } + + static long Min(long[] counts) { + long min = Long.MAX_VALUE; + for (long count : counts) { + if (min > count) { + min = count; + } + } + return min; + } + + public static double Max(double[] counts) { + double max = Double.MIN_VALUE; + for (double count : counts) { + if (max < count) { + max = count; + } + } + return max; + } + + static double Min(double[] counts) { + double min = Double.MAX_VALUE; + for (double count : counts) { + if (min > count) { + min = count; + } + } + return min; + } + +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/statistics/interval/Interval4Display.java b/core/src/main/java/com/alibaba/alink/operator/common/statistics/interval/Interval4Display.java new file mode 100644 index 000000000..97fdd5189 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/statistics/interval/Interval4Display.java @@ -0,0 +1,28 @@ +package com.alibaba.alink.operator.common.statistics.interval; + +import com.alibaba.alink.operator.common.statistics.statistics.IntervalMeasureCalculator; + +/** + * @author yangxu + */ +public class Interval4Display { + + public int n; + public long[] count = null; + public int nCol = -1; + public IntervalMeasureCalculator[][] mcs = null; + public String[] tags = null; + public String step = null; + + @Override + public String toString() { + StringBuilder sbd = new StringBuilder(); + sbd.append( + "Interval4Display{" + "n=" + n + ", count=" + count + ", nCol=" + nCol + ", mcs=" + mcs + ", tags=" + tags + + ", step=" + step + '}'); + for (int i = 0; i < n; i++) { + sbd.append("\n[" + tags[i] + " , " + tags[i + 1] + ") : " + count[i]); + } + return sbd.toString(); + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/statistics/statistics/BaseMeasureIterator.java b/core/src/main/java/com/alibaba/alink/operator/common/statistics/statistics/BaseMeasureIterator.java new file mode 100644 index 000000000..5843438ff --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/statistics/statistics/BaseMeasureIterator.java @@ -0,0 +1,31 @@ +package com.alibaba.alink.operator.common.statistics.statistics; + +public interface BaseMeasureIterator> { + /** + * add val. + */ + void visit(T val); + + /** + * merge iterator. + */ + void merge(I iterator); + + /** + * clone iterator. + */ + I clone(); + + /** + * missing value number. + */ + long missingCount(); + + /** + * valid value number. 
+ */ + long count(); + + void finalResult(SummaryResultCol src); + +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/statistics/statistics/BooleanMeasureIterator.java b/core/src/main/java/com/alibaba/alink/operator/common/statistics/statistics/BooleanMeasureIterator.java new file mode 100644 index 000000000..32290f561 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/statistics/statistics/BooleanMeasureIterator.java @@ -0,0 +1,88 @@ +package com.alibaba.alink.operator.common.statistics.statistics; + +public class BooleanMeasureIterator implements BaseMeasureIterator { + /** + * missing value number. + */ + public long countMissing = 0; + + /** + * true value number. + */ + public long countTrue = 0; + + /** + * false value number. + */ + public long countFalse = 0; + + public BooleanMeasureIterator() { + } + + @Override + public void visit(Boolean val) { + if (null == val) { + countMissing++; + } else if (val) { + countTrue++; + } else { + countFalse++; + } + } + + @Override + public void merge(BooleanMeasureIterator iterator) { + this.countTrue += iterator.countTrue; + this.countFalse += iterator.countFalse; + this.countMissing += iterator.countMissing; + } + + @Override + public BooleanMeasureIterator clone() { + BooleanMeasureIterator iterator = new BooleanMeasureIterator(); + iterator.countMissing = this.countMissing; + iterator.countTrue = this.countTrue; + iterator.countFalse = this.countFalse; + return iterator; + } + + @Override + public long missingCount() { + return countMissing; + } + + @Override + public long count() { + return countFalse + countTrue; + } + + @Override + public void finalResult(SummaryResultCol src) { + long count = countTrue + countFalse; + Boolean min = null; + Boolean max = null; + if (count > 0) { + min = Boolean.FALSE; + max = Boolean.TRUE; + if (0 == countTrue) { + max = Boolean.FALSE; + } + if (0 == countFalse) { + min = Boolean.TRUE; + } + } + src.init(null, count + countMissing, count, countMissing, 0, 0, 0, countFalse, + countTrue, countTrue, countTrue, countTrue, countTrue, min, max); + } + + @Override + public String toString() { + String result = ""; + result += String.valueOf(countTrue); + result += " "; + result += String.valueOf(countFalse); + result += " "; + result += String.valueOf(countMissing); + return result; + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/statistics/statistics/DateMeasureIterator.java b/core/src/main/java/com/alibaba/alink/operator/common/statistics/statistics/DateMeasureIterator.java new file mode 100644 index 000000000..027bf259d --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/statistics/statistics/DateMeasureIterator.java @@ -0,0 +1,151 @@ +package com.alibaba.alink.operator.common.statistics.statistics; + +import java.util.Date; + +public class DateMeasureIterator implements BaseMeasureIterator > { + /** + * valid value count. + */ + public long count = 0; + + /** + * missing value count. 
+ */
+	public long countMissing = 0;
+	/**
+	 * Sum of the values.
+	 */
+	public double sum = 0.0;
+	/**
+	 * Sum of the squared values.
+	 */
+	public double sum2 = 0.0;
+	/**
+	 * Sum of the cubed values.
+	 */
+	public double sum3 = 0.0;
+	/**
+	 * Sum of the 4th powers of the values.
+	 */
+	public double sum4 = 0.0;
+	/**
+	 * Minimum value.
+	 */
+	public T min = null;
+	/**
+	 * Maximum value.
+	 */
+	public T max = null;
+
+	private boolean needInit = true;
+
+	public DateMeasureIterator() {
+	}
+
+	@Override
+	public long missingCount() {
+		return countMissing;
+	}
+
+	@Override
+	public long count() {
+		return count;
+	}
+
+	@Override
+	public void visit(T val) {
+		if (null == val) {
+			countMissing++;
+		} else {
+			count++;
+
+			if (needInit) {
+				min = val;
+				max = val;
+				needInit = false;
+			} else {
+				if (val.compareTo(min) < 0) {
+					min = val;
+				}
+				if (val.compareTo(max) > 0) {
+					max = val;
+				}
+			}
+
+			double d = (double) val.getTime();
+			sum += d;
+			sum2 += d * d;
+			sum3 += d * d * d;
+			sum4 += d * d * d * d;
+		}
+	}
+
+	@Override
+	public void merge(DateMeasureIterator<T> iterator) {
+		// keep the global minimum / maximum when merging two partial results
+		if (null == this.min) {
+			this.min = iterator.min;
+		} else if (null != iterator.min && this.min.compareTo(iterator.min) > 0) {
+			this.min = iterator.min;
+		}
+		if (null == this.max) {
+			this.max = iterator.max;
+		} else if (null != iterator.max && this.max.compareTo(iterator.max) < 0) {
+			this.max = iterator.max;
+		}
+
+		this.sum += iterator.sum;
+		this.sum2 += iterator.sum2;
+		this.sum3 += iterator.sum3;
+		this.sum4 += iterator.sum4;
+		this.count += iterator.count;
+		this.countMissing += iterator.countMissing;
+	}
+
+	@Override
+	public DateMeasureIterator<T> clone() {
+		DateMeasureIterator<T> iterator = new DateMeasureIterator<T>();
+		iterator.count = this.count;
+		iterator.countMissing = this.countMissing;
+		iterator.sum = this.sum;
+		iterator.sum2 = this.sum2;
+		iterator.sum3 = this.sum3;
+		iterator.sum4 = this.sum4;
+		iterator.min = this.min;
+		iterator.max = this.max;
+		iterator.needInit = this.needInit;
+		return iterator;
+	}
+
+	@Override
+	public void finalResult(SummaryResultCol src) {
+		src.init(null,
+			count + countMissing, count, countMissing,
+			0, 0, 0, 0, 0, 0, 0, 0, 0,
+			min, max);
+	}
+
+	@Override
+	public String toString() {
+		String result = "";
+		result += String.valueOf(count);
+		result += " ";
+		result += String.valueOf(sum);
+		result += " ";
+		result += String.valueOf(sum2);
+		result += " ";
+		result += String.valueOf(sum3);
+		result += " ";
+		result += String.valueOf(sum4);
+		result += " ";
+		result += String.valueOf(min);
+		result += " ";
+		result += String.valueOf(max);
+		return result;
+	}
+}
diff --git a/core/src/main/java/com/alibaba/alink/operator/common/statistics/statistics/DistinctValueIterator.java b/core/src/main/java/com/alibaba/alink/operator/common/statistics/statistics/DistinctValueIterator.java
new file mode 100644
index 000000000..b8afb3f63
--- /dev/null
+++ b/core/src/main/java/com/alibaba/alink/operator/common/statistics/statistics/DistinctValueIterator.java
@@ -0,0 +1,15 @@
+package com.alibaba.alink.operator.common.statistics.statistics;
+
+import java.util.HashSet;
+
+public class DistinctValueIterator<T> {
+	public HashSet<T> mapFreq = null;
+
+	public DistinctValueIterator() {
+		this.mapFreq = new HashSet<>();
+	}
+
+	public void visit(T val) {
+		mapFreq.add(val);
+	}
+}
diff --git a/core/src/main/java/com/alibaba/alink/operator/common/statistics/statistics/FrequencyIterator.java b/core/src/main/java/com/alibaba/alink/operator/common/statistics/statistics/FrequencyIterator.java
new file mode 100644
index 000000000..c6de3ca68
--- /dev/null
+++ 
b/core/src/main/java/com/alibaba/alink/operator/common/statistics/statistics/FrequencyIterator.java @@ -0,0 +1,46 @@ +package com.alibaba.alink.operator.common.statistics.statistics; + +import com.alibaba.alink.common.exceptions.AkIllegalArgumentException; + +import java.util.HashMap; +import java.util.Iterator; +import java.util.Map; +import java.util.Map.Entry; +import java.util.TreeMap; + +public class FrequencyIterator { + private final int capacity; + private HashMap mapFreq; + private boolean inRange = true; + + public FrequencyIterator(int capacity) { + this.capacity = capacity; + if (capacity <= 0 || capacity == Integer.MAX_VALUE) { + throw new AkIllegalArgumentException("Wrong capacity value for Statistic Frequency."); + } + mapFreq = new HashMap <>(capacity + 1); + } + + public void visit(T val) { + if (inRange && null != val) { + this.mapFreq.merge(val, 1L, Long::sum); + if (mapFreq.size() > this.capacity) { + this.inRange = false; + this.mapFreq.clear(); + } + } + } + + public void finalResult(SummaryResultCol src) { + if (inRange) { + src.freq = new TreeMap (); + Iterator > it = this.mapFreq.entrySet().iterator(); + while (it.hasNext()) { + Map.Entry e = it.next(); + src.freq.put(e.getKey(), e.getValue()); + } + } else { + src.freq = null; + } + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/statistics/statistics/FullStats.java b/core/src/main/java/com/alibaba/alink/operator/common/statistics/statistics/FullStats.java index 9af1d17f7..35796b710 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/statistics/statistics/FullStats.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/statistics/statistics/FullStats.java @@ -3,7 +3,8 @@ import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.tuple.Tuple2; -import com.alibaba.alink.common.AlinkTypes; +import com.alibaba.alink.common.type.AlinkTypes; +import com.alibaba.alink.metadata.def.v0.BytesStatistics; import com.alibaba.alink.metadata.def.v0.CommonStatistics; import com.alibaba.alink.metadata.def.v0.DatasetFeatureStatistics; import com.alibaba.alink.metadata.def.v0.DatasetFeatureStatisticsList; @@ -18,7 +19,11 @@ import com.alibaba.alink.operator.common.io.types.FlinkTypeConverter; import java.io.Serializable; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; import java.util.Map; +import java.util.Map.Entry; import java.util.TreeMap; /** @@ -56,8 +61,9 @@ public static FullStats fromSummaryResultTable(String[] tableNames, String[] col String[] colNames = srt.colNames; for (int i = 0; i < colNames.length; i++) { String colName = colNames[i]; - if (AlinkTypes.DOUBLE == colTypes[i] || AlinkTypes.FLOAT == colTypes[i] - || AlinkTypes.LONG == colTypes[i] || AlinkTypes.INT == colTypes[i]) { + if (Number.class.isAssignableFrom(colTypes[i].getTypeClass())) { + //if (AlinkTypes.DOUBLE == colTypes[i] || AlinkTypes.FLOAT == colTypes[i] + // || AlinkTypes.LONG == colTypes[i] || AlinkTypes.INT == colTypes[i]) { SummaryResultCol src = srt.col(colName); @@ -89,7 +95,8 @@ public static FullStats fromSummaryResultTable(String[] tableNames, String[] col } Type feaType = - (AlinkTypes.DOUBLE == colTypes[i] || AlinkTypes.FLOAT == colTypes[i]) ? Type.FLOAT : Type.INT; + (AlinkTypes.DOUBLE == colTypes[i] || AlinkTypes.FLOAT == colTypes[i] + || AlinkTypes.BIG_DEC == colTypes[i]) ? 
Type.FLOAT : Type.INT; builder.addFeatures( FeatureNameStatistics.newBuilder() @@ -129,18 +136,25 @@ public static FullStats fromSummaryResultTable(String[] tableNames, String[] col .setNumMissing(src.countMissValue) .setTotNumValues(src.countTotal) .setNumNonMissing(src.count) - ); + ) + .setAvgLength((float) src.mean()); + if (src.hasFreq()) { TreeMap freq = src.getFrequencyMap(); stringBuilder.setUnique(freq.size()); - int k = 0; - for (Map.Entry entry : freq.entrySet()) { - stringBuilder.addTopValues(k++, - FreqAndValue.newBuilder() - .setValue(entry.getKey().toString()) - .setFrequency(entry.getValue()) - ); - } + + ArrayList > list = new ArrayList <>(freq.entrySet()); + Collections.sort(list, new Comparator >() { + @Override + public int compare(Entry o1, Entry o2) { + return o2.getValue().compareTo(o1.getValue()); + } + }); + stringBuilder.addTopValues(0, + FreqAndValue.newBuilder() + .setValue(list.get(0).getKey().toString()) + .setFrequency(list.get(0).getValue()) + ); for (Map.Entry entry : freq.entrySet()) { stringBuilder.getRankHistogramBuilder().addBuckets( @@ -155,6 +169,119 @@ public static FullStats fromSummaryResultTable(String[] tableNames, String[] col .setType(Type.STRING) .setStringStats(stringBuilder) ); + } else if (AlinkTypes.BOOLEAN == colTypes[i]) { + SummaryResultCol src = srt.col(colName); + + Long countTrue = (long) src.sum(); + Long countFalse = src.count - (long) src.sum(); + + Histogram.Builder histoBuilder = Histogram.newBuilder() + .setType(HistogramType.STANDARD); + + histoBuilder.addBuckets(Histogram.Bucket.newBuilder() + .setLowValue(0.0) + .setHighValue(1.0) + .setSampleCount(countFalse) + ); + histoBuilder.addBuckets(Histogram.Bucket.newBuilder() + .setLowValue(1.0) + .setHighValue(2.0) + .setSampleCount(countTrue) + ); + + Histogram.Builder percentileBuilder = Histogram.newBuilder() + .setType(HistogramType.QUANTILES); + + int k = (int) (countFalse * 10 / src.count); + for (int j = 0; j < 10; j++) { + percentileBuilder.addBuckets(Histogram.Bucket.newBuilder() + .setLowValue((j <= k) ? 0.0 : 1.0) + .setHighValue((j < k) ? 0.0 : 1.0) + .setSampleCount(src.count / 10.0) + ); + } + + builder.addFeatures( + FeatureNameStatistics.newBuilder() + .setName(colName) + .setType(Type.INT) + .setNumStats( + NumericStatistics.newBuilder() + .setCommonStats( + CommonStatistics.newBuilder() + .setNumMissing(src.countMissValue) + .setTotNumValues(src.countTotal) + .setNumNonMissing(src.count) + .setAvgNumValues(1) + .setMinNumValues(1) + .setMaxNumValues(1) + ) + .setNumZeros(src.countZero) + .setMax(src.maxDouble()) + .setMin(src.minDouble()) + .setMean(src.mean()) + .setStdDev(src.standardDeviation()) + .setMedian((countTrue >= countFalse) ? 1.0 : 0.0) + .addHistograms(histoBuilder) + .addHistograms(percentileBuilder) + ) + ); + + StringStatistics.Builder stringBuilder = StringStatistics.newBuilder() + .setCommonStats( + CommonStatistics.newBuilder() + .setNumMissing(src.countMissValue) + .setTotNumValues(src.countTotal) + .setNumNonMissing(src.count) + ) + .setUnique(((countTrue > 0) ? 1 : 0) + ((countFalse > 0) ? 1 : 0)); + + stringBuilder.addTopValues(0, + FreqAndValue.newBuilder() + .setValue((countTrue >= countFalse) ? "true" : "false") + .setFrequency((countTrue >= countFalse) ? 
countTrue : countFalse) + ); + + stringBuilder.getRankHistogramBuilder().addBuckets( + Bucket.newBuilder().setLabel("true").setSampleCount(countTrue) + ); + stringBuilder.getRankHistogramBuilder().addBuckets( + Bucket.newBuilder().setLabel("false").setSampleCount(countFalse) + ); + + builder.addFeatures( + FeatureNameStatistics.newBuilder() + .setName(colName + "_categorical") + .setType(Type.STRING) + .setStringStats(stringBuilder) + ); + + } else if (AlinkTypes.VARBINARY == colTypes[i]) { + SummaryResultCol src = srt.col(colName); + + BytesStatistics.Builder bytesBuilder = BytesStatistics.newBuilder() + .setCommonStats( + CommonStatistics.newBuilder() + .setNumMissing(src.countMissValue) + .setTotNumValues(src.countTotal) + .setNumNonMissing(src.count) + ) + .setAvgNumBytes((float) src.mean()) + .setMinNumBytes((float) src.minDouble()) + .setMaxNumBytes((float) src.maxDouble()) + .setMaxNumBytesInt((null == src.max) ? 0L : ((Number) src.max).longValue()); + + if (src.hasFreq()) { + TreeMap freq = src.getFrequencyMap(); + bytesBuilder.setUnique(freq.size()); + } + + builder.addFeatures( + FeatureNameStatistics.newBuilder() + .setName(colName) + .setType(Type.BYTES) + .setBytesStats(bytesBuilder) + ); } } } diff --git a/core/src/main/java/com/alibaba/alink/operator/common/statistics/statistics/Interval4Calc.java b/core/src/main/java/com/alibaba/alink/operator/common/statistics/statistics/Interval4Calc.java index 15ea94ae7..6a6ee505d 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/statistics/statistics/Interval4Calc.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/statistics/statistics/Interval4Calc.java @@ -484,10 +484,10 @@ Interval4Display toInterval4Display() { dis.tags = new String[n + 1]; if (ic.nCol > 0) { dis.nCol = ic.nCol; - dis.mcs = new MeasureCalculator[n][dis.nCol]; + dis.mcs = new IntervalMeasureCalculator[n][dis.nCol]; for (int i = 0; i < n; i++) { for (int j = 0; j < dis.nCol; j++) { - dis.mcs[i][j] = new MeasureCalculator(); + dis.mcs[i][j] = new IntervalMeasureCalculator(); } } } diff --git a/core/src/main/java/com/alibaba/alink/operator/common/statistics/statistics/Interval4Display.java b/core/src/main/java/com/alibaba/alink/operator/common/statistics/statistics/Interval4Display.java index 4adc8f628..dc30c487d 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/statistics/statistics/Interval4Display.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/statistics/statistics/Interval4Display.java @@ -8,7 +8,7 @@ public class Interval4Display { public int n; public long[] count = null; public int nCol = -1; - public MeasureCalculator[][] mcs = null; + public IntervalMeasureCalculator[][] mcs = null; public String[] tags = null; public String step = null; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/statistics/statistics/IntervalCalculator.java b/core/src/main/java/com/alibaba/alink/operator/common/statistics/statistics/IntervalCalculator.java index ce9594516..14564bfc5 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/statistics/statistics/IntervalCalculator.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/statistics/statistics/IntervalCalculator.java @@ -2,7 +2,6 @@ import com.alibaba.alink.common.exceptions.AkIllegalArgumentException; import com.alibaba.alink.common.exceptions.AkIllegalDataException; -import com.alibaba.alink.common.exceptions.AkIllegalOperatorParameterException; import com.alibaba.alink.common.exceptions.AkUnsupportedOperationException; import 
java.io.Serializable; @@ -52,7 +51,7 @@ public class IntervalCalculator implements Cloneable, Serializable { * * * 每个基本区间内数据的基本统计计算量 */ - public MeasureCalculator[][] mcs = null; + public IntervalMeasureCalculator[][] mcs = null; public int magnitude; // magnitude < n <= 10 * magnitude public long startIndex; @@ -75,7 +74,7 @@ public IntervalCalculator(double start, double step, long[] count, int magnitude this(start, step, count, null, magnitude); } - public IntervalCalculator(double start, double step, long[] count, MeasureCalculator[][] mcs, int magnitude) { + public IntervalCalculator(double start, double step, long[] count, IntervalMeasureCalculator[][] mcs, int magnitude) { this(Double.class, Math.round(start / step), Math.round(Math.log10(step)) - 1000, count, mcs, magnitude); } @@ -87,7 +86,7 @@ public IntervalCalculator(long start, long step, long[] count, int magnitude) { this(start, step, count, null, magnitude); } - public IntervalCalculator(long start, long step, long[] count, MeasureCalculator[][] mcs, int magnitude) { + public IntervalCalculator(long start, long step, long[] count, IntervalMeasureCalculator[][] mcs, int magnitude) { this(Long.class, start, step, count, mcs, magnitude); } @@ -95,7 +94,7 @@ public IntervalCalculator(Class type, long start, long step, long[] count) { this(type, start, step, count, null, DefaultMagnitude); } - IntervalCalculator(Class type, long start, long step, long[] count, MeasureCalculator[][] mcs, int magnitude) { + IntervalCalculator(Class type, long start, long step, long[] count, IntervalMeasureCalculator[][] mcs, int magnitude) { if (Double.class == type || Float.class == type) { this.type = "Decimal"; } else if (Long.class == type || Integer.class == type) { @@ -245,11 +244,11 @@ public static IntervalCalculator getEmptyInterval(long min, long max, int nCol) } public static IntervalCalculator getEmptyInterval(long min, long max, int nCol, int magnitude) { - MeasureCalculator[][] tmpmcs = null; + IntervalMeasureCalculator[][] tmpmcs = null; if (nCol > 0) { - tmpmcs = new MeasureCalculator[1][nCol]; + tmpmcs = new IntervalMeasureCalculator[1][nCol]; for (int i = 0; i < nCol; i++) { - tmpmcs[0][i] = new MeasureCalculator(); + tmpmcs[0][i] = new IntervalMeasureCalculator(); } } @@ -277,11 +276,11 @@ public static IntervalCalculator getEmptyInterval(double min, double max, int nC } long minBD = calcIntervalValBD(min, stepBD).longValue(); - MeasureCalculator[][] tmpmcs = null; + IntervalMeasureCalculator[][] tmpmcs = null; if (nCol > 0) { - tmpmcs = new MeasureCalculator[1][nCol]; + tmpmcs = new IntervalMeasureCalculator[1][nCol]; for (int i = 0; i < nCol; i++) { - tmpmcs[0][i] = new MeasureCalculator(); + tmpmcs[0][i] = new IntervalMeasureCalculator(); } } @@ -487,8 +486,8 @@ public long[] getCount() { return this.count.clone(); } - public MeasureCalculator[] updateMeasureCalculatorsByCol(int idx) throws Exception { - MeasureCalculator[] measureCalculators = new MeasureCalculator[n]; + public IntervalMeasureCalculator[] updateMeasureCalculatorsByCol(int idx) throws Exception { + IntervalMeasureCalculator[] measureCalculators = new IntervalMeasureCalculator[n]; for (int i = 0; i < this.mcs.length; i++) { measureCalculators[i] = mcs[i][idx]; } @@ -706,10 +705,10 @@ private void subUpgrade(long scale, long startNew, long endNew) { } if (null != this.mcs) { - MeasureCalculator[][] mscNew = new MeasureCalculator[nNew][this.nCol]; + IntervalMeasureCalculator[][] mscNew = new IntervalMeasureCalculator[nNew][this.nCol]; for (int i = 0; i < nNew; 
i++) { for (int j = 0; j < this.nCol; j++) { - mscNew[i][j] = new MeasureCalculator(); + mscNew[i][j] = new IntervalMeasureCalculator(); } } for (int i = 0; i < n; i++) { diff --git a/core/src/main/java/com/alibaba/alink/operator/common/statistics/statistics/MeasureCalculator.java b/core/src/main/java/com/alibaba/alink/operator/common/statistics/statistics/IntervalMeasureCalculator.java similarity index 93% rename from core/src/main/java/com/alibaba/alink/operator/common/statistics/statistics/MeasureCalculator.java rename to core/src/main/java/com/alibaba/alink/operator/common/statistics/statistics/IntervalMeasureCalculator.java index 00b9de212..8ad71e29c 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/statistics/statistics/MeasureCalculator.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/statistics/statistics/IntervalMeasureCalculator.java @@ -5,7 +5,7 @@ /** * @author yangxu */ -public class MeasureCalculator implements Serializable { +public class IntervalMeasureCalculator implements Serializable { private static final long serialVersionUID = -4625259870390900630L; /** @@ -44,7 +44,7 @@ public class MeasureCalculator implements Serializable { */ public double maxDouble; - public MeasureCalculator() { + public IntervalMeasureCalculator() { count = 0; sum = 0.0; sum2 = 0.0; @@ -82,7 +82,7 @@ public void calculate(double d) { * * @param mc 一个MeasureCalculator实例 */ - public void calculate(MeasureCalculator mc) { + public void calculate(IntervalMeasureCalculator mc) { if (mc.minDouble < this.minDouble) { this.minDouble = mc.minDouble; } diff --git a/core/src/main/java/com/alibaba/alink/operator/common/statistics/statistics/NumberMeasureIterator.java b/core/src/main/java/com/alibaba/alink/operator/common/statistics/statistics/NumberMeasureIterator.java new file mode 100644 index 000000000..62438a3f4 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/operator/common/statistics/statistics/NumberMeasureIterator.java @@ -0,0 +1,171 @@ +package com.alibaba.alink.operator.common.statistics.statistics; + +public class NumberMeasureIterator> implements + BaseMeasureIterator > { + /** + * valid number of data. + */ + public long count = 0; + + /** + * missing value number. + */ + public long countMissing = 0; + + /** + * nan value number. + */ + public long countNaN = 0; + + + private long countZero4SRC = 0; + /** + * * + * 数据和 + */ + public double sum = 0.0; + /** + * * + * 数据平方和 + */ + public double sum2 = 0.0; + /** + * * + * 数据立方和 + */ + public double sum3 = 0.0; + /** + * * + * 数据四次方和 + */ + public double sum4 = 0.0; + /** + * * + * 最小值 + */ + public N min = null; + /** + * * + * 最大值 + */ + public N max = null; + + /** + * l1 norm. + */ + public double normL1 = 0.0; + + private boolean needInit = true; + + public NumberMeasureIterator() { + } + + @Override + public long missingCount() { + return countMissing; + } + + @Override + public long count() { + return count; + } + + @Override + public void visit(N val) { + if (null == val) { + countMissing++; + } else if (val != val) { + countNaN++; + } else { + count++; + + if (needInit) { + min = val; + max = val; + needInit = false; + } else { + if (val.compareTo(min) < 0) { + min = val; + } + if (val.compareTo(max) > 0) { + max = val; + } + } + + double d = val.doubleValue(); + sum += d; + sum2 += d * d; + sum3 += d * d * d; + sum4 += d * d * d * d; + normL1 += Math.abs(d); + countZero4SRC += (0.0 == d) ? 
1 : 0;
+		}
+	}
+
+	@Override
+	public void merge(NumberMeasureIterator<N> iterator) {
+		// keep the global minimum / maximum when merging two partial results
+		if (null == this.min) {
+			this.min = iterator.min;
+		} else if (null != iterator.min && this.min.compareTo(iterator.min) > 0) {
+			this.min = iterator.min;
+		}
+		if (null == this.max) {
+			this.max = iterator.max;
+		} else if (null != iterator.max && this.max.compareTo(iterator.max) < 0) {
+			this.max = iterator.max;
+		}
+
+		this.sum += iterator.sum;
+		this.sum2 += iterator.sum2;
+		this.sum3 += iterator.sum3;
+		this.sum4 += iterator.sum4;
+		this.count += iterator.count;
+		this.countNaN += iterator.countNaN;
+		this.countMissing += iterator.countMissing;
+		this.normL1 += iterator.normL1;
+		this.countZero4SRC += iterator.countZero4SRC;
+	}
+
+	@Override
+	public NumberMeasureIterator<N> clone() {
+		NumberMeasureIterator<N> iterator = new NumberMeasureIterator<N>();
+		iterator.count = this.count;
+		iterator.countMissing = this.countMissing;
+		iterator.countNaN = this.countNaN;
+		iterator.sum = this.sum;
+		iterator.sum2 = this.sum2;
+		iterator.sum3 = this.sum3;
+		iterator.sum4 = this.sum4;
+		iterator.min = this.min;
+		iterator.max = this.max;
+		iterator.needInit = this.needInit;
+		iterator.countZero4SRC = this.countZero4SRC;
+		return iterator;
+	}
+
+	@Override
+	public void finalResult(SummaryResultCol src) {
+		src.init(null, count + countMissing + countNaN, count,
+			countMissing, countNaN, 0, 0, countZero4SRC,
+			sum, normL1, sum2, sum3, sum4, min, max);
+	}
+
+	@Override
+	public String toString() {
+		String result = "";
+		result += String.valueOf(count);
+		result += " ";
+		result += String.valueOf(sum);
+		result += " ";
+		result += String.valueOf(sum2);
+		result += " ";
+		result += String.valueOf(sum3);
+		result += " ";
+		result += String.valueOf(sum4);
+		result += " ";
+		result += String.valueOf(min);
+		result += " ";
+		result += String.valueOf(max);
+		return result;
+	}
+}
diff --git a/core/src/main/java/com/alibaba/alink/operator/common/statistics/statistics/ObjectMeasureIterator.java b/core/src/main/java/com/alibaba/alink/operator/common/statistics/statistics/ObjectMeasureIterator.java
new file mode 100644
index 000000000..1c007f347
--- /dev/null
+++ b/core/src/main/java/com/alibaba/alink/operator/common/statistics/statistics/ObjectMeasureIterator.java
@@ -0,0 +1,66 @@
+package com.alibaba.alink.operator.common.statistics.statistics;
+
+public class ObjectMeasureIterator<T> implements BaseMeasureIterator<T, ObjectMeasureIterator<T>> {
+	/**
+	 * valid value count.
+	 */
+	public long count = 0;
+
+	/**
+	 * missing value count.
+ */ + public long countMissing = 0; + + public ObjectMeasureIterator() { + } + + + @Override + public long missingCount() { + return countMissing; + } + + @Override + public long count() { + return count; + } + + + @Override + public void visit(T val) { + if (null == val) { + countMissing++; + } else { + count++; + } + } + + @Override + public void merge(ObjectMeasureIterator iterator) { + this.count += iterator.count; + this.countMissing += iterator.countMissing; + } + + @Override + public ObjectMeasureIterator clone() { + ObjectMeasureIterator iterator = new ObjectMeasureIterator (); + iterator.count = this.count; + iterator.countMissing = this.countMissing; + return iterator; + } + + @Override + public void finalResult(SummaryResultCol src) { + src.init(null, count + countMissing, count, countMissing, 0, 0, 0, 0, 0, + 0, 0, 0, 0, null, null); + } + + @Override + public String toString() { + String result = ""; + result += String.valueOf(count); + result += " "; + result += String.valueOf(countMissing); + return result; + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/statistics/statistics/Summary2.java b/core/src/main/java/com/alibaba/alink/operator/common/statistics/statistics/SrtUtil.java similarity index 51% rename from core/src/main/java/com/alibaba/alink/operator/common/statistics/statistics/Summary2.java rename to core/src/main/java/com/alibaba/alink/operator/common/statistics/statistics/SrtUtil.java index d37eda9fd..d7a8a64bb 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/statistics/statistics/Summary2.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/statistics/statistics/SrtUtil.java @@ -9,27 +9,21 @@ import com.alibaba.alink.common.exceptions.AkIllegalOperatorParameterException; import com.alibaba.alink.common.exceptions.AkIllegalStateException; -import com.alibaba.alink.common.exceptions.AkUnsupportedOperationException; import com.alibaba.alink.common.utils.TableUtil; -import com.alibaba.alink.operator.common.statistics.StatisticUtil; -import com.alibaba.alink.operator.common.statistics.basicstat.WindowTable; +import com.alibaba.alink.operator.batch.statistics.utils.StatisticUtil; import com.alibaba.alink.params.statistics.HasStatLevel_L1; import java.util.ArrayList; -import java.util.Comparator; import java.util.Date; import java.util.HashMap; -import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Map; -import java.util.PriorityQueue; -import java.util.TreeMap; /** * @author yangxu */ -public class Summary2 { +public class SrtUtil { //for stream: stream iter can iter two times. 
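 	// (that is, the Iterable backing the WindowTable is expected to support being traversed more than once)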
public static SummaryResultTable streamSummary(WindowTable wt, String[] statColNames,
@@ -119,9 +113,9 @@ static SummaryResultTable basicSummary(WindowTable wt, String[] statColNames,
 		int nStat = idxStat.length;
-		MeasureIteratorBase[] mis = newMeasures(wt.colTypes, idxStat);
+		BaseMeasureIterator[] mis = newMeasures(wt.colTypes, idxStat);
 		FrequencyIterator[] fis = newFreqs(wt.colTypes, idxStat, bCalcFreq, freqSize);
-		TopKInterator[] tis = newTopK(wt.colTypes, idxStat, bCalcFreq, smallK, largeK);
+		TopKIterator[] tis = newTopK(wt.colTypes, idxStat, bCalcFreq, smallK, largeK);
 		srt.dotProduction = newDotProduction(nStat, bCov);
 		Row val;
@@ -132,7 +126,7 @@ static SummaryResultTable basicSummary(WindowTable wt, String[] statColNames,
 				Object obj = val.getField(idxStat[i]);
 				mis[i].visit(obj);
 				if (tis != null && tis[i] != null) {
-					tis[i].visit(obj);
+					tis[i].visit((Comparable) obj);
 				}
 				if (fis != null && fis[i] != null) {
 					fis[i].visit(obj);
@@ -142,6 +136,7 @@ static SummaryResultTable basicSummary(WindowTable wt, String[] statColNames,
 		for (int i = 0; i < nStat; i++) {
 			mis[i].finalResult(srt.src[i]);
+			srt.src[i].dataType = wt.colTypes[idxStat[i]];
 			if (tis != null && tis[i] != null) {
 				tis[i].finalResult(srt.src[i]);
 			}
@@ -185,12 +180,12 @@ public static SrtForWp summaryForWp(String[] colNames, Class[] colTypes, Iterabl
 		int[] idxStat = TableUtil.findColIndicesWithAssertAndHint(colNames, statColNames);
 		int nStat = idxStat.length;
-		MeasureIteratorBase[] mis = newMeasures(colTypes, idxStat);
+		BaseMeasureIterator[] mis = newMeasures(colTypes, idxStat);
 		DistinctValueIterator[] fis = new DistinctValueIterator[nStat];
 		for (int j = 0; j < nStat; j++) {
 			if (needFreqs[j]) {
-				fis[j] = new DistinctValueIterator(colTypes[idxStat[j]]);
+				fis[j] = StatisticsIteratorFactory.getDistinctValueIterator(colTypes[idxStat[j]]);
 			}
 		}
@@ -208,6 +203,7 @@ public static SrtForWp summaryForWp(String[] colNames, Class[] colTypes, Iterabl
 		for (int i = 0; i < nStat; i++) {
 			mis[i].finalResult(srt.src[i]);
+			srt.src[i].dataType = colTypes[idxStat[i]];
 			if (needFreqs[i]) {
 				srt.distinctValues[i] = fis[i].mapFreq;
 			}
@@ -240,26 +236,27 @@ private static boolean isCalc(Row data, int timeColIdx, long startTime, long end
 		}
 	}
-	private static MeasureIteratorBase[] newMeasures(Class[] colTypes, int[] idxStat) {
+	private static BaseMeasureIterator[] newMeasures(Class[] colTypes, int[] idxStat) {
 		int nStat = idxStat.length;
-		MeasureIteratorBase[] mis = new MeasureIteratorBase[nStat];
+		BaseMeasureIterator[] mis = new BaseMeasureIterator[nStat];
 		for (int j = 0; j < nStat; j++) {
 			int index = idxStat[j];
-			if (Double.class == colTypes[index] || Float.class == colTypes[index]) {
-				mis[j] = new MeasureIteratorDouble(colTypes[index]);
-			} else if (Long.class == colTypes[index] || Integer.class == colTypes[index]
-				|| Short.class == colTypes[index] || Byte.class == colTypes[index]) {
-				mis[j] = new MeasureIteratorLong(colTypes[index]);
-			} else if (String.class == colTypes[index]) {
-				mis[j] = new MeasureIteratorString(colTypes[index]);
-			} else if (Boolean.class == colTypes[index]) {
-				mis[j] = new MeasureIteratorBoolean(colTypes[index]);
-			} else if (java.sql.Timestamp.class == colTypes[index]) {
-				mis[j] = new MeasureIteratorDate(colTypes[index]);
-			} else {
-				throw new AkUnsupportedOperationException(String.format(
-					"col type [%s] not supported.", colTypes[index].getSimpleName()));
-			}
+			mis[j] = StatisticsIteratorFactory.getMeasureIterator(colTypes[index]);
+			//if (Double.class == colTypes[index] || Float.class == colTypes[index]) {
+			//	mis[j] = new MeasureIteratorDouble(colTypes[index]);
+			//} else if (Long.class == colTypes[index] || Integer.class == colTypes[index]
+			//	|| Short.class == colTypes[index] || Byte.class == colTypes[index]) {
+			//	mis[j] = new MeasureIteratorLong(colTypes[index]);
+			//} else if (String.class == colTypes[index]) {
+			//	mis[j] = new MeasureIteratorString(colTypes[index]);
+			//} else if (Boolean.class == colTypes[index]) {
+			//	mis[j] = new MeasureIteratorBoolean(colTypes[index]);
+			//} else if (java.sql.Timestamp.class == colTypes[index]) {
+			//	mis[j] = new MeasureIteratorDate(colTypes[index]);
+			//} else {
+			//	throw new AkUnsupportedOperationException(String.format(
+			//		"col type [%s] not supported.", colTypes[index].getSimpleName()));
+			//}
 		}
 		return mis;
 	}
@@ -271,24 +268,22 @@ private static FrequencyIterator[] newFreqs(Class[] colTypes, int[] idxStat, boo
 			fis = new FrequencyIterator[nStat];
 			for (int j = 0; j < nStat; j++) {
 				int index = idxStat[j];
-				fis[j] = new FrequencyIterator(colTypes[index], freqSize);
+				fis[j] = StatisticsIteratorFactory.getFrequencyIterator(colTypes[index], freqSize);
 			}
 		}
 		return fis;
 	}
-	private static TopKInterator[] newTopK(Class[] colTypes, int[] idxStat, boolean bCalcFreq, int smallK, int
+	private static TopKIterator[] newTopK(Class[] colTypes, int[] idxStat, boolean bCalcFreq, int smallK, int
 		largeK) {
-		TopKInterator[] fis = null;
+		TopKIterator[] fis = null;
 		if (bCalcFreq) {
 			int nStat = idxStat.length;
-			fis = new TopKInterator[nStat];
+			fis = new TopKIterator[nStat];
 			for (int j = 0; j < nStat; j++) {
 				int index = idxStat[j];
-				if (colTypes[index] != String.class) {
-					fis[j] = new TopKInterator(colTypes[index], smallK, largeK);
-				}
+				fis[j] = StatisticsIteratorFactory.getTopKIterator(colTypes[index], smallK, largeK);
 			}
 		}
 		return fis;
@@ -394,7 +389,7 @@ public static Map summaryForGroup(String[] colNames, Class[]
 		List<String> keys = new ArrayList<>();
-		List<MeasureIteratorBase[]> measures = new ArrayList<>();
+		List<BaseMeasureIterator[]> measures = new ArrayList<>();
 		if (colNames == null || colNames.length == 0) {
 			throw new AkIllegalOperatorParameterException("colNames must not be empty.");
@@ -425,25 +420,11 @@ public static Map summaryForGroup(String[] colNames, Class[]
 			int keyIdx = keys.indexOf(keyVal);
 			if (keyIdx < 0) {
 				keys.add(keyVal);
-				MeasureIteratorBase[] mis = new MeasureIteratorBase[nStat];
+				BaseMeasureIterator[] mis = new BaseMeasureIterator[nStat];
 				for (int j = 0; j < nStat; j++) {
 					int index = idxStat[j];
-					if (Double.class == colTypes[index] || Float.class == colTypes[index]) {
-						mis[j] = new MeasureIteratorDouble(colTypes[index]);
-					} else if (Long.class == colTypes[index] || Integer.class == colTypes[index]
-						|| Short.class == colTypes[index] || Byte.class == colTypes[index]) {
-						mis[j] = new MeasureIteratorLong(colTypes[index]);
-					} else if (String.class == colTypes[index]) {
-						mis[j] = new MeasureIteratorString(colTypes[index]);
-					} else if (Boolean.class == colTypes[index]) {
-						mis[j] = new MeasureIteratorBoolean(colTypes[index]);
-					} else if (java.sql.Timestamp.class == colTypes[index]) {
-						mis[j] = new MeasureIteratorDate(colTypes[index]);
-					} else {
-						throw new AkUnsupportedOperationException(String.format(
-							"col type [%s] not supported.", colTypes[index].getSimpleName()));
-					}
+					mis[j] = StatisticsIteratorFactory.getMeasureIterator(colTypes[index]);
 				}
 				for (int i = 0; i < nStat; i++) {
 					Object obj = row.getField(idxStat[i]);
@@ -451,7 +432,7 @@ public static Map summaryForGroup(String[] colNames, Class[]
 				}
 				measures.add(mis);
 			} else {
-				MeasureIteratorBase[] mis = measures.get(keyIdx);
+				BaseMeasureIterator[] mis = measures.get(keyIdx);
 				for (int i = 0; i < nStat; i++) {
 					Object obj = row.getField(idxStat[i]);
 					mis[i].visit(obj);
@@ -463,7 +444,7 @@ public static Map summaryForGroup(String[] colNames, Class[]
 		for (int i = 0; i < keys.size(); i++) {
 			SrtForWp srt = new SrtForWp(statColNames);
 			String key = keys.get(i);
-			MeasureIteratorBase[] mis = measures.get(i);
+			BaseMeasureIterator[] mis = measures.get(i);
 			for (int j = 0; j < nStat; j++) {
 				mis[j].finalResult(srt.src[j]);
 			}
@@ -483,406 +464,3 @@ public static String getStringValue(Row row, int idx) {
 	}
 }
-
-class MeasureIteratorString extends MeasureIteratorBase {
-
-	public MeasureIteratorString(Class dataType) {
-		super(dataType);
-	}
-
-	@Override
-	public void visit(Object obj) {
-		if (null == obj) {
-			countMissing++;
-		} else {
-			count++;
-		}
-	}
-
-	@Override
-	public void finalResult(SummaryResultCol src) {
-		src.init(dataType, count + countMissing, count, countMissing, 0, 0, 0, 0, 0,
-			0, 0, 0, 0, null, null);
-	}
-}
-
-class MeasureIteratorBoolean extends MeasureIteratorBase {
-
-	long countTrue;
-	long countFalse;
-
-	public MeasureIteratorBoolean(Class dataType) {
-		super(dataType);
-		countTrue = 0;
-		countFalse = 0;
-	}
-
-	@Override
-	public void visit(Object obj) {
-		if (null == obj) {
-			countMissing++;
-		} else {
-			if (obj.equals(Boolean.TRUE)) {
-				countTrue++;
-			} else {
-				countFalse++;
-			}
-		}
-	}
-
-	@Override
-	public void finalResult(SummaryResultCol src) {
-		count = countTrue + countFalse;
-		Boolean min = null;
-		Boolean max = null;
-		if (count > 0) {
-			min = Boolean.FALSE;
-			max = Boolean.TRUE;
-			if (0 == countTrue) {
-				max = Boolean.FALSE;
-			}
-			if (0 == countFalse) {
-				min = Boolean.TRUE;
-			}
-		}
-		src.init(dataType, count + countMissing, count, countMissing, 0, 0, 0, 0,
-			countTrue, countTrue, countTrue, countTrue, countTrue, min, max);
-	}
-}
-
-class MeasureIteratorLong extends MeasureIteratorBase {
-	public double sum;
-	public long minLong;
-	public long maxLong;
-	public Number min = null;
-	public Number max = null;
-	double sum2;
-	double sum3;
-	double sum4;
-	double absSum;
-	public long countZero;
-
-	public MeasureIteratorLong(Class dataType) {
-		super(dataType);
-		sum = 0.0;
-		sum2 = 0.0;
-		sum3 = 0.0;
-		sum4 = 0.0;
-		minLong = Long.MAX_VALUE;
-		maxLong = Long.MIN_VALUE;
-		absSum = 0;
-		countZero = 0;
-	}
-
-	@Override
-	public void visit(Object obj) {
-		if (null == obj) {
-			countMissing++;
-		} else {
-			long d = ((Number) obj).longValue();
-			sum += d;
-			absSum += Math.abs(d);
-			sum2 += d * d;
-			sum3 += d * d * d;
-			sum4 += d * d * d * d;
-			count++;
-			if (d > maxLong) {
-				maxLong = d;
-				max = (Number) obj;
-			}
-			if (d < minLong) {
-				minLong = d;
-				min = (Number) obj;
-			}
-			if (0 == d) {
-				countZero++;
-			}
-		}
-	}
-
-	@Override
-	public void finalResult(SummaryResultCol src) {
-		src.init(dataType, count + countMissing, count,
-			countMissing, 0, 0, 0, countZero,
-			sum, absSum, sum2, sum3, sum4, min, max);
-	}
-}
-
-class MeasureIteratorDate extends MeasureIteratorBase {
-
-	public long minLong;
-	public long maxLong;
-
-	public MeasureIteratorDate(Class dataType) {
-		super(dataType);
-		minLong = Long.MAX_VALUE;
-		maxLong = Long.MIN_VALUE;
-	}
-
-	@Override
-	public void visit(Object obj) {
-		if (null == obj) {
-			countMissing++;
-		} else {
-			long d = ((java.sql.Timestamp) obj).getTime();
-			count++;
-			if (d > maxLong) {
-				maxLong = d;
-			}
-			if (d < minLong) {
-				minLong = d;
-			}
-		}
-
-	}
-
-	@Override
-	public void finalResult(SummaryResultCol src) {
-		src.init(dataType,
-			count + countMissing, count, countMissing,
-			0, 0, 0, 0, 0, 0, 0, 0, 0,
-			minLong, maxLong);
-	}
-}
-
-class MeasureIteratorDouble extends MeasureIteratorBase {
-
-	public double sum;
-	public double minDouble;
-	public double maxDouble;
-	public Number min = null;
-	public Number max = null;
-	double absSum;
-	double sum2;
-	double sum3;
-	double sum4;
-	public long countZero;
-
-	public MeasureIteratorDouble(Class dataType) {
-		super(dataType);
-		sum = 0.0;
-		sum2 = 0.0;
-		sum3 = 0.0;
-		sum4 = 0.0;
-		absSum = 0;
-		minDouble = Double.POSITIVE_INFINITY;
-		maxDouble = Double.NEGATIVE_INFINITY;
-		countZero = 0;
-	}
-
-	@Override
-	public void visit(Object obj) {
-		if (null == obj) {
-			countMissing++;
-		} else {
-			double d = ((Number) obj).doubleValue();
-			sum += d;
-			absSum += Math.abs(d);
-			sum2 += d * d;
-			sum3 += d * d * d;
-			sum4 += d * d * d * d;
-			count++;
-			if (d > maxDouble) {
-				maxDouble = d;
-				max = (Number) obj;
-			}
-			if (d < minDouble) {
-				minDouble = d;
-				min = (Number) obj;
-			}
-			if (0.0 == d) {
-				countZero++;
-			}
-		}
-	}
-
-	@Override
-	public void finalResult(SummaryResultCol src) {
-		src.init(dataType, count + countMissing, count,
-			countMissing, 0, 0, 0, countZero,
-			sum, absSum, sum2, sum3, sum4, min, max);
-	}
-}
-
-abstract class MeasureIteratorBase {
-
-	Class dataType;
-	long count;
-	long countMissing;
-
-	public MeasureIteratorBase(Class dataType) {
-		this.dataType = dataType;
-		count = 0;
-		countMissing = 0;
-	}
-
-	abstract void visit(Object obj);
-
-	abstract void finalResult(SummaryResultCol src);
-}
-
-class FrequencyIterator {
-
-	int capacity;
-	Class dataType;
-	TreeMap<Object, Long> mapFreq = null;
-	boolean bOutOfRange;
-
-	public FrequencyIterator(Class dataType, int capacity) {
-		this.capacity = capacity;
-		this.dataType = dataType;
-		this.mapFreq = new TreeMap<>();
-		this.bOutOfRange = false;
-	}
-
-	public void visit(Object o) {
-		if (bOutOfRange || o == null) {
-			return;
-		}
-		if (o.getClass() == this.dataType) {
-			if (this.mapFreq.containsKey(o)) {
-				this.mapFreq.put(o, this.mapFreq.get(o) + 1);
-			} else {
-				if (mapFreq.size() >= this.capacity) {
-					this.bOutOfRange = true;
-					this.mapFreq.clear();
-				} else {
-					this.mapFreq.put(o, new Long(1));
-				}
-			}
-		} else {
-			throw new AkIllegalStateException("Not valid class type!");
-		}
-	}
-
-	public void finalResult(SummaryResultCol src) {
-		if (this.bOutOfRange) {
-			src.freq = null;
-		} else {
-			src.freq = new TreeMap<>();
-			Iterator<Map.Entry<Object, Long>> it = this.mapFreq.entrySet().iterator();
-			if (dataType == java.sql.Timestamp.class) {
-				while (it.hasNext()) {
-					Map.Entry<Object, Long> e = it.next();
-					src.freq.put(((java.sql.Timestamp) e.getKey()).getTime(), e.getValue());
-				}
-			} else {
-				while (it.hasNext()) {
-					Map.Entry<Object, Long> e = it.next();
-					src.freq.put(e.getKey(), e.getValue());
-				}
-			}
-		}
-	}
-
-}
-
-class DistinctValueIterator {
-	public HashSet<Object> mapFreq = null;
-	Class dataType;
-
-	public DistinctValueIterator(Class dataType) {
-		this.dataType = dataType;
-		this.mapFreq = new HashSet<>();
-	}
-
-	public void visit(Object o) {
-		if (o.getClass() == this.dataType) {
-			mapFreq.add(o);
-		} else {
-			throw new AkIllegalStateException("Not valid class type!");
-		}
-	}
-}
-
-class TopKInterator {
-
-	Class dataType;
-	int smallK;
-	int largeK;
-	PriorityQueue<Object> priqueS;
-	PriorityQueue<Object> priqueL;
-
-	public TopKInterator(Class dataType, int small, int large) {
-		this.dataType = dataType;
-		this.smallK = small;
-		this.largeK = large;
-		priqueS = new PriorityQueue<>(small, new ValueComparator(this.dataType, -1));
-		priqueL = new PriorityQueue<>(large, new ValueComparator(this.dataType, 1));
-	}
-
-	public void visit(Object obj) {
-		if (obj == null) {
-			return;
-		}
-		if (priqueL.size() < largeK) {
-			priqueL.add(obj);
-		} else {
-			priqueL.add(obj);
-			priqueL.poll();
-		}
-		if (priqueS.size() < smallK) {
-			priqueS.add(obj);
-		} else {
-			priqueS.add(obj);
-			priqueS.poll();
-		}
-	}
-
-	public void finalResult(SummaryResultCol src) {
-		int large = priqueL.size();
-		int small = priqueS.size();
-		src.topItems = new Object[large];
-		src.bottomItems = new Object[small];
-		if (dataType == java.sql.Timestamp.class) {
-			for (int i = 0; i < small; i++) {
-				src.bottomItems[small - i - 1] = ((java.sql.Timestamp) priqueS.poll()).getTime();
-			}
-			for (int i = 0; i < large; i++) {
-				src.topItems[large - i - 1] = ((java.sql.Timestamp) priqueL.poll()).getTime();
-			}
-		} else {
-			for (int i = 0; i < small; i++) {
-				src.bottomItems[small - i - 1] = priqueS.poll();
-			}
-			for (int i = 0; i < large; i++) {
-				src.topItems[large - i - 1] = priqueL.poll();
-			}
-		}
-	}
-
-	class ValueComparator implements Comparator {
-
-		Class dataType;
-		int sortKey = 1;
-
-		ValueComparator(Class dataType, int sortKey) {
-			this.dataType = dataType;
-			this.sortKey = sortKey;
-		}
-
-		@Override
-		public int compare(Object t, Object t1) {
-			if (Double.class == this.dataType) {
-				return (sortKey) * (((Double) t).compareTo((Double) t1));
-			} else if (Integer.class == this.dataType) {
-				return (sortKey) * (((Integer) t).compareTo((Integer) t1));
-			} else if (Long.class == this.dataType) {
-				return (sortKey) * (((Long) t).compareTo((Long) t1));
-			} else if (Float.class == this.dataType) {
-				return (sortKey) * (((Float) t).compareTo((Float) t1));
-			} else if (Boolean.class == this.dataType) {
-				return (sortKey) * (((Boolean) t).compareTo((Boolean) t1));
-			} else if (String.class == this.dataType) {
-				return (sortKey) * (((String) t).compareTo((String) t1));
-			} else if (java.sql.Timestamp.class == this.dataType) {
-				return (sortKey) * (((java.sql.Timestamp) t).compareTo((java.sql.Timestamp) t1));
-			} else {
-				throw new AkIllegalStateException(String.format(
-					"type [%s] not support.", this.dataType.getSimpleName()));
-			}
-		}
-
-	}
-
-}
diff --git a/core/src/main/java/com/alibaba/alink/operator/common/statistics/statistics/StatisticsIteratorFactory.java b/core/src/main/java/com/alibaba/alink/operator/common/statistics/statistics/StatisticsIteratorFactory.java
new file mode 100644
index 000000000..e89f7e086
--- /dev/null
+++ b/core/src/main/java/com/alibaba/alink/operator/common/statistics/statistics/StatisticsIteratorFactory.java
@@ -0,0 +1,128 @@
+package com.alibaba.alink.operator.common.statistics.statistics;
+
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+
+import com.alibaba.alink.common.exceptions.AkIllegalStateException;
+
+import java.math.BigDecimal;
+import java.math.BigInteger;
+import java.sql.Date;
+import java.sql.Time;
+import java.sql.Timestamp;
+
+public class StatisticsIteratorFactory {
+
+	public static BaseMeasureIterator getMeasureIterator(TypeInformation type) {
+		return getMeasureIterator(type.getTypeClass());
+	}
+
+	public static BaseMeasureIterator getMeasureIterator(Class cls) {
+		if (Number.class.isAssignableFrom(cls)) {
+			if (Double.class == cls) {
+				return new NumberMeasureIterator<Double>();
+			} else if (Long.class == cls) {
+				return new NumberMeasureIterator<Long>();
+			} else if (Byte.class == cls) {
+				return new NumberMeasureIterator<Byte>();
+			} else if (Integer.class == cls) {
+				return new NumberMeasureIterator<Integer>();
+			} else if (Float.class == cls) {
+				return new NumberMeasureIterator<Float>();
+			} else if (Short.class == cls) {
+				return new NumberMeasureIterator<Short>();
+			} else if (BigDecimal.class == cls) {
+				return new NumberMeasureIterator<BigDecimal>();
+			} else if (BigInteger.class == cls) {
+				return new NumberMeasureIterator<BigInteger>();
+			}
+		} else if (java.util.Date.class.isAssignableFrom(cls)) {
+			if (java.sql.Timestamp.class == cls) {
+				return new DateMeasureIterator<Timestamp>();
+			} else if (java.sql.Date.class == cls) {
+				return new DateMeasureIterator<Date>();
+			} else if (java.sql.Time.class == cls) {
+				return new DateMeasureIterator