Skip to content

Commit

Permalink
Code clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
dsuponitskiy-duality committed Sep 10, 2024
1 parent dcd1917 commit bc51121
Showing 1 changed file with 23 additions and 47 deletions.
70 changes: 23 additions & 47 deletions src/pke/lib/cryptocontext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -376,8 +376,7 @@ Ciphertext<Element> CryptoContextImpl<Element>::EvalSum(ConstCiphertext<Element>
ValidateCiphertext(ciphertext);

auto evalSumKeys = CryptoContextImpl<Element>::GetEvalAutomorphismKeyMap(ciphertext->GetKeyTag());
auto rv = GetScheme()->EvalSum(ciphertext, batchSize, evalSumKeys);
return rv;
return GetScheme()->EvalSum(ciphertext, batchSize, evalSumKeys);
}

template <typename Element>
Expand All @@ -386,8 +385,7 @@ Ciphertext<Element> CryptoContextImpl<Element>::EvalSumRows(ConstCiphertext<Elem
usint subringDim) const {
ValidateCiphertext(ciphertext);

auto rv = GetScheme()->EvalSumRows(ciphertext, numRows, evalSumKeys, subringDim);
return rv;
return GetScheme()->EvalSumRows(ciphertext, numRows, evalSumKeys, subringDim);
}

template <typename Element>
Expand All @@ -397,26 +395,21 @@ Ciphertext<Element> CryptoContextImpl<Element>::EvalSumCols(
ValidateCiphertext(ciphertext);

auto evalSumKeys = CryptoContextImpl<Element>::GetEvalAutomorphismKeyMap(ciphertext->GetKeyTag());
auto rv = GetScheme()->EvalSumCols(ciphertext, numCols, evalSumKeys, evalSumKeysRight);
return rv;
return GetScheme()->EvalSumCols(ciphertext, numCols, evalSumKeys, evalSumKeysRight);
}

template <typename Element>
Ciphertext<Element> CryptoContextImpl<Element>::EvalAtIndex(ConstCiphertext<Element> ciphertext, int32_t index) const {
ValidateCiphertext(ciphertext);

// If the index is zero, no rotation is needed, copy the ciphertext and return
// This is done after the keyMap so that it is protected if there's not a
// valid key.
// This is done after the keyMap so that it is protected if there's not a valid key.
if (0 == index) {
auto rv = ciphertext->Clone();
return rv;
return ciphertext->Clone();
}

auto evalAutomorphismKeys = CryptoContextImpl<Element>::GetEvalAutomorphismKeyMap(ciphertext->GetKeyTag());

auto rv = GetScheme()->EvalAtIndex(ciphertext, index, evalAutomorphismKeys);
return rv;
return GetScheme()->EvalAtIndex(ciphertext, index, evalAutomorphismKeys);
}

template <typename Element>
Expand All @@ -425,41 +418,30 @@ Ciphertext<Element> CryptoContextImpl<Element>::EvalMerge(
ValidateCiphertext(ciphertextVector[0]);

auto evalAutomorphismKeys = CryptoContextImpl<Element>::GetEvalAutomorphismKeyMap(ciphertextVector[0]->GetKeyTag());

auto rv = GetScheme()->EvalMerge(ciphertextVector, evalAutomorphismKeys);

return rv;
return GetScheme()->EvalMerge(ciphertextVector, evalAutomorphismKeys);
}

template <typename Element>
Ciphertext<Element> CryptoContextImpl<Element>::EvalInnerProduct(ConstCiphertext<Element> ct1,
ConstCiphertext<Element> ct2, usint batchSize) const {
ValidateCiphertext(ct1);
if (ct2 == nullptr || ct1->GetKeyTag() != ct2->GetKeyTag())
OPENFHE_THROW(
"Information passed to EvalInnerProduct was not generated "
"with this crypto context");
OPENFHE_THROW("Information was not generated with this crypto context");

auto evalSumKeys = CryptoContextImpl<Element>::GetEvalAutomorphismKeyMap(ct1->GetKeyTag());
auto ek = CryptoContextImpl<Element>::GetEvalMultKeyVector(ct1->GetKeyTag());

auto rv = GetScheme()->EvalInnerProduct(ct1, ct2, batchSize, evalSumKeys, ek[0]);
return rv;
return GetScheme()->EvalInnerProduct(ct1, ct2, batchSize, evalSumKeys, ek[0]);
}

template <typename Element>
Ciphertext<Element> CryptoContextImpl<Element>::EvalInnerProduct(ConstCiphertext<Element> ct1, ConstPlaintext ct2,
usint batchSize) const {
ValidateCiphertext(ct1);
if (ct2 == nullptr)
OPENFHE_THROW(
"Information passed to EvalInnerProduct was not generated "
"with this crypto context");
OPENFHE_THROW("Information was not generated with this crypto context");

auto evalSumKeys = CryptoContextImpl<Element>::GetEvalAutomorphismKeyMap(ct1->GetKeyTag());

auto rv = GetScheme()->EvalInnerProduct(ct1, ct2, batchSize, evalSumKeys);
return rv;
return GetScheme()->EvalInnerProduct(ct1, ct2, batchSize, evalSumKeys);
}

template <typename Element>
Expand All @@ -477,9 +459,9 @@ template <typename Element>
DecryptResult CryptoContextImpl<Element>::Decrypt(ConstCiphertext<Element> ciphertext,
const PrivateKey<Element> privateKey, Plaintext* plaintext) {
if (ciphertext == nullptr)
OPENFHE_THROW("ciphertext passed to Decrypt is empty");
OPENFHE_THROW("ciphertext is empty");
if (plaintext == nullptr)
OPENFHE_THROW("plaintext passed to Decrypt is empty");
OPENFHE_THROW("plaintext is empty");
ValidateKey(privateKey);

// determine which type of plaintext that you need to decrypt into
Expand Down Expand Up @@ -582,13 +564,11 @@ template <>
DecryptResult CryptoContextImpl<DCRTPoly>::Decrypt(ConstCiphertext<DCRTPoly> ciphertext,
const PrivateKey<DCRTPoly> privateKey, Plaintext* plaintext) {
if (ciphertext == nullptr)
OPENFHE_THROW("ciphertext passed to Decrypt is empty");
OPENFHE_THROW("ciphertext is empty");
if (plaintext == nullptr)
OPENFHE_THROW("plaintext passed to Decrypt is empty");
OPENFHE_THROW("plaintext is empty");
if (privateKey == nullptr || Mismatched(privateKey->GetCryptoContext()))
OPENFHE_THROW(
"Information passed to Decrypt was not generated with "
"this crypto context");
OPENFHE_THROW("Information was not generated with this crypto context");

// determine which type of plaintext that you need to decrypt into
// Plaintext decrypted =
Expand Down Expand Up @@ -643,9 +623,7 @@ DecryptResult CryptoContextImpl<DCRTPoly>::MultipartyDecryptFusion(
for (size_t i = 0; i < last_ciphertext; i++) {
ValidateCiphertext(partialCiphertextVec[i]);
if (partialCiphertextVec[i]->GetEncodingType() != partialCiphertextVec[0]->GetEncodingType())
OPENFHE_THROW(
"Ciphertexts passed to MultipartyDecryptFusion have "
"mismatched encoding types");
OPENFHE_THROW("Ciphertexts have mismatched encoding types");
}

// determine which type of plaintext that you need to decrypt into
Expand Down Expand Up @@ -730,21 +708,21 @@ std::unordered_map<uint32_t, DCRTPoly> CryptoContextImpl<DCRTPoly>::ShareKeys(co
auto ring_dimension = elementParams->GetRingDimension();

// condition for inverse in lagrange coeff to exist.
for (usint k = 0; k < vecSize; k++) {
auto modq_k = elementParams->GetParams()[k]->GetModulus();
for (size_t i = 0; i < vecSize; ++i) {
auto modq_k = elementParams->GetParams()[i]->GetModulus();
if (N >= modq_k)
OPENFHE_THROW("Number of parties N needs to be less than DCRTPoly moduli");
}

// secret sharing
const usint num_of_shares = N - 1;
std::unordered_map<uint32_t, DCRTPoly> SecretShares;

if (shareType == "additive") {
// generate a random share of N-2 elements and create the last share as sk - (sk_1 + ... + sk_N-2)
typename DCRTPoly::DugType dug;
DCRTPoly rsum(dug, elementParams, Format::EVALUATION);

const uint32_t num_of_shares = N - 1;
std::vector<DCRTPoly> SecretSharesVec;
SecretSharesVec.reserve(num_of_shares);
SecretSharesVec.push_back(rsum);
Expand All @@ -755,11 +733,9 @@ std::unordered_map<uint32_t, DCRTPoly> CryptoContextImpl<DCRTPoly>::ShareKeys(co
}
SecretSharesVec.push_back(sk->GetPrivateElement() - rsum);

usint ctr = 0;
for (size_t i = 1; i <= N; i++) {
for (size_t i = 1, ctr = 0; i <= N; ++i) {
if (i != index) {
SecretShares[i] = SecretSharesVec[ctr];
ctr++;
SecretShares[i] = SecretSharesVec[ctr++];
}
}
}
Expand All @@ -777,7 +753,7 @@ std::unordered_map<uint32_t, DCRTPoly> CryptoContextImpl<DCRTPoly>::ShareKeys(co
}

// evaluate the polynomial at the index of the parties 1 to N
for (size_t i = 1; i <= N; i++) {
for (size_t i = 1; i <= N; ++i) {
if (i != index) {
DCRTPoly feval(elementParams, Format::COEFFICIENT, true);
for (size_t k = 0; k < vecSize; k++) {
Expand Down

0 comments on commit bc51121

Please sign in to comment.