Skip to content

Commit

Permalink
[feature](function) round function defaults to rounding normally
Browse files Browse the repository at this point in the history
  • Loading branch information
Mryange authored Mar 5, 2024
1 parent ea027f3 commit b288736
Show file tree
Hide file tree
Showing 10 changed files with 99 additions and 83 deletions.
92 changes: 28 additions & 64 deletions be/src/vec/functions/round.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ enum class RoundingMode {
};

enum class TieBreakingMode {
Auto, // use banker's rounding for floating point numbers, round up otherwise
Auto, // use round up
Bankers, // use banker's rounding
};

Expand Down Expand Up @@ -178,59 +178,16 @@ class DecimalRoundingImpl {
}
};

#if defined(__SSE4_1__) || defined(__aarch64__)

template <typename T>
class BaseFloatRoundingComputation;

template <>
class BaseFloatRoundingComputation<Float32> {
public:
using ScalarType = Float32;
using VectorType = __m128;
static const size_t data_count = 4;

static VectorType load(const ScalarType* in) { return _mm_loadu_ps(in); }
static VectorType load1(const ScalarType in) { return _mm_load1_ps(&in); }
static void store(ScalarType* out, VectorType val) { _mm_storeu_ps(out, val); }
static VectorType multiply(VectorType val, VectorType scale) { return _mm_mul_ps(val, scale); }
static VectorType divide(VectorType val, VectorType scale) { return _mm_div_ps(val, scale); }
template <RoundingMode mode>
static VectorType apply(VectorType val) {
return _mm_round_ps(val, int(mode));
}

static VectorType prepare(size_t scale) { return load1(scale); }
};

template <>
class BaseFloatRoundingComputation<Float64> {
public:
using ScalarType = Float64;
using VectorType = __m128d;
static const size_t data_count = 2;

static VectorType load(const ScalarType* in) { return _mm_loadu_pd(in); }
static VectorType load1(const ScalarType in) { return _mm_load1_pd(&in); }
static void store(ScalarType* out, VectorType val) { _mm_storeu_pd(out, val); }
static VectorType multiply(VectorType val, VectorType scale) { return _mm_mul_pd(val, scale); }
static VectorType divide(VectorType val, VectorType scale) { return _mm_div_pd(val, scale); }
template <RoundingMode mode>
static VectorType apply(VectorType val) {
return _mm_round_pd(val, int(mode));
}

static VectorType prepare(size_t scale) { return load1(scale); }
};

#else

/// Implementation for ARM. Not vectorized.

template <TieBreakingMode tie_breaking_mode>
inline float roundWithMode(float x, RoundingMode mode) {
switch (mode) {
case RoundingMode::Round:
return nearbyintf(x);
case RoundingMode::Round: {
if constexpr (tie_breaking_mode == TieBreakingMode::Bankers) {
return nearbyintf(x);
} else {
return roundf(x);
}
}
case RoundingMode::Floor:
return floorf(x);
case RoundingMode::Ceil:
Expand All @@ -243,10 +200,16 @@ inline float roundWithMode(float x, RoundingMode mode) {
__builtin_unreachable();
}

template <TieBreakingMode tie_breaking_mode>
inline double roundWithMode(double x, RoundingMode mode) {
switch (mode) {
case RoundingMode::Round:
return nearbyint(x);
case RoundingMode::Round: {
if constexpr (tie_breaking_mode == TieBreakingMode::Bankers) {
return nearbyint(x);
} else {
return round(x);
}
}
case RoundingMode::Floor:
return floor(x);
case RoundingMode::Ceil:
Expand All @@ -259,7 +222,7 @@ inline double roundWithMode(double x, RoundingMode mode) {
__builtin_unreachable();
}

template <typename T>
template <typename T, TieBreakingMode tie_breaking_mode>
class BaseFloatRoundingComputation {
public:
using ScalarType = T;
Expand All @@ -273,19 +236,18 @@ class BaseFloatRoundingComputation {
static VectorType divide(VectorType val, VectorType scale) { return val / scale; }
template <RoundingMode mode>
static VectorType apply(VectorType val) {
return roundWithMode(val, mode);
return roundWithMode<tie_breaking_mode>(val, mode);
}

static VectorType prepare(size_t scale) { return load1(scale); }
};

#endif

/** Implementation of low-level round-off functions for floating-point values.
*/
template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode>
class FloatRoundingComputation : public BaseFloatRoundingComputation<T> {
using Base = BaseFloatRoundingComputation<T>;
template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode,
TieBreakingMode tie_breaking_mode>
class FloatRoundingComputation : public BaseFloatRoundingComputation<T, tie_breaking_mode> {
using Base = BaseFloatRoundingComputation<T, tie_breaking_mode>;

public:
static inline void compute(const T* __restrict in, const typename Base::VectorType& scale,
Expand All @@ -312,12 +274,13 @@ class FloatRoundingComputation : public BaseFloatRoundingComputation<T> {

/** Implementing high-level rounding functions.
*/
template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode>
template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode,
TieBreakingMode tie_breaking_mode>
struct FloatRoundingImpl {
private:
static_assert(!IsDecimalNumber<T>);

using Op = FloatRoundingComputation<T, rounding_mode, scale_mode>;
using Op = FloatRoundingComputation<T, rounding_mode, scale_mode, tie_breaking_mode>;
using Data = std::array<T, Op::data_count>;
using ColumnType = ColumnVector<T>;
using Container = typename ColumnType::Container;
Expand Down Expand Up @@ -433,7 +396,8 @@ struct Dispatcher {
using FunctionRoundingImpl = std::conditional_t<
IsDecimalNumber<T>, DecimalRoundingImpl<T, rounding_mode, tie_breaking_mode>,
std::conditional_t<
std::is_floating_point_v<T>, FloatRoundingImpl<T, rounding_mode, scale_mode>,
std::is_floating_point_v<T>,
FloatRoundingImpl<T, rounding_mode, scale_mode, tie_breaking_mode>,
IntegerRoundingImpl<T, rounding_mode, scale_mode, tie_breaking_mode>>>;

static ColumnPtr apply(const IColumn* col_general, Int16 scale_arg) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ under the License.

`T round(T x[, d])`
Rounds the argument `x` to `d` decimal places. `d` defaults to 0 if not specified. If d is negative, the left d digits of the decimal point are 0. If x or d is null, null is returned.
2.5 will round up to 3. If you want to round down to 2, please use the round_bankers function.

### example

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ under the License.

`T round(T x[, d])`
`x`四舍五入后保留d位小数,d默认为0。如果d为负数,则小数点左边d位为0。如果x或d为null,返回null。

2.5会舍入到3,如果想要舍入到2的算法,请使用round_bankers函数。
### example

```
Expand Down
13 changes: 13 additions & 0 deletions regression-test/data/correctness/test_float_round_up.out
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
-- This file is automatically generated. You should know what you did if you want to edit this
-- !select --
2.5 3.0

-- !select --
3.5 4.0

-- !select --
2.5 2.0

-- !select --
3.5 4.0

Original file line number Diff line number Diff line change
Expand Up @@ -2934,7 +2934,7 @@ Monday
0.0
0.0
0.0
0.0
1.0
1.0
1.0
1.0
Expand All @@ -2948,7 +2948,7 @@ Monday
0.0
0.0
0.0
0.0
1.0
1.0
1.0
1.0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,7 @@ string3
0.0
0.0
0.0
0.0
1.0
1.0
1.0
1.0
Expand All @@ -483,7 +483,7 @@ string3
0.0
0.0
0.0
0.0
1.0
1.0
1.0
1.0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,6 @@
-- !select --
10.12

-- !select --
0.0 0.0
0.5 0.0
1.0 1.0
1.5 2.0
2.0 2.0
2.5 2.0
3.0 3.0
3.5 4.0
4.0 4.0
4.5 4.0

-- !select --
0.0 0.0
0.5 1.0
Expand All @@ -35,6 +23,18 @@
4.0 4.0
4.5 5.0

-- !select --
0.0 0
0.5 1
1.0 1
1.5 2
2.0 2
2.5 3
3.0 3
3.5 4
4.0 4
4.5 5

-- !truncate --
1 1989 1001 123.1 0.1 6.3
2 1986 1001 1243.5 20.2 789.2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ tipfortip/issues 10 1 10.0
Mindwerks/wildmidi 9 9 1.0
NeuroVault/NeuroVault 9 1 9.0
THE-ESCAPIST/RSSHub 9 7 1.29
WhisperSystems/TextSecure 9 8 1.12
WhisperSystems/TextSecure 9 8 1.13
XLabs/Xamarin-Forms-Labs 9 6 1.5
aws/eks-distro 9 1 9.0
disco-trooper/weather-app 9 9 1.0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ tipfortip/issues 10 1 10.0
Mindwerks/wildmidi 9 9 1.0
NeuroVault/NeuroVault 9 1 9.0
THE-ESCAPIST/RSSHub 9 7 1.29
WhisperSystems/TextSecure 9 8 1.12
WhisperSystems/TextSecure 9 8 1.13
XLabs/Xamarin-Forms-Labs 9 6 1.5
aws/eks-distro 9 1 9.0
disco-trooper/weather-app 9 9 1.0
Expand Down
38 changes: 38 additions & 0 deletions regression-test/suites/correctness/test_float_round_up.groovy
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

suite("test_float_round_up") {
sql """ set enable_nereids_planner=true; """
sql """ set enable_fallback_to_original_planner=false; """


qt_select """
select 5/2 , round(5/2);
"""

qt_select """
select 7/ 2 , round(7/2);
"""

qt_select """
select 5/2 , round_bankers(5/2);
"""

qt_select """
select 7/ 2 , round_bankers(7/2);
"""
}

0 comments on commit b288736

Please sign in to comment.