Skip to content

Commit

Permalink
intt optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
pascoec committed Oct 8, 2024
1 parent cf579ba commit b111bec
Showing 1 changed file with 40 additions and 11 deletions.
51 changes: 40 additions & 11 deletions src/core/include/math/hal/intnat/transformnat-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ void NumberTheoreticTransformNat<VecType>::ForwardTransformToBitReverseInPlace(c
(*element)[i + 0] += omegaFactor - (omegaFactor >= (modulus - loVal) ? modulus : 0);
if (omegaFactor > loVal)
loVal += modulus;
(*element)[i + t] = loVal - omegaFactor;
(*element)[i + 1] = loVal - omegaFactor;
#endif
}
}
Expand Down Expand Up @@ -496,11 +496,13 @@ void NumberTheoreticTransformNat<VecType>::InverseTransformFromBitReverseInPlace
const IntType& preconCycloOrderInv, VecType* element) {
auto modulus{element->GetModulus()};
uint32_t n(element->GetLength());
auto omega1Inv{rootOfUnityInverseTable[n]};
auto preconOmega1Inv{preconRootOfUnityInverseTable[n]};
for (uint32_t i{0}; i < n; i += 2) {
auto omega{rootOfUnityInverseTable[(i + n) >> 1]};
auto preconOmega{preconRootOfUnityInverseTable[(i + n) >> 1]};
auto hiVal{(*element)[i + 1]};
auto loVal{(*element)[i + 0]};
auto hiVal{(*element)[i + 1]};
#if defined(__GNUC__) && !defined(__clang__)
auto omegaFactor{loVal};
if (omegaFactor < hiVal)
Expand All @@ -509,28 +511,24 @@ void NumberTheoreticTransformNat<VecType>::InverseTransformFromBitReverseInPlace
loVal += hiVal;
if (loVal >= modulus)
loVal -= modulus;
loVal.ModMulFastConstEq(cycloOrderInv, modulus, preconCycloOrderInv);
omegaFactor.ModMulFastConstEq(omega, modulus, preconOmega);
omegaFactor.ModMulFastConstEq(cycloOrderInv, modulus, preconCycloOrderInv);
(*element)[i + 0] = loVal;
(*element)[i + 1] = omegaFactor;
#else
auto omegaFactor{loVal + (hiVal > loVal ? modulus : 0) - hiVal};
loVal += hiVal - (hiVal >= (modulus - loVal) ? modulus : 0);
loVal.ModMulFastConstEq(cycloOrderInv, modulus, preconCycloOrderInv);
(*element)[i + 0] = loVal;
omegaFactor.ModMulFastConstEq(omega, modulus, preconOmega);
omegaFactor.ModMulFastConstEq(cycloOrderInv, modulus, preconCycloOrderInv);
(*element)[i + 1] = omegaFactor;
#endif
}
for (uint32_t m{n >> 2}, t{2}, logt{2}; m >= 1; m >>= 1, t <<= 1, ++logt) {
for (uint32_t m{n >> 2}, t{2}, logt{2}; m > 1; m >>= 1, t <<= 1, ++logt) {
for (uint32_t i{0}; i < m; ++i) {
auto omega{rootOfUnityInverseTable[i + m]};
auto preconOmega{preconRootOfUnityInverseTable[i + m]};
for (uint32_t j1{i << logt}, j2{j1 + t}; j1 < j2; ++j1) {
auto hiVal{(*element)[j1 + t]};
auto loVal{(*element)[j1 + 0]};
auto hiVal{(*element)[j1 + t]};
#if defined(__GNUC__) && !defined(__clang__)
auto omegaFactor{loVal};
if (omegaFactor < hiVal)
Expand All @@ -551,6 +549,30 @@ void NumberTheoreticTransformNat<VecType>::InverseTransformFromBitReverseInPlace
}
}
}
const uint32_t j2 = n >> 1;
for (uint32_t j1{0}; j1 < j2; ++j1) {
auto loVal{(*element)[j1]};
auto hiVal{(*element)[j1 + j2]};
#if defined(__GNUC__) && !defined(__clang__)
auto omegaFactor{loVal};
if (omegaFactor < hiVal)
omegaFactor += modulus;
omegaFactor -= hiVal;
loVal += hiVal;
if (loVal >= modulus)
loVal -= modulus;
omegaFactor.ModMulFastConstEq(omega1Inv, modulus, preconOmega1Inv);
(*element)[j1 + 0] = loVal;
(*element)[j1 + j2] = omegaFactor;
#else
(*element)[j1] += hiVal - (hiVal >= (modulus - loVal) ? modulus : 0);
auto omegaFactor = loVal + (hiVal > loVal ? modulus : 0) - hiVal;
omegaFactor.ModMulFastConstEq(omega1Inv, modulus, preconOmega1Inv);
(*element)[j1 + j2] = omegaFactor;
#endif
}
for (uint32_t i = 0; i < j2; ++i)
(*element)[i].ModMulFastConstEq(cycloOrderInv, modulus, preconCycloOrderInv);
}

template <typename VecType>
Expand Down Expand Up @@ -706,7 +728,6 @@ void ChineseRemainderTransformFTTNat<VecType>::InverseTransformFromBitReverse(co
template <typename VecType>
void ChineseRemainderTransformFTTNat<VecType>::PreCompute(const IntType& rootOfUnity, const usint CycloOrder,
const IntType& modulus) {
// Half of cyclo order
usint CycloOrderHf = (CycloOrder >> 1);

auto mapSearch = m_rootOfUnityReverseTableByModulus.find(modulus);
Expand All @@ -717,7 +738,7 @@ void ChineseRemainderTransformFTTNat<VecType>::PreCompute(const IntType& rootOfU
usint msb = GetMSB(CycloOrderHf - 1);
IntType mu = modulus.ComputeMu();
VecType Table(CycloOrderHf, modulus);
VecType TableI(CycloOrderHf, modulus);
VecType TableI(CycloOrderHf + 1, modulus);
IntType rootOfUnityInverse = rootOfUnity.ModInverse(modulus);
usint iinv;
for (usint i = 0; i < CycloOrderHf; i++) {
Expand All @@ -739,7 +760,7 @@ void ChineseRemainderTransformFTTNat<VecType>::PreCompute(const IntType& rootOfU

NativeInteger nativeModulus = modulus.ConvertToInt();
VecType preconTable(CycloOrderHf, nativeModulus);
VecType preconTableI(CycloOrderHf, nativeModulus);
VecType preconTableI(CycloOrderHf + 1, nativeModulus);

for (usint i = 0; i < CycloOrderHf; i++) {
preconTable[i] = NativeInteger(m_rootOfUnityReverseTableByModulus[modulus][i].ConvertToInt())
Expand All @@ -757,6 +778,14 @@ void ChineseRemainderTransformFTTNat<VecType>::PreCompute(const IntType& rootOfU
m_rootOfUnityPreconReverseTableByModulus[modulus] = preconTable;
m_rootOfUnityInversePreconReverseTableByModulus[modulus] = preconTableI;
m_cycloOrderInversePreconTableByModulus[modulus] = preconTableCOI;

// optimization: precompute omega * cycloOrderInverse
m_rootOfUnityInverseReverseTableByModulus[modulus][CycloOrderHf] =
m_rootOfUnityInverseReverseTableByModulus[modulus][1].ModMul(
m_cycloOrderInverseTableByModulus[modulus][msb], modulus);
m_rootOfUnityInversePreconReverseTableByModulus[modulus][CycloOrderHf] =
NativeInteger(m_rootOfUnityInverseReverseTableByModulus[modulus][CycloOrderHf].ConvertToInt())
.PrepModMulConst(nativeModulus);
}
}
}
Expand Down

0 comments on commit b111bec

Please sign in to comment.