Skip to content

Commit

Permalink
Fixed syntax for multiple virtual functions, virtual destructors, mad…
Browse files Browse the repository at this point in the history
…e some print functions protected to force users to use operator<<(), fixed multiple issues, etc.
  • Loading branch information
dsuponitskiy-duality committed Nov 26, 2024
1 parent 8a5a5d7 commit 11e1319
Show file tree
Hide file tree
Showing 8 changed files with 234 additions and 230 deletions.
72 changes: 30 additions & 42 deletions src/pke/include/encoding/ckkspackedencoding.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class CKKSPackedEncoding : public PlaintextImpl {
std::is_same<T, NativePoly::Params>::value ||
std::is_same<T, DCRTPoly::Params>::value,
bool>::type = true>
CKKSPackedEncoding(std::shared_ptr<T> vp, EncodingParams ep) : PlaintextImpl(vp, ep, CKKSRNS_SCHEME) {
CKKSPackedEncoding(std::shared_ptr<T> vp, EncodingParams ep) : PlaintextImpl(vp, ep, CKKS_PACKED_ENCODING, CKKSRNS_SCHEME) {
this->slots = GetDefaultSlotSize();
if (this->slots > (GetElementRingDimension() / 2)) {
OPENFHE_THROW("The number of slots cannot be larger than half of ring dimension");
Expand All @@ -83,7 +83,7 @@ class CKKSPackedEncoding : public PlaintextImpl {
bool>::type = true>
CKKSPackedEncoding(std::shared_ptr<T> vp, EncodingParams ep, const std::vector<std::complex<double>>& coeffs,
size_t noiseScaleDeg, uint32_t level, double scFact, size_t slots)
: PlaintextImpl(vp, ep, CKKSRNS_SCHEME), value(coeffs) {
: PlaintextImpl(vp, ep, CKKS_PACKED_ENCODING, CKKSRNS_SCHEME), value(coeffs) {
// validate the number of slots
if ((slots & (slots - 1)) != 0) {
OPENFHE_THROW("The number of slots should be a power of two");
Expand All @@ -109,7 +109,7 @@ class CKKSPackedEncoding : public PlaintextImpl {
* @param rhs - The input object to copy.
*/
explicit CKKSPackedEncoding(const std::vector<std::complex<double>>& rhs, size_t slots)
: PlaintextImpl(std::shared_ptr<Poly::Params>(0), nullptr, CKKSRNS_SCHEME), value(rhs) {
: PlaintextImpl(std::shared_ptr<Poly::Params>(0), nullptr, CKKS_PACKED_ENCODING, CKKSRNS_SCHEME), value(rhs) {
// validate the number of slots
if ((slots & (slots - 1)) != 0) {
OPENFHE_THROW("The number of slots should be a power of two");
Expand All @@ -128,7 +128,7 @@ class CKKSPackedEncoding : public PlaintextImpl {
/**
* @brief Default empty constructor with empty uninitialized data elements.
*/
CKKSPackedEncoding() : PlaintextImpl(std::shared_ptr<Poly::Params>(0), nullptr, CKKSRNS_SCHEME) {
CKKSPackedEncoding() : PlaintextImpl(std::shared_ptr<Poly::Params>(0), nullptr, CKKS_PACKED_ENCODING, CKKSRNS_SCHEME) {
this->slots = GetDefaultSlotSize();
if (this->slots > (GetElementRingDimension() / 2)) {
OPENFHE_THROW("The number of slots cannot be larger than half of ring dimension");
Expand All @@ -147,7 +147,7 @@ class CKKSPackedEncoding : public PlaintextImpl {
OPENFHE_THROW("CKKSPackedEncoding::Decode() is not implemented. Use CKKSPackedEncoding::Decode(...) instead.");
}

bool Decode(size_t depth, double scalingFactor, ScalingTechnique scalTech, ExecutionMode executionMode);
bool Decode(size_t depth, double scalingFactor, ScalingTechnique scalTech, ExecutionMode executionMode) override;

const std::vector<std::complex<double>>& GetCKKSPackedValue() const override {
return value;
Expand Down Expand Up @@ -175,14 +175,6 @@ class CKKSPackedEncoding : public PlaintextImpl {
const std::vector<DCRTPoly::Integer>& b,
const std::vector<DCRTPoly::Integer>& mods);

/**
* GetEncodingType
* @return CKKS_PACKED_ENCODING
*/
PlaintextEncodings GetEncodingType() const override {
return CKKS_PACKED_ENCODING;
}

/**
* Get method to return the length of plaintext
*
Expand Down Expand Up @@ -215,40 +207,16 @@ class CKKSPackedEncoding : public PlaintextImpl {
value.resize(siz);
}

/**
* Method to compare two plaintext to test for equivalence. This method does
* not test that the plaintext are of the same type.
*
* @param other - the other plaintext to compare to.
* @return whether the two plaintext are equivalent.
*/
bool CompareTo(const PlaintextImpl& other) const override {
const auto& rv = static_cast<const CKKSPackedEncoding&>(other);
return this->value == rv.value;
}

/**
* @brief Destructor method.
*/
static void Destroy();

void PrintValue(std::ostream& out) const override {
// for sanity's sake, trailing zeros get elided into "..."
// out.precision(15);
out << "(";
size_t i = value.size();
while (--i > 0)
if (value[i] != std::complex<double>(0, 0))
break;

for (size_t j = 0; j <= i; j++) {
out << value[j].real() << ", ";
}

out << " ... ); ";
out << "Estimated precision: " << encodingParams->GetPlaintextModulus() - m_logError << " bits" << std::endl;
}

/**
* @brief GetFormattedValues() is called by operator<< and requires a precision as an argument
* @param precision number of decimal digits of precision to print
* @return string with all values and "estimated precision"
*/
std::string GetFormattedValues(int64_t precision) const override {
std::stringstream ss;
ss << "(";
Expand Down Expand Up @@ -279,10 +247,30 @@ class CKKSPackedEncoding : public PlaintextImpl {
double m_logError = 0;

protected:
void PrintValue(std::ostream& out) const override {
out << GetFormattedValues(8) << std::endl;
}

usint GetDefaultSlotSize() {
auto batchSize = GetEncodingParams()->GetBatchSize();
return (0 == batchSize) ? GetElementRingDimension() / 2 : batchSize;
}

/**
* Method to compare two plaintext to test for equivalence. This method does
* not test that the plaintext are of the same type.
*
* @param rhs - the other plaintext to compare to.
* @return whether the two plaintext are equivalent.
*/
bool CompareTo(const PlaintextImpl& rhs) const override {
const auto* el = dynamic_cast<const CKKSPackedEncoding*>(&rhs);
if (el == nullptr)
return false;

return this->value == el->value;
}

/**
* Set modulus and recalculates the vector values to fit the modulus
*
Expand Down
99 changes: 51 additions & 48 deletions src/pke/include/encoding/coefpackedencoding.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,106 +48,109 @@ namespace lbcrypto {
class CoefPackedEncoding : public PlaintextImpl {
std::vector<int64_t> value;

protected:
/**
* @brief PrintValue() is called by operator<<
* @param out stream to print to
*/
void PrintValue(std::ostream& out) const override {
out << "(";

// for sanity's sake: get rid of all trailing zeroes and print "..." instead
size_t i = value.size();
bool allZeroes = true;
while (i > 0) {
--i;
if (value[i] != 0) {
allZeroes = false;
break;
}
}

if (allZeroes == false) {
for (size_t j = 0; j <= i; ++j)
out << value[j] << ", ";
}
out << "... )";
}

/**
* Method to compare two plaintext to test for equivalence
* Testing that the plaintexts are of the same type done in operator==
*
* @param rhs - the other plaintext to compare to.
* @return whether the two plaintext are equivalent.
*/
bool CompareTo(const PlaintextImpl& rhs) const override {
const auto* el = dynamic_cast<const CoefPackedEncoding*>(&rhs);
if (el == nullptr)
return false;

return this->value == el->value;
}

public:
template <typename T, typename std::enable_if<std::is_same<T, Poly::Params>::value ||
std::is_same<T, NativePoly::Params>::value ||
std::is_same<T, DCRTPoly::Params>::value,
bool>::type = true>
CoefPackedEncoding(std::shared_ptr<T> vp, EncodingParams ep, SCHEME schemeId = SCHEME::INVALID_SCHEME)
: PlaintextImpl(vp, ep, schemeId) {}
: PlaintextImpl(vp, ep, COEF_PACKED_ENCODING, schemeId) {}

template <typename T, typename std::enable_if<std::is_same<T, Poly::Params>::value ||
std::is_same<T, NativePoly::Params>::value ||
std::is_same<T, DCRTPoly::Params>::value,
bool>::type = true>
CoefPackedEncoding(std::shared_ptr<T> vp, EncodingParams ep, const std::vector<int64_t>& coeffs,
SCHEME schemeId = SCHEME::INVALID_SCHEME)
: PlaintextImpl(vp, ep, schemeId), value(coeffs) {}
: PlaintextImpl(vp, ep, COEF_PACKED_ENCODING, schemeId), value(coeffs) {}

virtual ~CoefPackedEncoding() = default;
~CoefPackedEncoding() override = default;

/**
* GetCoeffsValue
* @return the un-encoded scalar
*/
const std::vector<int64_t>& GetCoefPackedValue() const {
const std::vector<int64_t>& GetCoefPackedValue() const override {
return value;
}

/**
* SetIntVectorValue
* @param val integer vector to initialize the plaintext
*/
void SetIntVectorValue(const std::vector<int64_t>& val) {
void SetIntVectorValue(const std::vector<int64_t>& val) override {
value = val;
}

/**
* Encode the plaintext into the Poly
* @return true on success
*/
bool Encode();
bool Encode() override;

/**
* Decode the Poly into the string
* @return true on success
*/
bool Decode();

/**
* GetEncodingType
* @return this is a COEF_PACKED_ENCODING encoding
*/
PlaintextEncodings GetEncodingType() const {
return COEF_PACKED_ENCODING;
}
*/
bool Decode() override;

/**
* Get length of the plaintext
*
* @return number of elements in this plaintext
*/
size_t GetLength() const {
size_t GetLength() const override {
return value.size();
}

/**
* SetLength of the plaintext to the given size
* @param siz
*/
void SetLength(size_t siz) {
void SetLength(size_t siz) override {
value.resize(siz);
}

/**
* Method to compare two plaintext to test for equivalence
* Testing that the plaintexts are of the same type done in operator==
*
* @param other - the other plaintext to compare to.
* @return whether the two plaintext are equivalent.
*/
bool CompareTo(const PlaintextImpl& other) const {
const auto& oth = static_cast<const CoefPackedEncoding&>(other);
return oth.value == this->value;
}

/**
* PrintValue - used by operator<< for this object
* @param out
*/
void PrintValue(std::ostream& out) const {
// for sanity's sake, trailing zeros get elided into "..."
out << "(";
size_t i = value.size();
while (--i > 0)
if (value[i] != 0)
break;

for (size_t j = 0; j <= i; j++)
out << ' ' << value[j];

out << " ... )";
}
};

} /* namespace lbcrypto */
Expand Down
Loading

0 comments on commit 11e1319

Please sign in to comment.