Merge pull request cdapio#15478 from cdapio/dataset_union_fix
[CDAP-20911] Fix ClassCastException by adding a workaround for SPARK-46176
tivv authored Nov 30, 2023
2 parents 3042d9c + dcb21cc commit 1be1fa0
Showing 8 changed files with 182 additions and 65 deletions.

Large diffs are not rendered by default.

@@ -0,0 +1,33 @@
/*
* Copyright © 2023 Cask Data, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/

package io.cdap.cdap.datapipeline;

import com.google.common.collect.ImmutableMap;
import io.cdap.cdap.etl.common.Constants;
import java.util.Map;

/**
* This test runs all test cases of DataPipelineTest while forcing maximum Dataset usage
* via {@link Constants#DATASET_FORCE}.
*/
public class DatasetDataPipelineTest extends DataPipelineTest {

@Override
protected Map<String, String> addRuntimeArguments(Map<String, String> arguments) {
return ImmutableMap.<String, String>builder().putAll(arguments)
.put(Constants.DATASET_FORCE, "true").build();
}
}
@@ -38,6 +38,12 @@ public final class Constants {
public static final String CONSOLIDATE_STAGES = "spark.cdap.pipeline.consolidate.stages";
public static final String CACHE_FUNCTIONS = "spark.cdap.pipeline.functioncache.enable";
public static final String DATASET_KRYO_ENABLED = "spark.cdap.pipeline.dataset.kryo.enable";

/**
* Force using Datasets instead of RDDs right out of BatchSource. Should mostly
* be used for testing.
*/
public static final String DATASET_FORCE = "spark.cdap.pipeline.dataset.force";
public static final String DATASET_AGGREGATE_ENABLED = "spark.cdap.pipeline.aggregate.dataset.enable";
public static final String DISABLE_ELT_PUSHDOWN = "cdap.pipeline.pushdown.disable";
public static final String DATASET_AGGREGATE_IGNORE_PARTITIONS =
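For orientation, here is a hedged, standalone sketch of how the new flag flows end to end: a caller merges Constants.DATASET_FORCE into the pipeline's runtime arguments, exactly as DatasetDataPipelineTest above does, and driver-side code reads it back with Boolean.parseBoolean, as the getSource change below does. The class name and the extra argument in this sketch are illustrative and not part of the commit.

import com.google.common.collect.ImmutableMap;
import java.util.Map;

public class DatasetForceFlagSketch {

  // Mirrors Constants.DATASET_FORCE from the diff above.
  static final String DATASET_FORCE = "spark.cdap.pipeline.dataset.force";

  public static void main(String[] args) {
    Map<String, String> userArguments = ImmutableMap.of("some.other.arg", "value");

    // Supplying the flag: the same builder pattern as DatasetDataPipelineTest#addRuntimeArguments.
    Map<String, String> runtimeArguments = ImmutableMap.<String, String>builder()
        .putAll(userArguments)
        .put(DATASET_FORCE, "true")
        .build();

    // Reading the flag: the same getOrDefault/parseBoolean pattern as the getSource change below.
    boolean shouldForceDatasets = Boolean.parseBoolean(
        runtimeArguments.getOrDefault(DATASET_FORCE, Boolean.FALSE.toString()));
    System.out.println("Force Datasets over RDDs: " + shouldForceDatasets);
  }
}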
@@ -62,6 +62,7 @@
import io.cdap.cdap.etl.spark.function.JoinOnFunction;
import io.cdap.cdap.etl.spark.function.PluginFunctionContext;
import io.cdap.cdap.internal.io.SchemaTypeAdapter;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.sql.SQLContext;
@@ -141,15 +142,22 @@ protected SparkCollection<RecordInfo<Object>> getSource(StageSpec stageSpec,
}
}

// If SQL engine is not initiated : use default spark method (RDDCollection)
// If SQL engine is not initiated : use default spark method (RDDCollection or OpaqueDatasetCollection)
boolean shouldForceDatasets = Boolean.parseBoolean(
sec.getRuntimeArguments().getOrDefault(Constants.DATASET_FORCE, Boolean.FALSE.toString()));
PluginFunctionContext pluginFunctionContext = new PluginFunctionContext(stageSpec, sec, collector);
FlatMapFunction<Tuple2<Object, Object>, RecordInfo<Object>> sourceFunction =
new BatchSourceFunction(pluginFunctionContext, functionCacheFactory.newCache());
this.functionCacheFactory = functionCacheFactory;
JavaRDD<RecordInfo<Object>> rdd = sourceFactory
.createRDD(sec, jsc, stageSpec.getName(), Object.class, Object.class)
.flatMap(sourceFunction);
if (shouldForceDatasets) {
return OpaqueDatasetCollection.fromRdd(
rdd, sec, jsc, new SQLContext(jsc), datasetContext, sinkFactory, functionCacheFactory);
}
return new RDDCollection<>(sec, functionCacheFactory, jsc,
new SQLContext(jsc), datasetContext, sinkFactory, sourceFactory
.createRDD(sec, jsc, stageSpec.getName(), Object.class, Object.class)
.flatMap(sourceFunction));
new SQLContext(jsc), datasetContext, sinkFactory, rdd);
}

@Override
@@ -84,6 +84,11 @@ public DataframeCollection(Schema schema, Dataset<Row> dataframe, JavaSparkExecu
super(sec, jsc, sqlContext, datasetContext, sinkFactory, functionCacheFactory);
this.schema = Objects.requireNonNull(schema);
this.dataframe = dataframe;
if (!Row.class.isAssignableFrom(dataframe.encoder().clsTag().runtimeClass())) {
throw new IllegalArgumentException(
"Dataframe collection received dataset of " + dataframe.encoder().clsTag()
.runtimeClass());
}
}

/**
@@ -56,9 +56,11 @@
public abstract class DatasetCollection<T> extends DelegatingSparkCollection<T>
implements BatchCollection<T>{
private static final Encoder KRYO_OBJECT_ENCODER = Encoders.kryo(Object.class);
private static final Encoder KRYO_ARRAY_ENCODER = Encoders.kryo(Object[].class);
private static final Encoder KRYO_TUPLE_ENCODER = Encoders.tuple(
KRYO_OBJECT_ENCODER, KRYO_OBJECT_ENCODER);
private static final Encoder JAVA_OBJECT_ENCODER = Encoders.javaSerialization(Object.class);
private static final Encoder JAVA_ARRAY_ENCODER = Encoders.javaSerialization(Object[].class);
private static final Encoder JAVA_TUPLE_ENCODER = Encoders.tuple(
JAVA_OBJECT_ENCODER, JAVA_OBJECT_ENCODER);

@@ -112,6 +114,13 @@ protected static <V> Encoder<V> objectEncoder(boolean useKryoForDatasets) {
return useKryoForDatasets ? KRYO_OBJECT_ENCODER : JAVA_OBJECT_ENCODER;
}

/**
* Helper function to provide a generified encoder for an array of a serializable type.
* @param useKryoForDatasets whether to use the Kryo-based encoder instead of Java serialization
*/
protected static <V> Encoder<V[]> arrayEncoder(boolean useKryoForDatasets) {
return useKryoForDatasets ? KRYO_ARRAY_ENCODER : JAVA_ARRAY_ENCODER;
}
/**
* helper function to provide a generified encoder for tuple of two serializable types
*/
@@ -126,6 +135,12 @@ protected <V> Encoder<V> objectEncoder() {
return objectEncoder(useKryoForDatasets);
}

/**
* Helper function to provide a generified encoder for an array of a serializable type.
*/
protected <V> Encoder<V[]> arrayEncoder() {
return arrayEncoder(useKryoForDatasets);
}
@Override
public <U> SparkCollection<U> map(Function<T, U> function) {
MapFunction<T, U> mapFunction = function::call;
@@ -158,10 +173,25 @@ protected DatasetCollection<T> cache(StorageLevel cacheStorageLevel) {
return wrap(getDataset().persist(cacheStorageLevel));
}

private static <T> Object[] arrayWrapper(T value) {
return value == null ? null : new Object[]{value};
}

private static <T> T arrayUnWrapper(Object[] array) {
return array == null ? null : (T) array[0];
}

@Override
public SparkCollection union(SparkCollection other) {
if (other instanceof DatasetCollection) {
return wrap(getDataset().unionAll(((DatasetCollection) other).getDataset()));
// We need to work around https://issues.apache.org/jira/browse/SPARK-46176, which
// causes problems with union for Dataset[Object]. We wrap each value into an array and unwrap it after the union.
MapFunction<T, Object[]> wrapper = DatasetCollection::arrayWrapper;
MapFunction<Object[], T> unWrapper = DatasetCollection::arrayUnWrapper;
Dataset<Object[]> left = getDataset().map(wrapper, arrayEncoder());
Dataset<Object[]> right = ((DatasetCollection) other).getDataset()
.map(wrapper, arrayEncoder());
return wrap(left.unionAll(right).map(unWrapper, objectEncoder()));
}
return super.union(other);
}
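For reference, here is a standalone sketch of the same wrap/union/unwrap detour in plain Spark Java, outside the CDAP classes in this diff. It assumes a local SparkSession and Kryo encoders; the class and variable names are illustrative and not part of the commit.

import java.util.Arrays;
import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.SparkSession;

public class UnionWorkaroundSketch {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder()
        .master("local[*]")
        .appName("spark-46176-union-workaround-sketch")
        .getOrCreate();

    // Opaque datasets of Object, analogous to what DatasetCollection holds
    // (Kryo here; the Java-serialization encoders would be wired the same way).
    Encoder<Object> objectEncoder = Encoders.kryo(Object.class);
    Encoder<Object[]> arrayEncoder = Encoders.kryo(Object[].class);
    Dataset<Object> left = spark.createDataset(Arrays.asList((Object) "a", "b"), objectEncoder);
    Dataset<Object> right = spark.createDataset(Arrays.asList((Object) "c", "d"), objectEncoder);

    // Wrap every value in a single-element array, union the wrapped datasets, then unwrap:
    // the same detour DatasetCollection#union takes above to sidestep the Dataset[Object]
    // union problem tracked as SPARK-46176.
    MapFunction<Object, Object[]> wrap = value -> value == null ? null : new Object[]{value};
    MapFunction<Object[], Object> unwrap = array -> array == null ? null : array[0];
    Dataset<Object> union = left.map(wrap, arrayEncoder)
        .unionAll(right.map(wrap, arrayEncoder))
        .map(unwrap, objectEncoder);

    union.collectAsList().forEach(System.out::println);
    spark.stop();
  }
}

Keeping both KRYO_ARRAY_ENCODER and JAVA_ARRAY_ENCODER in DatasetCollection lets the wrapper reuse whichever serialization the collection was already configured with.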
@@ -56,6 +56,11 @@ private OpaqueDatasetCollection(Dataset<T> dataset,
FunctionCache.Factory functionCacheFactory) {
super(sec, jsc, sqlContext, datasetContext, sinkFactory, functionCacheFactory);
this.dataset = dataset;
if (Row.class.isAssignableFrom(dataset.encoder().clsTag().runtimeClass())) {
throw new IllegalArgumentException(
"Opaque collection received dataset of Row (" + dataset.encoder().clsTag()
.runtimeClass() + "). DataframeCollection should be used.");
}
}

@Override
@@ -20,6 +20,7 @@
import io.cdap.cdap.etl.common.RecordInfo;
import io.cdap.cdap.etl.common.TrackedTransform;
import io.cdap.cdap.etl.spark.CombinedEmitter;
import java.util.concurrent.ExecutionException;
import org.apache.spark.api.java.function.FlatMapFunction;

import java.util.Iterator;
@@ -43,15 +44,21 @@ public TransformFunction(PluginFunctionContext pluginFunctionContext, FunctionCa

@Override
public Iterator<RecordInfo<Object>> call(T input) throws Exception {
if (transform == null) {
Transform<T, Object> plugin = pluginFunctionContext.createAndInitializePlugin(functionCache);
transform = new TrackedTransform<>(plugin, pluginFunctionContext.createStageMetrics(),
pluginFunctionContext.getDataTracer(),
pluginFunctionContext.getStageStatisticsCollector());
emitter = new CombinedEmitter<>(pluginFunctionContext.getStageName());
try {
if (transform == null) {
Transform<T, Object> plugin = pluginFunctionContext.createAndInitializePlugin(
functionCache);
transform = new TrackedTransform<>(plugin, pluginFunctionContext.createStageMetrics(),
pluginFunctionContext.getDataTracer(),
pluginFunctionContext.getStageStatisticsCollector());
emitter = new CombinedEmitter<>(pluginFunctionContext.getStageName());
}
emitter.reset();
transform.transform(input, emitter);
return emitter.getEmitted().iterator();
} catch (Exception e) {
throw new ExecutionException("Error when transforming stage "
+ pluginFunctionContext.getStageName() + ": " + e, e);
}
emitter.reset();
transform.transform(input, emitter);
return emitter.getEmitted().iterator();
}
}
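As a side note, the change above combines lazy plugin initialization with an error wrapper that attaches the stage name to any failure, so Spark task logs point at the failing pipeline stage. A minimal, generic sketch of that error-context pattern follows; the StageOperation interface and names here are illustrative, not CDAP APIs.

import java.util.concurrent.ExecutionException;

final class StageErrorWrapper {

  // A stand-in for the per-record work TransformFunction#call performs.
  interface StageOperation<I, O> {
    O run(I input) throws Exception;
  }

  static <I, O> O runWithStageContext(String stageName, StageOperation<I, O> operation, I input)
      throws ExecutionException {
    try {
      return operation.run(input);
    } catch (Exception e) {
      // Keep the original exception as the cause, same as TransformFunction does.
      throw new ExecutionException("Error when transforming stage " + stageName + ": " + e, e);
    }
  }
}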
