Skip to content

Commit

Permalink
Eliminated more unnecessary data copies
Browse files Browse the repository at this point in the history
  • Loading branch information
dsuponitskiy-duality committed Sep 3, 2024
1 parent aae0814 commit 5f01a66
Show file tree
Hide file tree
Showing 7 changed files with 27 additions and 26 deletions.
2 changes: 1 addition & 1 deletion src/core/include/lattice/hal/dcrtpoly-interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ class DCRTPolyInterface : public ILElement<DerivedType, BigVecType> {
* @param element The element to store
*/
void SetElementAtIndex(usint index, TowerType&& element) {
return this->GetDerived().SetElementAtIndex(index, element);
return this->GetDerived().SetElementAtIndex(index, std::move(element));
}

/***********************************************************************
Expand Down
19 changes: 10 additions & 9 deletions src/pke/lib/cryptocontext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -788,30 +788,31 @@ std::unordered_map<uint32_t, DCRTPoly> CryptoContextImpl<DCRTPoly>::ShareKeys(co
for (size_t k = 0; k < vecSize; k++) {
auto modq_k = elementParams->GetParams()[k]->GetModulus();

NativeVector powtempvec(ring_dimension, modq_k);
NativePoly powtemppoly(elementParams->GetParams()[k], Format::COEFFICIENT);
NativePoly fevalpoly(elementParams->GetParams()[k], Format::COEFFICIENT, true);

NativeInteger powtemp(1);
for (size_t t = 1; t < threshold; t++) {
powtemp = powtemp.ModMul(i, modq_k);

for (size_t d = 0; d < ring_dimension; d++) {
powtempvec.at(d) = powtemp;
NativeVector powtempvec(ring_dimension, modq_k);
for (size_t i = 0; i < powtempvec.GetLength(); ++i) {
// TODO (dsuponit): should we have a contructor to get a value for all m_data elements in NativeVector
powtempvec[i] = powtemp;
}

powtemppoly.SetValues(powtempvec, Format::COEFFICIENT);
powtemppoly.SetValues(std::move(powtempvec), Format::COEFFICIENT);

auto fst = fs[t].GetElementAtIndex(k);

for (size_t l = 0; l < ring_dimension; l++) {
fevalpoly.at(l) += powtemppoly.at(l).ModMul(fst.at(l), modq_k);
for (size_t i = 0; i < ring_dimension; ++i) {
fevalpoly.at(i) += powtemppoly.at(i).ModMul(fst.at(i), modq_k);
}
}
fevalpoly += fs[0].GetElementAtIndex(k);

fevalpoly.SetFormat(Format::COEFFICIENT);
feval.SetElementAtIndex(k, fevalpoly);
feval.SetElementAtIndex(k, std::move(fevalpoly));
}
// assign fi
SecretShares[i] = feval;
Expand Down Expand Up @@ -888,7 +889,7 @@ void CryptoContextImpl<DCRTPoly>::RecoverSharedKey(PrivateKey<DCRTPoly>& sk,
}
}
multpoly.SetFormat(Format::EVALUATION);
Lagrange_coeffs[j].SetElementAtIndex(k, multpoly);
Lagrange_coeffs[j].SetElementAtIndex(k, std::move(multpoly));
}
Lagrange_coeffs[j].SetFormat(Format::COEFFICIENT);
}
Expand All @@ -901,7 +902,7 @@ void CryptoContextImpl<DCRTPoly>::RecoverSharedKey(PrivateKey<DCRTPoly>& sk,
const auto& share = sk_shares[client_indexes[i]].GetAllElements()[k];
lagrange_sum_of_elems_poly += coeff.TimesNoCheck(share);
}
lagrange_sum_of_elems.SetElementAtIndex(k, lagrange_sum_of_elems_poly);
lagrange_sum_of_elems.SetElementAtIndex(k, std::move(lagrange_sum_of_elems_poly));
}
lagrange_sum_of_elems.SetFormat(Format::EVALUATION);
sk->SetPrivateElement(lagrange_sum_of_elems);
Expand Down
2 changes: 1 addition & 1 deletion src/pke/lib/keyswitch/keyswitch-bv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ EvalKey<DCRTPoly> KeySwitchBV::KeySwitchGenInternal(const PrivateKey<DCRTPoly> o

for (size_t k = 0; k < sOldDecomposed.size(); k++) {
DCRTPoly filtered(elementParams, Format::EVALUATION, true);
filtered.SetElementAtIndex(i, sOldDecomposed[k]);
filtered.SetElementAtIndex(i, std::move(sOldDecomposed[k]));

DCRTPoly u = (cryptoParams->GetSecretKeyDist() == GAUSSIAN) ?
DCRTPoly(dgg, elementParams, Format::EVALUATION) :
Expand Down
12 changes: 6 additions & 6 deletions src/pke/lib/keyswitch/keyswitch-hybrid.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ Ciphertext<DCRTPoly> KeySwitchHYBRID::KeySwitchExt(ConstCiphertext<DCRTPoly> cip
if ((addFirst) || (k > 0)) {
auto cMult = cv[k].TimesNoCheck(cryptoParams->GetPModq());
for (usint i = 0; i < sizeQl; i++) {
resultElements[k].SetElementAtIndex(i, cMult.GetElementAtIndex(i));
resultElements[k].SetElementAtIndex(i, std::move(cMult.GetElementAtIndex(i)));
}
}
}
Expand Down Expand Up @@ -386,11 +386,11 @@ std::shared_ptr<std::vector<DCRTPoly>> KeySwitchHYBRID::EvalKeySwitchPrecomputeC

uint32_t sizePartQl = partsCt[part].GetNumOfElements();
partsCtCompl[part] = partCtClone.ApproxSwitchCRTBasis(
cryptoParams->GetParamsPartQ(part), cryptoParams->GetParamsComplPartQ(sizeQl - 1, part),
cryptoParams->GetPartQlHatInvModq(part, sizePartQl - 1),
cryptoParams->GetPartQlHatInvModqPrecon(part, sizePartQl - 1),
cryptoParams->GetPartQlHatModp(sizeQl - 1, part),
cryptoParams->GetmodComplPartqBarrettMu(sizeQl - 1, part));
cryptoParams->GetParamsPartQ(part), cryptoParams->GetParamsComplPartQ(sizeQl - 1, part),
cryptoParams->GetPartQlHatInvModq(part, sizePartQl - 1),
cryptoParams->GetPartQlHatInvModqPrecon(part, sizePartQl - 1),
cryptoParams->GetPartQlHatModp(sizeQl - 1, part),
cryptoParams->GetmodComplPartqBarrettMu(sizeQl - 1, part));

partsCtCompl[part].SetFormat(Format::EVALUATION);

Expand Down
8 changes: 4 additions & 4 deletions src/pke/lib/scheme/ckksrns/ckksrns-fhe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2264,8 +2264,8 @@ Plaintext FHECKKSRNS::MakeAuxPlaintext(const CryptoContextImpl<DCRTPoly>& cc, co
NativeVector nativeVec(N, nativeParams[i]->GetModulus());
FitToNativeVector(N, temp, Max128BitValue(), &nativeVec);
NativePoly element = plainElement.GetElementAtIndex(i);
element.SetValues(nativeVec, Format::COEFFICIENT);
plainElement.SetElementAtIndex(i, element);
element.SetValues(std::move(nativeVec), Format::COEFFICIENT);
plainElement.SetElementAtIndex(i, std::move(element));
}

usint numTowers = nativeParams.size();
Expand Down Expand Up @@ -2407,8 +2407,8 @@ Plaintext FHECKKSRNS::MakeAuxPlaintext(const CryptoContextImpl<DCRTPoly>& cc, co
NativeVector nativeVec(N, nativeParams[i]->GetModulus());
FitToNativeVector(N, temp, Max64BitValue(), &nativeVec);
NativePoly element = plainElement.GetElementAtIndex(i);
element.SetValues(nativeVec, Format::COEFFICIENT);
plainElement.SetElementAtIndex(i, element);
element.SetValues(std::move(nativeVec), Format::COEFFICIENT);
plainElement.SetElementAtIndex(i, std::move(element));
}

usint numTowers = nativeParams.size();
Expand Down
2 changes: 1 addition & 1 deletion src/pke/lib/scheme/ckksrns/ckksrns-leveledshe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,7 @@ Ciphertext<DCRTPoly> LeveledSHECKKSRNS::EvalFastRotationExt(ConstCiphertext<DCRT
DCRTPoly psiC0 = DCRTPoly(paramsQlP, Format::EVALUATION, true);
auto cMult = ciphertext->GetElements()[0].TimesNoCheck(cryptoParams->GetPModq());
for (usint i = 0; i < sizeQl; i++) {
psiC0.SetElementAtIndex(i, cMult.GetElementAtIndex(i));
psiC0.SetElementAtIndex(i, std::move(cMult.GetElementAtIndex(i)));
}
(*cTilda)[0] += psiC0;
}
Expand Down
8 changes: 4 additions & 4 deletions src/pke/lib/scheme/ckksrns/ckksrns-schemeswitching.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -269,8 +269,8 @@ Plaintext SWITCHCKKSRNS::MakeAuxPlaintext(const CryptoContextImpl<DCRTPoly>& cc,
NativeVector nativeVec(N, nativeParams[i]->GetModulus());
FitToNativeVector(N, temp, Max128BitValue(), &nativeVec);
NativePoly element = plainElement.GetElementAtIndex(i);
element.SetValues(nativeVec, Format::COEFFICIENT);
plainElement.SetElementAtIndex(i, element);
element.SetValues(std::move(nativeVec), Format::COEFFICIENT);
plainElement.SetElementAtIndex(i, std::move(element));
}

usint numTowers = nativeParams.size();
Expand Down Expand Up @@ -415,8 +415,8 @@ Plaintext SWITCHCKKSRNS::MakeAuxPlaintext(const CryptoContextImpl<DCRTPoly>& cc,
NativeVector nativeVec(N, nativeParams[i]->GetModulus());
FitToNativeVector(N, temp, Max64BitValue(), &nativeVec);
NativePoly element = plainElement.GetElementAtIndex(i);
element.SetValues(nativeVec, Format::COEFFICIENT);
plainElement.SetElementAtIndex(i, element);
element.SetValues(std::move(nativeVec), Format::COEFFICIENT);
plainElement.SetElementAtIndex(i, std::move(element));
}

usint numTowers = nativeParams.size();
Expand Down

0 comments on commit 5f01a66

Please sign in to comment.