Skip to content

Commit

Permalink
intt optimization (#876)
Browse files Browse the repository at this point in the history
* intt optimization

* added comments

* more comments

* omega[bitreversed(1)] * (n inverse) no longer precomputed
  • Loading branch information
pascoec authored Oct 23, 2024
1 parent c938455 commit 8f72e90
Showing 1 changed file with 79 additions and 12 deletions.
91 changes: 79 additions & 12 deletions src/core/include/math/hal/intnat/transformnat-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -302,9 +302,25 @@ template <typename VecType>
void NumberTheoreticTransformNat<VecType>::ForwardTransformToBitReverseInPlace(const VecType& rootOfUnityTable,
const VecType& preconRootOfUnityTable,
VecType* element) {
auto modulus{element->GetModulus()};
uint32_t n(element->GetLength() >> 1), t{n}, logt{GetMSB(t)};
for (uint32_t m{1}; m < n; m <<= 1, t >>= 1, --logt) {
//
// NTT based on the Cooley-Tukey (CT) butterfly
// Inputs: element (vector of size n in standard ordering)
// rootOfUnityTable (precomputed roots of unity in bit-reversed ordering)
// Output: NTT(element) in bit-reversed ordering
//
// for (m = 1, t = n/2, logt = log2(n); m < n; m=2*m, t=t/2, --logt) do
// for (i = 0; i < m; ++i) do
// omega = rootOfUnityTable[i + m]
// for (j1 = (i << logt), j2 = (j1 + t); j1 < j2; ++j1) do
// loVal = element[j1 + 0]
// hiVal = element[j1 + t]*omega
// element[j1 + 0] = (loVal + hiVal) mod modulus
// element[j1 + t] = (loVal - hiVal) mod modulus
//

const auto modulus{element->GetModulus()};
const uint32_t n(element->GetLength() >> 1);
for (uint32_t m{1}, t{n}, logt{GetMSB(t)}; m < n; m <<= 1, t >>= 1, --logt) {
for (uint32_t i{0}; i < m; ++i) {
auto omega{rootOfUnityTable[i + m]};
auto preconOmega{preconRootOfUnityTable[i + m]};
Expand All @@ -331,6 +347,7 @@ void NumberTheoreticTransformNat<VecType>::ForwardTransformToBitReverseInPlace(c
}
}
}
// peeled off last ntt stage for performance
for (uint32_t i{0}; i < (n << 1); i += 2) {
auto omegaFactor{(*element)[i + 1]};
auto omega{rootOfUnityTable[(i >> 1) + n]};
Expand All @@ -350,7 +367,7 @@ void NumberTheoreticTransformNat<VecType>::ForwardTransformToBitReverseInPlace(c
(*element)[i + 0] += omegaFactor - (omegaFactor >= (modulus - loVal) ? modulus : 0);
if (omegaFactor > loVal)
loVal += modulus;
(*element)[i + t] = loVal - omegaFactor;
(*element)[i + 1] = loVal - omegaFactor;
#endif
}
}
Expand Down Expand Up @@ -494,13 +511,38 @@ template <typename VecType>
void NumberTheoreticTransformNat<VecType>::InverseTransformFromBitReverseInPlace(
const VecType& rootOfUnityInverseTable, const VecType& preconRootOfUnityInverseTable, const IntType& cycloOrderInv,
const IntType& preconCycloOrderInv, VecType* element) {
//
// INTT based on the Gentleman-Sande (GS) butterfly
// Inputs: element (vector of size n in bit-reversed ordering)
// rootOfUnityInverseTable (precomputed roots of unity in bit-reversed ordering)
// cycloOrderInv (n inverse)
// Output: INTT(element) in standard ordering
//
// for (m = n/2, t = 1, logt = 1; m >= 1; m=m/2, t=2*t, ++logt) do
// for (i = 0; i < m; ++i) do
// omega = rootOfUnityInverseTable[i + m]
// for (j1 = (i << logt), j2 = (j1 + t); j1 < j2; ++j1) do
// loVal = element[j1 + 0]
// hiVal = element[j1 + t]
// element[j1 + 0] = (loVal + hiVal) mod modulus
// element[j1 + t] = (loVal - hiVal)*omega mod modulus
// for (i = 0; i < n; ++i) do
// element[i] = element[i]*cycloOrderInv mod modulus
//

auto modulus{element->GetModulus()};
uint32_t n(element->GetLength());

// omega[bitreversed(1)] * (n inverse), computed once per call here (not stored in a table); used in the final stage of the INTT.
auto omega1Inv{rootOfUnityInverseTable[1].ModMulFastConst(cycloOrderInv, modulus, preconCycloOrderInv)};
auto preconOmega1Inv{omega1Inv.PrepModMulConst(modulus)};

// peeled off first stage for performance
for (uint32_t i{0}; i < n; i += 2) {
auto omega{rootOfUnityInverseTable[(i + n) >> 1]};
auto preconOmega{preconRootOfUnityInverseTable[(i + n) >> 1]};
auto hiVal{(*element)[i + 1]};
auto loVal{(*element)[i + 0]};
auto hiVal{(*element)[i + 1]};
#if defined(__GNUC__) && !defined(__clang__)
auto omegaFactor{loVal};
if (omegaFactor < hiVal)
Expand All @@ -509,28 +551,25 @@ void NumberTheoreticTransformNat<VecType>::InverseTransformFromBitReverseInPlace
loVal += hiVal;
if (loVal >= modulus)
loVal -= modulus;
loVal.ModMulFastConstEq(cycloOrderInv, modulus, preconCycloOrderInv);
omegaFactor.ModMulFastConstEq(omega, modulus, preconOmega);
omegaFactor.ModMulFastConstEq(cycloOrderInv, modulus, preconCycloOrderInv);
(*element)[i + 0] = loVal;
(*element)[i + 1] = omegaFactor;
#else
auto omegaFactor{loVal + (hiVal > loVal ? modulus : 0) - hiVal};
loVal += hiVal - (hiVal >= (modulus - loVal) ? modulus : 0);
loVal.ModMulFastConstEq(cycloOrderInv, modulus, preconCycloOrderInv);
(*element)[i + 0] = loVal;
omegaFactor.ModMulFastConstEq(omega, modulus, preconOmega);
omegaFactor.ModMulFastConstEq(cycloOrderInv, modulus, preconCycloOrderInv);
(*element)[i + 1] = omegaFactor;
#endif
}
for (uint32_t m{n >> 2}, t{2}, logt{2}; m >= 1; m >>= 1, t <<= 1, ++logt) {
// inner stages
for (uint32_t m{n >> 2}, t{2}, logt{2}; m > 1; m >>= 1, t <<= 1, ++logt) {
for (uint32_t i{0}; i < m; ++i) {
auto omega{rootOfUnityInverseTable[i + m]};
auto preconOmega{preconRootOfUnityInverseTable[i + m]};
for (uint32_t j1{i << logt}, j2{j1 + t}; j1 < j2; ++j1) {
auto hiVal{(*element)[j1 + t]};
auto loVal{(*element)[j1 + 0]};
auto hiVal{(*element)[j1 + t]};
#if defined(__GNUC__) && !defined(__clang__)
auto omegaFactor{loVal};
if (omegaFactor < hiVal)
Expand All @@ -551,6 +590,35 @@ void NumberTheoreticTransformNat<VecType>::InverseTransformFromBitReverseInPlace
}
}
}

// peeled off final stage to implement optimization where n/2 scalar multiplies
// by (n inverse) are incorporated into the omegaFactor calculation.
// Please see https://github.com/openfheorg/openfhe-development/issues/872 for details.
uint32_t j2{n >> 1};
for (uint32_t j1{0}; j1 < j2; ++j1) {
auto loVal{(*element)[j1]};
auto hiVal{(*element)[j1 + j2]};
#if defined(__GNUC__) && !defined(__clang__)
auto omegaFactor{loVal};
if (omegaFactor < hiVal)
omegaFactor += modulus;
omegaFactor -= hiVal;
loVal += hiVal;
if (loVal >= modulus)
loVal -= modulus;
omegaFactor.ModMulFastConstEq(omega1Inv, modulus, preconOmega1Inv);
(*element)[j1 + 0] = loVal;
(*element)[j1 + j2] = omegaFactor;
#else
(*element)[j1] += hiVal - (hiVal >= (modulus - loVal) ? modulus : 0);
auto omegaFactor = loVal + (hiVal > loVal ? modulus : 0) - hiVal;
omegaFactor.ModMulFastConstEq(omega1Inv, modulus, preconOmega1Inv);
(*element)[j1 + j2] = omegaFactor;
#endif
}
// perform remaining n/2 scalar multiplies by (n inverse)
for (uint32_t i = 0; i < j2; ++i)
(*element)[i].ModMulFastConstEq(cycloOrderInv, modulus, preconCycloOrderInv);
}

template <typename VecType>
Expand Down Expand Up @@ -706,7 +774,6 @@ void ChineseRemainderTransformFTTNat<VecType>::InverseTransformFromBitReverse(co
template <typename VecType>
void ChineseRemainderTransformFTTNat<VecType>::PreCompute(const IntType& rootOfUnity, const usint CycloOrder,
const IntType& modulus) {
// Half of cyclo order
usint CycloOrderHf = (CycloOrder >> 1);

auto mapSearch = m_rootOfUnityReverseTableByModulus.find(modulus);
Expand Down

0 comments on commit 8f72e90

Please sign in to comment.