Skip to content

Commit

Permalink
[feature](functions) impl scalar functions trim_in、ltrim_in and rtrim…
Browse files Browse the repository at this point in the history
…_in (apache#41681)

trim_in differs from trim: it finds and removes, at both ends of a string,
every character that belongs to the given set of characters (regardless of
the order in which they appear), as the examples below show.

mysql> SELECT TRIM('abcd', 'cde');
+---------------------+
| trim('abcd', 'cde') |
+---------------------+
| abcd                |
+---------------------+
1 row in set (0.02 sec)

mysql> SELECT TRIM_IN('abcd', 'cde');
+------------------------+
| trim_in('abcd', 'cde') |
+------------------------+
| ab                     |
+------------------------+
1 row in set (0.02 sec)
  • Loading branch information
liujiwen-up authored Oct 21, 2024
1 parent 09d02c0 commit fa5f1b9
Show file tree
Hide file tree
Showing 9 changed files with 1,250 additions and 3 deletions.
160 changes: 157 additions & 3 deletions be/src/vec/functions/function_string.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <math.h>
#include <re2/stringpiece.h>

#include <bitset>
#include <cstddef>
#include <string_view>

Expand Down Expand Up @@ -508,6 +509,15 @@ struct NameLTrim {
// Name tag carrying the function name "rtrim"; passed as the `Name` template
// argument of the trim implementations.
struct NameRTrim {
    static constexpr const char* name = "rtrim";
};
// Name tag carrying the function name "trim_in"; passed as the `Name` template
// argument of the trim implementations.
struct NameTrimIn {
    static constexpr const char* name = "trim_in";
};
// Name tag carrying the function name "ltrim_in"; passed as the `Name` template
// argument of the trim implementations.
struct NameLTrimIn {
    static constexpr const char* name = "ltrim_in";
};
// Name tag carrying the function name "rtrim_in"; passed as the `Name` template
// argument of the trim implementations.
struct NameRTrimIn {
    static constexpr const char* name = "rtrim_in";
};
template <bool is_ltrim, bool is_rtrim, bool trim_single>
struct TrimUtil {
static Status vector(const ColumnString::Chars& str_data,
Expand Down Expand Up @@ -535,6 +545,135 @@ struct TrimUtil {
return Status::OK();
}
};
template <bool is_ltrim, bool is_rtrim, bool trim_single>
struct TrimInUtil {
static Status vector(const ColumnString::Chars& str_data,
const ColumnString::Offsets& str_offsets, const StringRef& remove_str,
ColumnString::Chars& res_data, ColumnString::Offsets& res_offsets) {
const size_t offset_size = str_offsets.size();
res_offsets.resize(offset_size);
res_data.reserve(str_data.size());
bool all_ascii = simd::VStringFunctions::is_ascii(remove_str) &&
simd::VStringFunctions::is_ascii(StringRef(
reinterpret_cast<const char*>(str_data.data()), str_data.size()));

if (all_ascii) {
return impl_vectors_ascii(str_data, str_offsets, remove_str, res_data, res_offsets);
} else {
return impl_vectors_utf8(str_data, str_offsets, remove_str, res_data, res_offsets);
}
}

private:
static Status impl_vectors_ascii(const ColumnString::Chars& str_data,
const ColumnString::Offsets& str_offsets,
const StringRef& remove_str, ColumnString::Chars& res_data,
ColumnString::Offsets& res_offsets) {
const size_t offset_size = str_offsets.size();
std::bitset<128> char_lookup;
const char* remove_begin = remove_str.data;
const char* remove_end = remove_str.data + remove_str.size;

while (remove_begin < remove_end) {
char_lookup.set(static_cast<unsigned char>(*remove_begin));
remove_begin += 1;
}

for (size_t i = 0; i < offset_size; ++i) {
const char* str_begin =
reinterpret_cast<const char*>(str_data.data() + str_offsets[i - 1]);
const char* str_end = reinterpret_cast<const char*>(str_data.data() + str_offsets[i]);
const char* left_trim_pos = str_begin;
const char* right_trim_pos = str_end;

if constexpr (is_ltrim) {
while (left_trim_pos < str_end) {
if (!char_lookup.test(static_cast<unsigned char>(*left_trim_pos))) {
break;
}
++left_trim_pos;
}
}

if constexpr (is_rtrim) {
while (right_trim_pos > left_trim_pos) {
--right_trim_pos;
if (!char_lookup.test(static_cast<unsigned char>(*right_trim_pos))) {
++right_trim_pos;
break;
}
}
}

res_data.insert_assume_reserved(left_trim_pos, right_trim_pos);
res_offsets[i] = res_data.size();
}

return Status::OK();
}

static Status impl_vectors_utf8(const ColumnString::Chars& str_data,
const ColumnString::Offsets& str_offsets,
const StringRef& remove_str, ColumnString::Chars& res_data,
ColumnString::Offsets& res_offsets) {
const size_t offset_size = str_offsets.size();
res_offsets.resize(offset_size);
res_data.reserve(str_data.size());

std::unordered_set<std::string_view> char_lookup;
const char* remove_begin = remove_str.data;
const char* remove_end = remove_str.data + remove_str.size;

while (remove_begin < remove_end) {
size_t byte_len, char_len;
std::tie(byte_len, char_len) = simd::VStringFunctions::iterate_utf8_with_limit_length(
remove_begin, remove_end, 1);
char_lookup.insert(std::string_view(remove_begin, byte_len));
remove_begin += byte_len;
}

for (size_t i = 0; i < offset_size; ++i) {
const char* str_begin =
reinterpret_cast<const char*>(str_data.data() + str_offsets[i - 1]);
const char* str_end = reinterpret_cast<const char*>(str_data.data() + str_offsets[i]);
const char* left_trim_pos = str_begin;
const char* right_trim_pos = str_end;

if constexpr (is_ltrim) {
while (left_trim_pos < str_end) {
size_t byte_len, char_len;
std::tie(byte_len, char_len) =
simd::VStringFunctions::iterate_utf8_with_limit_length(left_trim_pos,
str_end, 1);
if (char_lookup.find(std::string_view(left_trim_pos, byte_len)) ==
char_lookup.end()) {
break;
}
left_trim_pos += byte_len;
}
}

if constexpr (is_rtrim) {
while (right_trim_pos > left_trim_pos) {
const char* prev_char_pos = right_trim_pos;
do {
--prev_char_pos;
} while ((*prev_char_pos & 0xC0) == 0x80);
size_t byte_len = right_trim_pos - prev_char_pos;
if (char_lookup.find(std::string_view(prev_char_pos, byte_len)) ==
char_lookup.end()) {
break;
}
right_trim_pos = prev_char_pos;
}
}

res_data.insert_assume_reserved(left_trim_pos, right_trim_pos);
res_offsets[i] = res_data.size();
}
return Status::OK();
}
};
// Implementation of the single-parameter form of the Trim functions.
template <bool is_ltrim, bool is_rtrim, typename Name>
struct Trim1Impl {
Expand Down Expand Up @@ -583,14 +722,23 @@ struct Trim2Impl {
const auto* remove_str_raw = col_right->get_chars().data();
const ColumnString::Offset remove_str_size = col_right->get_offsets()[0];
const StringRef remove_str(remove_str_raw, remove_str_size);

if (remove_str.size == 1) {
RETURN_IF_ERROR((TrimUtil<is_ltrim, is_rtrim, true>::vector(
col->get_chars(), col->get_offsets(), remove_str, col_res->get_chars(),
col_res->get_offsets())));
} else {
RETURN_IF_ERROR((TrimUtil<is_ltrim, is_rtrim, false>::vector(
col->get_chars(), col->get_offsets(), remove_str, col_res->get_chars(),
col_res->get_offsets())));
if constexpr (std::is_same<Name, NameTrimIn>::value ||
std::is_same<Name, NameLTrimIn>::value ||
std::is_same<Name, NameRTrimIn>::value) {
RETURN_IF_ERROR((TrimInUtil<is_ltrim, is_rtrim, false>::vector(
col->get_chars(), col->get_offsets(), remove_str,
col_res->get_chars(), col_res->get_offsets())));
} else {
RETURN_IF_ERROR((TrimUtil<is_ltrim, is_rtrim, false>::vector(
col->get_chars(), col->get_offsets(), remove_str,
col_res->get_chars(), col_res->get_offsets())));
}
}
block.replace_by_position(result, std::move(col_res));
} else {
Expand Down Expand Up @@ -1023,6 +1171,12 @@ void register_function_string(SimpleFunctionFactory& factory) {
factory.register_function<FunctionTrim<Trim2Impl<true, true, NameTrim>>>();
factory.register_function<FunctionTrim<Trim2Impl<true, false, NameLTrim>>>();
factory.register_function<FunctionTrim<Trim2Impl<false, true, NameRTrim>>>();
factory.register_function<FunctionTrim<Trim1Impl<true, true, NameTrimIn>>>();
factory.register_function<FunctionTrim<Trim1Impl<true, false, NameLTrimIn>>>();
factory.register_function<FunctionTrim<Trim1Impl<false, true, NameRTrimIn>>>();
factory.register_function<FunctionTrim<Trim2Impl<true, true, NameTrimIn>>>();
factory.register_function<FunctionTrim<Trim2Impl<true, false, NameLTrimIn>>>();
factory.register_function<FunctionTrim<Trim2Impl<false, true, NameRTrimIn>>>();
factory.register_function<FunctionConvertTo>();
factory.register_function<FunctionSubstring<Substr3Impl>>();
factory.register_function<FunctionSubstring<Substr2Impl>>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,7 @@
import org.apache.doris.nereids.trees.expressions.functions.scalar.Lower;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Lpad;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Ltrim;
import org.apache.doris.nereids.trees.expressions.functions.scalar.LtrimIn;
import org.apache.doris.nereids.trees.expressions.functions.scalar.MakeDate;
import org.apache.doris.nereids.trees.expressions.functions.scalar.MapContainsKey;
import org.apache.doris.nereids.trees.expressions.functions.scalar.MapContainsValue;
Expand Down Expand Up @@ -358,6 +359,7 @@
import org.apache.doris.nereids.trees.expressions.functions.scalar.RoundBankers;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Rpad;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Rtrim;
import org.apache.doris.nereids.trees.expressions.functions.scalar.RtrimIn;
import org.apache.doris.nereids.trees.expressions.functions.scalar.SecToTime;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Second;
import org.apache.doris.nereids.trees.expressions.functions.scalar.SecondCeil;
Expand Down Expand Up @@ -438,6 +440,7 @@
import org.apache.doris.nereids.trees.expressions.functions.scalar.Tokenize;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Translate;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Trim;
import org.apache.doris.nereids.trees.expressions.functions.scalar.TrimIn;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Truncate;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Unhex;
import org.apache.doris.nereids.trees.expressions.functions.scalar.UnixTimestamp;
Expand Down Expand Up @@ -760,6 +763,7 @@ public class BuiltinScalarFunctions implements FunctionHelper {
scalar(Lower.class, "lcase", "lower"),
scalar(Lpad.class, "lpad"),
scalar(Ltrim.class, "ltrim"),
scalar(LtrimIn.class, "ltrim_in"),
scalar(MakeDate.class, "makedate"),
scalar(MapContainsKey.class, "map_contains_key"),
scalar(MapContainsValue.class, "map_contains_value"),
Expand Down Expand Up @@ -835,6 +839,7 @@ public class BuiltinScalarFunctions implements FunctionHelper {
scalar(RoundBankers.class, "round_bankers"),
scalar(Rpad.class, "rpad"),
scalar(Rtrim.class, "rtrim"),
scalar(RtrimIn.class, "rtrim_in"),
scalar(Second.class, "second"),
scalar(SecondCeil.class, "second_ceil"),
scalar(SecondFloor.class, "second_floor"),
Expand Down Expand Up @@ -920,6 +925,7 @@ public class BuiltinScalarFunctions implements FunctionHelper {
scalar(ToQuantileState.class, "to_quantile_state"),
scalar(Translate.class, "translate"),
scalar(Trim.class, "trim"),
scalar(TrimIn.class, "trim_in"),
scalar(Truncate.class, "truncate"),
scalar(Unhex.class, "unhex"),
scalar(UnixTimestamp.class, "unix_timestamp"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,27 @@ private static String trimImpl(String first, String second, boolean left, boolea
return result;
}

/**
 * Removes from {@code first} every leading (when {@code left}) and/or trailing
 * (when {@code right}) character that occurs anywhere in {@code second}.
 *
 * <p>Matching is done per Unicode code point rather than per UTF-16 code unit, so a
 * supplementary character (surrogate pair) is only removed when that exact character is
 * in the set, and is never cut in half.
 *
 * @param first the string to trim
 * @param second the set of characters to remove; an empty set removes nothing
 * @param left whether to trim the beginning of the string
 * @param right whether to trim the end of the string
 * @return the trimmed string
 */
private static String trimInImpl(String first, String second, boolean left, boolean right) {
    int start = 0;
    int end = first.length();

    if (left) {
        while (start < end) {
            int codePoint = first.codePointAt(start);
            if (second.indexOf(codePoint) == -1) {
                break;
            }
            start += Character.charCount(codePoint);
        }
    }
    if (right) {
        // start always sits on a code point boundary, so stepping back by whole
        // code points can never move end below start.
        while (end > start) {
            int codePoint = first.codePointBefore(end);
            if (second.indexOf(codePoint) == -1) {
                break;
            }
            end -= Character.charCount(codePoint);
        }
    }

    return first.substring(start, end);
}

/**
* Executable arithmetic functions Trim
*/
Expand Down Expand Up @@ -199,6 +220,54 @@ public static Expression rtrimVarcharVarchar(StringLikeLiteral first, StringLike
return castStringLikeLiteral(first, trimImpl(first.getValue(), second.getValue(), false, true));
}

/**
 * Evaluates single-argument {@code trim_in} on a literal: removes space characters
 * from both ends of the literal's value.
 */
@ExecFunction(name = "trim_in")
public static Expression trimInVarchar(StringLikeLiteral first) {
    String trimmed = trimInImpl(first.getValue(), " ", true, true);
    return castStringLikeLiteral(first, trimmed);
}

/**
 * Evaluates two-argument {@code trim_in} on literals: removes from both ends of
 * {@code first} the characters contained in {@code second}.
 */
@ExecFunction(name = "trim_in")
public static Expression trimInVarcharVarchar(StringLikeLiteral first, StringLikeLiteral second) {
    String source = first.getValue();
    String removeSet = second.getValue();
    return castStringLikeLiteral(first, trimInImpl(source, removeSet, true, true));
}

/**
 * Evaluates single-argument {@code ltrim_in} on a literal: removes space characters
 * from the beginning of the literal's value.
 */
@ExecFunction(name = "ltrim_in")
public static Expression ltrimInVarchar(StringLikeLiteral first) {
    String trimmed = trimInImpl(first.getValue(), " ", true, false);
    return castStringLikeLiteral(first, trimmed);
}

/**
 * Evaluates two-argument {@code ltrim_in} on literals: removes from the beginning of
 * {@code first} the characters contained in {@code second}.
 */
@ExecFunction(name = "ltrim_in")
public static Expression ltrimInVarcharVarchar(StringLikeLiteral first, StringLikeLiteral second) {
    String source = first.getValue();
    String removeSet = second.getValue();
    return castStringLikeLiteral(first, trimInImpl(source, removeSet, true, false));
}

/**
 * Evaluates single-argument {@code rtrim_in} on a literal: removes space characters
 * from the end of the literal's value.
 */
@ExecFunction(name = "rtrim_in")
public static Expression rtrimInVarchar(StringLikeLiteral first) {
    String trimmed = trimInImpl(first.getValue(), " ", false, true);
    return castStringLikeLiteral(first, trimmed);
}

/**
 * Evaluates two-argument {@code rtrim_in} on literals: removes from the end of
 * {@code first} the characters contained in {@code second}.
 */
@ExecFunction(name = "rtrim_in")
public static Expression rtrimInVarcharVarchar(StringLikeLiteral first, StringLikeLiteral second) {
    String source = first.getValue();
    String removeSet = second.getValue();
    return castStringLikeLiteral(first, trimInImpl(source, removeSet, false, true));
}

/**
* Executable arithmetic functions Replace
*/
Expand Down
Loading

0 comments on commit fa5f1b9

Please sign in to comment.