Skip to content

Commit

Permalink
Removed more PublicKey input parameters and some duplicate functions
Browse files Browse the repository at this point in the history
  • Loading branch information
dsuponitskiy-duality committed Oct 11, 2023
1 parent 963139d commit 57cd4de
Show file tree
Hide file tree
Showing 8 changed files with 22 additions and 120 deletions.
58 changes: 3 additions & 55 deletions src/pke/include/cryptocontext.h
Original file line number Diff line number Diff line change
Expand Up @@ -371,41 +371,6 @@ class CryptoContextImpl : public Serializable {
}
}

void CompareKeyTag(const PrivateKey<Element> privateKey, const PublicKey<Element> publicKey,
CALLER_INFO_ARGS_HDR) const {
if (privateKey == nullptr) {
std::string errorMsg(std::string("PrivateKey is nullptr") + CALLER_INFO);
OPENFHE_THROW(config_error, errorMsg);
}
else if (publicKey == nullptr) {
std::string errorMsg(std::string("PublicKey is nullptr") + CALLER_INFO);
OPENFHE_THROW(config_error, errorMsg);
}

if (privateKey->GetKeyTag() != publicKey->GetKeyTag()) {
std::string errorMsg(std::string("Public key does not match private key") + CALLER_INFO);
OPENFHE_THROW(config_error, errorMsg);
}
}

void CompareKeyTag(const ConstCiphertext<Element> ciphertext1, const ConstCiphertext<Element> ciphertext2,
CALLER_INFO_ARGS_HDR) const {
if (ciphertext1 == nullptr) {
std::string errorMsg(std::string("Ciphertext1 is nullptr") + CALLER_INFO);
OPENFHE_THROW(config_error, errorMsg);
}
else if (ciphertext2 == nullptr) {
std::string errorMsg(std::string("Ciphertext2 is nullptr") + CALLER_INFO);
OPENFHE_THROW(config_error, errorMsg);
}

if (ciphertext1->GetKeyTag() != ciphertext2->GetKeyTag()) {
std::string errorMsg(std::string("Ciphertexts were not generated with the same crypto context") +
CALLER_INFO);
OPENFHE_THROW(config_error, errorMsg);
}
}

PrivateKey<Element> privateKey;

public:
Expand Down Expand Up @@ -1986,20 +1951,6 @@ class CryptoContextImpl : public Serializable {
return GetScheme()->EvalAutomorphismKeyGen(privateKey, indexList);
}

/**
* NOT USED BY ANY CRYPTO SCHEME: Generate automophism keys for a public and private key
*/
std::shared_ptr<std::map<usint, EvalKey<Element>>> EvalAutomorphismKeyGen(
const PublicKey<Element> publicKey, const PrivateKey<Element> privateKey,
const std::vector<usint>& indexList) const {
ValidateKey(publicKey);
ValidateKey(privateKey);
if (!indexList.size())
OPENFHE_THROW(config_error, "Input index vector is empty");

return GetScheme()->EvalAutomorphismKeyGen(publicKey, privateKey, indexList);
}

/**
* Function for evaluating automorphism of ciphertext at index i
*
Expand Down Expand Up @@ -2204,10 +2155,8 @@ class CryptoContextImpl : public Serializable {
*
* @param privateKey private key.
* @param indexList list of indices.
* @param publicKey public key (used in NTRU schemes). Not used anymore.
*/
void EvalAtIndexKeyGen(const PrivateKey<Element> privateKey, const std::vector<int32_t>& indexList,
const PublicKey<Element> publicKey = nullptr);
void EvalAtIndexKeyGen(const PrivateKey<Element> privateKey, const std::vector<int32_t>& indexList);

/**
* EvalRotateKeyGen generates evaluation keys for a list of rotation indices.
Expand All @@ -2217,9 +2166,8 @@ class CryptoContextImpl : public Serializable {
* @param indexList list of indices.
* @param publicKey public key (used in NTRU schemes).
*/
void EvalRotateKeyGen(const PrivateKey<Element> privateKey, const std::vector<int32_t>& indexList,
const PublicKey<Element> publicKey = nullptr) {
EvalAtIndexKeyGen(privateKey, indexList, publicKey);
void EvalRotateKeyGen(const PrivateKey<Element> privateKey, const std::vector<int32_t>& indexList) {
EvalAtIndexKeyGen(privateKey, indexList);
};

/**
Expand Down
19 changes: 1 addition & 18 deletions src/pke/include/schemebase/base-leveledshe.h
Original file line number Diff line number Diff line change
Expand Up @@ -549,21 +549,6 @@ class LeveledSHEBase {
virtual std::shared_ptr<std::map<usint, EvalKey<Element>>> EvalAutomorphismKeyGen(
const PrivateKey<Element> privateKey, const std::vector<usint>& indexList) const;

/**
* Virtual function to generate all isomorphism keys for a given private key
*
* @param publicKey encryption key for the new ciphertext.
* @param origPrivateKey original private key used for decryption.
* @param indexList list of automorphism indices to be computed
* @return returns the evaluation keys
*/
virtual std::shared_ptr<std::map<usint, EvalKey<Element>>> EvalAutomorphismKeyGen(
const PublicKey<Element> publicKey, const PrivateKey<Element> privateKey,
const std::vector<usint>& indexList) const {
std::string errMsg = "EvalAutomorphismKeyGen is not implemented for this scheme.";
OPENFHE_THROW(not_implemented_error, errMsg);
}

/**
* Virtual function for evaluating automorphism of ciphertext at index i
*
Expand Down Expand Up @@ -612,14 +597,12 @@ class LeveledSHEBase {
* Generates evaluation keys for a list of indices
* Currently works only for power-of-two and cyclic-group cyclotomics
*
* @param publicKey encryption key for the new ciphertext.
* @param origPrivateKey original private key used for decryption.
* @param indexList list of indices to be computed
* @return returns the evaluation keys
*/
virtual std::shared_ptr<std::map<usint, EvalKey<Element>>> EvalAtIndexKeyGen(
const PublicKey<Element> publicKey, const PrivateKey<Element> privateKey,
const std::vector<int32_t>& indexList) const;
const PrivateKey<Element> privateKey, const std::vector<int32_t>& indexList) const;

/**
* Moves i-th slot to slot 0
Expand Down
7 changes: 1 addition & 6 deletions src/pke/include/schemebase/base-scheme.h
Original file line number Diff line number Diff line change
Expand Up @@ -833,10 +833,6 @@ class SchemeBase {
virtual std::shared_ptr<std::map<usint, EvalKey<Element>>> EvalAutomorphismKeyGen(
const PrivateKey<Element> privateKey, const std::vector<usint>& indexList) const;

virtual std::shared_ptr<std::map<usint, EvalKey<Element>>> EvalAutomorphismKeyGen(
const PublicKey<Element> publicKey, const PrivateKey<Element> privateKey,
const std::vector<usint>& indexList) const;

virtual Ciphertext<Element> EvalAutomorphism(ConstCiphertext<Element> ciphertext, usint i,
const std::map<usint, EvalKey<Element>>& evalKeyMap,
CALLER_INFO_ARGS_HDR) const {
Expand Down Expand Up @@ -910,8 +906,7 @@ class SchemeBase {
}

virtual std::shared_ptr<std::map<usint, EvalKey<Element>>> EvalAtIndexKeyGen(
const PublicKey<Element> publicKey, const PrivateKey<Element> privateKey,
const std::vector<int32_t>& indexList) const;
const PrivateKey<Element> privateKey, const std::vector<int32_t>& indexList) const;

virtual Ciphertext<Element> EvalAtIndex(ConstCiphertext<Element> ciphertext, usint i,
const std::map<usint, EvalKey<Element>>& evalKeyMap) const {
Expand Down
10 changes: 2 additions & 8 deletions src/pke/lib/cryptocontext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,16 +147,10 @@ void CryptoContextImpl<Element>::InsertEvalSumKey(const std::shared_ptr<std::map

template <typename Element>
void CryptoContextImpl<Element>::EvalAtIndexKeyGen(const PrivateKey<Element> privateKey,
const std::vector<int32_t>& indexList,
const PublicKey<Element> publicKey) {
const std::vector<int32_t>& indexList) {
ValidateKey(privateKey);

if (publicKey != nullptr && privateKey->GetKeyTag() != publicKey->GetKeyTag()) {
OPENFHE_THROW(config_error, "Public key passed to EvalAtIndexKeyGen does not match private key");
}

auto evalKeys = GetScheme()->EvalAtIndexKeyGen(publicKey, privateKey, indexList);

auto evalKeys = GetScheme()->EvalAtIndexKeyGen(privateKey, indexList);
InsertEvalAutomorphismKey(evalKeys, privateKey->GetKeyTag());
}

Expand Down
16 changes: 8 additions & 8 deletions src/pke/lib/scheme/ckksrns/ckksrns-fhe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ std::shared_ptr<std::map<usint, EvalKey<DCRTPoly>>> FHECKKSRNS::EvalBootstrapKey
slots = M / 4;
// computing all indices for baby-step giant-step procedure
auto algo = cc->GetScheme();
auto evalKeys = algo->EvalAtIndexKeyGen(nullptr, privateKey, FindBootstrapRotationIndices(slots, M));
auto evalKeys = algo->EvalAtIndexKeyGen(privateKey, FindBootstrapRotationIndices(slots, M));

auto conjKey = ConjugateKeyGen(privateKey);
(*evalKeys)[M - 1] = conjKey;
Expand Down Expand Up @@ -2250,8 +2250,8 @@ Plaintext FHECKKSRNS::MakeAuxPlaintext(const CryptoContextImpl<DCRTPoly>& cc, co
if (logc < 0) {
OPENFHE_THROW(math_error, "Too small scaling factor");
}
int32_t logValid = (logc <= MAX_BITS_IN_WORD) ? logc : MAX_BITS_IN_WORD;
int32_t logApprox = logc - logValid;
int32_t logValid = (logc <= MAX_BITS_IN_WORD) ? logc : MAX_BITS_IN_WORD;
int32_t logApprox = logc - logValid;
double approxFactor = pow(2, logApprox);

std::vector<int64_t> temp(2 * slots);
Expand Down Expand Up @@ -2282,11 +2282,11 @@ Plaintext FHECKKSRNS::MakeAuxPlaintext(const CryptoContextImpl<DCRTPoly>& cc, co
double imagVal = prodFactor.imag();

if (realVal > realMax) {
realMax = realVal;
realMax = realVal;
realMaxIdx = idx;
}
if (imagVal > imagMax) {
imagMax = imagVal;
imagMax = imagVal;
imagMaxIdx = idx;
}
}
Expand All @@ -2309,11 +2309,11 @@ Plaintext FHECKKSRNS::MakeAuxPlaintext(const CryptoContextImpl<DCRTPoly>& cc, co
int64_t re = std::llround(dre);
int64_t im = std::llround(dim);

temp[i] = (re < 0) ? Max64BitValue() + re : re;
temp[i] = (re < 0) ? Max64BitValue() + re : re;
temp[i + slots] = (im < 0) ? Max64BitValue() + im : im;
}

const std::shared_ptr<ILDCRTParams<BigInteger>> bigParams = plainElement.GetParams();
const std::shared_ptr<ILDCRTParams<BigInteger>> bigParams = plainElement.GetParams();
const std::vector<std::shared_ptr<ILNativeParams>>& nativeParams = bigParams->GetParams();

for (size_t i = 0; i < nativeParams.size(); i++) {
Expand Down Expand Up @@ -2349,7 +2349,7 @@ Plaintext FHECKKSRNS::MakeAuxPlaintext(const CryptoContextImpl<DCRTPoly>& cc, co
// Scale back up by the approxFactor to get the correct encoding.
if (logApprox > 0) {
int32_t logStep = (logApprox <= MAX_LOG_STEP) ? logApprox : MAX_LOG_STEP;
auto intStep = DCRTPoly::Integer(uint64_t(1) << logStep);
auto intStep = DCRTPoly::Integer(uint64_t(1) << logStep);
std::vector<DCRTPoly::Integer> crtApprox(numTowers, intStep);
logApprox -= logStep;

Expand Down
8 changes: 4 additions & 4 deletions src/pke/lib/scheme/ckksrns/ckksrns-schemeswitching.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1193,7 +1193,7 @@ std::shared_ptr<std::map<usint, EvalKey<DCRTPoly>>> SWITCHCKKSRNS::EvalCKKStoFHE
indexRotationS2C.erase(unique(indexRotationS2C.begin(), indexRotationS2C.end()), indexRotationS2C.end());

auto algo = ccCKKS->GetScheme();
auto evalKeys = algo->EvalAtIndexKeyGen(publicKey, privateKey, indexRotationS2C);
auto evalKeys = algo->EvalAtIndexKeyGen(privateKey, indexRotationS2C);

const DCRTPoly& s = privateKey->GetPrivateElement();
usint N = s.GetRingDimension();
Expand Down Expand Up @@ -1386,7 +1386,7 @@ std::shared_ptr<std::map<usint, EvalKey<DCRTPoly>>> SWITCHCKKSRNS::EvalFHEWtoCKK
indexRotationHomDec.end());

auto algo = ccCKKS->GetScheme();
auto evalKeys = algo->EvalAtIndexKeyGen(publicKey, privateKey, indexRotationHomDec);
auto evalKeys = algo->EvalAtIndexKeyGen(privateKey, indexRotationHomDec);

// Compute multiplication key
ccCKKS->EvalMultKeyGen(privateKey);
Expand Down Expand Up @@ -1718,7 +1718,7 @@ std::shared_ptr<std::map<usint, EvalKey<DCRTPoly>>> SWITCHCKKSRNS::EvalSchemeSwi
indexRotationS2C.erase(unique(indexRotationS2C.begin(), indexRotationS2C.end()), indexRotationS2C.end());

auto algo = ccCKKS->GetScheme();
auto evalKeys = algo->EvalAtIndexKeyGen(publicKey, privateKey, indexRotationS2C);
auto evalKeys = algo->EvalAtIndexKeyGen(privateKey, indexRotationS2C);

// Compute conjugation key
const DCRTPoly& s = privateKey->GetPrivateElement();
Expand Down Expand Up @@ -1869,7 +1869,7 @@ std::vector<Ciphertext<DCRTPoly>> SWITCHCKKSRNS::EvalMinSchemeSwitching(ConstCip
std::vector<std::complex<double>> ones(numValues / (2 * M), 1.0);
Plaintext ptxtOnes = cc->MakeCKKSPackedPlaintext(ones, 1, 0, nullptr, slots);
cSelect = cc->EvalAdd(
cSelect, cc->EvalAtIndex(cc->EvalSub(ptxtOnes, cSelect), -static_cast<int32_t>(numValues / (2 * M))));
cSelect, cc->EvalAtIndex(cc->EvalSub(ptxtOnes, cSelect), -static_cast<int32_t>(numValues / (2 * M))));

auto cExpandSelect = cSelect;
if (M > 1) {
Expand Down
3 changes: 1 addition & 2 deletions src/pke/lib/schemebase/base-leveledshe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -514,8 +514,7 @@ Ciphertext<Element> LeveledSHEBase<Element>::EvalFastRotation(

template <class Element>
std::shared_ptr<std::map<usint, EvalKey<Element>>> LeveledSHEBase<Element>::EvalAtIndexKeyGen(
const PublicKey<Element> publicKey, const PrivateKey<Element> privateKey,
const std::vector<int32_t>& indexList) const {
const PrivateKey<Element> privateKey, const std::vector<int32_t>& indexList) const {
const auto cc = privateKey->GetCryptoContext();

usint M = privateKey->GetCryptoParameters()->GetElementParams()->GetCyclotomicOrder();
Expand Down
21 changes: 2 additions & 19 deletions src/pke/lib/schemebase/base-scheme.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,12 @@ std::vector<EvalKey<Element>> SchemeBase<Element>::EvalMultKeysGen(const Private

template <typename Element>
std::shared_ptr<std::map<usint, EvalKey<Element>>> SchemeBase<Element>::EvalAtIndexKeyGen(
const PublicKey<Element> publicKey, const PrivateKey<Element> privateKey,
const std::vector<int32_t>& indexList) const {
const PrivateKey<Element> privateKey, const std::vector<int32_t>& indexList) const {
VerifyLeveledSHEEnabled(__func__);
if (!privateKey)
OPENFHE_THROW(config_error, "Input private key is nullptr");

auto evalKeyMap = m_LeveledSHE->EvalAtIndexKeyGen(publicKey, privateKey, indexList);
auto evalKeyMap = m_LeveledSHE->EvalAtIndexKeyGen(privateKey, indexList);
for (auto& key : *evalKeyMap)
key.second->SetKeyTag(privateKey->GetKeyTag());
return evalKeyMap;
Expand Down Expand Up @@ -428,22 +427,6 @@ std::shared_ptr<std::map<usint, EvalKey<Element>>> SchemeBase<Element>::EvalAuto
return evalKeyMap;
}

template <typename Element>
std::shared_ptr<std::map<usint, EvalKey<Element>>> SchemeBase<Element>::EvalAutomorphismKeyGen(
const PublicKey<Element> publicKey, const PrivateKey<Element> privateKey,
const std::vector<usint>& indexList) const {
VerifyLeveledSHEEnabled(__func__);
if (!publicKey)
OPENFHE_THROW(config_error, "Input public key is nullptr");
if (!privateKey)
OPENFHE_THROW(config_error, "Input private key is nullptr");

auto evalKeyMap = m_LeveledSHE->EvalAutomorphismKeyGen(publicKey, privateKey, indexList);
for (auto& key : *evalKeyMap)
key.second->SetKeyTag(privateKey->GetKeyTag());
return evalKeyMap;
}

template class SchemeBase<DCRTPoly>;

} // namespace lbcrypto

0 comments on commit 57cd4de

Please sign in to comment.