Skip to content

Commit

Permalink
apply builtin i256
Browse files Browse the repository at this point in the history
  • Loading branch information
taiyang-li committed Dec 13, 2024
1 parent 6fba7de commit cf02634
Showing 1 changed file with 94 additions and 52 deletions.
146 changes: 94 additions & 52 deletions cpp-ch/local-engine/Functions/SparkFunctionDecimalBinaryArithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,40 +58,42 @@ struct DecimalPlusImpl
template <typename T>
static bool apply(T a, T b, T & r)
{
return !common::addOverflow(a, b, r);
r = a + b;
return true;
}

template <>
static bool apply(Int128 a, Int128 b, Int128 & r)
{
if (canCastLower(a, b))
{
UInt64 low_result;
if (common::addOverflow(static_cast<UInt64>(a), static_cast<UInt64>(b), low_result))
return !common::addOverflow(a, b, r);

r = static_cast<Int128>(low_result);
return true;
Int64 low_result;
if (!common::addOverflow(static_cast<Int64>(a), static_cast<Int64>(b), low_result))
{
r = static_cast<Int128>(low_result);
return true;
}
}
return !common::addOverflow(a, b, r);

r = a + b;
return true;
}

template <>
static bool apply(Int256 a, Int256 b, Int256 & r)
{
if (canCastLower(a, b))
{
UInt128 low_result;
if (common::addOverflow(static_cast<UInt128>(a), static_cast<UInt128>(b), low_result))
return !common::addOverflow(a, b, r);

r = static_cast<Int256>(low_result);
return true;
Int128 low_result;
if (!common::addOverflow(static_cast<Int128>(a), static_cast<Int128>(b), low_result))
{
r = static_cast<Int256>(low_result);
return true;
}
}

return !common::addOverflow(a, b, r);
// r = toInt256(toNewInt256(a) + toNewInt256(b));
// return true;
r = toInt256(toNewInt256(a) + toNewInt256(b));
return true;
}

#if USE_EMBEDDED_COMPILER
Expand All @@ -110,41 +112,42 @@ struct DecimalMinusImpl
template <typename T>
static bool apply(T a, T b, T & r)
{
return !common::subOverflow(a, b, r);
r = a - b;
return true;
}

template <>
static bool apply(Int128 a, Int128 b, Int128 & r)
{
if (canCastLower(a, b))
{
UInt64 low_result;
if (common::subOverflow(static_cast<UInt64>(a), static_cast<UInt64>(b), low_result))
return !common::subOverflow(a, b, r);

r = static_cast<Int128>(low_result);
return true;
Int64 low_result;
if (!common::subOverflow(static_cast<Int64>(a), static_cast<Int64>(b), low_result))
{
r = static_cast<Int128>(low_result);
return true;
}
}

return !common::subOverflow(a, b, r);
r = a - b;
return true;
}

template <>
static bool apply(Int256 a, Int256 b, Int256 & r)
{
if (canCastLower(a, b))
{
UInt128 low_result;
if (common::subOverflow(static_cast<UInt128>(a), static_cast<UInt128>(b), low_result))
return !common::subOverflow(a, b, r);

r = static_cast<Int256>(low_result);
return true;
Int128 low_result;
if (!common::subOverflow(static_cast<Int128>(a), static_cast<Int128>(b), low_result))
{
r = static_cast<Int256>(low_result);
return true;
}
}

return !common::subOverflow(a, b, r);
// r = toInt256(toNewInt256(a) - toNewInt256(b));
// return true;
r = toInt256(toNewInt256(a) - toNewInt256(b));
return true;
}


Expand All @@ -165,30 +168,41 @@ struct DecimalMultiplyImpl
template <typename T>
static bool apply(T a, T b, T & c)
{
return !common::mulOverflow(a, b, c);
c = a * b;
return true;
}

template <Int128>
static bool apply(Int128 a, Int128 b, Int128 & r)
{
if (canCastLower(a, b))
{
UInt64 low_result = 0;
if (common::mulOverflow(static_cast<UInt64>(a), static_cast<UInt64>(b), low_result))
return !common::mulOverflow(a, b, r);

r = static_cast<Int128>(low_result);
return true;
Int64 low_result = 0;
if (!common::mulOverflow(static_cast<Int64>(a), static_cast<Int64>(b), low_result))
{
r = static_cast<Int128>(low_result);
return true;
}
}

return !common::mulOverflow(a, b, r);
r = a * b;
return true;
}

template <>
static bool apply(Int256 a, Int256 b, Int256 & r)
{
// r = toInt256(toNewInt256(a) * toNewInt256(b));
r = a * b;
if (canCastLower(a, b))
{
Int128 low_result = 0;
if (!common::mulOverflow(static_cast<Int128>(a), static_cast<Int128>(b), low_result))
{
r = static_cast<Int256>(low_result);
return true;
}
}

r = toInt256(toNewInt256(a) * toNewInt256(b));
return true;
}

Expand Down Expand Up @@ -222,7 +236,7 @@ struct DecimalDivideImpl

if (canCastLower(a, b))
{
r = static_cast<Int128>(static_cast<UInt64>(a) / static_cast<UInt64>(b));
r = static_cast<Int128>(static_cast<Int64>(a) / static_cast<Int64>(b));
return true;
}

Expand All @@ -238,16 +252,11 @@ struct DecimalDivideImpl

if (canCastLower(a, b))
{
UInt128 low_result = 0;
UInt128 low_a = static_cast<UInt128>(a);
UInt128 low_b = static_cast<UInt128>(b);
apply(low_a, low_b, low_result);
r = static_cast<Int256>(low_result);
r = static_cast<Int256>(static_cast<Int128>(a) / static_cast<Int128>(b));
return true;
}

r = a / b;
// r = toInt256(toNewInt256(a) / toNewInt256(b));
r = toInt256(toNewInt256(a) / toNewInt256(b));
return true;
}

Expand Down Expand Up @@ -275,6 +284,39 @@ struct DecimalModuloImpl
return true;
}

template <>
static bool apply(Int128 a, Int128 b, Int128 & r)
{
if (b == 0)
return false;

if (canCastLower(a, b))
{
r = static_cast<Int128>(static_cast<Int64>(a) % static_cast<Int64>(b));
return true;
}

r = a % b;
return true;
}


template <>
static bool apply(Int256 a, Int256 b, Int256 & r)
{
if (b == 0)
return false;

if (canCastLower(a, b))
{
r = static_cast<Int256>(static_cast<Int128>(a) % static_cast<Int128>(b));
return true;
}

r = toInt256(toNewInt256(a) % toNewInt256(b));
return true;
}

#if USE_EMBEDDED_COMPILER
static constexpr bool compilable = true;

Expand Down

0 comments on commit cf02634

Please sign in to comment.