From 4612b7b08bfb059f5d71ed8d8247cb6f525059a7 Mon Sep 17 00:00:00 2001 From: xiedeyantu Date: Mon, 8 Apr 2024 12:46:17 +0800 Subject: [PATCH] fix return type --- be/src/vec/functions/function_string.h | 13 ++--- be/test/vec/function/function_string_test.cpp | 54 +++++++++---------- .../expressions/functions/scalar/Strcmp.java | 6 +-- 3 files changed, 35 insertions(+), 38 deletions(-) diff --git a/be/src/vec/functions/function_string.h b/be/src/vec/functions/function_string.h index 01766b9da456bb..0b4d5222d35e79 100644 --- a/be/src/vec/functions/function_string.h +++ b/be/src/vec/functions/function_string.h @@ -305,10 +305,7 @@ class FunctionStrcmp : public IFunction { size_t get_number_of_arguments() const override { return 2; } DataTypePtr get_return_type_impl(const DataTypes& arguments) const override { - if (arguments[0]->is_nullable() || arguments[1]->is_nullable()) { - return make_nullable(std::make_shared()); - } - return std::make_shared(); + return std::make_shared(); } Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments, @@ -318,7 +315,7 @@ class FunctionStrcmp : public IFunction { const auto& [arg1_column, arg1_const] = unpack_if_const(block.get_by_position(arguments[1]).column); - auto result_column = ColumnInt16::create(input_rows_count); + auto result_column = ColumnInt8::create(input_rows_count); if (auto arg0 = check_and_get_column(arg0_column.get())) { if (auto arg1 = check_and_get_column(arg1_column.get())) { @@ -337,14 +334,14 @@ class FunctionStrcmp : public IFunction { } private: - static void scalar_vector(const StringRef str, const ColumnString& vec1, ColumnInt16& res) { + static void scalar_vector(const StringRef str, const ColumnString& vec1, ColumnInt8& res) { size_t size = vec1.size(); for (size_t i = 0; i < size; ++i) { res.get_data()[i] = str.compare(vec1.get_data_at(i)); } } - static void vector_scalar(const ColumnString& vec0, const StringRef str, ColumnInt16& res) { + static void vector_scalar(const ColumnString& vec0, const StringRef str, ColumnInt8& res) { size_t size = vec0.size(); for (size_t i = 0; i < size; ++i) { res.get_data()[i] = vec0.get_data_at(i).compare(str); @@ -352,7 +349,7 @@ class FunctionStrcmp : public IFunction { } static void vector_vector(const ColumnString& vec0, const ColumnString& vec1, - ColumnInt16& res) { + ColumnInt8& res) { size_t size = vec0.size(); for (size_t i = 0; i < size; ++i) { res.get_data()[i] = vec0.get_data_at(i).compare(vec1.get_data_at(i)); diff --git a/be/test/vec/function/function_string_test.cpp b/be/test/vec/function/function_string_test.cpp index c4db26ddd4cafc..44b3f25e388166 100644 --- a/be/test/vec/function/function_string_test.cpp +++ b/be/test/vec/function/function_string_test.cpp @@ -1193,62 +1193,62 @@ TEST(function_string_test, function_strcmp_test) { InputTypeSet input_types = {TypeIndex::String, TypeIndex::String}; DataSet data_set = {{{Null(), Null()}, Null()}, - {{std::string(""), std::string("")}, (int16_t)0}, - {{std::string("test"), std::string("test")}, (int16_t)0}, - {{std::string("test1"), std::string("test")}, (int16_t)1}, - {{std::string("test"), std::string("test1")}, (int16_t)-1}, + {{std::string(""), std::string("")}, (int8_t)0}, + {{std::string("test"), std::string("test")}, (int8_t)0}, + {{std::string("test1"), std::string("test")}, (int8_t)1}, + {{std::string("test"), std::string("test1")}, (int8_t)-1}, {{Null(), std::string("test")}, Null()}, {{std::string("test"), Null()}, Null()}, - {{VARCHAR(""), VARCHAR("")}, (int16_t)0}, - {{VARCHAR("test"), VARCHAR("test")}, (int16_t)0}, - {{VARCHAR("test1"), VARCHAR("test")}, (int16_t)1}, - {{VARCHAR("test"), VARCHAR("test1")}, (int16_t)-1}, + {{VARCHAR(""), VARCHAR("")}, (int8_t)0}, + {{VARCHAR("test"), VARCHAR("test")}, (int8_t)0}, + {{VARCHAR("test1"), VARCHAR("test")}, (int8_t)1}, + {{VARCHAR("test"), VARCHAR("test1")}, (int8_t)-1}, {{Null(), VARCHAR("test")}, Null()}, {{VARCHAR("test"), Null()}, Null()}}; - static_cast(check_function(func_name, input_types, data_set)); + static_cast(check_function(func_name, input_types, data_set)); } { InputTypeSet input_types = {Consted {TypeIndex::String}, TypeIndex::String}; DataSet data_set = {{{Null(), Null()}, Null()}, - {{std::string(""), std::string("")}, (int16_t)0}, - {{std::string("test"), std::string("test")}, (int16_t)0}, - {{std::string("test1"), std::string("test")}, (int16_t)1}, - {{std::string("test"), std::string("test1")}, (int16_t)-1}, + {{std::string(""), std::string("")}, (int8_t)0}, + {{std::string("test"), std::string("test")}, (int8_t)0}, + {{std::string("test1"), std::string("test")}, (int8_t)1}, + {{std::string("test"), std::string("test1")}, (int8_t)-1}, {{Null(), std::string("test")}, Null()}, {{std::string("test"), Null()}, Null()}, - {{VARCHAR(""), VARCHAR("")}, (int16_t)0}, - {{VARCHAR("test"), VARCHAR("test")}, (int16_t)0}, - {{VARCHAR("test1"), VARCHAR("test")}, (int16_t)1}, - {{VARCHAR("test"), VARCHAR("test1")}, (int16_t)-1}, + {{VARCHAR(""), VARCHAR("")}, (int8_t)0}, + {{VARCHAR("test"), VARCHAR("test")}, (int8_t)0}, + {{VARCHAR("test1"), VARCHAR("test")}, (int8_t)1}, + {{VARCHAR("test"), VARCHAR("test1")}, (int8_t)-1}, {{Null(), VARCHAR("test")}, Null()}, {{VARCHAR("test"), Null()}, Null()}}; for (const auto& line : data_set) { DataSet const_dataset = {line}; static_cast( - check_function(func_name, input_types, const_dataset)); + check_function(func_name, input_types, const_dataset)); } } { InputTypeSet input_types = {TypeIndex::String, Consted {TypeIndex::String}}; DataSet data_set = {{{Null(), Null()}, Null()}, - {{std::string(""), std::string("")}, (int16_t)0}, - {{std::string("test"), std::string("test")}, (int16_t)0}, - {{std::string("test1"), std::string("test")}, (int16_t)1}, - {{std::string("test"), std::string("test1")}, (int16_t)-1}, + {{std::string(""), std::string("")}, (int8_t)0}, + {{std::string("test"), std::string("test")}, (int8_t)0}, + {{std::string("test1"), std::string("test")}, (int8_t)1}, + {{std::string("test"), std::string("test1")}, (int8_t)-1}, {{Null(), std::string("test")}, Null()}, {{std::string("test"), Null()}, Null()}, - {{VARCHAR(""), VARCHAR("")}, (int16_t)0}, - {{VARCHAR("test"), VARCHAR("test")}, (int16_t)0}, - {{VARCHAR("test1"), VARCHAR("test")}, (int16_t)1}, - {{VARCHAR("test"), VARCHAR("test1")}, (int16_t)-1}, + {{VARCHAR(""), VARCHAR("")}, (int8_t)0}, + {{VARCHAR("test"), VARCHAR("test")}, (int8_t)0}, + {{VARCHAR("test1"), VARCHAR("test")}, (int8_t)1}, + {{VARCHAR("test"), VARCHAR("test1")}, (int8_t)-1}, {{Null(), VARCHAR("test")}, Null()}, {{VARCHAR("test"), Null()}, Null()}}; for (const auto& line : data_set) { DataSet const_dataset = {line}; static_cast( - check_function(func_name, input_types, const_dataset)); + check_function(func_name, input_types, const_dataset)); } } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Strcmp.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Strcmp.java index aa1f6c5506fda4..b9aaff85fce252 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Strcmp.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Strcmp.java @@ -23,8 +23,8 @@ import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable; import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; -import org.apache.doris.nereids.types.SmallIntType; import org.apache.doris.nereids.types.StringType; +import org.apache.doris.nereids.types.TinyIntType; import org.apache.doris.nereids.types.VarcharType; import com.google.common.base.Preconditions; @@ -39,8 +39,8 @@ public class Strcmp extends ScalarFunction implements BinaryExpression, ExplicitlyCastableSignature, PropagateNullable { public static final List SIGNATURES = ImmutableList.of( - FunctionSignature.ret(SmallIntType.INSTANCE).args(VarcharType.SYSTEM_DEFAULT, VarcharType.SYSTEM_DEFAULT), - FunctionSignature.ret(SmallIntType.INSTANCE).args(StringType.INSTANCE, StringType.INSTANCE)); + FunctionSignature.ret(TinyIntType.INSTANCE).args(VarcharType.SYSTEM_DEFAULT, VarcharType.SYSTEM_DEFAULT), + FunctionSignature.ret(TinyIntType.INSTANCE).args(StringType.INSTANCE, StringType.INSTANCE)); /** * constructor with 2 argument.