Skip to content

Commit

Permalink
fix return type
Browse files Browse the repository at this point in the history
  • Loading branch information
xiedeyantu committed Apr 8, 2024
1 parent 332e138 commit 4612b7b
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 38 deletions.
13 changes: 5 additions & 8 deletions be/src/vec/functions/function_string.h
Original file line number Diff line number Diff line change
Expand Up @@ -305,10 +305,7 @@ class FunctionStrcmp : public IFunction {
size_t get_number_of_arguments() const override { return 2; }

DataTypePtr get_return_type_impl(const DataTypes& arguments) const override {
if (arguments[0]->is_nullable() || arguments[1]->is_nullable()) {
return make_nullable(std::make_shared<DataTypeInt16>());
}
return std::make_shared<DataTypeInt16>();
return std::make_shared<DataTypeInt8>();
}

Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
Expand All @@ -318,7 +315,7 @@ class FunctionStrcmp : public IFunction {
const auto& [arg1_column, arg1_const] =
unpack_if_const(block.get_by_position(arguments[1]).column);

auto result_column = ColumnInt16::create(input_rows_count);
auto result_column = ColumnInt8::create(input_rows_count);

if (auto arg0 = check_and_get_column<ColumnString>(arg0_column.get())) {
if (auto arg1 = check_and_get_column<ColumnString>(arg1_column.get())) {
Expand All @@ -337,22 +334,22 @@ class FunctionStrcmp : public IFunction {
}

private:
static void scalar_vector(const StringRef str, const ColumnString& vec1, ColumnInt16& res) {
static void scalar_vector(const StringRef str, const ColumnString& vec1, ColumnInt8& res) {
size_t size = vec1.size();
for (size_t i = 0; i < size; ++i) {
res.get_data()[i] = str.compare(vec1.get_data_at(i));
}
}

static void vector_scalar(const ColumnString& vec0, const StringRef str, ColumnInt16& res) {
static void vector_scalar(const ColumnString& vec0, const StringRef str, ColumnInt8& res) {
size_t size = vec0.size();
for (size_t i = 0; i < size; ++i) {
res.get_data()[i] = vec0.get_data_at(i).compare(str);
}
}

static void vector_vector(const ColumnString& vec0, const ColumnString& vec1,
ColumnInt16& res) {
ColumnInt8& res) {
size_t size = vec0.size();
for (size_t i = 0; i < size; ++i) {
res.get_data()[i] = vec0.get_data_at(i).compare(vec1.get_data_at(i));
Expand Down
54 changes: 27 additions & 27 deletions be/test/vec/function/function_string_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1193,62 +1193,62 @@ TEST(function_string_test, function_strcmp_test) {
InputTypeSet input_types = {TypeIndex::String, TypeIndex::String};

DataSet data_set = {{{Null(), Null()}, Null()},
{{std::string(""), std::string("")}, (int16_t)0},
{{std::string("test"), std::string("test")}, (int16_t)0},
{{std::string("test1"), std::string("test")}, (int16_t)1},
{{std::string("test"), std::string("test1")}, (int16_t)-1},
{{std::string(""), std::string("")}, (int8_t)0},
{{std::string("test"), std::string("test")}, (int8_t)0},
{{std::string("test1"), std::string("test")}, (int8_t)1},
{{std::string("test"), std::string("test1")}, (int8_t)-1},
{{Null(), std::string("test")}, Null()},
{{std::string("test"), Null()}, Null()},
{{VARCHAR(""), VARCHAR("")}, (int16_t)0},
{{VARCHAR("test"), VARCHAR("test")}, (int16_t)0},
{{VARCHAR("test1"), VARCHAR("test")}, (int16_t)1},
{{VARCHAR("test"), VARCHAR("test1")}, (int16_t)-1},
{{VARCHAR(""), VARCHAR("")}, (int8_t)0},
{{VARCHAR("test"), VARCHAR("test")}, (int8_t)0},
{{VARCHAR("test1"), VARCHAR("test")}, (int8_t)1},
{{VARCHAR("test"), VARCHAR("test1")}, (int8_t)-1},
{{Null(), VARCHAR("test")}, Null()},
{{VARCHAR("test"), Null()}, Null()}};
static_cast<void>(check_function<DataTypeInt16, true>(func_name, input_types, data_set));
static_cast<void>(check_function<DataTypeInt8, true>(func_name, input_types, data_set));
}
{
InputTypeSet input_types = {Consted {TypeIndex::String}, TypeIndex::String};
DataSet data_set = {{{Null(), Null()}, Null()},
{{std::string(""), std::string("")}, (int16_t)0},
{{std::string("test"), std::string("test")}, (int16_t)0},
{{std::string("test1"), std::string("test")}, (int16_t)1},
{{std::string("test"), std::string("test1")}, (int16_t)-1},
{{std::string(""), std::string("")}, (int8_t)0},
{{std::string("test"), std::string("test")}, (int8_t)0},
{{std::string("test1"), std::string("test")}, (int8_t)1},
{{std::string("test"), std::string("test1")}, (int8_t)-1},
{{Null(), std::string("test")}, Null()},
{{std::string("test"), Null()}, Null()},
{{VARCHAR(""), VARCHAR("")}, (int16_t)0},
{{VARCHAR("test"), VARCHAR("test")}, (int16_t)0},
{{VARCHAR("test1"), VARCHAR("test")}, (int16_t)1},
{{VARCHAR("test"), VARCHAR("test1")}, (int16_t)-1},
{{VARCHAR(""), VARCHAR("")}, (int8_t)0},
{{VARCHAR("test"), VARCHAR("test")}, (int8_t)0},
{{VARCHAR("test1"), VARCHAR("test")}, (int8_t)1},
{{VARCHAR("test"), VARCHAR("test1")}, (int8_t)-1},
{{Null(), VARCHAR("test")}, Null()},
{{VARCHAR("test"), Null()}, Null()}};

for (const auto& line : data_set) {
DataSet const_dataset = {line};
static_cast<void>(
check_function<DataTypeInt16, true>(func_name, input_types, const_dataset));
check_function<DataTypeInt8, true>(func_name, input_types, const_dataset));
}
}
{
InputTypeSet input_types = {TypeIndex::String, Consted {TypeIndex::String}};
DataSet data_set = {{{Null(), Null()}, Null()},
{{std::string(""), std::string("")}, (int16_t)0},
{{std::string("test"), std::string("test")}, (int16_t)0},
{{std::string("test1"), std::string("test")}, (int16_t)1},
{{std::string("test"), std::string("test1")}, (int16_t)-1},
{{std::string(""), std::string("")}, (int8_t)0},
{{std::string("test"), std::string("test")}, (int8_t)0},
{{std::string("test1"), std::string("test")}, (int8_t)1},
{{std::string("test"), std::string("test1")}, (int8_t)-1},
{{Null(), std::string("test")}, Null()},
{{std::string("test"), Null()}, Null()},
{{VARCHAR(""), VARCHAR("")}, (int16_t)0},
{{VARCHAR("test"), VARCHAR("test")}, (int16_t)0},
{{VARCHAR("test1"), VARCHAR("test")}, (int16_t)1},
{{VARCHAR("test"), VARCHAR("test1")}, (int16_t)-1},
{{VARCHAR(""), VARCHAR("")}, (int8_t)0},
{{VARCHAR("test"), VARCHAR("test")}, (int8_t)0},
{{VARCHAR("test1"), VARCHAR("test")}, (int8_t)1},
{{VARCHAR("test"), VARCHAR("test1")}, (int8_t)-1},
{{Null(), VARCHAR("test")}, Null()},
{{VARCHAR("test"), Null()}, Null()}};

for (const auto& line : data_set) {
DataSet const_dataset = {line};
static_cast<void>(
check_function<DataTypeInt16, true>(func_name, input_types, const_dataset));
check_function<DataTypeInt8, true>(func_name, input_types, const_dataset));
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.SmallIntType;
import org.apache.doris.nereids.types.StringType;
import org.apache.doris.nereids.types.TinyIntType;
import org.apache.doris.nereids.types.VarcharType;

import com.google.common.base.Preconditions;
Expand All @@ -39,8 +39,8 @@ public class Strcmp extends ScalarFunction
implements BinaryExpression, ExplicitlyCastableSignature, PropagateNullable {

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(SmallIntType.INSTANCE).args(VarcharType.SYSTEM_DEFAULT, VarcharType.SYSTEM_DEFAULT),
FunctionSignature.ret(SmallIntType.INSTANCE).args(StringType.INSTANCE, StringType.INSTANCE));
FunctionSignature.ret(TinyIntType.INSTANCE).args(VarcharType.SYSTEM_DEFAULT, VarcharType.SYSTEM_DEFAULT),
FunctionSignature.ret(TinyIntType.INSTANCE).args(StringType.INSTANCE, StringType.INSTANCE));

/**
* constructor with 2 argument.
Expand Down

0 comments on commit 4612b7b

Please sign in to comment.