From 11e13198c80cf331a933ba16318c4caef4a1d38e Mon Sep 17 00:00:00 2001 From: Dmitriy Suponitskiy Date: Mon, 25 Nov 2024 23:26:30 -0500 Subject: [PATCH] Fixed syntax for multiple virtual functions, virtual destructors, made some print functions protected to force users to use operator<<(), fixed multiple issues, etc. --- src/pke/include/encoding/ckkspackedencoding.h | 72 ++++++-------- src/pke/include/encoding/coefpackedencoding.h | 99 ++++++++++--------- src/pke/include/encoding/packedencoding.h | 92 +++++++++-------- src/pke/include/encoding/plaintext.h | 98 ++++++++++-------- src/pke/include/encoding/stringencoding.h | 54 +++++----- .../schemebase/base-cryptoparameters.h | 13 ++- .../schemebase/rlwe-cryptoparameters.h | 26 ++--- .../include/schemerns/rns-cryptoparameters.h | 10 +- 8 files changed, 234 insertions(+), 230 deletions(-) diff --git a/src/pke/include/encoding/ckkspackedencoding.h b/src/pke/include/encoding/ckkspackedencoding.h index 4a3237028..c53698a6b 100644 --- a/src/pke/include/encoding/ckkspackedencoding.h +++ b/src/pke/include/encoding/ckkspackedencoding.h @@ -64,7 +64,7 @@ class CKKSPackedEncoding : public PlaintextImpl { std::is_same::value || std::is_same::value, bool>::type = true> - CKKSPackedEncoding(std::shared_ptr vp, EncodingParams ep) : PlaintextImpl(vp, ep, CKKSRNS_SCHEME) { + CKKSPackedEncoding(std::shared_ptr vp, EncodingParams ep) : PlaintextImpl(vp, ep, CKKS_PACKED_ENCODING, CKKSRNS_SCHEME) { this->slots = GetDefaultSlotSize(); if (this->slots > (GetElementRingDimension() / 2)) { OPENFHE_THROW("The number of slots cannot be larger than half of ring dimension"); @@ -83,7 +83,7 @@ class CKKSPackedEncoding : public PlaintextImpl { bool>::type = true> CKKSPackedEncoding(std::shared_ptr vp, EncodingParams ep, const std::vector>& coeffs, size_t noiseScaleDeg, uint32_t level, double scFact, size_t slots) - : PlaintextImpl(vp, ep, CKKSRNS_SCHEME), value(coeffs) { + : PlaintextImpl(vp, ep, CKKS_PACKED_ENCODING, CKKSRNS_SCHEME), value(coeffs) { // validate the number of slots if ((slots & (slots - 1)) != 0) { OPENFHE_THROW("The number of slots should be a power of two"); @@ -109,7 +109,7 @@ class CKKSPackedEncoding : public PlaintextImpl { * @param rhs - The input object to copy. */ explicit CKKSPackedEncoding(const std::vector>& rhs, size_t slots) - : PlaintextImpl(std::shared_ptr(0), nullptr, CKKSRNS_SCHEME), value(rhs) { + : PlaintextImpl(std::shared_ptr(0), nullptr, CKKS_PACKED_ENCODING, CKKSRNS_SCHEME), value(rhs) { // validate the number of slots if ((slots & (slots - 1)) != 0) { OPENFHE_THROW("The number of slots should be a power of two"); @@ -128,7 +128,7 @@ class CKKSPackedEncoding : public PlaintextImpl { /** * @brief Default empty constructor with empty uninitialized data elements. */ - CKKSPackedEncoding() : PlaintextImpl(std::shared_ptr(0), nullptr, CKKSRNS_SCHEME) { + CKKSPackedEncoding() : PlaintextImpl(std::shared_ptr(0), nullptr, CKKS_PACKED_ENCODING, CKKSRNS_SCHEME) { this->slots = GetDefaultSlotSize(); if (this->slots > (GetElementRingDimension() / 2)) { OPENFHE_THROW("The number of slots cannot be larger than half of ring dimension"); @@ -147,7 +147,7 @@ class CKKSPackedEncoding : public PlaintextImpl { OPENFHE_THROW("CKKSPackedEncoding::Decode() is not implemented. Use CKKSPackedEncoding::Decode(...) instead."); } - bool Decode(size_t depth, double scalingFactor, ScalingTechnique scalTech, ExecutionMode executionMode); + bool Decode(size_t depth, double scalingFactor, ScalingTechnique scalTech, ExecutionMode executionMode) override; const std::vector>& GetCKKSPackedValue() const override { return value; @@ -175,14 +175,6 @@ class CKKSPackedEncoding : public PlaintextImpl { const std::vector& b, const std::vector& mods); - /** - * GetEncodingType - * @return CKKS_PACKED_ENCODING - */ - PlaintextEncodings GetEncodingType() const override { - return CKKS_PACKED_ENCODING; - } - /** * Get method to return the length of plaintext * @@ -215,40 +207,16 @@ class CKKSPackedEncoding : public PlaintextImpl { value.resize(siz); } - /** - * Method to compare two plaintext to test for equivalence. This method does - * not test that the plaintext are of the same type. - * - * @param other - the other plaintext to compare to. - * @return whether the two plaintext are equivalent. - */ - bool CompareTo(const PlaintextImpl& other) const override { - const auto& rv = static_cast(other); - return this->value == rv.value; - } - /** * @brief Destructor method. */ static void Destroy(); - void PrintValue(std::ostream& out) const override { - // for sanity's sake, trailing zeros get elided into "..." - // out.precision(15); - out << "("; - size_t i = value.size(); - while (--i > 0) - if (value[i] != std::complex(0, 0)) - break; - - for (size_t j = 0; j <= i; j++) { - out << value[j].real() << ", "; - } - - out << " ... ); "; - out << "Estimated precision: " << encodingParams->GetPlaintextModulus() - m_logError << " bits" << std::endl; - } - + /** + * @brief GetFormattedValues() is called by operator<< and requires a precision as an argument + * @param precision number of decimal digits of precision to print + * @return string with all values and "estimated precision" + */ std::string GetFormattedValues(int64_t precision) const override { std::stringstream ss; ss << "("; @@ -279,10 +247,30 @@ class CKKSPackedEncoding : public PlaintextImpl { double m_logError = 0; protected: + void PrintValue(std::ostream& out) const override { + out << GetFormattedValues(8) << std::endl; + } + usint GetDefaultSlotSize() { auto batchSize = GetEncodingParams()->GetBatchSize(); return (0 == batchSize) ? GetElementRingDimension() / 2 : batchSize; } + + /** + * Method to compare two plaintext to test for equivalence. This method does + * not test that the plaintext are of the same type. + * + * @param rhs - the other plaintext to compare to. + * @return whether the two plaintext are equivalent. + */ + bool CompareTo(const PlaintextImpl& rhs) const override { + const auto* el = dynamic_cast(&rhs); + if (el == nullptr) + return false; + + return this->value == el->value; + } + /** * Set modulus and recalculates the vector values to fit the modulus * diff --git a/src/pke/include/encoding/coefpackedencoding.h b/src/pke/include/encoding/coefpackedencoding.h index 089bbe1a8..f89f19808 100644 --- a/src/pke/include/encoding/coefpackedencoding.h +++ b/src/pke/include/encoding/coefpackedencoding.h @@ -48,13 +48,54 @@ namespace lbcrypto { class CoefPackedEncoding : public PlaintextImpl { std::vector value; +protected: + /** + * @brief PrintValue() is called by operator<< + * @param out stream to print to + */ + void PrintValue(std::ostream& out) const override { + out << "("; + + // for sanity's sake: get rid of all trailing zeroes and print "..." instead + size_t i = value.size(); + bool allZeroes = true; + while (i > 0) { + --i; + if (value[i] != 0) { + allZeroes = false; + break; + } + } + + if (allZeroes == false) { + for (size_t j = 0; j <= i; ++j) + out << value[j] << ", "; + } + out << "... )"; + } + + /** + * Method to compare two plaintext to test for equivalence + * Testing that the plaintexts are of the same type done in operator== + * + * @param rhs - the other plaintext to compare to. + * @return whether the two plaintext are equivalent. + */ + bool CompareTo(const PlaintextImpl& rhs) const override { + const auto* el = dynamic_cast(&rhs); + if (el == nullptr) + return false; + + return this->value == el->value; + } + public: template ::value || std::is_same::value || std::is_same::value, bool>::type = true> CoefPackedEncoding(std::shared_ptr vp, EncodingParams ep, SCHEME schemeId = SCHEME::INVALID_SCHEME) - : PlaintextImpl(vp, ep, schemeId) {} + : PlaintextImpl(vp, ep, COEF_PACKED_ENCODING, schemeId) {} template ::value || std::is_same::value || @@ -62,15 +103,15 @@ class CoefPackedEncoding : public PlaintextImpl { bool>::type = true> CoefPackedEncoding(std::shared_ptr vp, EncodingParams ep, const std::vector& coeffs, SCHEME schemeId = SCHEME::INVALID_SCHEME) - : PlaintextImpl(vp, ep, schemeId), value(coeffs) {} + : PlaintextImpl(vp, ep, COEF_PACKED_ENCODING, schemeId), value(coeffs) {} - virtual ~CoefPackedEncoding() = default; + ~CoefPackedEncoding() override = default; /** * GetCoeffsValue * @return the un-encoded scalar */ - const std::vector& GetCoefPackedValue() const { + const std::vector& GetCoefPackedValue() const override { return value; } @@ -78,7 +119,7 @@ class CoefPackedEncoding : public PlaintextImpl { * SetIntVectorValue * @param val integer vector to initialize the plaintext */ - void SetIntVectorValue(const std::vector& val) { + void SetIntVectorValue(const std::vector& val) override { value = val; } @@ -86,28 +127,20 @@ class CoefPackedEncoding : public PlaintextImpl { * Encode the plaintext into the Poly * @return true on success */ - bool Encode(); + bool Encode() override; /** * Decode the Poly into the string * @return true on success - */ - bool Decode(); - - /** - * GetEncodingType - * @return this is a COEF_PACKED_ENCODING encoding - */ - PlaintextEncodings GetEncodingType() const { - return COEF_PACKED_ENCODING; - } + */ + bool Decode() override; /** * Get length of the plaintext * * @return number of elements in this plaintext */ - size_t GetLength() const { + size_t GetLength() const override { return value.size(); } @@ -115,39 +148,9 @@ class CoefPackedEncoding : public PlaintextImpl { * SetLength of the plaintext to the given size * @param siz */ - void SetLength(size_t siz) { + void SetLength(size_t siz) override { value.resize(siz); } - - /** - * Method to compare two plaintext to test for equivalence - * Testing that the plaintexts are of the same type done in operator== - * - * @param other - the other plaintext to compare to. - * @return whether the two plaintext are equivalent. - */ - bool CompareTo(const PlaintextImpl& other) const { - const auto& oth = static_cast(other); - return oth.value == this->value; - } - - /** - * PrintValue - used by operator<< for this object - * @param out - */ - void PrintValue(std::ostream& out) const { - // for sanity's sake, trailing zeros get elided into "..." - out << "("; - size_t i = value.size(); - while (--i > 0) - if (value[i] != 0) - break; - - for (size_t j = 0; j <= i; j++) - out << ' ' << value[j]; - - out << " ... )"; - } }; } /* namespace lbcrypto */ diff --git a/src/pke/include/encoding/packedencoding.h b/src/pke/include/encoding/packedencoding.h index 97a02c4a0..6d76f1467 100644 --- a/src/pke/include/encoding/packedencoding.h +++ b/src/pke/include/encoding/packedencoding.h @@ -70,21 +70,21 @@ class PackedEncoding : public PlaintextImpl { std::is_same::value || std::is_same::value, bool>::type = true> - PackedEncoding(std::shared_ptr vp, EncodingParams ep) : PlaintextImpl(vp, ep) {} + PackedEncoding(std::shared_ptr vp, EncodingParams ep) : PlaintextImpl(vp, ep, PACKED_ENCODING) {} template ::value || std::is_same::value || std::is_same::value, bool>::type = true> PackedEncoding(std::shared_ptr vp, EncodingParams ep, const std::vector& coeffs) - : PlaintextImpl(vp, ep), value(coeffs) {} + : PlaintextImpl(vp, ep, PACKED_ENCODING), value(coeffs) {} template ::value || std::is_same::value || std::is_same::value, bool>::type = true> PackedEncoding(std::shared_ptr vp, EncodingParams ep, std::initializer_list coeffs) - : PlaintextImpl(vp, ep), value(coeffs) {} + : PlaintextImpl(vp, ep, PACKED_ENCODING), value(coeffs) {} /** * @brief Constructs a container with a copy of each of the elements in rhs, @@ -92,7 +92,7 @@ class PackedEncoding : public PlaintextImpl { * @param rhs - The input object to copy. */ explicit PackedEncoding(const std::vector& rhs) - : PlaintextImpl(std::shared_ptr(0), nullptr), value(rhs) {} + : PlaintextImpl(std::shared_ptr(0), nullptr, PACKED_ENCODING), value(rhs) {} /** * @brief Constructs a container with a copy of each of the elements in il, in @@ -100,22 +100,22 @@ class PackedEncoding : public PlaintextImpl { * @param arr the list to copy. */ PackedEncoding(std::initializer_list arr) - : PlaintextImpl(std::shared_ptr(0), nullptr), value(arr) {} + : PlaintextImpl(std::shared_ptr(0), nullptr, PACKED_ENCODING), value(arr) {} /** * @brief Default empty constructor with empty uninitialized data elements. */ - PackedEncoding() : PlaintextImpl(std::shared_ptr(0), nullptr), value() {} + PackedEncoding() : PlaintextImpl(std::shared_ptr(0), nullptr, PACKED_ENCODING), value() {} static usint GetAutomorphismGenerator(usint m) { return m_automorphismGenerator[m]; } - bool Encode(); + bool Encode() override; - bool Decode(); + bool Decode() override; - const std::vector& GetPackedValue() const { + const std::vector& GetPackedValue() const override { return value; } @@ -123,24 +123,16 @@ class PackedEncoding : public PlaintextImpl { * SetIntVectorValue * @param val integer vector to initialize the plaintext */ - void SetIntVectorValue(const std::vector& val) { + void SetIntVectorValue(const std::vector& val) override { value = val; } - /** - * GetEncodingType - * @return PACKED_ENCODING - */ - PlaintextEncodings GetEncodingType() const { - return PACKED_ENCODING; - } - /** * Get method to return the length of plaintext * * @return the length of the plaintext in terms of the number of bits. */ - size_t GetLength() const { + size_t GetLength() const override { return value.size(); } @@ -164,39 +156,53 @@ class PackedEncoding : public PlaintextImpl { * SetLength of the plaintext to the given size * @param siz */ - void SetLength(size_t siz) { + void SetLength(size_t siz) override { value.resize(siz); } /** - * Method to compare two plaintext to test for equivalence. This method does - * not test that the plaintext are of the same type. - * - * @param other - the other plaintext to compare to. - * @return whether the two plaintext are equivalent. - */ - bool CompareTo(const PlaintextImpl& other) const { - const auto& rv = static_cast(other); - return this->value == rv.value; - } - - /** - * @brief Destructor method. - */ + * @brief Destructor method. + */ static void Destroy(); - void PrintValue(std::ostream& out) const { - // for sanity's sake, trailing zeros get elided into "..." +protected: + /** + * @brief PrintValue() is called by operator<< + * @param out stream to print to + */ + void PrintValue(std::ostream& out) const override { out << "("; - size_t i = value.size(); - while (--i > 0) - if (value[i] != 0) + // for sanity's sake: get rid of all trailing zeroes and print "..." instead + size_t i = value.size(); + bool allZeroes = true; + while (i > 0) { + --i; + if (value[i] != 0) { + allZeroes = false; break; + } + } + + if (allZeroes == false) { + for (size_t j = 0; j <= i; ++j) + out << value[j] << ", "; + } + out << "... )"; + } - for (size_t j = 0; j <= i; j++) - out << ' ' << value[j]; - - out << " ... )"; + /** + * Method to compare two plaintext to test for equivalence. This method does + * not test that the plaintext are of the same type. + * + * @param rhs - the other plaintext to compare to. + * @return whether the two plaintext are equivalent. + */ + bool CompareTo(const PlaintextImpl& rhs) const override { + const auto* el = dynamic_cast(&rhs); + if (el == nullptr) + return false; + + return this->value == el->value; } private: diff --git a/src/pke/include/encoding/plaintext.h b/src/pke/include/encoding/plaintext.h index 9e419c599..edba4ba12 100644 --- a/src/pke/include/encoding/plaintext.h +++ b/src/pke/include/encoding/plaintext.h @@ -83,32 +83,52 @@ class PlaintextImpl { size_t level = 0; size_t noiseScaleDeg = 1; usint slots = 0; + PlaintextEncodings ptxtEncoding = INVALID_ENCODING; SCHEME schemeID; +protected: + /** + * @brief PrintValue() is called by operator<< + * @param out + */ + virtual void PrintValue(std::ostream& out) const = 0; + + /** + * Method to compare two plaintext to test for equivalence. + * This method is called by operator== + * + * @param other - the other plaintext to compare to. + * @return whether the two plaintext are equivalent. + */ + virtual bool CompareTo(const PlaintextImpl& other) const = 0; + public: - PlaintextImpl(const std::shared_ptr& vp, EncodingParams ep, SCHEME schemeTag = SCHEME::INVALID_SCHEME, - bool isEncoded = false) + PlaintextImpl(const std::shared_ptr& vp, EncodingParams ep, PlaintextEncodings encoding, + SCHEME schemeTag = SCHEME::INVALID_SCHEME, bool isEncoded = false) : isEncoded(isEncoded), typeFlag(IsPoly), encodingParams(std::move(ep)), encodedVector(vp, Format::COEFFICIENT), + ptxtEncoding(encoding), schemeID(schemeTag) {} - PlaintextImpl(const std::shared_ptr& vp, EncodingParams ep, + PlaintextImpl(const std::shared_ptr& vp, EncodingParams ep, PlaintextEncodings encoding, SCHEME schemeTag = SCHEME::INVALID_SCHEME, bool isEncoded = false) : isEncoded(isEncoded), typeFlag(IsNativePoly), encodingParams(std::move(ep)), encodedNativeVector(vp, Format::COEFFICIENT), + ptxtEncoding(encoding), schemeID(schemeTag) {} - PlaintextImpl(const std::shared_ptr& vp, EncodingParams ep, + PlaintextImpl(const std::shared_ptr& vp, EncodingParams ep, PlaintextEncodings encoding, SCHEME schemeTag = SCHEME::INVALID_SCHEME, bool isEncoded = false) : isEncoded(isEncoded), typeFlag(IsDCRTPoly), encodingParams(std::move(ep)), encodedVector(vp, Format::COEFFICIENT), encodedVectorDCRT(vp, Format::COEFFICIENT), + ptxtEncoding(encoding), schemeID(schemeTag) {} PlaintextImpl(const PlaintextImpl& rhs) @@ -122,6 +142,7 @@ class PlaintextImpl { level(rhs.level), noiseScaleDeg(rhs.noiseScaleDeg), slots(rhs.slots), + ptxtEncoding(rhs.ptxtEncoding), schemeID(rhs.schemeID) {} PlaintextImpl(PlaintextImpl&& rhs) @@ -135,15 +156,18 @@ class PlaintextImpl { level(rhs.level), noiseScaleDeg(rhs.noiseScaleDeg), slots(rhs.slots), + ptxtEncoding(rhs.ptxtEncoding), schemeID(rhs.schemeID) {} - virtual ~PlaintextImpl() {} + virtual ~PlaintextImpl() = default; /** * GetEncodingType * @return Encoding type used by this plaintext */ - virtual PlaintextEncodings GetEncodingType() const = 0; + PlaintextEncodings GetEncodingType() const { + return ptxtEncoding; + } /** * Get the scaling factor of the plaintext for CKKS-based plaintexts. @@ -203,10 +227,13 @@ class PlaintextImpl { virtual bool Encode() = 0; /** - * Decode the polynomial into the plaintext + * @brief Decode the polynomial into the plaintext * @return */ virtual bool Decode() = 0; + virtual bool Decode(size_t depth, double scalingFactor, ScalingTechnique scalTech, ExecutionMode executionMode) { + OPENFHE_THROW("Not implemented"); + } /** * Calculate and return lower bound that can be encoded with the plaintext @@ -358,7 +385,7 @@ class PlaintextImpl { OPENFHE_THROW("not a packed coefficient vector"); } virtual const std::vector& GetPackedValue() const { - OPENFHE_THROW("not a packed coefficient vector"); + OPENFHE_THROW("not a packed vector"); } virtual const std::vector>& GetCKKSPackedValue() const { OPENFHE_THROW("not a packed vector of complex numbers"); @@ -373,15 +400,6 @@ class PlaintextImpl { OPENFHE_THROW("does not support an int vector"); } - /** - * Method to compare two plaintext to test for equivalence. - * This method is called by operator== - * - * @param other - the other plaintext to compare to. - * @return whether the two plaintext are equivalent. - */ - virtual bool CompareTo(const PlaintextImpl& other) const = 0; - /** * operator== for plaintexts. This method makes sure the plaintexts are of * the same type. @@ -398,39 +416,33 @@ class PlaintextImpl { } /** - * operator<< for ostream integration - calls PrintValue - * @param out - * @param item - * @return - */ - friend std::ostream& operator<<(std::ostream& out, const PlaintextImpl& item); - - /** - * PrintValue is called by operator<< - * @param out - */ - virtual void PrintValue(std::ostream& out) const = 0; + * @brief operator<< for ostream integration - calls PrintValue() + * @param out + * @param item + * @return + */ + friend std::ostream& operator<<(std::ostream& out, const PlaintextImpl& item) { + item.PrintValue(out); + return out; + } + friend std::ostream& operator<<(std::ostream& out, const Plaintext& item) { + if (item) + out << *item; // Call the non-pointer version + else + OPENFHE_THROW("Cannot de-reference nullptr for printing"); + return out; + } /** - * GetFormattedValues() has a logic similar to PrintValue(), but requires a precision as an argument - * @param precision number of decimal digits of precision to print - * @return string with all values and "estimated precision" - */ + * @brief GetFormattedValues() is similar to PrintValue() and requires a precision as an argument + * @param precision number of decimal digits of precision to print + * @return string with all values + */ virtual std::string GetFormattedValues(int64_t precision) const { OPENFHE_THROW("not implemented"); } }; -inline std::ostream& operator<<(std::ostream& out, const PlaintextImpl& item) { - item.PrintValue(out); - return out; -} - -inline std::ostream& operator<<(std::ostream& out, const Plaintext& item) { - item->PrintValue(out); - return out; -} - inline bool operator==(const Plaintext& p1, const Plaintext& p2) { return *p1 == *p2; } diff --git a/src/pke/include/encoding/stringencoding.h b/src/pke/include/encoding/stringencoding.h index 8ab1fbe27..146d24683 100644 --- a/src/pke/include/encoding/stringencoding.h +++ b/src/pke/include/encoding/stringencoding.h @@ -53,25 +53,25 @@ class StringEncoding : public PlaintextImpl { std::is_same::value || std::is_same::value, bool>::type = true> - StringEncoding(std::shared_ptr vp, EncodingParams ep) : PlaintextImpl(vp, ep) {} + StringEncoding(std::shared_ptr vp, EncodingParams ep) : PlaintextImpl(vp, ep, STRING_ENCODING) {} template ::value || std::is_same::value || std::is_same::value, bool>::type = true> StringEncoding(std::shared_ptr vp, EncodingParams ep, const std::string& str) - : PlaintextImpl(vp, ep), ptx(str) {} + : PlaintextImpl(vp, ep, STRING_ENCODING), ptx(str) {} // TODO provide wide-character version (for unicode); right now this class // only supports strings of 7-bit ASCII characters - virtual ~StringEncoding() {} + ~StringEncoding() override = default; /** * GetStringValue * @return the un-encoded string */ - const std::string& GetStringValue() const { + const std::string& GetStringValue() const override { return ptx; } @@ -79,7 +79,7 @@ class StringEncoding : public PlaintextImpl { * SetStringValue * @param val to initialize the Plaintext */ - void SetStringValue(const std::string& value) { + void SetStringValue(const std::string& value) override { ptx = value; } @@ -87,48 +87,44 @@ class StringEncoding : public PlaintextImpl { * Encode the plaintext into the Poly * @return true on success */ - bool Encode(); + bool Encode() override; /** * Decode the Poly into the string * @return true on success */ - bool Decode(); - - /** - * GetEncodingType - * @return STRING_ENCODING - */ - PlaintextEncodings GetEncodingType() const { - return STRING_ENCODING; - } + bool Decode() override; /** * Get length of the plaintext * * @return number of elements in this plaintext */ - size_t GetLength() const { + size_t GetLength() const override { return ptx.size(); } +protected: /** - * Method to compare two plaintext to test for equivalence - * Testing that the plaintexts are of the same type done in operator== - * - * @param other - the other plaintext to compare to. - * @return whether the two plaintext are equivalent. - */ - bool CompareTo(const PlaintextImpl& other) const { - const auto& oth = static_cast(other); - return oth.ptx == this->ptx; + * Method to compare two plaintext to test for equivalence + * Testing that the plaintexts are of the same type done in operator== + * + * @param rhs - the other plaintext to compare to. + * @return whether the two plaintext are equivalent. + */ + bool CompareTo(const PlaintextImpl& rhs) const override { + const auto* el = dynamic_cast(&rhs); + if (el == nullptr) + return false; + + return this->ptx == el->ptx; } /** - * PrintValue - used by operator<< for this object - * @param out - */ - void PrintValue(std::ostream& out) const { + * PrintValue - used by operator<< for this object + * @param out + */ + void PrintValue(std::ostream& out) const override { out << ptx; } }; diff --git a/src/pke/include/schemebase/base-cryptoparameters.h b/src/pke/include/schemebase/base-cryptoparameters.h index 1c0ad4989..5ae25e2c8 100644 --- a/src/pke/include/schemebase/base-cryptoparameters.h +++ b/src/pke/include/schemebase/base-cryptoparameters.h @@ -68,7 +68,7 @@ class CryptoParametersBase : public Serializable { * * @return the plaintext modulus. */ - virtual const PlaintextModulus& GetPlaintextModulus() const { + PlaintextModulus GetPlaintextModulus() const { return m_encodingParams->GetPlaintextModulus(); } @@ -77,7 +77,7 @@ class CryptoParametersBase : public Serializable { * * @return the ring element parameters. */ - virtual const std::shared_ptr GetElementParams() const { + const std::shared_ptr GetElementParams() const { return m_params; } @@ -88,14 +88,14 @@ class CryptoParametersBase : public Serializable { * * @return the encoding parameters. */ - virtual const EncodingParams GetEncodingParams() const { + const EncodingParams GetEncodingParams() const { return m_encodingParams; } /** * Sets the value of plaintext modulus p */ - virtual void SetPlaintextModulus(const PlaintextModulus& plaintextModulus) { + void SetPlaintextModulus(PlaintextModulus plaintextModulus) { m_encodingParams->SetPlaintextModulus(plaintextModulus); } @@ -119,7 +119,7 @@ class CryptoParametersBase : public Serializable { return out; } - virtual usint GetDigitSize() const { + virtual uint32_t GetDigitSize() const { return 0; } @@ -167,7 +167,7 @@ class CryptoParametersBase : public Serializable { ar(::cereal::make_nvp("enp", m_encodingParams)); } - std::string SerializedObjectName() const { + std::string SerializedObjectName() const override { return "CryptoParametersBase"; } static uint32_t SerializedVersion() { @@ -209,7 +209,6 @@ class CryptoParametersBase : public Serializable { out << "Encoding Parameters: " << *m_encodingParams << std::endl; } -protected: // element-specific parameters std::shared_ptr m_params; diff --git a/src/pke/include/schemebase/rlwe-cryptoparameters.h b/src/pke/include/schemebase/rlwe-cryptoparameters.h index eb8192815..86e4ae44c 100644 --- a/src/pke/include/schemebase/rlwe-cryptoparameters.h +++ b/src/pke/include/schemebase/rlwe-cryptoparameters.h @@ -129,7 +129,7 @@ class CryptoParametersRLWE : public CryptoParametersBase { /** * Virtual Destructor */ - ~CryptoParametersRLWE() = default; + ~CryptoParametersRLWE() override = default; /** * Returns the value of standard deviation r for discrete Gaussian @@ -174,7 +174,7 @@ class CryptoParametersRLWE : public CryptoParametersBase { * * @return the digit size. */ - usint GetDigitSize() const { + uint32_t GetDigitSize() const override { return m_digitSize; } @@ -184,7 +184,7 @@ class CryptoParametersRLWE : public CryptoParametersBase { * * @return maximum power of secret key */ - uint32_t GetMaxRelinSkDeg() const { + uint32_t GetMaxRelinSkDeg() const override { return m_maxRelinSkDeg; } @@ -414,14 +414,6 @@ class CryptoParametersRLWE : public CryptoParametersBase { m_thresholdNumOfParties = thresholdNumOfParties; } - void PrintParameters(std::ostream& os) const { - CryptoParametersBase::PrintParameters(os); - - os << "Distrib parm " << GetDistributionParameter() << ", Assurance measure " << GetAssuranceMeasure() - << ", Noise scale " << GetNoiseScale() << ", Digit Size " << GetDigitSize() << ", SecretKeyDist " - << GetSecretKeyDist() << ", Standard security level " << GetStdLevel() << std::endl; - } - template void save(Archive& ar, std::uint32_t const version) const { ar(::cereal::base_class>(this)); @@ -465,7 +457,7 @@ class CryptoParametersRLWE : public CryptoParametersBase { m_dggFlooding.SetStd(m_floodingDistributionParameter); } - std::string SerializedObjectName() const { + std::string SerializedObjectName() const override { return "CryptoParametersRLWE"; } @@ -479,7 +471,7 @@ class CryptoParametersRLWE : public CryptoParametersBase { // noise scale PlaintextModulus m_noiseScale = 1; // digit size - usint m_digitSize = 1; + uint32_t m_digitSize = 1; // the highest power of secret key for which relinearization key is generated uint32_t m_maxRelinSkDeg = 2; // specifies whether the secret polynomials are generated from discrete @@ -541,6 +533,14 @@ class CryptoParametersRLWE : public CryptoParametersBase { m_numAdversarialQueries == el->GetNumAdversarialQueries() && m_thresholdNumOfParties == el->GetThresholdNumOfParties(); } + + void PrintParameters(std::ostream& os) const override { + CryptoParametersBase::PrintParameters(os); + + os << "Distrib parm " << GetDistributionParameter() << ", Assurance measure " << GetAssuranceMeasure() + << ", Noise scale " << GetNoiseScale() << ", Digit Size " << GetDigitSize() << ", SecretKeyDist " + << GetSecretKeyDist() << ", Standard security level " << GetStdLevel() << std::endl; + } }; } // namespace lbcrypto diff --git a/src/pke/include/schemerns/rns-cryptoparameters.h b/src/pke/include/schemerns/rns-cryptoparameters.h index b252b9b56..75588b682 100644 --- a/src/pke/include/schemerns/rns-cryptoparameters.h +++ b/src/pke/include/schemerns/rns-cryptoparameters.h @@ -140,7 +140,7 @@ class CryptoParametersRNS : public CryptoParametersRLWE { m_MPIntBootCiphertextCompressionLevel = mPIntBootCiphertextCompressionLevel; } - ~CryptoParametersRNS() = default; + ~CryptoParametersRNS() override = default; /** * @brief CompareTo() is a method to compare two CryptoParametersRNS objects. @@ -160,6 +160,10 @@ class CryptoParametersRNS : public CryptoParametersRLWE { m_multipartyMode == el->GetMultipartyMode() && m_executionMode == el->GetExecutionMode(); } + void PrintParameters(std::ostream& os) const override { + CryptoParametersRLWE::PrintParameters(os); + } + public: /** * Computes all tables needed for decryption, homomorphic multiplication and key switching. @@ -201,10 +205,6 @@ class CryptoParametersRNS : public CryptoParametersRLWE { return static_cast(NoiseFlooding::MULTIPARTY_MOD_SIZE * NoiseFlooding::NUM_MODULI_MULTIPARTY); } - void PrintParameters(std::ostream& os) const override { - CryptoParametersBase::PrintParameters(os); - } - ///////////////////////////////////// // PrecomputeCRTTables /////////////////////////////////////