[GLUTEN-7015][VL] Remove udf native registration (#7016)
marin-ma authored Sep 4, 2024
1 parent 8bc1842 commit e620830
Showing 18 changed files with 156 additions and 608 deletions.
UdfJniWrapper.java
@@ -18,5 +18,5 @@

 public class UdfJniWrapper {

-  public static native void getFunctionSignatures();
+  public static native void registerFunctionSignatures();
 }
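The wrapper exposes a single static native method; the rename reflects that the native side now registers the signatures it discovers directly, rather than merely fetching them for JVM-side registration. A minimal sketch of the JVM-side call pattern (the caller object is hypothetical; only UdfJniWrapper.registerFunctionSignatures comes from the diff):

```scala
import org.apache.gluten.udf.UdfJniWrapper

// Hypothetical caller, for illustration only. The backend's native library
// must already be loaded, since registerFunctionSignatures() is a static
// native method; its C++ binding is renamed in cpp/velox/jni/VeloxJniWrapper.cc
// further down in this diff.
object UdfStartupSketch {
  def registerNativeUdfSignatures(): Unit =
    UdfJniWrapper.registerFunctionSignatures()
}
```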
VeloxListenerApi.scala
@@ -21,6 +21,7 @@ import org.apache.gluten.backendsapi.ListenerApi
 import org.apache.gluten.execution.datasource.{GlutenOrcWriterInjects, GlutenParquetWriterInjects, GlutenRowSplitter}
 import org.apache.gluten.expression.UDFMappings
 import org.apache.gluten.init.NativeBackendInitializer
+import org.apache.gluten.udf.UdfJniWrapper
 import org.apache.gluten.utils._
 import org.apache.gluten.vectorized.{JniLibLoader, JniWorkspace}

@@ -81,6 +82,7 @@ class VeloxListenerApi extends ListenerApi with Logging {
     SparkDirectoryUtil.init(conf)
     UDFResolver.resolveUdfConf(conf, isDriver = true)
     initialize(conf)
+    UdfJniWrapper.registerFunctionSignatures()
   }

   override def onDriverShutdown(): Unit = shutdown()
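The new call is appended to the driver-start path after initialize(conf), which loads the native backend and any UDF libraries recorded by UDFResolver.resolveUdfConf. A sketch of the resulting ordering; the stand-in parameters are assumptions, only the call order comes from the diff:

```scala
import org.apache.gluten.udf.UdfJniWrapper
import org.apache.spark.SparkConf

object DriverStartSketch {
  // resolveUdfConf and initializeBackend are stand-ins for the real
  // UDFResolver.resolveUdfConf and initialize(conf) calls shown above.
  def onDriverStartSketch(
      conf: SparkConf,
      resolveUdfConf: SparkConf => Unit,
      initializeBackend: SparkConf => Unit): Unit = {
    resolveUdfConf(conf) // record configured UDF library paths
    initializeBackend(conf) // load the native backend and the UDF libraries
    // Only after the native library is loaded can this JNI call resolve:
    UdfJniWrapper.registerFunctionSignatures()
  }
}
```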
VeloxRuleApi.scala
@@ -30,7 +30,6 @@ import org.apache.gluten.extension.injector.GlutenInjector.{LegacyInjector, RasI
 import org.apache.gluten.sql.shims.SparkShimLoader

 import org.apache.spark.sql.execution.{ColumnarCollapseTransformStages, GlutenFallbackReporter}
-import org.apache.spark.sql.expression.UDFResolver

 class VeloxRuleApi extends RuleApi {
   import VeloxRuleApi._
@@ -47,7 +46,6 @@ private object VeloxRuleApi {
     // Regular Spark rules.
     injector.injectOptimizerRule(CollectRewriteRule.apply)
     injector.injectOptimizerRule(HLLRewriteRule.apply)
-    UDFResolver.getFunctionSignatures().foreach(injector.injectFunction)
     injector.injectPostHocResolutionRule(ArrowConvertorRule.apply)
   }
VeloxAggregateFunctionsBuilder (Scala)
@@ -30,6 +30,8 @@ import org.apache.gluten.utils.VeloxIntermediateData
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.execution._
+import org.apache.spark.sql.expression.UDFResolver
+import org.apache.spark.sql.hive.HiveUDAFInspector
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._

@@ -681,14 +683,25 @@ object VeloxAggregateFunctionsBuilder {
       aggregateFunc: AggregateFunction,
       mode: AggregateMode): Long = {
     val functionMap = args.asInstanceOf[JHashMap[String, JLong]]
-    val sigName = AggregateFunctionsBuilder.getSubstraitFunctionName(aggregateFunc)
+    val (sigName, aggFunc) =
+      try {
+        (AggregateFunctionsBuilder.getSubstraitFunctionName(aggregateFunc), aggregateFunc)
+      } catch {
+        case e: GlutenNotSupportException =>
+          HiveUDAFInspector.getUDAFClassName(aggregateFunc) match {
+            case Some(udafClass) if UDFResolver.UDAFNames.contains(udafClass) =>
+              (udafClass, UDFResolver.getUdafExpression(udafClass)(aggregateFunc.children))
+            case _ => throw e
+          }
+        case e: Throwable => throw e
+      }

     ExpressionBuilder.newScalarFunction(
       functionMap,
       ConverterUtils.makeFuncName(
         // Substrait-to-Velox procedure will choose appropriate companion function if needed.
         sigName,
-        VeloxIntermediateData.getInputTypes(aggregateFunc, mode == PartialMerge || mode == Final),
+        VeloxIntermediateData.getInputTypes(aggFunc, mode == PartialMerge || mode == Final),
         FunctionConfig.REQ
       )
     )
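This hunk introduces two-step resolution: the built-in Substrait name mapping is tried first, and only a GlutenNotSupportException triggers the Hive-UDAF lookup. Note that the resolved replacement (aggFunc), not the original aggregateFunc, feeds VeloxIntermediateData.getInputTypes, so intermediate types come from the native UDAF's definition. A condensed restatement with explanatory comments; the import path of AggregateFunctionsBuilder is assumed, everything else appears in the diff:

```scala
import org.apache.gluten.exception.GlutenNotSupportException
import org.apache.gluten.expression.AggregateFunctionsBuilder // path assumed
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction
import org.apache.spark.sql.expression.UDFResolver
import org.apache.spark.sql.hive.HiveUDAFInspector

object AggregateResolutionSketch {
  def resolveAggregate(agg: AggregateFunction) =
    try {
      // Fast path: the aggregate maps onto a built-in Substrait name.
      (AggregateFunctionsBuilder.getSubstraitFunctionName(agg), agg)
    } catch {
      case e: GlutenNotSupportException =>
        HiveUDAFInspector.getUDAFClassName(agg) match {
          // Offload only when the wrapped Hive UDAF's implementation class
          // was registered by a native UDF library at driver start.
          case Some(cls) if UDFResolver.UDAFNames.contains(cls) =>
            (cls, UDFResolver.getUdafExpression(cls)(agg.children))
          case _ => throw e // no native replacement; keep the original error
        }
    }
}
```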
UDFResolver.scala
@@ -18,16 +18,14 @@ package org.apache.spark.sql.expression

 import org.apache.gluten.backendsapi.velox.VeloxBackendSettings
 import org.apache.gluten.exception.{GlutenException, GlutenNotSupportException}
-import org.apache.gluten.expression.{ConverterUtils, ExpressionTransformer, ExpressionType, GenericExpressionTransformer, Transformable}
-import org.apache.gluten.udf.UdfJniWrapper
+import org.apache.gluten.expression._
 import org.apache.gluten.vectorized.JniWorkspace

-import org.apache.spark.{SparkConf, SparkContext, SparkFiles}
+import org.apache.spark.{SparkConf, SparkFiles}
 import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow}
-import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast, Expression, ExpressionInfo, Unevaluable}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast, Expression, Unevaluable}
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
@@ -334,32 +332,6 @@ object UDFResolver extends Logging {
       .mkString(",")
   }

-  def getFunctionSignatures(): Seq[(FunctionIdentifier, ExpressionInfo, FunctionBuilder)] = {
-    val sparkContext = SparkContext.getActive.get
-    val sparkConf = sparkContext.conf
-    val udfLibPaths = sparkConf.getOption(VeloxBackendSettings.GLUTEN_VELOX_UDF_LIB_PATHS)
-
-    udfLibPaths match {
-      case None =>
-        Seq.empty
-      case Some(_) =>
-        UdfJniWrapper.getFunctionSignatures()
-        UDFNames.map {
-          name =>
-            (
-              new FunctionIdentifier(name),
-              new ExpressionInfo(classOf[UDFExpression].getName, name),
-              (e: Seq[Expression]) => getUdfExpression(name, name)(e))
-        }.toSeq ++ UDAFNames.map {
-          name =>
-            (
-              new FunctionIdentifier(name),
-              new ExpressionInfo(classOf[UserDefinedAggregateFunction].getName, name),
-              (e: Seq[Expression]) => getUdafExpression(name)(e))
-        }.toSeq
-    }
-  }
-
   private def checkAllowTypeConversion: Boolean = {
     SQLConf.get
       .getConfString(VeloxBackendSettings.GLUTEN_VELOX_UDF_ALLOW_TYPE_CONVERSION, "false")
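With getFunctionSignatures removed, native UDFs and UDAFs are no longer auto-registered as Spark functions at rule-injection time. Instead, the UDFNames/UDAFNames maps are populated natively at driver start and consulted when a user-registered function resolves, as the updated suite below suggests. A usage sketch following that pattern; the table and function names are illustrative:

```scala
// Assumes an active SparkSession `spark` with Hive support.
spark.sql("""CREATE TEMPORARY FUNCTION my_double_avg
            |AS 'test.org.apache.spark.sql.MyDoubleAvg'
            |""".stripMargin)
// If 'test.org.apache.spark.sql.MyDoubleAvg' is also present in
// UDFResolver.UDAFNames (populated by registerFunctionSignatures at driver
// start), Gluten swaps in the native UDAF; otherwise the JVM implementation
// runs as a fallback.
spark.sql("SELECT my_double_avg(CAST(col1 AS DOUBLE)) FROM some_table").show()
```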
VeloxUdfSuite.scala
@@ -16,7 +16,6 @@
  */
 package org.apache.gluten.expression

-import org.apache.gluten.backendsapi.velox.VeloxBackendSettings
 import org.apache.gluten.tags.{SkipTestTags, UDFTest}

 import org.apache.spark.SparkConf
@@ -25,7 +24,6 @@ import org.apache.spark.sql.catalyst.plans.SQLHelper
 import org.apache.spark.sql.expression.UDFResolver

 import java.nio.file.Paths
-import java.sql.Date

 abstract class VeloxUdfSuite extends GlutenQueryTest with SQLHelper {

@@ -92,64 +90,55 @@ abstract class VeloxUdfSuite extends GlutenQueryTest with SQLHelper {
       .set("spark.memory.offHeap.size", "1024MB")
   }

-  test("test udf") {
-    val df = spark.sql("""select
-                         | myudf1(100L),
-                         | myudf2(1),
-                         | myudf2(1L),
-                         | myudf3(),
-                         | myudf3(1),
-                         | myudf3(1, 2, 3),
-                         | myudf3(1L),
-                         | myudf3(1L, 2L, 3L),
-                         | mydate(cast('2024-03-25' as date), 5)
-                         |""".stripMargin)
-    assert(
-      df.collect()
-        .sameElements(Array(Row(105L, 6, 6L, 5, 6, 11, 6L, 11L, Date.valueOf("2024-03-30")))))
-  }
-
-  test("test udf allow type conversion") {
-    withSQLConf(VeloxBackendSettings.GLUTEN_VELOX_UDF_ALLOW_TYPE_CONVERSION -> "true") {
-      val df = spark.sql("""select myudf1("100"), myudf1(1), mydate('2024-03-25', 5)""")
-      assert(
-        df.collect()
-          .sameElements(Array(Row(105L, 6L, Date.valueOf("2024-03-30")))))
-    }
-
-    withSQLConf(VeloxBackendSettings.GLUTEN_VELOX_UDF_ALLOW_TYPE_CONVERSION -> "false") {
-      assert(
-        spark
-          .sql("select mydate2('2024-03-25', 5)")
-          .collect()
-          .sameElements(Array(Row(Date.valueOf("2024-03-30")))))
-    }
-  }
-
-  test("test udaf") {
-    val df = spark.sql("""select
-                         | myavg(1),
-                         | myavg(1L),
-                         | myavg(cast(1.0 as float)),
-                         | myavg(cast(1.0 as double)),
-                         | mycount_if(true)
-                         |""".stripMargin)
-    df.collect()
-    assert(
-      df.collect()
-        .sameElements(Array(Row(1.0, 1.0, 1.0, 1.0, 1L))))
-  }
-
-  test("test udaf allow type conversion") {
-    withSQLConf(VeloxBackendSettings.GLUTEN_VELOX_UDF_ALLOW_TYPE_CONVERSION -> "true") {
-      val df = spark.sql("""select myavg("1"), myavg("1.0"), mycount_if("true")""")
-      assert(
-        df.collect()
-          .sameElements(Array(Row(1.0, 1.0, 1L))))
-    }
-  }
+  // Aggregate result can be flaky.
+  ignore("test native hive udaf") {
+    val tbl = "test_hive_udaf_replacement"
+    withTempPath {
+      dir =>
+        try {
+          // Check native hive udaf has been registered.
+          val udafClass = "test.org.apache.spark.sql.MyDoubleAvg"
+          assert(UDFResolver.UDAFNames.contains(udafClass))
+
+          spark.sql(s"""
+                       |CREATE TEMPORARY FUNCTION my_double_avg
+                       |AS '$udafClass'
+                       |""".stripMargin)
+          spark.sql(s"""
+                       |CREATE EXTERNAL TABLE $tbl
+                       |LOCATION 'file://$dir'
+                       |AS select * from values (1, '1'), (2, '2'), (3, '3')
+                       |""".stripMargin)
+          val df = spark.sql(s"""select
+                                | my_double_avg(cast(col1 as double)),
+                                | my_double_avg(cast(col2 as double))
+                                | from $tbl
+                                |""".stripMargin)
+          val nativeImplicitConversionDF = spark.sql(s"""select
+                                                        | my_double_avg(col1),
+                                                        | my_double_avg(col2)
+                                                        | from $tbl
+                                                        |""".stripMargin)
+          val nativeResult = df.collect()
+          val nativeImplicitConversionResult = nativeImplicitConversionDF.collect()
+
+          UDFResolver.UDAFNames.remove(udafClass)
+          val fallbackDF = spark.sql(s"""select
+                                        | my_double_avg(cast(col1 as double)),
+                                        | my_double_avg(cast(col2 as double))
+                                        | from $tbl
+                                        |""".stripMargin)
+          val fallbackResult = fallbackDF.collect()
+          assert(nativeResult.sameElements(fallbackResult))
+          assert(nativeImplicitConversionResult.sameElements(fallbackResult))
+        } finally {
+          spark.sql(s"DROP TABLE IF EXISTS $tbl")
+          spark.sql(s"DROP TEMPORARY FUNCTION IF EXISTS my_double_avg")
+        }
+    }
+  }

-  test("test hive udf replacement") {
+  test("test native hive udf") {
     val tbl = "test_hive_udf_replacement"
     withTempPath {
       dir =>
@@ -169,12 +158,15 @@ abstract class VeloxUdfSuite extends GlutenQueryTest with SQLHelper {
           |AS 'org.apache.spark.sql.hive.execution.UDFStringString'
           |""".stripMargin)

+        val nativeResultWithImplicitConversion =
+          spark.sql(s"""SELECT hive_string_string(col1, 'a') FROM $tbl""").collect()
         val nativeResult =
           spark.sql(s"""SELECT hive_string_string(col2, 'a') FROM $tbl""").collect()
         // Unregister native hive udf to fallback.
         UDFResolver.UDFNames.remove("org.apache.spark.sql.hive.execution.UDFStringString")
         val fallbackResult =
           spark.sql(s"""SELECT hive_string_string(col2, 'a') FROM $tbl""").collect()
+        assert(nativeResultWithImplicitConversion.sameElements(fallbackResult))
         assert(nativeResult.sameElements(fallbackResult))

         // Add an unimplemented udf to the map to test fallback of registered native hive udf.
@@ -237,6 +229,7 @@ class VeloxUdfSuiteLocal extends VeloxUdfSuite {
     super.sparkConf
       .set("spark.files", udfLibPath)
       .set("spark.gluten.sql.columnar.backend.velox.udfLibraryPaths", udfLibRelativePath)
+      .set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.ColumnarShuffleManager")
   }
 }
2 changes: 1 addition & 1 deletion cpp/velox/jni/JniUdf.cc
@@ -49,7 +49,7 @@ void gluten::finalizeVeloxJniUDF(JNIEnv* env) {
   env->DeleteGlobalRef(udfResolverClass);
 }

-void gluten::jniGetFunctionSignatures(JNIEnv* env) {
+void gluten::jniRegisterFunctionSignatures(JNIEnv* env) {
   auto udfLoader = gluten::UdfLoader::getInstance();

   const auto& signatures = udfLoader->getRegisteredUdfSignatures();
2 changes: 1 addition & 1 deletion cpp/velox/jni/JniUdf.h
@@ -27,6 +27,6 @@ void initVeloxJniUDF(JNIEnv* env);

 void finalizeVeloxJniUDF(JNIEnv* env);

-void jniGetFunctionSignatures(JNIEnv* env);
+void jniRegisterFunctionSignatures(JNIEnv* env);

 } // namespace gluten
4 changes: 2 additions & 2 deletions cpp/velox/jni/VeloxJniWrapper.cc
@@ -88,11 +88,11 @@ JNIEXPORT void JNICALL Java_org_apache_gluten_init_NativeBackendInitializer_shut
   JNI_METHOD_END()
 }

-JNIEXPORT void JNICALL Java_org_apache_gluten_udf_UdfJniWrapper_getFunctionSignatures( // NOLINT
+JNIEXPORT void JNICALL Java_org_apache_gluten_udf_UdfJniWrapper_registerFunctionSignatures( // NOLINT
     JNIEnv* env,
     jclass) {
   JNI_METHOD_START
-  gluten::jniGetFunctionSignatures(env);
+  gluten::jniRegisterFunctionSignatures(env);
   JNI_METHOD_END()
 }
4 changes: 2 additions & 2 deletions cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc
@@ -1162,11 +1162,11 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::AggregateRel& ag
       "regr_sxy",
       "regr_replacement"};

-  auto udfFuncs = UdfLoader::getInstance()->getRegisteredUdafNames();
+  auto udafFuncs = UdfLoader::getInstance()->getRegisteredUdafNames();

   for (const auto& funcSpec : funcSpecs) {
     auto funcName = SubstraitParser::getNameBeforeDelimiter(funcSpec);
-    if (supportedAggFuncs.find(funcName) == supportedAggFuncs.end() && udfFuncs.find(funcName) == udfFuncs.end()) {
+    if (supportedAggFuncs.find(funcName) == supportedAggFuncs.end() && udafFuncs.find(funcName) == udafFuncs.end()) {
       LOG_VALIDATION_MSG(funcName + " was not supported in AggregateRel.");
       return false;
     }
21 changes: 13 additions & 8 deletions cpp/velox/tests/MyUdfTest.cc
@@ -34,12 +34,17 @@ class MyUdfTest : public FunctionBaseTest {
   }
 };

-TEST_F(MyUdfTest, myudf1) {
-  const auto myudf1 = [&](const int64_t& number) {
-    return evaluateOnce<int64_t>("myudf1(c0)", BIGINT(), std::make_optional(number));
-  };
-
-  EXPECT_EQ(5, myudf1(0));
-  EXPECT_EQ(105, myudf1(100));
-  EXPECT_EQ(3147483652, myudf1(3147483647)); // int64
+TEST_F(MyUdfTest, hivestringstring) {
+  auto map = facebook::velox::exec::vectorFunctionFactories();
+  const std::string candidate = "org.apache.spark.sql.hive.execution.UDFStringString";
+  // The factory map is guarded by a read-write lock; look the UDF up under
+  // the read lock and check that its factory yields a usable function.
+  ASSERT_TRUE(map.withRLock([&candidate](auto& self) -> bool {
+    auto iter = self.find(candidate);
+    if (iter == self.end()) {
+      return false;
+    }
+    std::unordered_map<std::string, std::string> values;
+    const facebook::velox::core::QueryConfig config(std::move(values));
+    return iter->second.factory(
+               candidate,
+               {facebook::velox::exec::VectorFunctionArg{facebook::velox::VARCHAR()},
+                facebook::velox::exec::VectorFunctionArg{facebook::velox::VARCHAR()}},
+               config) != nullptr;
+  }));
 }
3 changes: 0 additions & 3 deletions cpp/velox/udf/examples/CMakeLists.txt
@@ -18,6 +18,3 @@ target_link_libraries(myudf velox)

 add_library(myudaf SHARED "MyUDAF.cc")
 target_link_libraries(myudaf velox)
-
-add_executable(test_myudf "TestMyUDF.cc")
-target_link_libraries(test_myudf velox)