diff --git a/.gitignore b/.gitignore index 3e759b75b..f8bc871b2 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,50 @@ +# Other stuff +examples/CMakeCache.txt +examples/CMakeFiles +examples/Makefile +examples/cmake_install.cmake +src/CMakeCache.txt +src/CMakeFiles +src/Makefile +src/.config +src/autom4te.cache/* +src/cmake/SEALConfig.cmake +src/cmake/SEALConfigVersion.cmake +src/cmake/SEALTargets.cmake +src/cmake_install.cmake +src/install_manifest.txt +src/seal/CMakeCache.txt +src/seal/CMakeFiles +src/seal/Makefile +src/seal/cmake_install.cmake +src/seal/util/CMakeCache.txt +src/seal/util/CMakeFiles +src/seal/util/Makefile +src/seal/util/cmake_install.cmake +src/seal/util/config.h +tests/CMakeCache.txt +tests/CMakeFiles +tests/Makefile +tests/cmake_install.cmake +tests/install_manifest.txt +tests/seal/CMakeCache.txt +tests/seal/CMakeFiles +tests/seal/Makefile +tests/seal/cmake_install.cmake +tests/seal/util/CMakeCache.txt +tests/seal/util/CMakeFiles +tests/seal/util/Makefile +tests/seal/util/cmake_install.cmake +.ycm_extra_conf.py +.vimrc +.lvimrc +.local_vimrc +*/*.code-workspace +*/.vscode +*/build +*/*.build +*/compile_commands.json + ## Ignore Visual Studio temporary files, build results, and ## files generated by popular Visual Studio add-ons. ## @@ -23,15 +70,13 @@ bld/ [Bb]in/ [Oo]bj/ [Ll]og/ +[Ll]ib/ -# Visual Studio 2015/2017 cache/options directory +# Visual Studio 2015 cache/options directory .vs/ # Uncomment if you have tasks that create the project's static files in wwwroot #wwwroot/ -# Visual Studio 2017 auto generated files -Generated\ Files/ - # MSTest test Results [Tt]est[Rr]esult*/ [Bb]uild[Ll]og.* @@ -54,20 +99,14 @@ project.fragment.lock.json artifacts/ **/Properties/launchSettings.json -# StyleCop -StyleCopReport.xml - -# Files built by Visual Studio *_i.c *_p.c *_i.h *.ilk *.meta *.obj -*.iobj *.pch *.pdb -*.ipdb *.pgc *.pgd *.rsp @@ -221,10 +260,6 @@ ClientBin/ *.publishsettings orleans.codegen.cs -# Including strong name files can present a security risk -# (https://github.com/github/gitignore/pull/2483#issue-259490424) -#*.snk - # Since there are multiple workflows, uncomment next line to ignore bower_components # (https://github.com/github/gitignore/pull/1529#issuecomment-104372622) #bower_components/ @@ -239,8 +274,6 @@ _UpgradeReport_Files/ Backup*/ UpgradeLog*.XML UpgradeLog*.htm -ServiceFabricBackup/ -*.rptproj.bak # SQL Server files *.mdf @@ -251,7 +284,6 @@ ServiceFabricBackup/ *.rdl.data *.bim.layout *.bim_*.settings -*.rptproj.rsuser # Microsoft Fakes FakesAssemblies/ @@ -263,6 +295,9 @@ FakesAssemblies/ .ntvs_analysis.dat node_modules/ +# Typescript v1 declaration files +typings/ + # Visual Studio 6 build log *.plg @@ -316,15 +351,3 @@ __pycache__/ # OpenCover UI analysis results OpenCover/ - -# Azure Stream Analytics local run output -ASALocalRun/ - -# MSBuild Binary and Structured Log -*.binlog - -# NVidia Nsight GPU debugger configuration file -*.nvuser - -# MFractors (Xamarin productivity tool) working folder -.mfractor/ diff --git a/Contrib.md b/Contrib.md new file mode 100644 index 000000000..0e8e69417 --- /dev/null +++ b/Contrib.md @@ -0,0 +1,18 @@ +# Contributing + +This project welcomes contributions and suggestions. Most contributions require you +to agree to a Contributor License Agreement (CLA) declaring that you have the right to, +and actually do, grant us the rights to use your contribution. For details, visit +https://cla.microsoft.com. + +When you submit a pull request, a CLA-bot will automatically determine whether you need +to provide a CLA and decorate the PR appropriately (e.g., label, comment). Simply follow +the instructions provided by the bot. You will only need to do this once across all +repos using our CLA. + +This project has adopted the +[Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). +For more information see the +[Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) +or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional +questions or comments. diff --git a/Issues.md b/Issues.md new file mode 100644 index 000000000..6d2149e8f --- /dev/null +++ b/Issues.md @@ -0,0 +1,27 @@ +# Issues + +## Technical questions + +The best way to get help with technical questions is on +[StackOverflow](https://stackoverflow.com/questions/tagged/seal) using the +[seal] tag. To contact the Microsoft SEAL team directly, please email +[sealcrypto@microsoft.com](mailto:sealcrypto@microsoft.com). + +## Bug reports + +We appreciate community efforts to find and fix bugs and issues in SEAL. If +you believe you have found a bug or want to report some other issue, please +do so on [GitHub](https://github.com/Microsoft/SEAL/issues). To help others +determine what the problem may be, we provide a helpful script that collects +relevant system information that you can submit with the bug report (see below). + +### System information + +To collect system information for an improved bug report, please run + make -C tools system_info +This will result in a file system\_info.tar.gz to be generated, which you can +optionally attach with your bug report. + +## Critical security issues + +For reporting critical security issues, see Security.md. diff --git a/README.md b/README.md index 441291b75..282c9a8dc 100644 --- a/README.md +++ b/README.md @@ -1,46 +1,141 @@ # Introduction -SEAL (Simple Encrypted Arithmetic Library) is an easy-to-use homomorphic encryption -library, developed by researchers in the Cryptography Research group at Microsoft -Research. SEAL is written in standard C++17 and can be compiled also as C++14. -# System requirements -Since SEAL has no external dependencies and is written in standard C++ it is easy -to build on any 64-bit system. For building in Windows, SEAL contains a Visual -Studio 2017 solution file. For building in Linux and Mac OS X, SEAL requires either -g++-6 or newer, or clang++-5 or newer. Please see INSTALL.txt for installation -instructions using CMake. +Microsoft Simple Encrypted Arithmetic Library (Microsoft SEAL) is an easy-to-use +homomorphic encryption library developed by researchers in the Cryptography +Research group at Microsoft Research. SEAL is written in modern standard C++ and +has no external dependencies, making it easy to compile and run in many different +environments. -# Documentation -The code-base contains (see SEALExamples/main.cpp) extensive and thoroughly -commented examples that should serve as a self-contained introduction to using SEAL. -In addition, the header files contain detailed comments for the public API. +For more information about the Microsoft SEAL project, see [http://sealcrypto.org](http://sealcrypto.org). # License -SEAL is licensed under the MIT license; see LICENSE.txt. - -# Contributing - -This project welcomes contributions and suggestions. Most contributions require you -to agree to a Contributor License Agreement (CLA) declaring that you have the right to, -and actually do, grant us the rights to use your contribution. For details, visit -https://cla.microsoft.com. - -When you submit a pull request, a CLA-bot will automatically determine whether you need -to provide a CLA and decorate the PR appropriately (e.g., label, comment). Simply follow -the instructions provided by the bot. You will only need to do this once across all -repos using our CLA. - -This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). -For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) -or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional -questions or comments. - -# Acknowledgements -We would like to thank John Wernsing, Michael Naehrig, Nathan Dowlin, Rachel Player, -Gizem Cetin, Susan Xia, Peter Rindal, Kyoohyung Han, Zhicong Huang, Amir Jalali, Wei Dai, -Ilia Iliashenko, and Sadegh Riazi for their contributions to the SEAL project. We would also -like to thank everyone who has sent us helpful comments, suggestions, and bug reports. - -# Contact Us -The best way to ask technical questions is on StackOverflow using the [seal] tag. To contact -us directly, please email [sealcrypto@microsoft.com](mailto:sealcrypto@microsoft.com). + +SEAL is licensed under the MIT license; see LICENSE. + +# Building and using SEAL + +## Windows + +SEAL comes with a Microsoft Visual Studio 2017 solution file SEAL.sln that can be +used to conveniently build the library, examples, and unit tests. + +#### Debug and release builds + +You can easily switch from Visual Studio configuration menu whether SEAL should be +built in Debug mode (no optimizations) or in Release mode. Please note that Debug +mode should not be used except for debugging SEAL itself, as the performance will be +orders of magnitude worse than in Release mode. + +#### Library + +Build the SEAL project (src/SEAL.vcxproj) from SEAL.sln. Building SEAL results +in the static library seal.lib to be created in lib/x64/$(Configuration). When +linking with applications, you need to add src/ (full path) as an include directory +for SEAL header files. + +#### Examples + +Build the SEALExamples project (examples/SEALExamples.vcxproj) from SEAL.sln. +This results in an executable sealexamples.exe to be created in bin/x64/$(Configuration). + +#### Unit tests + +The unit tests require the Google Test framework to be installed. The appropriate +NuGet package is already listed in tests/packages.config, so once you attempt to build +the SEALTest project (tests/SEALTest.vcxproj) from SEAL.sln Visual Studio will +automatically download and install it for you. + +## Linux and OS X + +SEAL is very easy to configure and build in Linux and OS X using CMake (>= 3.10). +A modern version of GNU G++ (>= 6.0) or Clang++ (>= 5.0) is needed. In OS X the +Xcode toolchain (>= 9.3) will work. + +In OS X you will need CMake with command line tools. For this, you can either +1. install the cmake package with [Homebrew](https://brew.sh), or +2. download CMake directly from [https://cmake.org/download](https://cmake.org/download) and [enable command line tools](https://stackoverflow.com/questions/30668601/installing-cmake-command-line-tools-on-a-mac). + +Below we give instructions for how to configure, build, and install SEAL either +system-wide (global install), or for a single user (local install). A system-wide +install requires elevated (root) privileges. + +#### Debug and release builds + +You can easily switch from CMake configuration options whether SEAL should be built in +Debug mode (no optimizations) or in Release mode. Please note that Debug mode should not +be used except for debugging SEAL itself, as the performance will be orders of magnitude +worse than in Release mode. + +### Global install + +#### Library + +If you have root access to the system you can install SEAL system-wide as follows: + cd src + cmake . + make + sudo make install + cd .. + +#### Examples + +To build the examples do: + cd examples + cmake . + make + cd .. + +After completing the above steps the sealexamples executable can be found in bin/. +See examples/CMakeLists.txt for how to link SEAL with your own project using cmake. + +#### Unit tests + +To build the unit tests, make sure you have the Google Test library (libgtest-dev) +installed. Then do: + cd tests + cmake . + make + cd .. + +After completing these steps the sealtest executable can be found in bin/. All unit +tests should pass successfully. + +### Local install + +#### Library + +To install SEAL locally, e.g., to ~/mylibs, do the following: + cd src + cmake -DCMAKE_INSTALL_PREFIX=~/mylibs . + make + make install + cd .. + +#### Examples + +To build the examples do: + cd examples + cmake -DCMAKE_PREFIX_PATH=~/mylibs . + make + cd .. + +After completing the above steps the sealexamples executable can be found in bin/. +See examples/CMakeLists.txt for how to link SEAL with your own project using cmake. + +#### Unit tests + +To build the unit tests, make sure you have the Google Test library (libgtest-dev) +installed. Then do: + cd tests + cmake -DCMAKE_PREFIX_PATH=~/mylibs . + make + cd .. + +After completing these steps the sealtest executable can be found in bin/. All unit +tests should pass successfully. + +# Documentation + +The code-base contains extensive and thoroughly commented examples that should +serve as a self-contained introduction to using SEAL (see examples/examples.cpp). In +addition, the header files contain detailed comments for the public API. diff --git a/SEAL.sln b/SEAL.sln new file mode 100644 index 000000000..de2ef65ad --- /dev/null +++ b/SEAL.sln @@ -0,0 +1,48 @@ +Microsoft Visual Studio Solution File, Format Version 12.00 +# Visual Studio 15 +VisualStudioVersion = 15.0.26430.16 +MinimumVisualStudioVersion = 10.0.40219.1 +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "SEAL", "src\SEAL.vcxproj", "{7EA96C25-FC0D-485A-BB71-32B6DA55652A}" +EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "SEALTest", "tests\SEALTest.vcxproj", "{0345DC4D-EFE3-460E-AB7E-AA6E05BB8DFF}" + ProjectSection(ProjectDependencies) = postProject + {7EA96C25-FC0D-485A-BB71-32B6DA55652A} = {7EA96C25-FC0D-485A-BB71-32B6DA55652A} + EndProjectSection +EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "SEALExamples", "examples\SEALExamples.vcxproj", "{2B57D847-26DC-45FF-B9AF-EE33910B5093}" + ProjectSection(ProjectDependencies) = postProject + {7EA96C25-FC0D-485A-BB71-32B6DA55652A} = {7EA96C25-FC0D-485A-BB71-32B6DA55652A} + EndProjectSection +EndProject +Global + GlobalSection(SolutionConfigurationPlatforms) = preSolution + Debug|x64 = Debug|x64 + Release|x64 = Release|x64 + EndGlobalSection + GlobalSection(ProjectConfigurationPlatforms) = postSolution + {7EA96C25-FC0D-485A-BB71-32B6DA55652A}.Debug|x64.ActiveCfg = Debug|x64 + {7EA96C25-FC0D-485A-BB71-32B6DA55652A}.Debug|x64.Build.0 = Debug|x64 + {7EA96C25-FC0D-485A-BB71-32B6DA55652A}.Release|x64.ActiveCfg = Release|x64 + {7EA96C25-FC0D-485A-BB71-32B6DA55652A}.Release|x64.Build.0 = Release|x64 + {0345DC4D-EFE3-460E-AB7E-AA6E05BB8DFF}.Debug|x64.ActiveCfg = Debug|x64 + {0345DC4D-EFE3-460E-AB7E-AA6E05BB8DFF}.Debug|x64.Build.0 = Debug|x64 + {0345DC4D-EFE3-460E-AB7E-AA6E05BB8DFF}.Release|x64.ActiveCfg = Release|x64 + {0345DC4D-EFE3-460E-AB7E-AA6E05BB8DFF}.Release|x64.Build.0 = Release|x64 + {2B57D847-26DC-45FF-B9AF-EE33910B5093}.Debug|x64.ActiveCfg = Debug|x64 + {2B57D847-26DC-45FF-B9AF-EE33910B5093}.Debug|x64.Build.0 = Debug|x64 + {2B57D847-26DC-45FF-B9AF-EE33910B5093}.Release|x64.ActiveCfg = Release|x64 + {2B57D847-26DC-45FF-B9AF-EE33910B5093}.Release|x64.Build.0 = Release|x64 + EndGlobalSection + GlobalSection(SolutionProperties) = preSolution + HideSolutionNode = FALSE + EndGlobalSection + GlobalSection(ExtensibilityGlobals) = postSolution + SolutionGuid = {15A17F22-F747-4B82-BF5F-E0224AF4B3ED} + EndGlobalSection + GlobalSection(Performance) = preSolution + HasPerformanceSessions = true + EndGlobalSection + GlobalSection(Performance) = preSolution + HasPerformanceSessions = true + EndGlobalSection +EndGlobal diff --git a/Security.md b/Security.md new file mode 100644 index 000000000..777860636 --- /dev/null +++ b/Security.md @@ -0,0 +1,8 @@ +# Reporting Security Issues + +Security issues and bugs should be reported privately, via email, to the Microsoft Security +Response Center (MSRC) at [secure@microsoft.com](mailto:secure@microsoft.com). You should +receive a response within 24 hours. If for some reason you do not, please follow up via +email to ensure we received your original message. Further information, including the +[MSRC PGP](https://technet.microsoft.com/en-us/security/dn606155) key, can be found in +the [Security TechCenter](https://technet.microsoft.com/en-us/security/default). diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt new file mode 100644 index 000000000..03ca7661d --- /dev/null +++ b/examples/CMakeLists.txt @@ -0,0 +1,17 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +cmake_minimum_required(VERSION 3.10) + +project(SEALExamples VERSION 3.1.0 LANGUAGES CXX) + +# Executable will be in ../bin +set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_SOURCE_DIR}/../bin) + +add_executable(sealexamples examples.cpp) + +# Import SEAL +find_package(SEAL 3.1.0 EXACT REQUIRED) + +# Link SEAL +target_link_libraries(sealexamples SEAL::seal) diff --git a/examples/SEALExamples.vcxproj b/examples/SEALExamples.vcxproj new file mode 100644 index 000000000..4364cbc56 --- /dev/null +++ b/examples/SEALExamples.vcxproj @@ -0,0 +1,109 @@ + + + + + Debug + x64 + + + Release + x64 + + + + {2B57D847-26DC-45FF-B9AF-EE33910B5093} + Win32Proj + SEALExamples + 10.0.16299.0 + + + + Application + true + v141 + Unicode + + + Application + false + v141 + true + Unicode + + + + + + + + + + + + + + + true + $(SolutionDir)bin\$(Platform)\$(Configuration)\ + $(ProjectDir)obj\$(Platform)\$(Configuration)\ + sealexamples + + + false + $(SolutionDir)bin\$(Platform)\$(Configuration)\ + $(ProjectDir)obj\$(Platform)\$(Configuration)\ + sealexamples + + + + Level3 + NotUsing + Disabled + + + true + $(SolutionDir)src + stdcpp17 + /Zc:__cplusplus %(AdditionalOptions) + + + Console + true + $(SolutionDir)lib\$(Platform)\$(Configuration);%(AdditionalLibraryDirectories) + seal.lib;%(AdditionalDependencies) + + + + + Level3 + NotUsing + MaxSpeed + true + true + + + true + $(SolutionDir)src + stdcpp17 + /Zc:__cplusplus %(AdditionalOptions) + + + Console + true + true + true + $(SolutionDir)lib\$(Platform)\$(Configuration);%(AdditionalLibraryDirectories) + seal.lib;%(AdditionalDependencies) + true + + + + + + + + + + + + \ No newline at end of file diff --git a/examples/SEALExamples.vcxproj.filters b/examples/SEALExamples.vcxproj.filters new file mode 100644 index 000000000..272ea16f4 --- /dev/null +++ b/examples/SEALExamples.vcxproj.filters @@ -0,0 +1,30 @@ + + + + + {4FC737F1-C7A5-4376-A066-2A32D752A2FF} + cpp;c;cc;cxx;def;odl;idl;hpj;bat;asm;asmx + + + {93995380-89BD-4b04-88EB-625FBE52EBFB} + h;hh;hpp;hxx;hm;inl;inc;xsd + + + {67DA6AB6-F800-4c08-8B7A-83BB121AAD01} + rc;ico;cur;bmp;dlg;rc2;rct;bin;rgs;gif;jpg;jpeg;jpe;resx;tiff;tif;png;wav;mfcribbon-ms + + + {abd2e216-316f-4dad-a2a4-a72ffccfd92b} + + + + + Source Files + + + + + Linux + + + diff --git a/examples/examples.cpp b/examples/examples.cpp new file mode 100644 index 000000000..21950a821 --- /dev/null +++ b/examples/examples.cpp @@ -0,0 +1,2666 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "seal/seal.h" + +using namespace std; +using namespace seal; + +/* +Helper function: Prints the name of the example in a fancy banner. +*/ +void print_example_banner(string title) +{ + if (!title.empty()) + { + size_t title_length = title.length(); + size_t banner_length = title_length + 2 + 2 * 10; + string banner_top(banner_length, '*'); + string banner_middle = string(10, '*') + " " + title + " " + string(10, '*'); + + cout << endl + << banner_top << endl + << banner_middle << endl + << banner_top << endl + << endl; + } +} + +/* +Helper function: Prints the parameters in a SEALContext. +*/ +void print_parameters(shared_ptr context) +{ + // Verify parameters + if (!context) + { + throw invalid_argument("context is not set"); + } + auto &context_data = *context->context_data(); + + /* + Which scheme are we using? + */ + string scheme_name; + switch (context_data.parms().scheme()) + { + case scheme_type::BFV: + scheme_name = "BFV"; + break; + case scheme_type::CKKS: + scheme_name = "CKKS"; + break; + default: + throw invalid_argument("unsupported scheme"); + } + + cout << "/ Encryption parameters:" << endl; + cout << "| scheme: " << scheme_name << endl; + cout << "| poly_modulus_degree: " << + context_data.parms().poly_modulus_degree() << endl; + + /* + Print the size of the true (product) coefficient modulus. + */ + cout << "| coeff_modulus size: " << context_data. + total_coeff_modulus_bit_count() << " bits" << endl; + + /* + For the BFV scheme print the plain_modulus parameter. + */ + if (context_data.parms().scheme() == scheme_type::BFV) + { + cout << "| plain_modulus: " << context_data. + parms().plain_modulus().value() << endl; + } + + cout << "\\ noise_standard_deviation: " << context_data. + parms().noise_standard_deviation() << endl; + cout << endl; +} + +/* +Helper function: Prints the `parms_id' to std::ostream. +*/ +ostream &operator <<(ostream &stream, parms_id_type parms_id) +{ + stream << hex << parms_id[0] << " " << parms_id[1] << " " + << parms_id[2] << " " << parms_id[3] << dec; + return stream; +} + +/* +Helper function: Prints a vector of floating-point values. +*/ +template +void print_vector(vector vec, size_t print_size = 4, int prec = 3) +{ + /* + Save the formatting information for std::cout. + */ + ios old_fmt(nullptr); + old_fmt.copyfmt(cout); + + size_t slot_count = vec.size(); + + cout << fixed << setprecision(prec) << endl; + if(slot_count <= 2 * print_size) + { + cout << " ["; + for (size_t i = 0; i < slot_count; i++) + { + cout << " " << vec[i] << ((i != slot_count - 1) ? "," : " ]\n"); + } + } + else + { + vec.resize(max(vec.size(), 2 * print_size)); + cout << " ["; + for (size_t i = 0; i < print_size; i++) + { + cout << " " << vec[i] << ","; + } + if(vec.size() > 2 * print_size) + { + cout << " ...,"; + } + for (size_t i = slot_count - print_size; i < slot_count; i++) + { + cout << " " << vec[i] << ((i != slot_count - 1) ? "," : " ]\n"); + } + } + cout << endl; + + /* + Restore the old std::cout formatting. + */ + cout.copyfmt(old_fmt); +} + +void example_bfv_basics_i(); + +void example_bfv_basics_ii(); + +void example_bfv_basics_iii(); + +void example_bfv_basics_iv(); + +void example_ckks_basics_i(); + +void example_ckks_basics_ii(); + +void example_ckks_basics_iii(); + +void example_bfv_performance(); + +void example_ckks_performance(); + +int main() +{ +#ifdef SEAL_VERSION + cout << "SEAL version: " << SEAL_VERSION << endl; +#endif + while (true) + { + cout << "\nSEAL Examples:" << endl << endl; + cout << " 1. BFV Basics I" << endl; + cout << " 2. BFV Basics II" << endl; + cout << " 3. BFV Basics III" << endl; + cout << " 4. BFV Basics IV" << endl; + cout << " 5. BFV Performance Test" << endl; + cout << " 6. CKKS Basics I" << endl; + cout << " 7. CKKS Basics II" << endl; + cout << " 8. CKKS Basics III" << endl; + cout << " 9. CKKS Performance Test" << endl; + cout << " 0. Exit" << endl; + + /* + Print how much memory we have allocated from the current memory pool. + By default the memory pool will be a static global pool and the + MemoryManager class can be used to change it. Most users should have + little or no reason to touch the memory allocation system. + */ + cout << "\nTotal memory allocated from the current memory pool: " + << (MemoryManager::GetPool().alloc_byte_count() >> 20) << " MB" << endl; + + int selection = 0; + cout << endl << "Run example: "; + if (!(cin >> selection)) + { + cout << "Invalid option." << endl; + cin.clear(); + cin.ignore(numeric_limits::max(), '\n'); + continue; + } + + switch (selection) + { + case 1: + example_bfv_basics_i(); + break; + + case 2: + example_bfv_basics_ii(); + break; + + case 3: + example_bfv_basics_iii(); + break; + + case 4: + example_bfv_basics_iv(); + break; + + case 5: + example_bfv_performance(); + break; + + case 6: + example_ckks_basics_i(); + break; + + case 7: + example_ckks_basics_ii(); + break; + + case 8: + example_ckks_basics_iii(); + break; + + case 9: { + example_ckks_performance(); + break; + } + + case 0: + return 0; + + default: + cout << "Invalid option." << endl; + } + } + + return 0; +} + +void example_bfv_basics_i() +{ + print_example_banner("Example: BFV Basics I"); + + /* + In this example we demonstrate setting up encryption parameters and other + relevant objects for performing simple computations on encrypted integers. + + SEAL implements two encryption schemes: the Brakerski/Fan-Vercauteren (BFV) + scheme and the Cheon-Kim-Kim-Song (CKKS) scheme. In the first examples we + use the BFV scheme as it is far easier to understand and use than CKKS. For + more details on the basics of the BFV scheme, we refer the reader to the + original paper https://eprint.iacr.org/2012/144. In truth, to achieve good + performance SEAL implements the "FullRNS" optimization as described in + https://eprint.iacr.org/2016/510, but this optiomization is invisible to + the user and has no security implications. We will discuss the CKKS scheme + in later examples. + + The first task is to set up an instance of the EncryptionParameters class. + It is critical to understand how these different parameters behave, how they + affect the encryption scheme, performance, and the security level. There are + three encryption parameters that are necessary to set: + + - poly_modulus_degree (degree of polynomial modulus); + - coeff_modulus ([ciphertext] coefficient modulus); + - plain_modulus (plaintext modulus). + + A fourth parameter -- noise_standard_deviation -- has a default value 3.20 + and should not be necessary to modify unless the user has a specific reason + to do so and has an in-depth understanding of the security implications. + + A fifth parameter -- random_generator -- can be set to use customized random + number generators. By default, SEAL uses hardware-based AES in counter mode + for pseudo-randomness with key generated using std::random_device. If the + AES-NI instruction set is not available, all randomness is generated from + std::random_device. Most academic users in particular should have little + reason to change this. + + The BFV scheme cannot perform arbitrary computations on encrypted data. + Instead, each ciphertext has a specific quantity called the `invariant noise + budget' -- or `noise budget' for short -- measured in bits. The noise budget + in a freshly encrypted ciphertext (initial noise budget) is determined by + the encryption parameters. Homomorphic operations consume the noise budget + at a rate also determined by the encryption parameters. In BFV the two basic + operations allowed on encrypted data are additions and multiplications, of + which additions can generally be thought of as being nearly free in terms of + noise budget consumption compared to multiplications. Since noise budget + consumption compounds in sequential multiplications, the most significant + factor in choosing appropriate encryption parameters is the multiplicative + depth of the arithmetic circuit that the user wants to evaluate on encrypted + data. Once the noise budget of a ciphertext reaches zero it becomes too + corrupted to be decrypted. Thus, it is essential to choose the parameters to + be large enough to support the desired computation; otherwise the result is + impossible to make sense of even with the secret key. + */ + EncryptionParameters parms(scheme_type::BFV); + + /* + The first parameter we set is the degree of the polynomial modulus. This must + be a positive power of 2, representing the degree of a power-of-2 cyclotomic + polynomial; it is not necessary to understand what this means. The polynomial + modulus degree should be thought of mainly affecting the security level of the + scheme: larger degree makes the scheme more secure. Larger degree also makes + ciphertext sizes larger, and consequently all operations slower. Recommended + degrees are 1024, 2048, 4096, 8192, 16384, 32768, but it is also possible to + go beyond this. In this example we use a relatively small polynomial modulus. + */ + parms.set_poly_modulus_degree(2048); + + /* + Next we set the [ciphertext] coefficient modulus (coeff_modulus). The size + of the coefficient modulus should be thought of as the most significant + factor in determining the noise budget in a freshly encrypted ciphertext: + bigger means more noise budget, which is desirable. On the other hand, + a larger coefficient modulus lowers the security level of the scheme. Thus, + if a large noise budget is required for complicated computations, a large + coefficient modulus needs to be used, and the reduction in the security + level must be countered by simultaneously increasing the polynomial modulus. + Overall, this will result in worse performance. + + To make parameter selection easier for the user, we have constructed sets + of largest safe coefficient moduli for 128-bit and 192-bit security levels + for different choices of the polynomial modulus. These default parameters + follow the recommendations in the Security Standard Draft available at + http://HomomorphicEncryption.org. The security estimates are a complicated + topic and we highly recommend consulting with experts in the field when + selecting parameters. + + Our recommended values for the coefficient modulus can be easily accessed + through the functions + + coeff_modulus_128bit(int) + coeff_modulus_192bit(int) + coeff_modulus_256bit(int) + + for 128-bit, 192-bit, and 256-bit security levels. The integer parameter is + the degree of the polynomial modulus used. + + In SEAL the coefficient modulus is a positive composite number -- a product + of distinct primes of size up to 60 bits. When we talk about the size of the + coefficient modulus we mean the bit length of the product of the primes. The + small primes are represented by instances of the SmallModulus class so for + example coeff_modulus_128bit(int) returns a vector of SmallModulus instances. + + It is possible for the user to select their own small primes. Since SEAL uses + the Number Theoretic Transform (NTT) for polynomial multiplications modulo the + factors of the coefficient modulus, the factors need to be prime numbers + congruent to 1 modulo 2*poly_modulus_degree. We have generated a list of such + prime numbers of various sizes that the user can easily access through the + functions + + small_mods_60bit(int) + small_mods_50bit(int) + small_mods_40bit(int) + small_mods_30bit(int) + + each of which gives access to an array of primes of the denoted size. These + primes are located in the source file util/globals.cpp. Again, please keep + in mind that the choice of coeff_modulus has a dramatic effect on security + and should almost always be obtained through coeff_modulus_xxx(int). + + Performance is mainly affected by the size of the polynomial modulus, and + the number of prime factors in the coefficient modulus; hence in some cases + it can be important to use as few prime factors in the coefficient modulus + as possible. + + In this example we use the default coefficient modulus for a 128-bit security + level. Concretely, this coefficient modulus consists of only one 54-bit prime + factor: 0x3fffffff000001. + */ + parms.set_coeff_modulus(coeff_modulus_128(2048)); + + /* + The plaintext modulus can be any positive integer, even though here we take + it to be a power of two. In fact, in many cases one might instead want it + to be a prime number; we will see this in later examples. The plaintext + modulus determines the size of the plaintext data type but it also affects + the noise budget in a freshly encrypted ciphertext and the consumption of + noise budget in homomorphic (encrypted) multiplications. Thus, it is + essential to try to keep the plaintext data type as small as possible for + best performance. The noise budget in a freshly encrypted ciphertext is + + ~ log2(coeff_modulus/plain_modulus) (bits) + + and the noise budget consumption in a homomorphic multiplication is of the + form log2(plain_modulus) + (other terms). + */ + parms.set_plain_modulus(1 << 8); + + /* + Now that all parameters are set, we are ready to construct a SEALContext + object. This is a heavy class that checks the validity and properties of the + parameters we just set and performs several important pre-computations. + */ + auto context = SEALContext::Create(parms); + + /* + Print the parameters that we have chosen. + */ + print_parameters(context); + + /* + Plaintexts in the BFV scheme are polynomials with coefficients integers + modulo plain_modulus. This is not a very practical object to encrypt: much + more useful would be encrypting integers or floating point numbers. For this + we need an `encoding scheme' to convert data from integer representation to + an appropriate plaintext polynomial representation than can subsequently be + encrypted. SEAL comes with a few basic encoders for the BFV scheme: + + [IntegerEncoder] + Given an integer base b, encodes integers as plaintext polynomials as follows. + First, a base-b expansion of the integer is computed. This expansion uses + a `balanced' set of representatives of integers modulo b as the coefficients. + Namely, when b is odd the coefficients are integers between -(b-1)/2 and + (b-1)/2. When b is even, the integers are between -b/2 and (b-1)/2, except + when b is two and the usual binary expansion is used (coefficients 0 and 1). + Decoding amounts to evaluating the polynomial at x=b. For example, if b=2, + the integer + + 26 = 2^4 + 2^3 + 2^1 + + is encoded as the polynomial 1x^4 + 1x^3 + 1x^1. When b=3, + + 26 = 3^3 - 3^0 + + is encoded as the polynomial 1x^3 - 1. In memory polynomial coefficients are + always stored as unsigned integers by storing their smallest non-negative + representatives modulo plain_modulus. To create a base-b integer encoder, + use the constructor IntegerEncoder(plain_modulus, b). If no b is given, b=2 + is used. + + [FractionalEncoder] + The FractionalEncoder encodes fixed-precision rational numbers as follows. + It expands the number in a given base b, possibly truncating an infinite + fractional part to finite precision, e.g. + + 26.75 = 2^4 + 2^3 + 2^1 + 2^(-1) + 2^(-2) + + when b=2. For the sake of the example, suppose poly_modulus is 1x^1024 + 1. + It then represents the integer part of the number in the same way as in + IntegerEncoder (with b=2 here), and moves the fractional part instead to the + highest degree part of the polynomial, but with signs of the coefficients + changed. In this example we would represent 26.75 as the polynomial + + -1x^1023 - 1x^1022 + 1x^4 + 1x^3 + 1x^1. + + In memory the negative coefficients of the polynomial will be represented as + their negatives modulo plain_modulus. While easy to use, the fractional + encoder suffers from drawbacks that can be avoided using the CKKS scheme + instead of BFV; hence, we do not demonstrate the FractionalEncoder in these + examples. + + [BatchEncoder] + If plain_modulus is a prime congruent to 1 modulo 2*poly_modulus_degree, the + plaintext elements can be viewed as 2-by-(poly_modulus_degree / 2) matrices + with elements integers modulo plain_modulus. When a desired computation can + be vectorized, using BatchEncoder can result in a massive performance boost + over naively encrypting and operating on each input number separately. Thus, + in more complicated computations this is likely to be by far the most + important and useful encoder. In example_bfv_basics_iii() we show how to + operate on encrypted matrix plaintexts. + + Here we choose to create an IntegerEncoder with base b=2. For most use-cases + of the IntegerEncoder this is a good choice. + */ + IntegerEncoder encoder(parms.plain_modulus()); + + /* + We are now ready to generate the secret and public keys. For this purpose + we need an instance of the KeyGenerator class. Constructing a KeyGenerator + automatically generates the public and secret key, which can then be read to + local variables. + */ + KeyGenerator keygen(context); + PublicKey public_key = keygen.public_key(); + SecretKey secret_key = keygen.secret_key(); + + /* + To be able to encrypt we need to construct an instance of Encryptor. Note + that the Encryptor only requires the public key, as expected. + */ + Encryptor encryptor(context, public_key); + + /* + Computations on the ciphertexts are performed with the Evaluator class. In + a real use-case the Evaluator would not be constructed by the same party + that holds the secret key. + */ + Evaluator evaluator(context); + + /* + We will of course want to decrypt our results to verify that everything worked, + so we need to also construct an instance of Decryptor. Note that the Decryptor + requires the secret key. + */ + Decryptor decryptor(context, secret_key); + + /* + We start by encoding two integers as plaintext polynomials. + */ + int value1 = 5; + Plaintext plain1 = encoder.encode(value1); + cout << "Encoded " << value1 << " as polynomial " << plain1.to_string() + << " (plain1)" << endl; + + int value2 = -7; + Plaintext plain2 = encoder.encode(value2); + cout << "Encoded " << value2 << " as polynomial " << plain2.to_string() + << " (plain2)" << endl; + + /* + Encrypting the encoded values is easy. + */ + Ciphertext encrypted1, encrypted2; + cout << "Encrypting plain1: "; + encryptor.encrypt(plain1, encrypted1); + cout << "Done (encrypted1)" << endl; + + cout << "Encrypting plain2: "; + encryptor.encrypt(plain2, encrypted2); + cout << "Done (encrypted2)" << endl; + + /* + To illustrate the concept of noise budget, we print the budgets in the fresh + encryptions. + */ + cout << "Noise budget in encrypted1: " + << decryptor.invariant_noise_budget(encrypted1) << " bits" << endl; + cout << "Noise budget in encrypted2: " + << decryptor.invariant_noise_budget(encrypted2) << " bits" << endl; + + /* + As a simple example, we compute (-encrypted1 + encrypted2) * encrypted2. Most + basic arithmetic operations come as in-place two-argument versions that + overwrite the first argument with the result, and as three-argument versions + taking as separate destination parameter. In most cases the in-place variants + are slightly faster. + */ + + /* + Negation is a unary operation and does not consume any noise budget. + */ + evaluator.negate_inplace(encrypted1); + cout << "Noise budget in -encrypted1: " + << decryptor.invariant_noise_budget(encrypted1) << " bits" << endl; + + /* + Compute the sum of encrypted1 and encrypted2; the sum overwrites encrypted1. + */ + evaluator.add_inplace(encrypted1, encrypted2); + + /* + Addition sets the noise budget to the minimum of the input noise budgets. + In this case both inputs had roughly the same budget going in, so the output + (in encrypted1) has just a slightly lower budget. Depending on probabilistic + effects the noise growth consumption may or may not be visible when measured + in whole bits. + */ + cout << "Noise budget in -encrypted1 + encrypted2: " + << decryptor.invariant_noise_budget(encrypted1) << " bits" << endl; + + /* + Finally multiply with encrypted2. Again, we use the in-place version of the + function, overwriting encrypted1 with the product. + */ + evaluator.multiply_inplace(encrypted1, encrypted2); + + /* + Multiplication consumes a lot of noise budget. This is clearly seen in the + print-out. The user can change the plain_modulus to see its effect on the + rate of noise budget consumption. + */ + cout << "Noise budget in (-encrypted1 + encrypted2) * encrypted2: " + << decryptor.invariant_noise_budget(encrypted1) << " bits" << endl; + + /* + Now we decrypt and decode our result. + */ + Plaintext plain_result; + cout << "Decrypting result: "; + decryptor.decrypt(encrypted1, plain_result); + cout << "Done" << endl; + + /* + Print the result plaintext polynomial. + */ + cout << "Plaintext polynomial: " << plain_result.to_string() << endl; + + /* + Decode to obtain an integer result. + */ + cout << "Decoded integer: " << encoder.decode_int32(plain_result) << endl; +} + +void example_bfv_basics_ii() +{ + print_example_banner("Example: BFV Basics II"); + + /* + In this example we explain what relinearization is, how to use it, and how + it affects noise budget consumption. Relinearization is used both in the BFV + and the CKKS schemes but in this example (for the sake of simplicity) we + again focus on BFV. + + First we set the parameters, create a SEALContext, and generate the public + and secret keys. We use slightly larger parameters than before to be able to + do more homomorphic multiplications. + */ + EncryptionParameters parms(scheme_type::BFV); + parms.set_poly_modulus_degree(8192); + + /* + The default coefficient modulus consists of the following primes: + + 0x7fffffff380001, 0x7ffffffef00001, + 0x3fffffff000001, 0x3ffffffef40001 + + The total size is 218 bits. + */ + parms.set_coeff_modulus(coeff_modulus_128(8192)); + parms.set_plain_modulus(1 << 10); + + auto context = SEALContext::Create(parms); + print_parameters(context); + + /* + We generate the public and secret keys as before. + + There are actually two more types of keys in SEAL: `relinearization keys' + and `Galois keys'. In this example we will discuss relinearization keys, and + Galois keys will be discussed later in example_bfv_basics_iii(). + */ + KeyGenerator keygen(context); + auto public_key = keygen.public_key(); + auto secret_key = keygen.secret_key(); + + /* + We also set up an Encryptor, Evaluator, and Decryptor here. We will + encrypt polynomials directly in this example, so there is no need for + an encoder. + */ + Encryptor encryptor(context, public_key); + Evaluator evaluator(context); + Decryptor decryptor(context, secret_key); + + /* + We can easily construct a plaintext polynomial from a string. Again, note + how there is no need for encoding since the BFV scheme natively encrypts + polynomials. + */ + Plaintext plain1("1x^2 + 2x^1 + 3"); + Ciphertext encrypted; + cout << "Encrypting " << plain1.to_string() << ": "; + encryptor.encrypt(plain1, encrypted); + cout << "Done" << endl; + + /* + In SEAL, a valid ciphertext consists of two or more polynomials whose + coefficients are integers modulo the product of the primes in coeff_modulus. + The current size of a ciphertext can be found using Ciphertext::size(). + A freshly encrypted ciphertext always has size 2. + */ + cout << "Size of a fresh encryption: " << encrypted.size() << endl; + cout << "Noise budget in fresh encryption: " + << decryptor.invariant_noise_budget(encrypted) << " bits" << endl; + + /* + Homomorphic multiplication results in the output ciphertext growing in size. + More precisely, if the input ciphertexts have size M and N, then the output + ciphertext after homomorphic multiplication will have size M+N-1. In this + case we square encrypted twice to observe this growth (also observe noise + budget consumption). + */ + evaluator.square_inplace(encrypted); + cout << "Size after squaring: " << encrypted.size() << endl; + cout << "Noise budget after squaring: " + << decryptor.invariant_noise_budget(encrypted) << " bits" << endl; + + evaluator.square_inplace(encrypted); + cout << "Size after second squaring: " << encrypted.size() << endl; + cout << "Noise budget after second squaring: " + << decryptor.invariant_noise_budget(encrypted) << " bits" << endl; + + /* + It does not matter that the size has grown -- decryption works as usual. + Observe from the print-out that the coefficients in the plaintext have grown + quite large. One more squaring would cause some of them to wrap around the + plain_modulus (0x400) and as a result we would no longer obtain the expected + result as an integer-coefficient polynomial. We can fix this problem to some + extent by increasing plain_modulus. This makes sense since we still have + plenty of noise budget left. + */ + Plaintext plain2; + decryptor.decrypt(encrypted, plain2); + cout << "Fourth power: " << plain2.to_string() << endl; + cout << endl; + + /* + The problem here is that homomorphic operations on large ciphertexts are + computationally much more costly than on small ciphertexts. Specifically, + homomorphic multiplication on input ciphertexts of size M and N will require + O(M*N) polynomial multiplications to be performed, and an addition will + require O(M+N) additions. Relinearization reduces the size of ciphertexts + after multiplication back to the initial size (2). Thus, relinearizing one + or both inputs before the next multiplication or e.g. before serializing the + ciphertexts, can have a huge positive impact on performance. + + Another problem is that the noise budget consumption in multiplication is + bigger when the input ciphertexts sizes are bigger. In a complicated + computation the contribution of the sizes to the noise budget consumption + can actually become the dominant term. We will point this out again below + once we get to our example. + + Relinearization itself has both a computational cost and a noise budget cost. + These both depend on a parameter called `decomposition bit count', which can + be any integer at least 1 [dbc_min()] and at most 60 [dbc_max()]. A large + decomposition bit count makes relinearization fast, but consumes more noise + budget. A small decomposition bit count can make relinearization slower, but + might not change the noise budget by any observable amount. + + Relinearization requires a special type of key called `relinearization keys'. + These can be created by the KeyGenerator for any decomposition bit count. + To relinearize a ciphertext of size M >= 2 back to size 2, we actually need + M-2 relinearization keys. Attempting to relinearize a too large ciphertext + with too few relinearization keys will result in an exception being thrown. + + We repeat our computation, but this time relinearize after both squarings. + Since our ciphertext never grows past size 3 (we relinearize after every + multiplication), it suffices to generate only one relinearization key. This + (relinearizing after every multiplication) should be the preferred approach + in almost all cases. + + First, we need to create relinearization keys. We use a decomposition bit + count of 16 here, which should be thought of as very small. + + This function generates one single relinearization key. Another overload + of KeyGenerator::relin_keys takes the number of keys to be generated as an + argument, but one is all we need in this example (see above). + */ + auto relin_keys16 = keygen.relin_keys(16); + + cout << "Encrypting " << plain1.to_string() << ": "; + encryptor.encrypt(plain1, encrypted); + cout << "Done" << endl; + cout << "Size of a fresh encryption: " << encrypted.size() << endl; + cout << "Noise budget in fresh encryption: " + << decryptor.invariant_noise_budget(encrypted) << " bits" << endl; + + evaluator.square_inplace(encrypted); + cout << "Size after squaring: " << encrypted.size() << endl; + cout << "Noise budget after squaring: " + << decryptor.invariant_noise_budget(encrypted) << " bits" << endl; + + evaluator.relinearize_inplace(encrypted, relin_keys16); + cout << "Size after relinearization: " << encrypted.size() << endl; + cout << "Noise budget after relinearizing (dbc = " + << relin_keys16.decomposition_bit_count() << "): " + << decryptor.invariant_noise_budget(encrypted) << " bits" << endl; + + evaluator.square_inplace(encrypted); + cout << "Size after second squaring: " << encrypted.size() << endl; + cout << "Noise budget after second squaring: " + << decryptor.invariant_noise_budget(encrypted) << " bits" << endl; + + evaluator.relinearize_inplace(encrypted, relin_keys16); + cout << "Size after relinearization: " << encrypted.size() << endl; + cout << "Noise budget after relinearizing (dbc = " + << relin_keys16.decomposition_bit_count() << "): " + << decryptor.invariant_noise_budget(encrypted) << " bits" << endl; + + decryptor.decrypt(encrypted, plain2); + cout << "Fourth power: " << plain2.to_string() << endl; + cout << endl; + + /* + Of course the result is still the same, but this time we actually used less + of our noise budget. This is not surprising for two reasons: + + - We used a very small decomposition bit count, which is why + relinearization itself did not consume the noise budget by any + observable amount; + - Since our ciphertext sizes remain small throughout the two + squarings, the noise budget consumption rate in multiplication + remains as small as possible. Recall from above that operations + on larger ciphertexts actually cause more noise growth. + + To make things more clear, we repeat the computation a third time, now using + the largest possible decomposition bit count (60). We are not measuring + running time here, but relinearization with relin_keys60 (below) is much + faster than with relin_keys16. + */ + auto relin_keys60 = keygen.relin_keys(dbc_max()); + + cout << "Encrypting " << plain1.to_string() << ": "; + encryptor.encrypt(plain1, encrypted); + cout << "Done" << endl; + cout << "Size of a fresh encryption: " << encrypted.size() << endl; + cout << "Noise budget in fresh encryption: " + << decryptor.invariant_noise_budget(encrypted) << " bits" << endl; + + evaluator.square_inplace(encrypted); + cout << "Size after squaring: " << encrypted.size() << endl; + cout << "Noise budget after squaring: " + << decryptor.invariant_noise_budget(encrypted) << " bits" << endl; + evaluator.relinearize_inplace(encrypted, relin_keys60); + cout << "Size after relinearization: " << encrypted.size() << endl; + cout << "Noise budget after relinearizing (dbc = " + << relin_keys60.decomposition_bit_count() << "): " + << decryptor.invariant_noise_budget(encrypted) << " bits" << endl; + + evaluator.square_inplace(encrypted); + cout << "Size after second squaring: " << encrypted.size() << endl; + cout << "Noise budget after second squaring: " + << decryptor.invariant_noise_budget(encrypted) << " bits" << endl; + evaluator.relinearize_inplace(encrypted, relin_keys60); + cout << "Size after relinearization: " << encrypted.size() << endl; + cout << "Noise budget after relinearizing (dbc = " + << relin_keys60.decomposition_bit_count() << "): " + << decryptor.invariant_noise_budget(encrypted) << " bits" << endl; + + decryptor.decrypt(encrypted, plain2); + cout << "Fourth power: " << plain2.to_string() << endl; + cout << endl; + + /* + Observe from the print-out that we have now used significantly more of our + noise budget than in the two previous runs. This is again not surprising, + since the first relinearization chops off a huge part of the noise budget. + + However, note that the second relinearization does not change the noise + budget by any observable amount. This is very important to understand when + optimal performance is desired: relinearization always drops the noise + budget from the maximum (freshly encrypted ciphertext) down to a fixed + amount depending on the encryption parameters and the decomposition bit + count. On the other hand, homomorphic multiplication always consumes the + noise budget from its current level. This is why the second relinearization + does not change the noise budget anymore: it is already consumed past the + fixed amount determinted by the decomposition bit count and the encryption + parameters. + + We now perform a third squaring and observe an even further compounded + decrease in the noise budget. Again, relinearization does not consume the + noise budget at this point by any observable amount, even with the largest + possible decomposition bit count. + */ + evaluator.square_inplace(encrypted); + cout << "Size after third squaring: " << encrypted.size() << endl; + cout << "Noise budget after third squaring: " + << decryptor.invariant_noise_budget(encrypted) << " bits" << endl; + + evaluator.relinearize_inplace(encrypted, relin_keys60); + cout << "Size after relinearization: " << encrypted.size() << endl; + cout << "Noise budget after relinearizing (dbc = " + << relin_keys60.decomposition_bit_count() << "): " + << decryptor.invariant_noise_budget(encrypted) << " bits" << endl; + + decryptor.decrypt(encrypted, plain2); + cout << "Eighth power: " << plain2.to_string() << endl; + + /* + Observe from the print-out that the polynomial coefficients are no longer + correct as integers: they have been reduced modulo plain_modulus, and there + was no warning sign about this. It might be necessary to carefully analyze + the computation to make sure such overflow does not occur unexpectedly. + + These experiments suggest that an optimal strategy might be to relinearize + first with relinearization keys with a small decomposition bit count, and + later with relinearization keys with a larger decomposition bit count (for + performance) when noise budget has already been consumed past the bound + determined by the larger decomposition bit count. For example, the best + strategy might have been to use relin_keys16 in the first relinearization + and relin_keys60 in the next two relinearizations for optimal noise budget + consumption/performance trade-off. Luckily, in most use-cases it is not so + critical to squeeze out every last bit of performance, especially when + larger parameters are used. + */ +} + +void example_bfv_basics_iii() +{ + print_example_banner("Example: BFV Basics III"); + + /* + In this fundamental example we discuss and demonstrate a powerful technique + called `batching'. If N denotes the degree of the polynomial modulus, and T + the plaintext modulus, then batching is automatically enabled for the BFV + scheme when T is a prime number congruent to 1 modulo 2*N. In batching the + plaintexts are viewed as matrices of size 2-by-(N/2) with each element an + integer modulo T. Homomorphic operations act element-wise between encrypted + matrices, allowing the user to obtain speeds-ups of several orders of + magnitude in naively vectorizable computations. We demonstrate two more + homomorphic operations which act on encrypted matrices by rotating the rows + cyclically, or rotate the columns (i.e. swap the rows). These operations + require the construction of so-called `Galois keys', which are very similar + to relinearization keys. + + The batching functionality is totally optional in the BFV scheme and is + exposed through the BatchEncoder class. + */ + EncryptionParameters parms(scheme_type::BFV); + + parms.set_poly_modulus_degree(4096); + parms.set_coeff_modulus(coeff_modulus_128(4096)); + + /* + Note that 40961 is a prime number and 2*4096 divides 40960, so batching will + automatically be enabled for these parameters. + */ + parms.set_plain_modulus(40961); + + auto context = SEALContext::Create(parms); + print_parameters(context); + + /* + We can verify that batching is indeed enabled by looking at the encryption + parameter qualifiers created by SEALContext. + */ + auto qualifiers = context->context_data()->qualifiers(); + cout << "Batching enabled: " << boolalpha << qualifiers.using_batching << endl; + + KeyGenerator keygen(context); + auto public_key = keygen.public_key(); + auto secret_key = keygen.secret_key(); + + /* + We need to create so-called `Galois keys' for performing matrix row and + column rotations on encrypted matrices. Like relinearization keys, the + behavior of Galois keys depends on a decomposition bit count. The noise + budget consumption behavior of matrix row and column rotations is exactly + like that of relinearization (recall example_bfv_basics_ii()). + + Here we use a moderate size decomposition bit count. + */ + auto gal_keys = keygen.galois_keys(30); + + /* + Since we are going to do some multiplications we will also relinearize. + */ + auto relin_keys = keygen.relin_keys(30); + + /* + We also set up an Encryptor, Evaluator, and Decryptor here. + */ + Encryptor encryptor(context, public_key); + Evaluator evaluator(context); + Decryptor decryptor(context, secret_key); + + /* + Batching is done through an instance of the BatchEncoder class so need to + construct one. + */ + BatchEncoder batch_encoder(context); + + /* + The total number of batching `slots' is poly_modulus_degree. The matrices + we encrypt are of size 2-by-(slot_count / 2). + */ + size_t slot_count = batch_encoder.slot_count(); + size_t row_size = slot_count / 2; + cout << "Plaintext matrix row size: " << row_size << endl; + + /* + Printing the matrix is a bit of a pain. + */ + auto print_matrix = [row_size](auto &matrix) + { + cout << endl; + + /* + We're not going to print every column of the matrix (there are 2048). Instead + print this many slots from beginning and end of the matrix. + */ + size_t print_size = 5; + + cout << " ["; + for (size_t i = 0; i < print_size; i++) + { + cout << setw(3) << matrix[i] << ","; + } + cout << setw(3) << " ...,"; + for (size_t i = row_size - print_size; i < row_size; i++) + { + cout << setw(3) << matrix[i] << ((i != row_size - 1) ? "," : " ]\n"); + } + cout << " ["; + for (size_t i = row_size; i < row_size + print_size; i++) + { + cout << setw(3) << matrix[i] << ","; + } + cout << setw(3) << " ...,"; + for (size_t i = 2 * row_size - print_size; i < 2 * row_size; i++) + { + cout << setw(3) << matrix[i] << ((i != 2 * row_size - 1) ? "," : " ]\n"); + } + cout << endl; + }; + + /* + The matrix plaintext is simply given to BatchEncoder as a flattened vector + of numbers of size slot_count. The first row_size numbers form the first row, + and the rest form the second row. Here we create the following matrix: + + [ 0, 1, 2, 3, 0, 0, ..., 0 ] + [ 4, 5, 6, 7, 0, 0, ..., 0 ] + */ + vector pod_matrix(slot_count, 0ULL); + pod_matrix[0] = 0ULL; + pod_matrix[1] = 1ULL; + pod_matrix[2] = 2ULL; + pod_matrix[3] = 3ULL; + pod_matrix[row_size] = 4ULL; + pod_matrix[row_size + 1] = 5ULL; + pod_matrix[row_size + 2] = 6ULL; + pod_matrix[row_size + 3] = 7ULL; + + cout << "Input plaintext matrix:" << endl; + print_matrix(pod_matrix); + + /* + First we use BatchEncoder to compose the matrix into a plaintext. + */ + Plaintext plain_matrix; + batch_encoder.encode(pod_matrix, plain_matrix); + + /* + Next we encrypt the plaintext as usual. + */ + Ciphertext encrypted_matrix; + cout << "Encrypting: "; + encryptor.encrypt(plain_matrix, encrypted_matrix); + cout << "Done" << endl; + cout << "Noise budget in fresh encryption: " + << decryptor.invariant_noise_budget(encrypted_matrix) << " bits" << endl; + + /* + Operating on the ciphertext results in homomorphic operations being performed + simultaneously in all 4096 slots (matrix elements). To illustrate this, we + form another plaintext matrix + + [ 1, 2, 1, 2, 1, 2, ..., 2 ] + [ 1, 2, 1, 2, 1, 2, ..., 2 ] + + and compose it into a plaintext. + */ + vector pod_matrix2; + for (size_t i = 0; i < slot_count; i++) + { + pod_matrix2.push_back((i % 2) + 1); + } + Plaintext plain_matrix2; + batch_encoder.encode(pod_matrix2, plain_matrix2); + cout << "Second input plaintext matrix:" << endl; + print_matrix(pod_matrix2); + + /* + We now add the second (plaintext) matrix to the encrypted one using another + new operation -- plain addition -- and square the sum. + */ + cout << "Adding and squaring: "; + evaluator.add_plain_inplace(encrypted_matrix, plain_matrix2); + evaluator.square_inplace(encrypted_matrix); + evaluator.relinearize_inplace(encrypted_matrix, relin_keys); + cout << "Done" << endl; + + /* + How much noise budget do we have left? + */ + cout << "Noise budget in result: " + << decryptor.invariant_noise_budget(encrypted_matrix) << " bits" << endl; + + /* + We decrypt and decompose the plaintext to recover the result as a matrix. + */ + Plaintext plain_result; + cout << "Decrypting result: "; + decryptor.decrypt(encrypted_matrix, plain_result); + cout << "Done" << endl; + + vector pod_result; + batch_encoder.decode(plain_result, pod_result); + + cout << "Result plaintext matrix:" << endl; + print_matrix(pod_result); + + /* + Note how the operation was performed in one go for each of the elements of + the matrix. It is possible to achieve incredible performance improvements by + using this method when the computation is easily vectorizable. + + Our discussion so far could have applied just as well for a simple vector + data type (not matrix). Now we show how the matrix view of the plaintext can + be used for more functionality. Namely, it is possible to rotate the matrix + rows cyclically, and same for the columns (i.e. swap the two rows). For this + we need the Galois keys that we generated earlier. + + We return to the original matrix that we started with. + */ + encryptor.encrypt(plain_matrix, encrypted_matrix); + cout << "Unrotated matrix: " << endl; + print_matrix(pod_matrix); + cout << "Noise budget in fresh encryption: " + << decryptor.invariant_noise_budget(encrypted_matrix) << " bits" << endl; + + /* + Now rotate the rows to the left 3 steps, decrypt, decompose, and print. + */ + evaluator.rotate_rows_inplace(encrypted_matrix, 3, gal_keys); + cout << "Rotated rows 3 steps left: " << endl; + decryptor.decrypt(encrypted_matrix, plain_result); + batch_encoder.decode(plain_result, pod_result); + print_matrix(pod_result); + cout << "Noise budget after rotation: " + << decryptor.invariant_noise_budget(encrypted_matrix) << " bits" << endl; + + /* + Rotate columns (swap rows), decrypt, decompose, and print. + */ + evaluator.rotate_columns_inplace(encrypted_matrix, gal_keys); + cout << "Rotated columns: " << endl; + decryptor.decrypt(encrypted_matrix, plain_result); + batch_encoder.decode(plain_result, pod_result); + print_matrix(pod_result); + cout << "Noise budget after rotation: " + << decryptor.invariant_noise_budget(encrypted_matrix) << " bits" << endl; + + /* + Rotate rows to the right 4 steps, decrypt, decompose, and print. + */ + evaluator.rotate_rows_inplace(encrypted_matrix, -4, gal_keys); + cout << "Rotated rows 4 steps right: " << endl; + decryptor.decrypt(encrypted_matrix, plain_result); + batch_encoder.decode(plain_result, pod_result); + print_matrix(pod_result); + cout << "Noise budget after rotation: " + << decryptor.invariant_noise_budget(encrypted_matrix) << " bits" << endl; + + /* + The output is as expected. Note how the noise budget gets a big hit in the + first rotation, but remains almost unchanged in the next rotations. This is + again the same phenomenon that occurs with relinearization, where the noise + budget is consumed down to some bound determined by the decomposition bit + count and the encryption parameters. For example, after some multiplications + have been performed rotations come basically for free (noise budget-wise), + whereas they can be relatively expensive when the noise budget is nearly + full unless a small decomposition bit count is used, which on the other hand + is computationally costly. + */ +} + +void example_bfv_basics_iv() +{ + print_example_banner("Example: BFV Basics IV"); + + /* + In this example we describe the concept of `parms_id' in the context of the + BFV scheme and show how modulus switching can be used for improving both + computation and communication cost. + + We start by setting up medium size parameters for BFV as usual. + */ + EncryptionParameters parms(scheme_type::BFV); + + parms.set_poly_modulus_degree(8192); + parms.set_coeff_modulus(coeff_modulus_128(8192)); + parms.set_plain_modulus(1 << 20); + + /* + In SEAL a particular set of encryption parameters (excluding the random + number generator) is identified uniquely by a SHA-3 hash of the parameters. + This hash is called the `parms_id' and can be easily accessed and printed + at any time. The hash will change as soon as any of the relevant parameters + is changed. + */ + cout << "Current parms_id: " << parms.parms_id() << endl; + cout << "Changing plain_modulus ..." << endl; + parms.set_plain_modulus((1 << 20) + 1); + cout << "Current parms_id: " << parms.parms_id() << endl << endl; + + /* + Create the context. + */ + auto context = SEALContext::Create(parms); + print_parameters(context); + + /* + All keys and ciphertext, and in the CKKS also plaintexts, carry the parms_id + for the encryption parameters they are created with, allowing SEAL to very + quickly determine whether the objects are valid for use and compatible for + homomorphic computations. SEAL takes care of managing, and verifying the + parms_id for all objects so the user should have no reason to change it by + hand. + */ + KeyGenerator keygen(context); + auto public_key = keygen.public_key(); + auto secret_key = keygen.secret_key(); + cout << "parms_id of public_key: " << public_key.parms_id() << endl; + cout << "parms_id of secret_key: " << secret_key.parms_id() << endl; + + Encryptor encryptor(context, public_key); + Evaluator evaluator(context); + Decryptor decryptor(context, secret_key); + + /* + Note how in the BFV scheme plaintexts do not carry the parms_id, but + ciphertexts do. + */ + Plaintext plain("1x^3 + 2x^2 + 3x^1 + 4"); + Ciphertext encrypted; + encryptor.encrypt(plain, encrypted); + cout << "parms_id of plain: " << plain.parms_id() << " (not set)" << endl; + cout << "parms_id of encrypted: " << encrypted.parms_id() << endl << endl; + + /* + When SEALContext is created from a given EncryptionParameters instance, + SEAL automatically creates a so-called "modulus switching chain", which is + a chain of other encryption parameters derived from the original set. + The parameters in the modulus switching chain are the same as the original + parameters with the exception that size of the coefficient modulus is + decreasing going down the chain. More precisely, each parameter set in the + chain attempts to remove one of the coefficient modulus primes from the + previous set; this continues until the parameter set is no longer valid + (e.g. plain_modulus is larger than the remaining coeff_modulus). It is easy + to walk through the chain and access all the parameter sets. Additionally, + each parameter set in the chain has a `chain_index' that indicates its + position in the chain so that the last set has index 0. We say that a set + of encryption parameters, or an object carrying those encryption parameters, + is at a higher level in the chain than another set of parameters if its the + chain index is bigger, i.e. it is earlier in the chain. + */ + for(auto context_data = context->context_data(); context_data; + context_data = context_data->next_context_data()) + { + cout << "Chain index: " << context_data->chain_index() << endl; + cout << "parms_id: " << context_data->parms().parms_id() << endl; + cout << "coeff_modulus primes: "; + cout << hex; + for(const auto &prime : context_data->parms().coeff_modulus()) + { + cout << prime.value() << " "; + } + cout << dec << endl; + cout << "\\" << endl; + cout << " \\-->" << endl; + } + cout << "End of chain reached" << endl << endl; + + /* + Modulus switching changes the ciphertext parameters to any set down the + chain from the current one. The function mod_switch_to_next(...) always + switches to the next set down the chain, whereas mod_switch_to(...) switches + to a parameter set down the chain corresponding to a given parms_id. + */ + auto context_data = context->context_data(); + while(context_data->next_context_data()) + { + cout << "Chain index: " << context_data->chain_index() << endl; + cout << "parms_id of encrypted: " << encrypted.parms_id() << endl; + cout << "Noise budget at this level: " + << decryptor.invariant_noise_budget(encrypted) << " bits" << endl; + cout << "\\" << endl; + cout << " \\-->" << endl; + evaluator.mod_switch_to_next_inplace(encrypted); + context_data = context_data->next_context_data(); + } + cout << "Chain index: " << context_data->chain_index() << endl; + cout << "parms_id of encrypted: " << encrypted.parms_id() << endl; + cout << "Noise budget at this level: " + << decryptor.invariant_noise_budget(encrypted) << " bits" << endl; + cout << "\\" << endl; + cout << " \\-->" << endl; + cout << "End of chain reached" << endl << endl; + + /* + At this point it is hard to see any benefit in doing this: we lost a huge + amount of noise budget (i.e. computational power) at each switch and seemed + to get nothing in return. The ciphertext still decrypts to the exact same + value. + */ + decryptor.decrypt(encrypted, plain); + cout << "Decryption: " << plain.to_string() << endl << endl; + + /* + However, there is a hidden benefit: the size of the ciphertext depends + linearly on the number of primes in the coefficient modulus. Thus, if there + is no need or intention to perform any more computations on a given + ciphertext, we might as well switch it down to the smallest (last) set of + parameters in the chain before sending it back to the secret key holder for + decryption. + + Also the lost noise budget is actually not as issue at all, if we do things + right, as we will see below. First we recreate the original ciphertext (with + largest parameters) and perform some simple computations on it. + */ + encryptor.encrypt(plain, encrypted); + auto relin_keys = keygen.relin_keys(60); + cout << "Noise budget before squaring: " + << decryptor.invariant_noise_budget(encrypted) << " bits" << endl; + evaluator.square_inplace(encrypted); + evaluator.relinearize_inplace(encrypted, relin_keys); + cout << "Noise budget after squaring: " + << decryptor.invariant_noise_budget(encrypted) << " bits" << endl; + + /* + From the print-out we see that the noise budget after these computations is + just slightly below the level we would have in a fresh ciphertext after one + modulus switch (135 bits). Surprisingly, in this case modulus switching has + no effect at all on the modulus. + */ + evaluator.mod_switch_to_next_inplace(encrypted); + cout << "Noise budget after modulus switching: " + << decryptor.invariant_noise_budget(encrypted) << " bits" << endl; + + /* + This means that there is no harm at all in dropping some of the coefficient + modulus after doing enough computations. In some cases one might want to + switch to a lower level slightly earlier, actually sacrificing some of the + noise budget in the process, to gain computational performance from having + a smaller coefficient modulus. We see from the print-out that that the next + modulus switch should be done ideally when the noise budget reaches 81 bits. + */ + evaluator.square_inplace(encrypted); + evaluator.relinearize_inplace(encrypted, relin_keys); + cout << "Noise budget after squaring: " + << decryptor.invariant_noise_budget(encrypted) << " bits" << endl; + evaluator.mod_switch_to_next_inplace(encrypted); + cout << "Noise budget after modulus switching: " + << decryptor.invariant_noise_budget(encrypted) << " bits" << endl; + evaluator.square_inplace(encrypted); + evaluator.relinearize_inplace(encrypted, relin_keys); + cout << "Noise budget after squaring: " + << decryptor.invariant_noise_budget(encrypted) << " bits" << endl; + evaluator.mod_switch_to_next_inplace(encrypted); + cout << "Noise budget after modulus switching: " + << decryptor.invariant_noise_budget(encrypted) << " bits" << endl << endl; + + /* + At this point the ciphertext still decrypts correctly, has very small size, + and the computation was as efficient as possible. Note that the decryptor + can be used to decrypt a ciphertext at any level in the modulus switching + chain as long as the secret key is at a higher level in the same chain. + */ + decryptor.decrypt(encrypted, plain); + cout << "Decryption of eighth power: " << plain.to_string() << endl << endl; + + /* + In BFV modulus switching is not necessary and in some cases the user might + not want to create the modulus switching chain. This can be done by passing + a bool `false' to the SEALContext::Create(...) function as follows. + */ + context = SEALContext::Create(parms, false); + + /* + We can check that indeed the modulus switching chain has not been created. + The following loop should execute only once. + */ + for (context_data = context->context_data(); context_data; + context_data = context_data->next_context_data()) + { + cout << "Chain index: " << context_data->chain_index() << endl; + cout << "parms_id: " << context_data->parms().parms_id() << endl; + cout << "coeff_modulus primes: "; + cout << hex; + for (const auto &prime : context_data->parms().coeff_modulus()) + { + cout << prime.value() << " "; + } + cout << dec << endl; + cout << "\\" << endl; + cout << " \\-->" << endl; + } + cout << "End of chain reached" << endl << endl; + + /* + It is very important to understand how this example works since in the CKKS + scheme modulus switching has a much more fundamental purpose and the next + examples will be difficult to understand unless these basic properties are + totally clear. + */ +} + +void example_ckks_basics_i() +{ + print_example_banner("Example: CKKS Basics I"); + + /* + In this example we demonstrate using the Cheon-Kim-Kim-Song (CKKS) scheme + for encrypting and computing on floating point numbers. For full details on + the CKKS scheme, we refer the reader to https://eprint.iacr.org/2016/421. + For better performance, SEAL implements the "FullRNS" optimization for CKKS + described in https://eprint.iacr.org/2018/931. + */ + + /* + We start by creating encryption parameters for the CKKS scheme. One major + difference to the BFV scheme is that the CKKS scheme does not use the + plain_modulus parameter. + */ + EncryptionParameters parms(scheme_type::CKKS); + parms.set_poly_modulus_degree(8192); + parms.set_coeff_modulus(coeff_modulus_128(8192)); + + /* + We create the SEALContext as usual and print the parameters. + */ + auto context = SEALContext::Create(parms); + print_parameters(context); + + /* + Keys are created the same way as for the BFV scheme. + */ + KeyGenerator keygen(context); + auto public_key = keygen.public_key(); + auto secret_key = keygen.secret_key(); + auto relin_keys = keygen.relin_keys(60); + + /* + We also set up an Encryptor, Evaluator, and Decryptor as usual. + */ + Encryptor encryptor(context, public_key); + Evaluator evaluator(context); + Decryptor decryptor(context, secret_key); + + /* + To create CKKS plaintexts we need a special encoder: we cannot create them + directly from polynomials. Note that the IntegerEncoder, FractionalEncoder, + and BatchEncoder cannot be used with the CKKS scheme. The CKKS scheme allows + encryption and approximate computation on vectors of real or complex numbers + which the CKKSEncoder converts into Plaintext objects. At a high level this + looks a lot like BatchEncoder for the BFV scheme, but the theory behind it + is different. + */ + CKKSEncoder encoder(context); + + /* + In CKKS the number of slots is poly_modulus_degree / 2 and each slot encodes + one complex (or real) number. This should be contrasted with BatchEncoder in + the BFV scheme, where the number of slots is equal to poly_modulus_degree + and they are arranged into a 2-by-(poly_modulus_degree / 2) matrix. + */ + size_t slot_count = encoder.slot_count(); + cout << "Number of slots: " << slot_count << endl; + + /* + We create a small vector to encode; the CKKSEncoder will implicitly pad it + with zeros to full size (poly_modulus_degree / 2) when encoding. + */ + vector input{ 0.0, 1.1, 2.2, 3.3 }; + cout << "Input vector: " << endl; + print_vector(input); + + /* + Now we encode it with CKKSEncoder. The floating-point coefficients of input + will be scaled up by the parameter `scale'; this is necessary since even in + the CKKS scheme the plaintexts are polynomials with integer coefficients. + It is instructive to think of the scale as determining the bit-precision of + the encoding; naturally it will also affect the precision of the result. + + In CKKS the message is stored modulo coeff_modulus (in BFV it is stored + modulo plain_modulus), so the scale must not get too close to the total size + of coeff_modulus. In this case our coeff_modulus is quite large (218 bits) + so we have little to worry about in this regard. For this example a 60-bit + scale is more than enough. + */ + Plaintext plain; + double scale = pow(2.0, 60); + encoder.encode(input, scale, plain); + + /* + The vector is encrypted the same was as in BFV. + */ + Ciphertext encrypted; + encryptor.encrypt(plain, encrypted); + + /* + Another difference to the BFV scheme is that in CKKS also plaintexts are + linked to specific parameter sets: they carry the corresponding parms_id. + An overload of CKKSEncoder::encode(...) allows the caller to specify which + parameter set in the modulus switching chain (identified by parms_id) should + be used to encode the plaintext. This is important as we will see later. + */ + cout << "parms_id of plain: " << plain.parms_id() << endl; + cout << "parms_id of encrypted: " << encrypted.parms_id() << endl << endl; + + /* + The ciphertexts will keep track of the scales in the underlying plaintexts. + The current scale in every plaintext and ciphertext is easy to access. + */ + cout << "Scale in plain: " << plain.scale() << endl; + cout << "Scale in encrypted: " << encrypted.scale() << endl << endl; + + /* + Basic operations on the ciphertexts are still easy to do. Here we square + the ciphertext, decrypt, decode, and print the result. We note also that + decoding returns a vector of full size (poly_modulus_degree / 2); this is + because of the implicit zero-padding mentioned above. + */ + evaluator.square_inplace(encrypted); + evaluator.relinearize_inplace(encrypted, relin_keys); + decryptor.decrypt(encrypted, plain); + encoder.decode(plain, input); + cout << "Squared input: " << endl; + print_vector(input); + + /* + We notice that the results are correct. We can also print the scale in the + result and observe that it has increased. In fact, it is now the square of + the original scale (2^60). + */ + cout << "Scale in the square: " << encrypted.scale() + << " (" << log2(encrypted.scale()) << " bits)" << endl; + + /* + CKKS supports modulus switching just like the BFV scheme. We can switch + away parts of the coefficient modulus. + */ + cout << "Current coeff_modulus size: " + << context->context_data(encrypted.parms_id())-> + total_coeff_modulus_bit_count() << " bits" << endl; + + cout << "Modulus switching ..." << endl; + evaluator.mod_switch_to_next_inplace(encrypted); + + cout << "Current coeff_modulus size: " + << context->context_data(encrypted.parms_id())-> + total_coeff_modulus_bit_count() << " bits" << endl; + cout << endl; + + /* + At this point if we tried switching further SEAL would throw an exception. + This is because the scale is 120 bits and after modulus switching we would + be down to a total coeff_modulus smaller than that, which is not enough to + contain the plaintext. We decrypt and decode, and observe that the result + is the same as before. + */ + decryptor.decrypt(encrypted, plain); + encoder.decode(plain, input); + cout << "Squared input: " << endl; + print_vector(input); + + /* + In some cases it can be convenient to change the scale of a ciphertext by + hand. For example, multiplying the scale by a number effectively divides the + underlying plaintext by that number, and vice versa. The caveat is that the + resulting scale can be incompatible with the scales of other ciphertexts. + Here we divide the ciphertext by 3. + */ + encrypted.scale() *= 3; + decryptor.decrypt(encrypted, plain); + encoder.decode(plain, input); + cout << "Divided by 3: " << endl; + print_vector(input); + + /* + Homomorphic addition and subtraction naturally require that the scales of + the inputs are the same, but also that the encryption parameters (parms_id) + are the same. Here we add a plaintext to encrypted. Note that a scale or + parms_id mismatch would make Evaluator::add_plain(..) throw an exception; + there is no problem here since we encode the plaintext just-in-time with + exactly the right scale. + */ + vector vec_summand{ 20.2, 30.3, 40.4, 50.5 }; + cout << "Plaintext summand: " << endl; + print_vector(vec_summand); + + /* + Get the parms_id and scale from encrypted and do the addition. + */ + Plaintext plain_summand; + encoder.encode(vec_summand, encrypted.parms_id(), encrypted.scale(), + plain_summand); + evaluator.add_plain_inplace(encrypted, plain_summand); + + /* + Decryption and decoding should give the correct result. + */ + decryptor.decrypt(encrypted, plain); + encoder.decode(plain, input); + cout << "Sum: " << endl; + print_vector(input); + + /* + Note that we have not mentioned noise budget at all. In fact, CKKS does not + have a similar concept of a noise budget as BFV; instead, the homomorphic + encryption noise will overlap the low-order bits of the message. This is why + scaling is needed: the message must be moved to higher-order bits to protect + it from the noise. Still, it is difficult to completely decouple the noise + from the message itself; hence the noise/error budget cannot be exactly + measured from a ciphertext alone. + */ +} + +void example_ckks_basics_ii() +{ + print_example_banner("Example: CKKS Basics II"); + + /* + The previous example did not really make it clear why CKKS is useful at all. + Certainly one can scale floating-point numbers to integers, encrypt them, + keep track of the scale, and operate on them by just using BFV. The problem + with this approach is that the scale quickly grows larger than the size of + the coefficient modulus, preventing further computations. The true power of + CKKS is that it allows the scale to be switched down (`rescaling') without + changing the encrypted values. + + To demonstrate this, we start by setting up the same environment we had in + the previous example. + */ + EncryptionParameters parms(scheme_type::CKKS); + parms.set_poly_modulus_degree(8192); + parms.set_coeff_modulus(coeff_modulus_128(8192)); + + auto context = SEALContext::Create(parms); + print_parameters(context); + + KeyGenerator keygen(context); + auto public_key = keygen.public_key(); + auto secret_key = keygen.secret_key(); + auto relin_keys = keygen.relin_keys(60); + + Encryptor encryptor(context, public_key); + Evaluator evaluator(context); + Decryptor decryptor(context, secret_key); + + CKKSEncoder encoder(context); + + size_t slot_count = encoder.slot_count(); + cout << "Number of slots: " << slot_count << endl; + + vector input{ 0.0, 1.1, 2.2, 3.3 }; + cout << "Input vector: " << endl; + print_vector(input); + + /* + We use a slightly smaller scale in this example. + */ + Plaintext plain; + double scale = pow(2.0, 60); + encoder.encode(input, scale, plain); + + Ciphertext encrypted; + encryptor.encrypt(plain, encrypted); + + /* + Print the scale and the parms_id for encrypted. + */ + cout << "Chain index of (encryption parameters of) encrypted: " + << context->context_data(encrypted.parms_id())->chain_index() << endl; + cout << "Scale in encrypted before squaring: " << encrypted.scale() << endl; + + /* + We did this already in the previous example: square encrypted and observe + the scale growth. + */ + evaluator.square_inplace(encrypted); + evaluator.relinearize_inplace(encrypted, relin_keys); + cout << "Scale in encrypted after squaring: " << encrypted.scale() + << " (" << log2(encrypted.scale()) << " bits)" << endl; + cout << "Current coeff_modulus size: " + << context->context_data(encrypted.parms_id())-> + total_coeff_modulus_bit_count() << " bits" << endl; + cout << endl; + + /* + Now, to prevent the scale from growing too large in subsequent operations, + we apply rescaling. + */ + cout << "Rescaling ..." << endl << endl; + evaluator.rescale_to_next_inplace(encrypted); + + /* + Rescaling changes the coefficient modulus as modulus switching does. These + operations are in fact very closely related. Moreover, the scale indeed has + been significantly reduced: rescaling divides the scale by the coefficient + modulus prime that was switched away. Since our coefficient modulus in this + case consisted of the primes (see seal/utils/global.cpp) + + 0x7fffffff380001, 0x7ffffffef00001, + 0x3fffffff000001, 0x3ffffffef40001, + + the last of which is 54 bits, the bit-size of the scale was reduced by + precisely 54 bits. Finer granularity rescaling would require smaller primes + to be used, but this might lead to performance problems as the computational + cost of homomorphic operations and the size of ciphertexts depends linearly + on the number of primes in coeff_modulus. + */ + cout << "Chain index of (encryption parameters of) encrypted: " + << context->context_data(encrypted.parms_id())->chain_index() << endl; + cout << "Scale in encrypted: " << encrypted.scale() + << " (" << log2(encrypted.scale()) << " bits)" << endl; + cout << "Current coeff_modulus size: " + << context->context_data(encrypted.parms_id())-> + total_coeff_modulus_bit_count() << " bits" << endl; + cout << endl; + + /* + We can even compute the fourth power of the input. Note that it is very + important to first relinearize and then rescale. Trying to do these two + operations in the opposite order will make SEAL throw and exception. + */ + cout << "Squaring and rescaling ..." << endl << endl; + evaluator.square_inplace(encrypted); + evaluator.relinearize_inplace(encrypted, relin_keys); + evaluator.rescale_to_next_inplace(encrypted); + + cout << "Chain index of (encryption parameters of) encrypted: " + << context->context_data(encrypted.parms_id())->chain_index() << endl; + cout << "Scale in encrypted: " << encrypted.scale() + << " (" << log2(encrypted.scale()) << " bits)" << endl; + cout << "Current coeff_modulus size: " + << context->context_data(encrypted.parms_id())-> + total_coeff_modulus_bit_count() << " bits" << endl; + cout << endl; + + /* + At this point our scale is 78 bits and the coefficient modulus is 110 bits. + This means that we cannot square the result anymore, but if we rescale once + more and then square, things should work out better. We cannot relinearize + with relin_keys at this point due to the large decomposition bit count we + used: the noise from relinearization would completely destroy our result + due to the small scale we are at. + */ + cout << "Rescaling and squaring (no relinearization) ..." << endl << endl; + evaluator.rescale_to_next_inplace(encrypted); + evaluator.square_inplace(encrypted); + + cout << "Chain index of (encryption parameters of) encrypted: " + << context->context_data(encrypted.parms_id())->chain_index() << endl; + cout << "Scale in encrypted: " << encrypted.scale() + << " (" << log2(encrypted.scale()) << " bits)" << endl; + cout << "Current coeff_modulus size: " + << context->context_data(encrypted.parms_id())-> + total_coeff_modulus_bit_count() << " bits" << endl; + cout << endl; + + /* + We decrypt, decode, and print the results. + */ + decryptor.decrypt(encrypted, plain); + vector result; + encoder.decode(plain, result); + cout << "Eighth powers: " << endl; + print_vector(result); + + /* + We have gone pretty low in the scale at this point and can no longer expect + to get entirely accurate results. Still, our results are quite accurate. + */ + vector precise_result{}; + transform(input.begin(), input.end(), back_inserter(precise_result), + [](auto in) { return pow(in, 8); }); + cout << "Precise result: " << endl; + print_vector(precise_result); +} + +void example_ckks_basics_iii() +{ + print_example_banner("Example: CKKS Basics III"); + + /* + In this example we demonstrate evaluating a polynomial function on + floating-point input data. The challenges we encounter will be related to + matching scales and encryption parameters when adding together terms of + different degrees in the polynomial evaluation. We start by setting up an + environment similar to what we had in the above examples. + */ + EncryptionParameters parms(scheme_type::CKKS); + parms.set_poly_modulus_degree(8192); + + /* + In this example we decide to use four 40-bit moduli for more flexible + rescaling. Note that 4*40 bits = 160 bits, which is well below the size of + the default coefficient modulus (see seal/util/globals.cpp). It is always + more secure to use a smaller coefficient modulus while keeping the degree of + the polynomial modulus fixed. Since the coeff_mod_128(8192) default 218-bit + coefficient modulus achieves already a 128-bit security level, this 160-bit + modulus must be much more secure. + + We use the small_mods_40bit(int) function to get primes from a hard-coded + list of 40-bit prime numbers; it is important that all primes used for the + coefficient modulus are distinct. + */ + parms.set_coeff_modulus({ + small_mods_40bit(0), small_mods_40bit(1), + small_mods_40bit(2), small_mods_40bit(3) }); + + auto context = SEALContext::Create(parms); + print_parameters(context); + + KeyGenerator keygen(context); + auto public_key = keygen.public_key(); + auto secret_key = keygen.secret_key(); + auto relin_keys = keygen.relin_keys(60); + + Encryptor encryptor(context, public_key); + Evaluator evaluator(context); + Decryptor decryptor(context, secret_key); + + CKKSEncoder encoder(context); + size_t slot_count = encoder.slot_count(); + cout << "Number of slots: " << slot_count << endl; + + /* + In this example our goal is to evaluate the polynomial PI*x^3 + 0.4x + 1 on + an encrypted input x for 4096 equidistant points x in the interval [0, 1]. + */ + vector input; + input.reserve(slot_count); + double curr_point = 0, step_size = 1.0 / (static_cast(slot_count) - 1); + for (size_t i = 0; i < slot_count; i++, curr_point += step_size) + { + input.push_back(curr_point); + } + cout << "Input vector: " << endl; + print_vector(input, 3, 7); + cout << "Evaluating polynomial PI*x^3 + 0.4x + 1 ..." << endl << endl; + + /* + Now encode and encrypt the input using the last of the coeff_modulus primes + as the scale for a reason that will become clear soon. + */ + auto scale = static_cast(parms.coeff_modulus().back().value()); + Plaintext plain_x; + encoder.encode(input, scale, plain_x); + Ciphertext encrypted_x1; + encryptor.encrypt(plain_x, encrypted_x1); + + /* + We create plaintext elements for PI, 0.4, and 1, using an overload of + CKKSEncoder::encode(...) that encodes the given floating-point value to + every slot in the vector. + */ + Plaintext plain_coeff3, plain_coeff1, plain_coeff0; + encoder.encode(3.14159265, scale, plain_coeff3); + encoder.encode(0.4, scale, plain_coeff1); + encoder.encode(1.0, scale, plain_coeff0); + + /* + To compute x^3 we first compute x^2, relinearize, and rescale. + */ + Ciphertext encrypted_x3; + evaluator.square(encrypted_x1, encrypted_x3); + evaluator.relinearize_inplace(encrypted_x3, relin_keys); + evaluator.rescale_to_next_inplace(encrypted_x3); + + /* + Now encrypted_x3 is at different encryption parameters than encrypted_x1, + preventing us from multiplying them together to compute x^3. We could simply + switch encrypted_x1 down to the next parameters in the modulus switching + chain. Since we still need to multiply the x^3 term with PI (plain_coeff3), + we instead compute PI*x first and multiply that with x^2 to obtain PI*x^3. + This product poses no problems since both inputs are at the same scale and + use the same encryption parameters. We rescale afterwards to change the + scale back to 40 bits, which will also drop the coefficient modulus down to + 120 bits. + */ + Ciphertext encrypted_x1_coeff3; + evaluator.multiply_plain(encrypted_x1, plain_coeff3, encrypted_x1_coeff3); + evaluator.rescale_to_next_inplace(encrypted_x1_coeff3); + + /* + Since both encrypted_x3 and encrypted_x1_coeff3 now have the same scale and + use same encryption parameters, we can multiply them together. We write the + result to encrypted_x3. + */ + evaluator.multiply_inplace(encrypted_x3, encrypted_x1_coeff3); + evaluator.relinearize_inplace(encrypted_x3, relin_keys); + evaluator.rescale_to_next_inplace(encrypted_x3); + + /* + Next we compute the degree one term. All this requires is one multiply_plain + with plain_coeff1. We overwrite encrypted_x1 with the result. + */ + evaluator.multiply_plain_inplace(encrypted_x1, plain_coeff1); + evaluator.rescale_to_next_inplace(encrypted_x1); + + /* + Now we would hope to compute the sum of all three terms. However, there is + a serious problem: the encryption parameters used by all three terms are + different due to modulus switching from rescaling. + */ + cout << "Parameters used by all three terms are different:" << endl; + cout << "Modulus chain index for encrypted_x3: " + << context->context_data(encrypted_x3.parms_id())->chain_index() << endl; + cout << "Modulus chain index for encrypted_x1: " + << context->context_data(encrypted_x1.parms_id())->chain_index() << endl; + cout << "Modulus chain index for plain_coeff0: " + << context->context_data(plain_coeff0.parms_id())->chain_index() << endl; + cout << endl; + + /* + Let us carefully consider what the scales are at this point. If we denote + the primes in coeff_modulus as q1, q2, q3, q4 (order matters here), then all + fresh encodings start with a scale equal to q4 (this was a choice we made + above). After the computations above the scale in encrypted_x3 is q4^2/q3: + + * The product x^2 has scale q4^2; + * The produt PI*x has scale q4^2; + * Rescaling both of these by q4 (last prime) results in scale q4; + * Multiplication to obtain PI*x^3 raises the scale to q4^2; + * Rescaling by q3 (last prime) yields a scale of q4^2/q3. + + The scale in both encrypted_x1 and plain_coeff0 is just q4. + */ + ios old_fmt(nullptr); + old_fmt.copyfmt(cout); + cout << fixed << setprecision(10); + cout << "Scale in encrypted_x3: " << encrypted_x3.scale() << endl; + cout << "Scale in encrypted_x1: " << encrypted_x1.scale() << endl; + cout << "Scale in plain_coeff0: " << plain_coeff0.scale() << endl; + cout << endl; + cout.copyfmt(old_fmt); + + /* + There are a couple of ways to fix this this problem. Since q4 and q3 are + really close to each other, we could simply "lie" to SEAL and set the scales + to be the same. For example, changing the scale of encrypted_x3 to be q4 + simply means that we scale the value of encrypted_x3 by q4/q3 which is very + close to 1; this should not result in any noticeable error. + + Another option would be to encode 1 with scale q4, perform a multiply_plain + with encrypted_x1, and finally rescale. In this case we would additionally + make sure to encode 1 with the appropriate encryption parameters (parms_id). + + A third option would be to initially encode plain_coeff1 with scale q4^2/q3. + Then, after multiplication with encrypted_x1 and rescaling, the result would + have scale q4^2/q3. Since encoding can be computationally costly, this may + not be a realistic option in some cases. + + In this example we will use the first (simplest) approach and simply change + the scale of encrypted_x3. + */ + encrypted_x3.scale() = encrypted_x1.scale(); + + /* + We still have a problem with mismatching encryption parameters. This is easy + to fix by using traditional modulus switching (no rescaling). Note that we + use here the Evaluator::mod_switch_to_inplace(...) function to switch to + encryption parameters down the chain with a specific parms_id. + */ + evaluator.mod_switch_to_inplace(encrypted_x1, encrypted_x3.parms_id()); + evaluator.mod_switch_to_inplace(plain_coeff0, encrypted_x3.parms_id()); + + /* + All three ciphertexts are now compatible and can be added. + */ + Ciphertext encrypted_result; + evaluator.add(encrypted_x3, encrypted_x1, encrypted_result); + evaluator.add_plain_inplace(encrypted_result, plain_coeff0); + + /* + Print the chain index and scale for encrypted_result. + */ + cout << "Modulus chain index for encrypted_result: " + << context->context_data(encrypted_result.parms_id()) + ->chain_index() << endl; + old_fmt.copyfmt(cout); + cout << fixed << setprecision(10); + cout << "Scale in encrypted_result: " << encrypted_result.scale(); + cout.copyfmt(old_fmt); + cout << " (" << log2(encrypted_result.scale()) << " bits)" << endl; + + /* + We decrypt, decode, and print the result. + */ + Plaintext plain_result; + decryptor.decrypt(encrypted_result, plain_result); + vector result; + encoder.decode(plain_result, result); + cout << "Result of PI*x^3 + 0.4x + 1:" << endl; + print_vector(result, 3, 7); + + /* + At this point if we wanted to multiply encrypted_result one more time, the + other multiplicand would have to have scale less than 40 bits, otherwise + the scale would become larger than the coeff_modulus itself. + */ + cout << "Current coeff_modulus size for encrypted_result: " + << context->context_data(encrypted_result.parms_id())-> + total_coeff_modulus_bit_count() << " bits" << endl << endl; + + /* + A very extreme case for multiplication is where we multiply a ciphertext + with a vector of values that are all the same integer. For example, let us + multiply encrypted_result by 7. In this case we do not need any scaling in + the multiplicand due to a different (much simpler) encoding process. + */ + Plaintext plain_integer_scalar; + encoder.encode(7, encrypted_result.parms_id(), plain_integer_scalar); + evaluator.multiply_plain_inplace(encrypted_result, plain_integer_scalar); + + old_fmt.copyfmt(cout); + cout << fixed << setprecision(10); + cout << "Scale in plain_integer_scalar scale: " + << plain_integer_scalar.scale() << endl; + cout << "Scale in encrypted_result: " << encrypted_result.scale() << endl; + cout.copyfmt(old_fmt); + + /* + We decrypt, decode, and print the result. + */ + decryptor.decrypt(encrypted_result, plain_result); + encoder.decode(plain_result, result); + cout << "Result of 7 * (PI*x^3 + 0.4x + 1):" << endl; + print_vector(result, 3, 7); + + /* + Finally, we show how to apply vector rotations on the encrypted data. This + is very similar to how matrix rotations work in the BFV scheme. We try this + with three sizes of Galois keys. In some cases it is desirable for memory + reasons to create Galois keys that support only specific rotations. This can + be done by passing to KeyGenerator::galois_keys(...) a vector of signed + integers specifying the desired rotation step counts. Here we create Galois + keys that only allow cyclic rotation by a single step (at a time) to the left. + */ + auto gal_keys30 = keygen.galois_keys(30, vector{ 1 }); + auto gal_keys15 = keygen.galois_keys(15, vector{ 1 }); + + Ciphertext rotated_result; + evaluator.rotate_vector(encrypted_result, 1, gal_keys15, rotated_result); + decryptor.decrypt(rotated_result, plain_result); + encoder.decode(plain_result, result); + cout << "Result rotated with dbc 15:" << endl; + print_vector(result, 3, 7); + + evaluator.rotate_vector(encrypted_result, 1, gal_keys30, rotated_result); + decryptor.decrypt(rotated_result, plain_result); + encoder.decode(plain_result, result); + cout << "Result rotated with dbc 30:" << endl; + print_vector(result, 3, 5); + + /* + We notice that the using the smallest decomposition bit count introduces + the least amount of error in the result. The problem is that our scale at + this point is very small -- only 40 bits -- so a rotation with decomposition + bit count 30 or bigger already destroys most or all of the message bits. + Ideally rotations would be performed right after multiplications before any + rescaling takes place. This way the scale is as large as possible and the + additive noise coming from the rotation (or relinearization) will be totally + shadowed by the large scale, and subsequently scaled down by the following + rescaling. Of course this may not always be possible to arrange. + + We did not show any computations on complex numbers in these examples, but + the CKKSEncoder would allow us to have done that just as easily. Additions + and multiplications behave just as one would expect. It is also possible + to complex conjugate the values in a ciphertext by using the functions + Evaluator::complex_conjugate[_inplace](...). + */ +} + +void example_bfv_performance() +{ + print_example_banner("Example: BFV Performance Test"); + + /* + In this example we time all the basic operations. We use the following + lambda function to run the test. + */ + auto performance_test = [](auto context) + { + chrono::high_resolution_clock::time_point time_start, time_end; + + print_parameters(context); + auto &parms = context->context_data()->parms(); + auto &plain_modulus = parms.plain_modulus(); + size_t poly_modulus_degree = parms.poly_modulus_degree(); + + /* + Set up keys. For both relinearization and rotations we use a large + decomposition bit count for best possible computational performance. + */ + cout << "Generating secret/public keys: "; + KeyGenerator keygen(context); + cout << "Done" << endl; + + auto secret_key = keygen.secret_key(); + auto public_key = keygen.public_key(); + + /* + Generate relinearization keys. + */ + int dbc = dbc_max(); + cout << "Generating relinearization keys (dbc = " << dbc << "): "; + time_start = chrono::high_resolution_clock::now(); + auto relin_keys = keygen.relin_keys(dbc); + time_end = chrono::high_resolution_clock::now(); + auto time_diff = chrono::duration_cast(time_end - time_start); + cout << "Done [" << time_diff.count() << " microseconds]" << endl; + + /* + Generate Galois keys. In larger examples the Galois keys can use + a significant amount of memory, which can be a problem in constrained + systems. The user should try enabling some of the larger runs of the + test (see below) and to observe their effect on the memory pool + allocation size. The key generation can also take a significant amount + of time, as can be observed from the print-out. + */ + if (!context->context_data()->qualifiers().using_batching) + { + cout << "Given encryption parameters do not support batching." << endl; + return; + } + cout << "Generating Galois keys (dbc = " << dbc << "): "; + time_start = chrono::high_resolution_clock::now(); + auto gal_keys = keygen.galois_keys(dbc); + time_end = chrono::high_resolution_clock::now(); + time_diff = chrono::duration_cast(time_end - time_start); + cout << "Done [" << time_diff.count() << " microseconds]" << endl; + + Encryptor encryptor(context, public_key); + Decryptor decryptor(context, secret_key); + Evaluator evaluator(context); + BatchEncoder batch_encoder(context); + IntegerEncoder encoder(plain_modulus); + + /* + These will hold the total times used by each operation. + */ + chrono::microseconds time_batch_sum(0); + chrono::microseconds time_unbatch_sum(0); + chrono::microseconds time_encrypt_sum(0); + chrono::microseconds time_decrypt_sum(0); + chrono::microseconds time_add_sum(0); + chrono::microseconds time_multiply_sum(0); + chrono::microseconds time_multiply_plain_sum(0); + chrono::microseconds time_square_sum(0); + chrono::microseconds time_relinearize_sum(0); + chrono::microseconds time_rotate_rows_one_step_sum(0); + chrono::microseconds time_rotate_rows_random_sum(0); + chrono::microseconds time_rotate_columns_sum(0); + + /* + How many times to run the test? + */ + int count = 10; + + /* + Populate a vector of values to batch. + */ + vector pod_vector; + random_device rd; + for (size_t i = 0; i < batch_encoder.slot_count(); i++) + { + pod_vector.push_back(rd() % plain_modulus.value()); + } + + cout << "Running tests "; + for (int i = 0; i < count; i++) + { + /* + [Batching] + There is nothing unusual here. We batch our random plaintext matrix + into the polynomial. The user can try changing the decomposition bit + count to something smaller to see the effect. Note how the plaintext + we create is of the exactly right size so unnecessary reallocations + are avoided. + */ + Plaintext plain(parms.poly_modulus_degree(), 0); + time_start = chrono::high_resolution_clock::now(); + batch_encoder.encode(pod_vector, plain); + time_end = chrono::high_resolution_clock::now(); + time_batch_sum += chrono::duration_cast< + chrono::microseconds>(time_end - time_start); + + /* + [Unbatching] + We unbatch what we just batched. + */ + vector pod_vector2(batch_encoder.slot_count()); + time_start = chrono::high_resolution_clock::now(); + batch_encoder.decode(plain, pod_vector2); + time_end = chrono::high_resolution_clock::now(); + time_unbatch_sum += chrono::duration_cast< + chrono::microseconds>(time_end - time_start); + if (pod_vector2 != pod_vector) + { + throw runtime_error("Batch/unbatch failed. Something is wrong."); + } + + /* + [Encryption] + We make sure our ciphertext is already allocated and large enough to + hold the encryption with these encryption parameters. We encrypt our + random batched matrix here. + */ + Ciphertext encrypted(context); + time_start = chrono::high_resolution_clock::now(); + encryptor.encrypt(plain, encrypted); + time_end = chrono::high_resolution_clock::now(); + time_encrypt_sum += chrono::duration_cast< + chrono::microseconds>(time_end - time_start); + + /* + [Decryption] + We decrypt what we just encrypted. + */ + Plaintext plain2(poly_modulus_degree, 0); + time_start = chrono::high_resolution_clock::now(); + decryptor.decrypt(encrypted, plain2); + time_end = chrono::high_resolution_clock::now(); + time_decrypt_sum += chrono::duration_cast< + chrono::microseconds>(time_end - time_start); + if (plain2 != plain) + { + throw runtime_error("Encrypt/decrypt failed. Something is wrong."); + } + + /* + [Add] + We create two ciphertexts that are both of size 2, and perform a few + additions with them. + */ + Ciphertext encrypted1(context); + encryptor.encrypt(encoder.encode(i), encrypted1); + Ciphertext encrypted2(context); + encryptor.encrypt(encoder.encode(i + 1), encrypted2); + time_start = chrono::high_resolution_clock::now(); + evaluator.add_inplace(encrypted1, encrypted1); + evaluator.add_inplace(encrypted2, encrypted2); + evaluator.add_inplace(encrypted1, encrypted2); + time_end = chrono::high_resolution_clock::now(); + time_add_sum += chrono::duration_cast< + chrono::microseconds>(time_end - time_start) / 3; + + /* + [Multiply] + We multiply two ciphertexts of size 2. Since the size of the result + will be 3, and will overwrite the first argument, we reserve first + enough memory to avoid reallocating during multiplication. + */ + encrypted1.reserve(3); + time_start = chrono::high_resolution_clock::now(); + evaluator.multiply_inplace(encrypted1, encrypted2); + time_end = chrono::high_resolution_clock::now(); + time_multiply_sum += chrono::duration_cast< + chrono::microseconds>(time_end - time_start); + + /* + [Multiply Plain] + We multiply a ciphertext of size 2 with a random plaintext. Recall + that multiply_plain does not change the size of the ciphertext so we + use encrypted2 here, which still has size 2. + */ + time_start = chrono::high_resolution_clock::now(); + evaluator.multiply_plain_inplace(encrypted2, plain); + time_end = chrono::high_resolution_clock::now(); + time_multiply_plain_sum += chrono::duration_cast< + chrono::microseconds>(time_end - time_start); + + /* + [Square] + We continue to use the size 2 ciphertext encrypted2. Now we square + it; this should be faster than generic homomorphic multiplication. + */ + time_start = chrono::high_resolution_clock::now(); + evaluator.square_inplace(encrypted2); + time_end = chrono::high_resolution_clock::now(); + time_square_sum += chrono::duration_cast< + chrono::microseconds>(time_end - time_start); + + /* + [Relinearize] + Time to get back to encrypted1; at this point it still has size 3. + We now relinearize it back to size 2. Since the allocation is + currently big enough to contain a ciphertext of size 3, no costly + reallocations are needed in the process. + */ + time_start = chrono::high_resolution_clock::now(); + evaluator.relinearize_inplace(encrypted1, relin_keys); + time_end = chrono::high_resolution_clock::now(); + time_relinearize_sum += chrono::duration_cast< + chrono::microseconds>(time_end - time_start); + + /* + [Rotate Rows One Step] + We rotate matrix rows by one step left and measure the time. + */ + time_start = chrono::high_resolution_clock::now(); + evaluator.rotate_rows_inplace(encrypted, 1, gal_keys); + evaluator.rotate_rows_inplace(encrypted, -1, gal_keys); + time_end = chrono::high_resolution_clock::now(); + time_rotate_rows_one_step_sum += chrono::duration_cast< + chrono::microseconds>(time_end - time_start) / 2; + + /* + [Rotate Rows Random] + We rotate matrix rows by a random number of steps. This is more + expensive than rotating by just one step. + */ + size_t row_size = batch_encoder.slot_count() / 2; + int random_rotation = static_cast(rd() % row_size); + time_start = chrono::high_resolution_clock::now(); + evaluator.rotate_rows_inplace(encrypted, random_rotation, gal_keys); + time_end = chrono::high_resolution_clock::now(); + time_rotate_rows_random_sum += chrono::duration_cast< + chrono::microseconds>(time_end - time_start); + + /* + [Rotate Columns] + Nothing surprising here. + */ + time_start = chrono::high_resolution_clock::now(); + evaluator.rotate_columns_inplace(encrypted, gal_keys); + time_end = chrono::high_resolution_clock::now(); + time_rotate_columns_sum += chrono::duration_cast< + chrono::microseconds>(time_end - time_start); + + /* + Print a dot to indicate progress. + */ + cout << "."; + cout.flush(); + } + + cout << " Done" << endl << endl; + cout.flush(); + + auto avg_batch = time_batch_sum.count() / count; + auto avg_unbatch = time_unbatch_sum.count() / count; + auto avg_encrypt = time_encrypt_sum.count() / count; + auto avg_decrypt = time_decrypt_sum.count() / count; + auto avg_add = time_add_sum.count() / count; + auto avg_multiply = time_multiply_sum.count() / count; + auto avg_multiply_plain = time_multiply_plain_sum.count() / count; + auto avg_square = time_square_sum.count() / count; + auto avg_relinearize = time_relinearize_sum.count() / count; + auto avg_rotate_rows_one_step = time_rotate_rows_one_step_sum.count() / count; + auto avg_rotate_rows_random = time_rotate_rows_random_sum.count() / count; + auto avg_rotate_columns = time_rotate_columns_sum.count() / count; + + cout << "Average batch: " << avg_batch << " microseconds" << endl; + cout << "Average unbatch: " << avg_unbatch << " microseconds" << endl; + cout << "Average encrypt: " << avg_encrypt << " microseconds" << endl; + cout << "Average decrypt: " << avg_decrypt << " microseconds" << endl; + cout << "Average add: " << avg_add << " microseconds" << endl; + cout << "Average multiply: " << avg_multiply << " microseconds" << endl; + cout << "Average multiply plain: " << avg_multiply_plain << " microseconds" << endl; + cout << "Average square: " << avg_square << " microseconds" << endl; + cout << "Average relinearize: " << avg_relinearize << " microseconds" << endl; + cout << "Average rotate rows one step: " << avg_rotate_rows_one_step << " microseconds" << endl; + cout << "Average rotate rows random: " << avg_rotate_rows_random << " microseconds" << endl; + cout << "Average rotate columns: " << avg_rotate_columns << " microseconds" << endl; + cout.flush(); + }; + + EncryptionParameters parms(scheme_type::BFV); + parms.set_poly_modulus_degree(4096); + parms.set_coeff_modulus(coeff_modulus_128(4096)); + parms.set_plain_modulus(786433); + performance_test(SEALContext::Create(parms)); + + cout << endl; + parms.set_poly_modulus_degree(8192); + parms.set_coeff_modulus(coeff_modulus_128(8192)); + parms.set_plain_modulus(786433); + performance_test(SEALContext::Create(parms)); + + cout << endl; + parms.set_poly_modulus_degree(16384); + parms.set_coeff_modulus(coeff_modulus_128(16384)); + parms.set_plain_modulus(786433); + performance_test(SEALContext::Create(parms)); + + /* + Comment out the following to run the biggest example. + */ + // cout << endl; + // parms.set_poly_modulus_degree(32768); + // parms.set_coeff_modulus(coeff_modulus_128(32768)); + // parms.set_plain_modulus(786433); + // performance_test(SEALContext::Create(parms)); +} + +void example_ckks_performance() +{ + print_example_banner("Example: CKKS Performance Test"); + + /* + In this example we time all the basic operations. We use the following + lambda function to run the test. This is largely similar to the function + in the previous example. + */ + auto performance_test = [](auto context) + { + chrono::high_resolution_clock::time_point time_start, time_end; + + print_parameters(context); + auto &parms = context->context_data()->parms(); + size_t poly_modulus_degree = parms.poly_modulus_degree(); + + cout << "Generating secret/public keys: "; + KeyGenerator keygen(context); + cout << "Done" << endl; + + auto secret_key = keygen.secret_key(); + auto public_key = keygen.public_key(); + + int dbc = dbc_max(); + cout << "Generating relinearization keys (dbc = " << dbc << "): "; + time_start = chrono::high_resolution_clock::now(); + auto relin_keys = keygen.relin_keys(dbc); + time_end = chrono::high_resolution_clock::now(); + auto time_diff = chrono::duration_cast(time_end - time_start); + cout << "Done [" << time_diff.count() << " microseconds]" << endl; + + if (!context->context_data()->qualifiers().using_batching) + { + cout << "Given encryption parameters do not support batching." << endl; + return; + } + cout << "Generating Galois keys (dbc = " << dbc << "): "; + time_start = chrono::high_resolution_clock::now(); + auto gal_keys = keygen.galois_keys(dbc); + time_end = chrono::high_resolution_clock::now(); + time_diff = chrono::duration_cast(time_end - time_start); + cout << "Done [" << time_diff.count() << " microseconds]" << endl; + + Encryptor encryptor(context, public_key); + Decryptor decryptor(context, secret_key); + Evaluator evaluator(context); + CKKSEncoder ckks_encoder(context); + + chrono::microseconds time_encode_sum(0); + chrono::microseconds time_decode_sum(0); + chrono::microseconds time_encrypt_sum(0); + chrono::microseconds time_decrypt_sum(0); + chrono::microseconds time_add_sum(0); + chrono::microseconds time_multiply_sum(0); + chrono::microseconds time_multiply_plain_sum(0); + chrono::microseconds time_square_sum(0); + chrono::microseconds time_relinearize_sum(0); + chrono::microseconds time_rescale_sum(0); + chrono::microseconds time_rotate_one_step_sum(0); + chrono::microseconds time_rotate_random_sum(0); + chrono::microseconds time_conjugate_sum(0); + + /* + How many times to run the test? + */ + int count = 10; + + /* + Populate a vector of floating-point values to batch. + */ + vector pod_vector; + random_device rd; + for (size_t i = 0; i < ckks_encoder.slot_count(); i++) + { + pod_vector.push_back(1.001 * static_cast(i)); + } + + cout << "Running tests "; + for (int i = 0; i < count; i++) + { + /* + [Encoding] + */ + Plaintext plain(parms.poly_modulus_degree() * + parms.coeff_modulus().size(), 0); + time_start = chrono::high_resolution_clock::now(); + ckks_encoder.encode(pod_vector, + static_cast(parms.coeff_modulus().back().value()), plain); + time_end = chrono::high_resolution_clock::now(); + time_encode_sum += chrono::duration_cast< + chrono::microseconds>(time_end - time_start); + + /* + [Decoding] + */ + vector pod_vector2(ckks_encoder.slot_count()); + time_start = chrono::high_resolution_clock::now(); + ckks_encoder.decode(plain, pod_vector2); + time_end = chrono::high_resolution_clock::now(); + time_decode_sum += chrono::duration_cast< + chrono::microseconds>(time_end - time_start); + + /* + [Encryption] + */ + Ciphertext encrypted(context); + time_start = chrono::high_resolution_clock::now(); + encryptor.encrypt(plain, encrypted); + time_end = chrono::high_resolution_clock::now(); + time_encrypt_sum += chrono::duration_cast< + chrono::microseconds>(time_end - time_start); + + /* + [Decryption] + */ + Plaintext plain2(poly_modulus_degree, 0); + time_start = chrono::high_resolution_clock::now(); + decryptor.decrypt(encrypted, plain2); + time_end = chrono::high_resolution_clock::now(); + time_decrypt_sum += chrono::duration_cast< + chrono::microseconds>(time_end - time_start); + + /* + [Add] + */ + Ciphertext encrypted1(context); + ckks_encoder.encode(i + 1, plain); + encryptor.encrypt(plain, encrypted1); + Ciphertext encrypted2(context); + ckks_encoder.encode(i + 1, plain2); + encryptor.encrypt(plain2, encrypted2); + time_start = chrono::high_resolution_clock::now(); + evaluator.add_inplace(encrypted1, encrypted1); + evaluator.add_inplace(encrypted2, encrypted2); + evaluator.add_inplace(encrypted1, encrypted2); + time_end = chrono::high_resolution_clock::now(); + time_add_sum += chrono::duration_cast< + chrono::microseconds>(time_end - time_start) / 3; + + /* + [Multiply] + */ + encrypted1.reserve(3); + time_start = chrono::high_resolution_clock::now(); + evaluator.multiply_inplace(encrypted1, encrypted2); + time_end = chrono::high_resolution_clock::now(); + time_multiply_sum += chrono::duration_cast< + chrono::microseconds>(time_end - time_start); + + /* + [Multiply Plain] + */ + time_start = chrono::high_resolution_clock::now(); + evaluator.multiply_plain_inplace(encrypted2, plain); + time_end = chrono::high_resolution_clock::now(); + time_multiply_plain_sum += chrono::duration_cast< + chrono::microseconds>(time_end - time_start); + + /* + [Square] + */ + time_start = chrono::high_resolution_clock::now(); + evaluator.square_inplace(encrypted2); + time_end = chrono::high_resolution_clock::now(); + time_square_sum += chrono::duration_cast< + chrono::microseconds>(time_end - time_start); + + /* + [Relinearize] + */ + time_start = chrono::high_resolution_clock::now(); + evaluator.relinearize_inplace(encrypted1, relin_keys); + time_end = chrono::high_resolution_clock::now(); + time_relinearize_sum += chrono::duration_cast< + chrono::microseconds>(time_end - time_start); + + /* + [Rescale] + */ + time_start = chrono::high_resolution_clock::now(); + evaluator.rescale_to_next_inplace(encrypted1); + time_end = chrono::high_resolution_clock::now(); + time_rescale_sum += chrono::duration_cast< + chrono::microseconds>(time_end - time_start); + + /* + [Rotate Vector] + */ + time_start = chrono::high_resolution_clock::now(); + evaluator.rotate_vector_inplace(encrypted, 1, gal_keys); + evaluator.rotate_vector_inplace(encrypted, -1, gal_keys); + time_end = chrono::high_resolution_clock::now(); + time_rotate_one_step_sum += chrono::duration_cast< + chrono::microseconds>(time_end - time_start) / 2; + + /* + [Rotate Vector Random] + */ + int random_rotation = static_cast(rd() % ckks_encoder.slot_count()); + time_start = chrono::high_resolution_clock::now(); + evaluator.rotate_vector_inplace(encrypted, random_rotation, gal_keys); + time_end = chrono::high_resolution_clock::now(); + time_rotate_random_sum += chrono::duration_cast< + chrono::microseconds>(time_end - time_start); + + /* + [Complex Conjugate] + */ + time_start = chrono::high_resolution_clock::now(); + evaluator.complex_conjugate_inplace(encrypted, gal_keys); + time_end = chrono::high_resolution_clock::now(); + time_conjugate_sum += chrono::duration_cast< + chrono::microseconds>(time_end - time_start); + + /* + Print a dot to indicate progress. + */ + cout << "."; + cout.flush(); + } + + cout << " Done" << endl << endl; + cout.flush(); + + auto avg_encode = time_encode_sum.count() / count; + auto avg_decode = time_decode_sum.count() / count; + auto avg_encrypt = time_encrypt_sum.count() / count; + auto avg_decrypt = time_decrypt_sum.count() / count; + auto avg_add = time_add_sum.count() / count; + auto avg_multiply = time_multiply_sum.count() / count; + auto avg_multiply_plain = time_multiply_plain_sum.count() / count; + auto avg_square = time_square_sum.count() / count; + auto avg_relinearize = time_relinearize_sum.count() / count; + auto avg_rescale = time_rescale_sum.count() / count; + auto avg_rotate_one_step = time_rotate_one_step_sum.count() / count; + auto avg_rotate_random = time_rotate_random_sum.count() / count; + auto avg_conjugate = time_conjugate_sum.count() / count; + + cout << "Average encode: " << avg_encode << " microseconds" << endl; + cout << "Average decode: " << avg_decode << " microseconds" << endl; + cout << "Average encrypt: " << avg_encrypt << " microseconds" << endl; + cout << "Average decrypt: " << avg_decrypt << " microseconds" << endl; + cout << "Average add: " << avg_add << " microseconds" << endl; + cout << "Average multiply: " << avg_multiply << " microseconds" << endl; + cout << "Average multiply plain: " << avg_multiply_plain << " microseconds" << endl; + cout << "Average square: " << avg_square << " microseconds" << endl; + cout << "Average relinearize: " << avg_relinearize << " microseconds" << endl; + cout << "Average rescale: " << avg_rescale << " microseconds" << endl; + cout << "Average rotate vector one step: " << avg_rotate_one_step << " microseconds" << endl; + cout << "Average rotate vector random: " << avg_rotate_random << " microseconds" << endl; + cout << "Average complex conjugate: " << avg_conjugate << " microseconds" << endl; + cout.flush(); + }; + + EncryptionParameters parms(scheme_type::CKKS); + parms.set_poly_modulus_degree(4096); + parms.set_coeff_modulus(coeff_modulus_128(4096)); + performance_test(SEALContext::Create(parms)); + + cout << endl; + parms.set_poly_modulus_degree(8192); + parms.set_coeff_modulus(coeff_modulus_128(8192)); + performance_test(SEALContext::Create(parms)); + + cout << endl; + parms.set_poly_modulus_degree(16384); + parms.set_coeff_modulus(coeff_modulus_128(16384)); + performance_test(SEALContext::Create(parms)); + + /* + Comment out the following to run the biggest example. + */ + // cout << endl; + // parms.set_poly_modulus_degree(32768); + // parms.set_coeff_modulus(coeff_modulus_128(32768)); + // performance_test(SEALContext::Create(parms)); +} diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt new file mode 100644 index 000000000..274a92824 --- /dev/null +++ b/src/CMakeLists.txt @@ -0,0 +1,369 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +cmake_minimum_required(VERSION 3.10) + +project(SEAL VERSION 3.1.0 LANGUAGES CXX C) + +if(DEFINED MSVC AND NOT DEFINED ALLOW_COMMAND_LINE_BUILD) + message(FATAL_ERROR "Please build SEAL using the attached Visual Studio solution/project files") +endif() + +if(${ALLOW_COMMAND_LINE_BUILD}) + message(STATUS "Configuring for Visual Studio") +endif() + +# Build in Release mode by default; otherwise use selected option +set(SEAL_DEFAULT_BUILD_TYPE "Release") +if(NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE ${SEAL_DEFAULT_BUILD_TYPE} CACHE + STRING "Build type" FORCE) + set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS + "Release" "Debug" "MinSizeRel" "RelWithDebInfo") +endif() +message(STATUS "Build type (CMAKE_BUILD_TYPE): ${CMAKE_BUILD_TYPE}") + +# In Debug mode enable also SEAL_DEBUG by default +if(${CMAKE_BUILD_TYPE} STREQUAL "Debug") + set(SEAL_DEBUG_DEFAULT ON) +else() + set(SEAL_DEBUG_DEFAULT OFF) +endif() + +# Required files and directories +set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${SEAL_SOURCE_DIR}/../lib) +set(SEAL_INCLUDES_INSTALL_DIR include) +set(SEAL_CONFIG_IN_FILENAME ${SEAL_SOURCE_DIR}/cmake/SEALConfig.cmake.in) +set(SEAL_CONFIG_FILENAME ${SEAL_SOURCE_DIR}/cmake/SEALConfig.cmake) +set(SEAL_TARGETS_FILENAME ${SEAL_SOURCE_DIR}/cmake/SEALTargets.cmake) +set(SEAL_CONFIG_VERSION_FILENAME ${SEAL_SOURCE_DIR}/cmake/SEALConfigVersion.cmake) +set(SEAL_CONFIG_INSTALL_DIR lib/cmake/SEAL) + +# For extra modules we might have +list(APPEND CMAKE_MODULE_PATH ${SEAL_SOURCE_DIR}/cmake) + +include(CMakePushCheckState) +include(CMakeDependentOption) +include(CheckIncludeFiles) +include(CheckCXXSourceRuns) +include(CheckTypeSize) + +# Are we using SEAL_DEBUG? +set(SEAL_DEBUG ${SEAL_DEBUG_DEFAULT}) +message(STATUS "SEAL debug mode: ${SEAL_DEBUG}") + +# Should we use C++14 or C++17? +set(SEAL_USE_CXX17_OPTION_STR "Use C++17") +option(SEAL_USE_CXX17 ${SEAL_USE_CXX17_OPTION_STR} ON) +message(STATUS "${SEAL_USE_CXX17_OPTION_STR} (SEAL_USE_CXX17): ${SEAL_USE_CXX17}") + +# Conditionally enable features from C++17 +set(SEAL_USE_STD_BYTE OFF) +set(SEAL_USE_SHARED_MUTEX OFF) +set(SEAL_USE_IF_CONSTEXPR OFF) +set(SEAL_USE_MAYBE_UNUSED OFF) +set(SEAL_LANG_FLAG "-std=c++14") +if(SEAL_USE_CXX17) + set(SEAL_USE_STD_BYTE ON) + set(SEAL_USE_SHARED_MUTEX ON) + set(SEAL_USE_IF_CONSTEXPR ON) + set(SEAL_USE_MAYBE_UNUSED ON) + set(SEAL_LANG_FLAG "-std=c++17") +endif() + +# Enforce at least 128-bit security level based on HomomorphicEncryption.org estimates +set(SEAL_ENFORCE_HE_STD_SECURITY_STR "Enforce at least 128-bit security level from HomomorphicEncryption.org security standard") +option(SEAL_ENFORCE_HE_STD_SECURITY ${SEAL_ENFORCE_HE_STD_SECURITY_STR} OFF) + +# Use intrinsics if available +set(SEAL_USE_INTRIN_OPTION_STR "Use intrinsics") +option(SEAL_USE_INTRIN ${SEAL_USE_INTRIN_OPTION_STR} ON) + +# Use Microsoft GSL if available +set(SEAL_USE_MSGSL_OPTION_STR "Use Microsoft GSL") +option(SEAL_USE_MSGSL ${SEAL_USE_MSGSL_OPTION_STR} ON) + +# Check for intrin.h or x64intrin.h +if(SEAL_USE_INTRIN) + if(DEFINED MSVC) + set(SEAL_INTRIN_HEADER "intrin.h") + else() + set(SEAL_INTRIN_HEADER "x86intrin.h") + endif() + + check_include_file_cxx(${SEAL_INTRIN_HEADER} HAVE_INTRIN_HEADER) + + if(NOT HAVE_INTRIN_HEADER) + set(SEAL_USE_INTRIN OFF CACHE BOOL ${SEAL_USE_INTRIN_OPTION_STR} FORCE) + endif() +endif() +message(STATUS "${SEAL_USE_INTRIN_OPTION_STR} (SEAL_USE_INTRIN): ${SEAL_USE_INTRIN}") + +# Specific intrinsics depending on SEAL_USE_INTRIN +if(DEFINED MSVC) + set(SEAL_USE__UMUL128_OPTION_STR "Use _umul128") + cmake_dependent_option(SEAL_USE__UMUL128 SEAL_USE__UMUL128_OPTION_STR ON "SEAL_USE_INTRIN" OFF) + + set(SEAL_USE__BITSCANREVERSE64_OPTION_STR "Use _BitScanReverse64") + cmake_dependent_option(SEAL_USE__BITSCANREVERSE64 SEAL_USE__BITSCANREVERSE64_OPTION_STR ON "SEAL_USE_INTRIN" OFF) +else() + set(SEAL_USE___INT128_OPTION_STR "Use __int128") + cmake_dependent_option(SEAL_USE___INT128 SEAL_USE___INT128_OPTION_STR ON "SEAL_USE_INTRIN" OFF) + + set(SEAL_USE___BUILTIN_CLZLL_OPTION_STR "Use __builtin_clzll") + cmake_dependent_option(SEAL_USE___BUILTIN_CLZLL SEAL_USE___BUILTIN_CLZLL_OPTION_STR ON "SEAL_USE_INTRIN" OFF) +endif() + +set(SEAL_USE__ADDCARRY_U64_OPTION_STR "Use _addcarry_u64") +cmake_dependent_option(SEAL_USE__ADDCARRY_U64 SEAL_USE__ADDCARRY_U64_OPTION_STR ON "SEAL_USE_INTRIN" OFF) + +set(SEAL_USE__SUBBORROW_U64_OPTION_STR "Use _subborrow_u64") +cmake_dependent_option(SEAL_USE__SUBBORROW_U64 SEAL_USE__SUBBORROW_U64_OPTION_STR ON "SEAL_USE_INTRIN" OFF) + +set(SEAL_USE_AES_NI_PRNG_OPTION_STR "Use fast AES-NI PRNG") +cmake_dependent_option(SEAL_USE_AES_NI_PRNG SEAL_USE_AES_NI_PRNG_OPTION_STR ON "SEAL_USE_INTRIN" OFF) + +if(SEAL_USE_INTRIN) + cmake_push_check_state(RESET) + set(CMAKE_REQUIRED_QUIET TRUE) + if(NOT DEFINED MSVC) + set(CMAKE_REQUIRED_FLAGS "${CMAKE_REQUIRED_FLAGS} -O0 ${SEAL_LANG_FLAG}") + endif() + + if(DEFINED MSVC) + # Check for presence of _umul128 + if(SEAL_USE__UMUL128) + check_cxx_source_runs(" + #include <${SEAL_INTRIN_HEADER}> + int main() { + unsigned long long a = 0, b = 0; + unsigned long long c; + volatile unsigned long long d; + d = _umul128(a, b, &c); + return 0; + }" + USE_UMUL128 + ) + if(NOT USE_UMUL128 EQUAL 1) + set(SEAL_USE__UMUL128 OFF CACHE BOOL ${SEAL_USE__UMUL128_OPTION_STR} FORCE) + endif() + endif() + + # Check for _BitScanReverse64 + if(SEAL_USE__BITSCANREVERSE64) + check_cxx_source_runs(" + #include <${SEAL_INTRIN_HEADER}> + int main() { + unsigned long a = 0, b = 0; + volatile unsigned char res = _BitScanReverse64(&a, b); + return 0; + }" + USE_BITSCANREVERSE64 + ) + if(NOT USE_BITSCANREVERSE64 EQUAL 1) + set(SEAL_USE__BITSCANREVERSE64 OFF CACHE BOOL ${SEAL_USE__BITSCANREVERSE64_OPTION_STR} FORCE) + endif() + endif() + else() + # Check for presence of ___int128 + if(SEAL_USE___INT128) + check_type_size("__int128" INT128 LANGUAGE CXX) + if(NOT INT128 EQUAL 16) + set(SEAL_USE___INT128 OFF CACHE BOOL ${SEAL_USE___INT128_OPTION_STR} FORCE) + endif() + endif() + + # Check for __builtin_clzll + if(SEAL_USE___BUILTIN_CLZLL) + check_cxx_source_runs(" + int main() { + volatile auto = __builtin_clzll(0); + return 0; + }" + USE_BUILTIN_CLZLL + ) + if(NOT USE_BUILTIN_CLZLL EQUAL 1) + set(SEAL_USE___BUILTIN_CLZLL OFF CACHE BOOL ${SEAL_USE___BUILTIN_CLZLL_OPTION_STR} FORCE) + endif() + endif() + endif() + + # Check for _addcarry_u64 + if(SEAL_USE__ADDCARRY_U64) + check_cxx_source_runs(" + #include <${SEAL_INTRIN_HEADER}> + int main() { + unsigned long long a; + volatile auto res = _addcarry_u64(0,0,0,&a); + return 0; + }" + USE_ADDCARRY_U64 + ) + if(NOT USE_ADDCARRY_U64 EQUAL 1) + set(SEAL_USE__ADDCARRY_U64 OFF CACHE BOOL ${SEAL_USE__ADDCARRY_U64_OPTION_STR} FORCE) + endif() + endif() + + # Check for _subborrow_u64 + if(SEAL_USE__SUBBORROW_U64) + check_cxx_source_runs(" + #include <${SEAL_INTRIN_HEADER}> + int main() { + unsigned long long a; + volatile auto res = _subborrow_u64(0,0,0,&a); + return 0; + }" + USE_SUBBORROW_U64 + ) + if(NOT USE_SUBBORROW_U64 EQUAL 1) + set(SEAL_USE__SUBBORROW_U64 OFF CACHE BOOL ${SEAL_USE__SUBBORROW_U64_OPTION_STR} FORCE) + endif() + endif() + + check_include_file_cxx("wmmintrin.h" HAVE_WMMINTRIN_HEADER) + if(NOT HAVE_WMMINTRIN_HEADER) + set(SEAL_USE_AES_NI_PRNG OFF CACHE BOOL ${SEAL_USE_AES_NI_PRNG_OPTION_STR} FORCE) + endif() + + # Check that AES-NI runs + if(SEAL_USE_AES_NI_PRNG) + if(NOT DEFINED MSVC) + set(CMAKE_REQUIRED_FLAGS "${CMAKE_REQUIRED_FLAGS} -maes") + endif() + check_cxx_source_runs(" + #include + int main() { + __m128i a{ 0 }; + volatile auto b = _mm_aeskeygenassist_si128(a, 0); + return 0; + }" + USE_AES_KEYGEN_ASSIST + ) + if(NOT USE_AES_KEYGEN_ASSIST EQUAL 1) + set(SEAL_USE_AES_NI_PRNG OFF CACHE BOOL ${SEAL_USE_AES_NI_PRNG_OPTION_STR} FORCE) + endif() + endif() + message(STATUS "${SEAL_USE_AES_NI_PRNG_OPTION_STR}: ${SEAL_USE_AES_NI_PRNG}") + + cmake_pop_check_state() +endif() + +# Try to find MSGSL if requested +if(SEAL_USE_MSGSL) + find_package(msgsl MODULE) + if(NOT msgsl_FOUND) + set(SEAL_USE_MSGSL OFF CACHE BOOL ${SEAL_USE_MSGSL_OPTION_STR} FORCE) + endif() +endif() +message(STATUS "${SEAL_USE_MSGSL_OPTION_STR} (SEAL_USE_MSGSL): ${SEAL_USE_MSGSL}") + +# Specific options depending on SEAL_USE_MSGSL +set(SEAL_USE_MSGSL_SPAN_OPTION_STR "Use gsl::span") +cmake_dependent_option(SEAL_USE_MSGSL_SPAN ${SEAL_USE_MSGSL_SPAN_OPTION_STR} ON "SEAL_USE_MSGSL" OFF) + +set(SEAL_USE_MSGSL_MULTISPAN_OPTION_STR "Use gsl::multi_span") +cmake_dependent_option(SEAL_USE_MSGSL_MULTISPAN ${SEAL_USE_MSGSL_MULTISPAN_OPTION_STR} ON "SEAL_USE_MSGSL" OFF) + +if(SEAL_USE_MSGSL) + # Now check for individual classes + cmake_push_check_state(RESET) + set(CMAKE_REQUIRED_INCLUDES ${MSGSL_INCLUDE_DIR}) + set(CMAKE_EXTRA_INCLUDE_FILES gsl/gsl) + set(CMAKE_REQUIRED_FLAGS "${CMAKE_REQUIRED_FLAGS} -O0 ${SEAL_LANG_FLAG}") + set(CMAKE_REQUIRED_QUIET TRUE) + + # Detect gsl::span + if(SEAL_USE_MSGSL_SPAN) + check_type_size("gsl::span" MSGSL_SPAN LANGUAGE CXX) + if(NOT MSGSL_SPAN GREATER 0) + set(SEAL_USE_MSGSL_SPAN OFF CACHE BOOL ${SEAL_USE_MSGSL_SPAN_OPTION_STR} FORCE) + endif() + endif() + + # Detect gsl::multi_span + if(SEAL_USE_MSGSL_MULTISPAN) + check_type_size("gsl::multi_span" MSGSL_MULTISPAN LANGUAGE CXX) + if(NOT MSGSL_MULTISPAN GREATER 0) + set(SEAL_USE_MSGSL_MULTISPAN OFF CACHE BOOL ${SEAL_USE_MSGSL_MULTISPAN_OPTION_STR} FORCE) + endif() + endif() + + cmake_pop_check_state() +endif() + +# Create library but add no source files yet +add_library(seal STATIC "") + +# Add source files to library and header files to install +add_subdirectory(seal) + +# Add local include directories for build +target_include_directories(seal PUBLIC + $) + +# We require at least C++14 +if(SEAL_USE_CXX17) + target_compile_features(seal PUBLIC cxx_std_17) +else() + target_compile_features(seal PUBLIC cxx_std_14) +endif() + +# Add -maes flag if needed +if(SEAL_USE_AES_NI_PRNG) + target_compile_options(seal PUBLIC "-maes") +endif() + +# Require thread library +set(CMAKE_THREAD_PREFER_PTHREAD TRUE) +set(THREADS_PREFER_PTHREAD_FLAG TRUE) +find_package(Threads REQUIRED) + +# Link Threads with seal +target_link_libraries(seal PUBLIC Threads::Threads) + +# Create msgsl interface target +if(SEAL_USE_MSGSL) + # Create interface target + add_library(msgsl INTERFACE) + set_target_properties(msgsl PROPERTIES + INTERFACE_INCLUDE_DIRECTORIES ${MSGSL_INCLUDE_DIR}) + + # Associate msgsl with export seal_export + install(TARGETS msgsl EXPORT seal_export) + + # Link with seal + target_link_libraries(seal PUBLIC msgsl) +endif() + +# Associate seal to export seal_export +install(TARGETS seal EXPORT seal_export + ARCHIVE DESTINATION lib + INCLUDES DESTINATION ${SEAL_INCLUDES_INSTALL_DIR}) + +# Export the targets +export(EXPORT seal_export + FILE ${SEAL_TARGETS_FILENAME} + NAMESPACE SEAL::) + +# Create the CMake config file +configure_file(${SEAL_CONFIG_IN_FILENAME} ${SEAL_CONFIG_FILENAME} @ONLY) + +# Install the export +install( + EXPORT seal_export + FILE SEALTargets.cmake + NAMESPACE SEAL:: + DESTINATION ${SEAL_CONFIG_INSTALL_DIR}) + +# Version file; we require exact version match for downstream +include(CMakePackageConfigHelpers) +write_basic_package_version_file(${SEAL_CONFIG_VERSION_FILENAME} + VERSION ${SEAL_VERSION} + COMPATIBILITY ExactVersion) + +# Install other files +install( + FILES + ${SEAL_CONFIG_FILENAME} + ${SEAL_CONFIG_VERSION_FILENAME} + DESTINATION ${SEAL_CONFIG_INSTALL_DIR}) diff --git a/src/SEAL.vcxproj b/src/SEAL.vcxproj new file mode 100644 index 000000000..ce9896235 --- /dev/null +++ b/src/SEAL.vcxproj @@ -0,0 +1,201 @@ + + + + + Debug + x64 + + + Release + x64 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + {7EA96C25-FC0D-485A-BB71-32B6DA55652A} + SEAL + 10.0.16299.0 + + + + StaticLibrary + true + v141 + Unicode + false + + + StaticLibrary + false + v141 + true + Unicode + false + + + + + + + + + + + + + + + $(SolutionDir)lib\$(Platform)\$(Configuration)\ + $(ProjectDir)obj\$(Platform)\$(Configuration)\ + .lib + seal + + + $(SolutionDir)lib\$(Platform)\$(Configuration)\ + $(ProjectDir)obj\$(Platform)\$(Configuration)\ + .lib + seal + + + + Level3 + Disabled + true + $(ProjectDir) + true + Neither + stdcpp17 + %(PreprocessorDefinitions); _ENABLE_EXTENDED_ALIGNED_STORAGE + /Zc:__cplusplus %(AdditionalOptions) + + + true + + + "$(SolutionDir)tools\scripts\cmake_config.cmd" $(Configuration) "$(DevEnvDir)" "$(IncludePath)" + + + Configure SEAL through CMake + + + + + Level3 + MaxSpeed + true + true + true + $(ProjectDir) + Speed + Default + stdcpp17 + %(PreprocessorDefinitions); _ENABLE_EXTENDED_ALIGNED_STORAGE + /Zc:__cplusplus %(AdditionalOptions) + + + true + true + true + + + "$(SolutionDir)tools\scripts\cmake_config.cmd" $(Configuration) "$(DevEnvDir)" "$(IncludePath)" + + + Configure SEAL through CMake + + + + + + \ No newline at end of file diff --git a/src/SEAL.vcxproj.filters b/src/SEAL.vcxproj.filters new file mode 100644 index 000000000..11de12252 --- /dev/null +++ b/src/SEAL.vcxproj.filters @@ -0,0 +1,295 @@ + + + + + {4FC737F1-C7A5-4376-A066-2A32D752A2FF} + cpp;c;cc;cxx;def;odl;idl;hpj;bat;asm;asmx + + + {93995380-89BD-4b04-88EB-625FBE52EBFB} + h;hh;hpp;hxx;hm;inl;inc;xsd + + + {67DA6AB6-F800-4c08-8B7A-83BB121AAD01} + rc;ico;cur;bmp;dlg;rc2;rct;bin;rgs;gif;jpg;jpeg;jpe;resx;tiff;tif;png;wav;mfcribbon-ms + + + {a119ce23-aae9-4b06-be2c-1c8aada4ab20} + + + {8740bd83-253c-49f3-8f9a-3b9c526f67c2} + + + {8585bc5e-eaa9-481a-a6ee-c38be1319a32} + + + {aaf838b1-cab2-4ccc-a016-a81c7edf506e} + + + {31fb1149-1a6f-438b-a86a-744384986d21} + + + {497d5f96-98a3-44e9-8b38-a2ea4bbea366} + + + + + Header Files + + + Header Files + + + Header Files + + + Header Files + + + Header Files + + + Header Files + + + Header Files + + + Header Files + + + Header Files + + + Header Files + + + Header Files\util + + + Header Files\util + + + Header Files + + + Header Files + + + Header Files + + + Header Files + + + Header Files + + + Header Files + + + Header Files + + + Header Files\util + + + Header Files\util + + + Header Files\util + + + Header Files\util + + + Header Files\util + + + Header Files\util + + + Header Files\util + + + Header Files\util + + + Header Files\util + + + Header Files\util + + + Header Files\util + + + Header Files\util + + + Header Files\util + + + Header Files\util + + + Header Files\util + + + Header Files\util + + + Header Files\util + + + Header Files\util + + + Header Files + + + Header Files\util + + + Header Files\util + + + Header Files\util + + + Header Files + + + Header Files + + + Header Files + + + Header Files + + + Header Files\util + + + + + Source Files + + + Source Files + + + Source Files + + + Source Files + + + Source Files + + + Source Files + + + Source Files + + + Source Files + + + Source Files + + + Source Files\util + + + Source Files\util + + + Source Files + + + Source Files\util + + + Source Files\util + + + Source Files\util + + + Source Files\util + + + Source Files\util + + + Source Files\util + + + Source Files\util + + + Source Files\util + + + Source Files\util + + + Source Files\util + + + Source Files\util + + + Source Files\util + + + Source Files + + + Source Files + + + Source Files + + + Source Files + + + Source Files + + + Source Files + + + Source Files + + + Source Files\util + + + + + Other + + + Other\seal + + + Other\seal\util + + + + + Other\cmake + + + Other\seal\util + + + Other\cmake + + + \ No newline at end of file diff --git a/src/cmake/Findmsgsl.cmake b/src/cmake/Findmsgsl.cmake new file mode 100644 index 000000000..65f41536d --- /dev/null +++ b/src/cmake/Findmsgsl.cmake @@ -0,0 +1,14 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +# Simple attempt to locate Microsoft GSL +set(CURRENT_MSGSL_INCLUDE_DIR ${MSGSL_INCLUDE_DIR}) +unset(MSGSL_INCLUDE_DIR CACHE) +find_path(MSGSL_INCLUDE_DIR + NAMES gsl/gsl gsl/span gsl/multi_span + HINTS ${CMAKE_INCLUDE_PATH} ${CURRENT_MSGSL_INCLUDE_DIR}) + +# Determine whether found based on MSGSL_INCLUDE_DIR +find_package(PackageHandleStandardArgs) +find_package_handle_standard_args(msgsl + REQUIRED_VARS MSGSL_INCLUDE_DIR) diff --git a/src/cmake/SEALConfig.cmake.in b/src/cmake/SEALConfig.cmake.in new file mode 100644 index 000000000..666759149 --- /dev/null +++ b/src/cmake/SEALConfig.cmake.in @@ -0,0 +1,34 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +# Exports target SEAL::seal +# +# Creates variables: +# SEAL_BUILD_TYPE : The build configuration used +# SEAL_DEBUG : Set to non-zero value if SEAL is compiled with extra debugging code (very slow!) +# SEAL_ENFORCE_HE_STD_SECURITY : Set to non-zero value if SEAL is compiled to enforce at least +# a 128-bit security level based on HomomorphicEncryption.org security estimates +# SEAL_USE_MSGSL : Set to non-zero value if SEAL is compiled with Microsoft GSL support +# MSGSL_INCLUDE_DIR : Holds the path to Microsoft GSL if SEAL is compiled with Microsoft GSL support + +include(CMakeFindDependencyMacro) + +set(SEAL_BUILD_TYPE @CMAKE_BUILD_TYPE@) +set(SEAL_DEBUG @SEAL_DEBUG@) +set(SEAL_USE_CXX17 @SEAL_USE_CXX17@) +set(SEAL_ENFORCE_HE_STD_SECURITY @SEAL_ENFORCE_HE_STD_SECURITY@) +set(SEAL_USE_MSGSL @SEAL_USE_MSGSL@) +if(SEAL_USE_MSGSL) + set(MSGSL_INCLUDE_DIR @MSGSL_INCLUDE_DIR@) +endif() + +set(CMAKE_THREAD_PREFER_PTHREAD TRUE) +set(THREADS_PREFER_PTHREAD_FLAG TRUE) +find_dependency(Threads REQUIRED) + +include(${CMAKE_CURRENT_LIST_DIR}/SEALTargets.cmake) + +message(STATUS "SEAL detected (version ${SEAL_VERSION})") +if(SEAL_DEBUG) + message(STATUS "Performance warning: SEAL compiled in debug mode") +endif() diff --git a/src/seal/CMakeLists.txt b/src/seal/CMakeLists.txt new file mode 100644 index 000000000..eb6b0f8bf --- /dev/null +++ b/src/seal/CMakeLists.txt @@ -0,0 +1,53 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +target_sources(seal + PRIVATE + ${CMAKE_CURRENT_LIST_DIR}/batchencoder.cpp + ${CMAKE_CURRENT_LIST_DIR}/biguint.cpp + ${CMAKE_CURRENT_LIST_DIR}/ciphertext.cpp + ${CMAKE_CURRENT_LIST_DIR}/ckks.cpp + ${CMAKE_CURRENT_LIST_DIR}/context.cpp + ${CMAKE_CURRENT_LIST_DIR}/decryptor.cpp + ${CMAKE_CURRENT_LIST_DIR}/encoder.cpp + ${CMAKE_CURRENT_LIST_DIR}/encryptionparams.cpp + ${CMAKE_CURRENT_LIST_DIR}/encryptor.cpp + ${CMAKE_CURRENT_LIST_DIR}/evaluator.cpp + ${CMAKE_CURRENT_LIST_DIR}/galoiskeys.cpp + ${CMAKE_CURRENT_LIST_DIR}/keygenerator.cpp + ${CMAKE_CURRENT_LIST_DIR}/memorymanager.cpp + ${CMAKE_CURRENT_LIST_DIR}/plaintext.cpp + ${CMAKE_CURRENT_LIST_DIR}/randomgen.cpp + ${CMAKE_CURRENT_LIST_DIR}/relinkeys.cpp + ${CMAKE_CURRENT_LIST_DIR}/smallmodulus.cpp +) + +install( + FILES + ${CMAKE_CURRENT_LIST_DIR}/batchencoder.h + ${CMAKE_CURRENT_LIST_DIR}/biguint.h + ${CMAKE_CURRENT_LIST_DIR}/ciphertext.h + ${CMAKE_CURRENT_LIST_DIR}/ckks.h + ${CMAKE_CURRENT_LIST_DIR}/context.h + ${CMAKE_CURRENT_LIST_DIR}/decryptor.h + ${CMAKE_CURRENT_LIST_DIR}/defaultparams.h + ${CMAKE_CURRENT_LIST_DIR}/encoder.h + ${CMAKE_CURRENT_LIST_DIR}/encryptionparams.h + ${CMAKE_CURRENT_LIST_DIR}/encryptor.h + ${CMAKE_CURRENT_LIST_DIR}/evaluator.h + ${CMAKE_CURRENT_LIST_DIR}/galoiskeys.h + ${CMAKE_CURRENT_LIST_DIR}/intarray.h + ${CMAKE_CURRENT_LIST_DIR}/keygenerator.h + ${CMAKE_CURRENT_LIST_DIR}/memorymanager.h + ${CMAKE_CURRENT_LIST_DIR}/plaintext.h + ${CMAKE_CURRENT_LIST_DIR}/publickey.h + ${CMAKE_CURRENT_LIST_DIR}/randomgen.h + ${CMAKE_CURRENT_LIST_DIR}/relinkeys.h + ${CMAKE_CURRENT_LIST_DIR}/seal.h + ${CMAKE_CURRENT_LIST_DIR}/secretkey.h + ${CMAKE_CURRENT_LIST_DIR}/smallmodulus.h + DESTINATION + ${SEAL_INCLUDES_INSTALL_DIR}/seal +) + +add_subdirectory(util) diff --git a/src/seal/batchencoder.cpp b/src/seal/batchencoder.cpp new file mode 100644 index 000000000..e2e8ef46f --- /dev/null +++ b/src/seal/batchencoder.cpp @@ -0,0 +1,581 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include +#include +#include +#include +#include "seal/batchencoder.h" +#include "seal/util/polycore.h" + +using namespace std; +using namespace seal::util; + +namespace seal +{ + BatchEncoder::BatchEncoder(std::shared_ptr context) : + context_(std::move(context)) + { + // Verify parameters + if (!context_) + { + throw invalid_argument("invalid context"); + } + if (!context_->parameters_set()) + { + throw invalid_argument("encryption parameters are not set correctly"); + } + + auto &context_data = *context_->context_data(); + if (context_data.parms().scheme() != scheme_type::BFV) + { + throw invalid_argument("unsupported scheme"); + } + if (!context_data.qualifiers().using_batching) + { + throw invalid_argument("encryption parameters are not valid for batching"); + } + + // Set the slot count + slots_ = context_data.parms().poly_modulus_degree(); + + // Reserve space for all of the primitive roots + roots_of_unity_ = allocate_uint(slots_, pool_); + + // Fill the vector of roots of unity with all distinct odd powers of generator. + // These are all the primitive (2*slots_)-th roots of unity in integers modulo + // parms.plain_modulus(). + populate_roots_of_unity_vector(context_data); + + // Populate matrix representation index map + populate_matrix_reps_index_map(); + } + + void BatchEncoder::populate_roots_of_unity_vector( + const SEALContext::ContextData &context_data) + { + uint64_t root = context_data.plain_ntt_tables()->get_root(); + auto &modulus = context_data.parms().plain_modulus(); + + uint64_t generator_sq = multiply_uint_uint_mod(root, root, modulus); + roots_of_unity_[0] = root; + + for (size_t i = 1; i < slots_; i++) + { + roots_of_unity_[i] = multiply_uint_uint_mod(roots_of_unity_[i - 1], + generator_sq, modulus); + } + } + + void BatchEncoder::populate_matrix_reps_index_map() + { + int logn = get_power_of_two(slots_); + matrix_reps_index_map_ = allocate_uint(slots_, pool_); + + // Copy from the matrix to the value vectors + size_t row_size = slots_ >> 1; + size_t m = slots_ << 1; + uint64_t gen = 3; + uint64_t pos = 1; + for (size_t i = 0; i < row_size; i++) + { + // Position in normal bit order + uint64_t index1 = (pos - 1) >> 1; + uint64_t index2 = (m - pos - 1) >> 1; + + // Set the bit-reversed locations + matrix_reps_index_map_[i] = util::reverse_bits(index1, logn); + matrix_reps_index_map_[row_size | i] = util::reverse_bits(index2, logn); + + // Next primitive root + pos *= gen; + pos &= (m - 1); + } + } + + void BatchEncoder::encode(const vector &values_matrix, + Plaintext &destination) + { + auto &context_data = *context_->context_data(); + + // Validate input parameters + size_t values_matrix_size = values_matrix.size(); + if (values_matrix_size > slots_) + { + throw logic_error("values_matrix size is too large"); + } +#ifdef SEAL_DEBUG + uint64_t modulus = context_data.parms().plain_modulus().value(); + for (auto v : values_matrix) + { + // Validate the i-th input + if (v >= modulus) + { + throw invalid_argument("input value is larger than plain_modulus"); + } + } +#endif + // Set destination to full size + destination.resize(slots_); + destination.parms_id() = parms_id_zero; + + // First write the values to destination coefficients. + // Read in top row, then bottom row. + for (size_t i = 0; i < values_matrix_size; i++) + { + *(destination.data() + matrix_reps_index_map_[i]) = values_matrix[i]; + } + for (size_t i = values_matrix_size; i < slots_; i++) + { + *(destination.data() + matrix_reps_index_map_[i]) = 0; + } + + // Transform destination using inverse of negacyclic NTT + // Note: We already performed bit-reversal when reading in the matrix + inverse_ntt_negacyclic_harvey(destination.data(), *context_data.plain_ntt_tables()); + } + + void BatchEncoder::encode(const vector &values_matrix, + Plaintext &destination) + { + auto &context_data = *context_->context_data(); + uint64_t modulus = context_data.parms().plain_modulus().value(); + + // Validate input parameters + size_t values_matrix_size = values_matrix.size(); + if (values_matrix_size > slots_) + { + throw logic_error("values_matrix size is too large"); + } +#ifdef SEAL_DEBUG + uint64_t plain_modulus_div_two = modulus >> 1; + for (auto v : values_matrix) + { + // Validate the i-th input + if (unsigned_gt(llabs(v), plain_modulus_div_two)) + { + throw invalid_argument("input value is larger than plain_modulus"); + } + } +#endif + // Set destination to full size + destination.resize(slots_); + destination.parms_id() = parms_id_zero; + + // First write the values to destination coefficients. + // Read in top row, then bottom row. + for (size_t i = 0; i < values_matrix_size; i++) + { + *(destination.data() + matrix_reps_index_map_[i]) = + (values_matrix[i] < 0) ? (modulus + static_cast(values_matrix[i])) : + static_cast(values_matrix[i]); + } + for (size_t i = values_matrix_size; i < slots_; i++) + { + *(destination.data() + matrix_reps_index_map_[i]) = 0; + } + + // Transform destination using inverse of negacyclic NTT + // Note: We already performed bit-reversal when reading in the matrix + inverse_ntt_negacyclic_harvey(destination.data(), *context_data.plain_ntt_tables()); + } +#ifdef SEAL_USE_MSGSL_SPAN + void BatchEncoder::encode(gsl::span values_matrix, + Plaintext &destination) + { + auto &context_data = *context_->context_data(); + + // Validate input parameters + size_t values_matrix_size = static_cast(values_matrix.size()); + if (values_matrix_size > slots_) + { + throw logic_error("values_matrix size is too large"); + } +#ifdef SEAL_DEBUG + uint64_t modulus = context_data.parms().plain_modulus().value(); + for (auto v : values_matrix) + { + // Validate the i-th input + if (v >= modulus) + { + throw invalid_argument("input value is larger than plain_modulus"); + } + } +#endif + // Set destination to full size + destination.resize(slots_); + destination.parms_id() = parms_id_zero; + + // First write the values to destination coefficients. Read + // in top row, then bottom row. + using index_type = decltype(values_matrix)::index_type; + for (size_t i = 0; i < values_matrix_size; i++) + { + *(destination.data() + matrix_reps_index_map_[i]) = + values_matrix[static_cast(i)]; + } + for (size_t i = values_matrix_size; i < slots_; i++) + { + *(destination.data() + matrix_reps_index_map_[i]) = 0; + } + + // Transform destination using inverse of negacyclic NTT + // Note: We already performed bit-reversal when reading in the matrix + inverse_ntt_negacyclic_harvey(destination.data(), *context_data.plain_ntt_tables()); + } + + void BatchEncoder::encode(gsl::span values_matrix, + Plaintext &destination) + { + auto &context_data = *context_->context_data(); + uint64_t modulus = context_data.parms().plain_modulus().value(); + + // Validate input parameters + size_t values_matrix_size = static_cast(values_matrix.size()); + if (values_matrix_size > slots_) + { + throw logic_error("values_matrix size is too large"); + } +#ifdef SEAL_DEBUG + uint64_t plain_modulus_div_two = modulus >> 1; + for (auto v : values_matrix) + { + // Validate the i-th input + if (unsigned_gt(llabs(v), plain_modulus_div_two)) + { + throw invalid_argument("input value is larger than plain_modulus"); + } + } +#endif + // Set destination to full size + destination.resize(slots_); + destination.parms_id() = parms_id_zero; + + // First write the values to destination coefficients. Read + // in top row, then bottom row. + using index_type = decltype(values_matrix)::index_type; + for (size_t i = 0; i < values_matrix_size; i++) + { + *(destination.data() + matrix_reps_index_map_[i]) = + (values_matrix[static_cast(i)] < 0) ? + (modulus + static_cast(values_matrix[static_cast(i)])) : + static_cast(values_matrix[static_cast(i)]); + } + for (size_t i = values_matrix_size; i < slots_; i++) + { + *(destination.data() + matrix_reps_index_map_[i]) = 0; + } + + // Transform destination using inverse of negacyclic NTT + // Note: We already performed bit-reversal when reading in the matrix + inverse_ntt_negacyclic_harvey(destination.data(), *context_data.plain_ntt_tables()); + } +#endif + void BatchEncoder::encode(Plaintext &plain, MemoryPoolHandle pool) + { + if (plain.is_ntt_form()) + { + throw invalid_argument("plain cannot be in NTT form"); + } + if (!pool) + { + throw invalid_argument("pool is uninitialized"); + } + + auto &context_data = *context_->context_data(); + + // Validate input parameters + if (plain.coeff_count() > context_data.parms().poly_modulus_degree()) + { + throw invalid_argument("plain is not valid for encryption parameters"); + } +#ifdef SEAL_DEBUG + if (!are_poly_coefficients_less_than(plain.data(), + plain.coeff_count(), context_data.parms().plain_modulus().value())) + { + throw invalid_argument("plain is not valid for encryption parameters"); + } +#endif + // We need to permute the coefficients of plain. To do this, we allocate + // temporary space. + size_t input_plain_coeff_count = min(plain.coeff_count(), slots_); + auto temp(allocate_uint(input_plain_coeff_count, pool)); + set_uint_uint(plain.data(), input_plain_coeff_count, temp.get()); + + // Set plain to full slot count size. + plain.resize(slots_); + plain.parms_id() = parms_id_zero; + + // First write the values to destination coefficients. Read + // in top row, then bottom row. + for (size_t i = 0; i < input_plain_coeff_count; i++) + { + *(plain.data() + matrix_reps_index_map_[i]) = temp[i]; + } + for (size_t i = input_plain_coeff_count; i < slots_; i++) + { + *(plain.data() + matrix_reps_index_map_[i]) = 0; + } + + // Transform destination using inverse of negacyclic NTT + // Note: We already performed bit-reversal when reading in the matrix + inverse_ntt_negacyclic_harvey(plain.data(), *context_data.plain_ntt_tables()); + } + + void BatchEncoder::decode(const Plaintext &plain, vector &destination, + MemoryPoolHandle pool) + { + if (plain.is_ntt_form()) + { + throw invalid_argument("plain cannot be in NTT form"); + } + if (!pool) + { + throw invalid_argument("pool is uninitialized"); + } + + auto &context_data = *context_->context_data(); + + // Validate input parameters + if (plain.coeff_count() > context_data.parms().poly_modulus_degree()) + { + throw invalid_argument("plain is not valid for encryption parameters"); + } +#ifdef SEAL_DEBUG + if (!are_poly_coefficients_less_than(plain.data(), + plain.coeff_count(), context_data.parms().plain_modulus().value())) + { + throw invalid_argument("plain is not valid for encryption parameters"); + } +#endif + // Set destination size + destination.resize(slots_); + + // Never include the leading zero coefficient (if present) + size_t plain_coeff_count = min(plain.coeff_count(), slots_); + + auto temp_dest(allocate_uint(slots_, pool)); + + // Make a copy of poly + set_uint_uint(plain.data(), plain_coeff_count, temp_dest.get()); + set_zero_uint(slots_ - plain_coeff_count, temp_dest.get() + plain_coeff_count); + + // Transform destination using negacyclic NTT. + ntt_negacyclic_harvey(temp_dest.get(), *context_data.plain_ntt_tables()); + + // Read top row + for (size_t i = 0; i < slots_; i++) + { + destination[i] = temp_dest[matrix_reps_index_map_[i]]; + } + } + + void BatchEncoder::decode(const Plaintext &plain, vector &destination, + MemoryPoolHandle pool) + { + if (plain.is_ntt_form()) + { + throw invalid_argument("plain cannot be in NTT form"); + } + if (!pool) + { + throw invalid_argument("pool is uninitialized"); + } + + auto &context_data = *context_->context_data(); + uint64_t modulus = context_data.parms().plain_modulus().value(); + + // Validate input parameters + if (plain.coeff_count() > context_data.parms().poly_modulus_degree()) + { + throw invalid_argument("plain is not valid for encryption parameters"); + } +#ifdef SEAL_DEBUG + if (!are_poly_coefficients_less_than(plain.data(), plain.coeff_count(), modulus)) + { + throw invalid_argument("plain is not valid for encryption parameters"); + } +#endif + // Set destination size + destination.resize(slots_); + + // Never include the leading zero coefficient (if present) + size_t plain_coeff_count = min(plain.coeff_count(), slots_); + + auto temp_dest(allocate_uint(slots_, pool)); + + // Make a copy of poly + set_uint_uint(plain.data(), plain_coeff_count, temp_dest.get()); + set_zero_uint(slots_ - plain_coeff_count, temp_dest.get() + plain_coeff_count); + + // Transform destination using negacyclic NTT. + ntt_negacyclic_harvey(temp_dest.get(), *context_data.plain_ntt_tables()); + + // Read top row, then bottom row + uint64_t plain_modulus_div_two = modulus >> 1; + for (size_t i = 0; i < slots_; i++) + { + uint64_t curr_value = temp_dest[matrix_reps_index_map_[i]]; + destination[i] = (curr_value > plain_modulus_div_two) ? + (static_cast(curr_value) - static_cast(modulus)) : + static_cast(curr_value); + } + } +#ifdef SEAL_USE_MSGSL_SPAN + void BatchEncoder::decode(const Plaintext &plain, gsl::span destination, + MemoryPoolHandle pool) + { + if (plain.is_ntt_form()) + { + throw invalid_argument("plain cannot be in NTT form"); + } + if (!pool) + { + throw invalid_argument("pool is uninitialized"); + } + + auto &context_data = *context_->context_data(); + + // Validate input parameters + if (plain.coeff_count() > context_data.parms().poly_modulus_degree()) + { + throw invalid_argument("plain is not valid for encryption parameters"); + } +#ifdef SEAL_DEBUG + if (!are_poly_coefficients_less_than(plain.data(), + plain.coeff_count(), context_data.parms().plain_modulus().value())) + { + throw invalid_argument("plain is not valid for encryption parameters"); + } +#endif + using index_type = decltype(destination)::index_type; + if(unsigned_gt(destination.size(), numeric_limits::max()) || + unsigned_neq(destination.size(), slots_)) + { + throw invalid_argument("destination has incorrect size"); + } + + // Never include the leading zero coefficient (if present) + size_t plain_coeff_count = min(plain.coeff_count(), slots_); + + auto temp_dest(allocate_uint(slots_, pool)); + + // Make a copy of poly + set_uint_uint(plain.data(), plain_coeff_count, temp_dest.get()); + set_zero_uint(slots_ - plain_coeff_count, temp_dest.get() + plain_coeff_count); + + // Transform destination using negacyclic NTT. + ntt_negacyclic_harvey(temp_dest.get(), *context_data.plain_ntt_tables()); + + // Read top row + for (size_t i = 0; i < slots_; i++) + { + destination[static_cast(i)] = temp_dest[matrix_reps_index_map_[i]]; + } + } + + void BatchEncoder::decode(const Plaintext &plain, gsl::span destination, + MemoryPoolHandle pool) + { + if (plain.is_ntt_form()) + { + throw invalid_argument("plain cannot be in NTT form"); + } + if (!pool) + { + throw invalid_argument("pool is uninitialized"); + } + + auto &context_data = *context_->context_data(); + uint64_t modulus = context_data.parms().plain_modulus().value(); + + // Validate input parameters + if (plain.coeff_count() > context_data.parms().poly_modulus_degree()) + { + throw invalid_argument("plain is not valid for encryption parameters"); + } +#ifdef SEAL_DEBUG + if (!are_poly_coefficients_less_than(plain.data(), plain.coeff_count(), modulus)) + { + throw invalid_argument("plain is not valid for encryption parameters"); + } +#endif + using index_type = decltype(destination)::index_type; + if(unsigned_gt(destination.size(), numeric_limits::max()) || + unsigned_neq(destination.size(), slots_)) + { + throw invalid_argument("destination has incorrect size"); + } + + // Never include the leading zero coefficient (if present) + size_t plain_coeff_count = min(plain.coeff_count(), slots_); + + auto temp_dest(allocate_uint(slots_, pool)); + + // Make a copy of poly + set_uint_uint(plain.data(), plain_coeff_count, temp_dest.get()); + set_zero_uint(slots_ - plain_coeff_count, temp_dest.get() + plain_coeff_count); + + // Transform destination using negacyclic NTT. + ntt_negacyclic_harvey(temp_dest.get(), *context_data.plain_ntt_tables()); + + // Read top row, then bottom row + uint64_t plain_modulus_div_two = modulus >> 1; + for (size_t i = 0; i < slots_; i++) + { + uint64_t curr_value = temp_dest[matrix_reps_index_map_[i]]; + destination[static_cast(i)] = (curr_value > plain_modulus_div_two) ? + (static_cast(curr_value) - static_cast(modulus)) : + static_cast(curr_value); + } + } +#endif + void BatchEncoder::decode(Plaintext &plain, MemoryPoolHandle pool) + { + if (plain.is_ntt_form()) + { + throw invalid_argument("plain cannot be in NTT form"); + } + if (!pool) + { + throw invalid_argument("pool is uninitialized"); + } + + auto &context_data = *context_->context_data(); + + // Validate input parameters + if (plain.coeff_count() > context_data.parms().poly_modulus_degree()) + { + throw invalid_argument("plain is not valid for encryption parameters"); + } +#ifdef SEAL_DEBUG + if (!are_poly_coefficients_less_than(plain.data(), + plain.coeff_count(), context_data.parms().plain_modulus().value())) + { + throw invalid_argument("plain is not valid for encryption parameters"); + } +#endif + // Never include the leading zero coefficient (if present) + size_t plain_coeff_count = min(plain.coeff_count(), slots_); + + // Allocate temporary space to store a wide copy of plain + auto temp(allocate_uint(slots_, pool)); + + // Make a copy of poly + set_uint_uint(plain.data(), plain_coeff_count, temp.get()); + set_zero_uint(slots_ - plain_coeff_count, temp.get() + plain_coeff_count); + + // Transform destination using negacyclic NTT. + ntt_negacyclic_harvey(temp.get(), *context_data.plain_ntt_tables()); + + // Set plain to full slot count size (note that all new coefficients are + // set to zero). + plain.resize(slots_); + + // Read top row, then bottom row + for (size_t i = 0; i < slots_; i++) + { + *(plain.data() + i) = temp[matrix_reps_index_map_[i]]; + } + } +} diff --git a/src/seal/batchencoder.h b/src/seal/batchencoder.h new file mode 100644 index 000000000..e76c037e2 --- /dev/null +++ b/src/seal/batchencoder.h @@ -0,0 +1,405 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include +#include +#include "seal/util/defines.h" +#include "seal/util/common.h" +#include "seal/util/uintcore.h" +#include "seal/util/uintarithsmallmod.h" +#include "seal/plaintext.h" +#include "seal/context.h" + +namespace seal +{ + /** + Provides functionality for CRT batching. If the polynomial modulus degree is N, and + the plaintext modulus is a prime number T such that T is congruent to 1 modulo 2N, + then BatchEncoder allows the SEAL plaintext elements to be viewed as 2-by-(N/2) + matrices of integers modulo T. Homomorphic operations performed on such encrypted + matrices are applied coefficient (slot) wise, enabling powerful SIMD functionality + for computations that are vectorizable. This functionality is often called "batching" + in the homomorphic encryption literature. + + @par Mathematical Background + Mathematically speaking, if the polynomial modulus is X^N+1, N is a power of two, and + plain_modulus is a prime number T such that 2N divides T-1, then integers modulo T + contain a primitive 2N-th root of unity and the polynomial X^N+1 splits into n distinct + linear factors as X^N+1 = (X-a_1)*...*(X-a_N) mod T, where the constants a_1, ..., a_n + are all the distinct primitive 2N-th roots of unity in integers modulo T. The Chinese + Remainder Theorem (CRT) states that the plaintext space Z_T[X]/(X^N+1) in this case is + isomorphic (as an algebra) to the N-fold direct product of fields Z_T. The isomorphism + is easy to compute explicitly in both directions, which is what this class does. + Furthermore, the Galois group of the extension is (Z/2NZ)* ~= Z/2Z x Z/(N/2) whose + action on the primitive roots of unity is easy to describe. Since the batching slots + correspond 1-to-1 to the primitive roots of unity, applying Galois automorphisms on the + plaintext act by permuting the slots. By applying generators of the two cyclic + subgroups of the Galois group, we can effectively view the plaintext as a 2-by-(N/2) + matrix, and enable cyclic row rotations, and column rotations (row swaps). + + @par Valid Parameters + Whether batching can be used depends on whether the plaintext modulus has been chosen + appropriately. Thus, to construct a BatchEncoder the user must provide an instance + of SEALContext such that its associated EncryptionParameterQualifiers object has the + flags parameters_set and enable_batching set to true. + + @see EncryptionParameters for more information about encryption parameters. + @see EncryptionParameterQualifiers for more information about parameter qualifiers. + @see Evaluator for rotating rows and columns of encrypted matrices. + */ + class BatchEncoder + { + public: + /** + Creates a BatchEncoder. It is necessary that the encryption parameters + given through the SEALContext object support batching. + + @param[in] context The SEALContext + @throws std::invalid_argument if the context is not set or encryption + parameters are not valid for batching + @throws std::invalid_argument if scheme is not scheme_type::BFV + */ + BatchEncoder(std::shared_ptr context); + + /** + Creates a SEAL plaintext from a given matrix. This function "batches" a given matrix + of integers modulo the plaintext modulus into a SEAL plaintext element, and stores + the result in the destination parameter. The input vector must have size at most equal + to the degree of the polynomial modulus. The first half of the elements represent the + first row of the matrix, and the second half represent the second row. The numbers + in the matrix can be at most equal to the plaintext modulus for it to represent + a valid SEAL plaintext. + + If the destination plaintext overlaps the input values in memory, the behavior of + this function is undefined. + + @param[in] values The matrix of integers modulo plaintext modulus to batch + @param[out] destination The plaintext polynomial to overwrite with the result + @throws std::invalid_argument if values is too large + */ + void encode(const std::vector &values, Plaintext &destination); + + /** + Creates a SEAL plaintext from a given matrix. This function "batches" a given matrix + of integers modulo the plaintext modulus into a SEAL plaintext element, and stores + the result in the destination parameter. The input vector must have size at most equal + to the degree of the polynomial modulus. The first half of the elements represent the + first row of the matrix, and the second half represent the second row. The numbers + in the matrix can be at most equal to the plaintext modulus for it to represent + a valid SEAL plaintext. + + If the destination plaintext overlaps the input values in memory, the behavior of + this function is undefined. + + @param[in] values The matrix of integers modulo plaintext modulus to batch + @param[out] destination The plaintext polynomial to overwrite with the result + @throws std::invalid_argument if values is too large + */ + void encode(const std::vector &values, Plaintext &destination); +#ifdef SEAL_USE_MSGSL_SPAN + /** + Creates a SEAL plaintext from a given matrix. This function "batches" a given matrix + of integers modulo the plaintext modulus into a SEAL plaintext element, and stores + the result in the destination parameter. The input vector must have size at most equal + to the degree of the polynomial modulus. The first half of the elements represent the + first row of the matrix, and the second half represent the second row. The numbers + in the matrix can be at most equal to the plaintext modulus for it to represent + a valid SEAL plaintext. + + If the destination plaintext overlaps the input values in memory, the behavior of + this function is undefined. + + @param[in] values The matrix of integers modulo plaintext modulus to batch + @param[out] destination The plaintext polynomial to overwrite with the result + @throws std::invalid_argument if values is too large + */ + void encode(gsl::span values, Plaintext &destination); + + /** + Creates a SEAL plaintext from a given matrix. This function "batches" a given matrix + of integers modulo the plaintext modulus into a SEAL plaintext element, and stores + the result in the destination parameter. The input vector must have size at most equal + to the degree of the polynomial modulus. The first half of the elements represent the + first row of the matrix, and the second half represent the second row. The numbers + in the matrix can be at most equal to the plaintext modulus for it to represent + a valid SEAL plaintext. + + If the destination plaintext overlaps the input values in memory, the behavior of + this function is undefined. + + @param[in] values The matrix of integers modulo plaintext modulus to batch + @param[out] destination The plaintext polynomial to overwrite with the result + @throws std::invalid_argument if values is too large + */ + void encode(gsl::span values, Plaintext &destination); +#ifdef SEAL_USE_MSGSL_MULTISPAN + /** + Creates a SEAL plaintext from a given matrix. This function "batches" a given matrix + of integers modulo the plaintext modulus into a SEAL plaintext element, and stores + the result in the destination parameter. The input must have dimensions [2, N/2], + where N denotes the degree of the polynomial modulus, representing a 2 x (N/2) + matrix. The numbers in the matrix can be at most equal to the plaintext modulus for + it to represent a valid SEAL plaintext. + + If the destination plaintext overlaps the input values in memory, the behavior of + this function is undefined. + + @param[in] values The matrix of integers modulo plaintext modulus to batch + @param[out] destination The plaintext polynomial to overwrite with the result + @throws std::invalid_argument if values is too large or has incorrect size + */ + inline void encode(gsl::multi_span< + const std::uint64_t, + static_cast(2), + gsl::dynamic_range> values, Plaintext &destination) + { + encode(gsl::span(values.data(), values.size()), + destination); + } + + /** + Creates a SEAL plaintext from a given matrix. This function "batches" a given matrix + of integers modulo the plaintext modulus into a SEAL plaintext element, and stores + the result in the destination parameter. The input must have dimensions [2, N/2], + where N denotes the degree of the polynomial modulus, representing a 2 x (N/2) + matrix. The numbers in the matrix can be at most equal to the plaintext modulus for + it to represent a valid SEAL plaintext. + + If the destination plaintext overlaps the input values in memory, the behavior of + this function is undefined. + + @param[in] values The matrix of integers modulo plaintext modulus to batch + @param[out] destination The plaintext polynomial to overwrite with the result + @throws std::invalid_argument if values is too large or has incorrect size + */ + inline void encode(gsl::multi_span< + const std::int64_t, + static_cast(2), + gsl::dynamic_range> values, Plaintext &destination) + { + encode(gsl::span(values.data(), values.size()), + destination); + } +#endif +#endif + /** + Creates a SEAL plaintext from a given matrix. This function "batches" a given matrix + of integers modulo the plaintext modulus in-place into a SEAL plaintext ready to be + encrypted. The matrix is given as a plaintext element whose first N/2 coefficients + represent the first row of the matrix, and the second N/2 coefficients represent the + second row, where N denotes the degree of the polynomial modulus. The input plaintext + must have degress less than the polynomial modulus, and coefficients less than the + plaintext modulus, i.e. it must be a valid plaintext for the encryption parameters. + Dynamic memory allocations in the process are allocated from the memory pool pointed + to by the given MemoryPoolHandle. + + @param[in] plain The matrix of integers modulo plaintext modulus to batch + @param[in] pool The MemoryPoolHandle pointing to a valid memory pool + @throws std::invalid_argument if plain is not valid for the encryption parameters + @throws std::invalid_argument if plain is in NTT form + @throws std::invalid_argument if pool is uninitialized + */ + void encode(Plaintext &plain, MemoryPoolHandle pool = MemoryManager::GetPool()); + + /** + Inverse of encode. This function "unbatches" a given SEAL plaintext into a matrix + of integers modulo the plaintext modulus, and stores the result in the destination + parameter. The input plaintext must have degress less than the polynomial modulus, + and coefficients less than the plaintext modulus, i.e. it must be a valid plaintext + for the encryption parameters. Dynamic memory allocations in the process are + allocated from the memory pool pointed to by the given MemoryPoolHandle. + + @param[in] plain The plaintext polynomial to unbatch + @param[out] destination The matrix to be overwritten with the values in the slots + @param[in] pool The MemoryPoolHandle pointing to a valid memory pool + @throws std::invalid_argument if plain is not valid for the encryption parameters + @throws std::invalid_argument if plain is in NTT form + @throws std::invalid_argument if pool is uninitialized + */ + void decode(const Plaintext &plain, std::vector &destination, + MemoryPoolHandle pool = MemoryManager::GetPool()); + + /** + Inverse of encode. This function "unbatches" a given SEAL plaintext into a matrix + of integers modulo the plaintext modulus, and stores the result in the destination + parameter. The input plaintext must have degress less than the polynomial modulus, + and coefficients less than the plaintext modulus, i.e. it must be a valid plaintext + for the encryption parameters. Dynamic memory allocations in the process are + allocated from the memory pool pointed to by the given MemoryPoolHandle. + + @param[in] plain The plaintext polynomial to unbatch + @param[out] destination The matrix to be overwritten with the values in the slots + @param[in] pool The MemoryPoolHandle pointing to a valid memory pool + @throws std::invalid_argument if plain is not valid for the encryption parameters + @throws std::invalid_argument if plain is in NTT form + @throws std::invalid_argument if pool is uninitialized + */ + void decode(const Plaintext &plain, std::vector &destination, + MemoryPoolHandle pool = MemoryManager::GetPool()); +#ifdef SEAL_USE_MSGSL_SPAN + /** + Inverse of encode. This function "unbatches" a given SEAL plaintext into a matrix + of integers modulo the plaintext modulus, and stores the result in the destination + parameter. The input plaintext must have degress less than the polynomial modulus, + and coefficients less than the plaintext modulus, i.e. it must be a valid plaintext + for the encryption parameters. Dynamic memory allocations in the process are + allocated from the memory pool pointed to by the given MemoryPoolHandle. + + @param[in] plain The plaintext polynomial to unbatch + @param[out] destination The matrix to be overwritten with the values in the slots + @param[in] pool The MemoryPoolHandle pointing to a valid memory pool + @throws std::invalid_argument if plain is not valid for the encryption parameters + @throws std::invalid_argument if plain is in NTT form + @throws std::invalid_argument if destination has incorrect size + @throws std::invalid_argument if pool is uninitialized + */ + void decode(const Plaintext &plain, gsl::span destination, + MemoryPoolHandle pool = MemoryManager::GetPool()); + + /** + Inverse of encode. This function "unbatches" a given SEAL plaintext into a matrix + of integers modulo the plaintext modulus, and stores the result in the destination + parameter. The input plaintext must have degress less than the polynomial modulus, + and coefficients less than the plaintext modulus, i.e. it must be a valid plaintext + for the encryption parameters. Dynamic memory allocations in the process are + allocated from the memory pool pointed to by the given MemoryPoolHandle. + + @param[in] plain The plaintext polynomial to unbatch + @param[out] destination The matrix to be overwritten with the values in the slots + @param[in] pool The MemoryPoolHandle pointing to a valid memory pool + @throws std::invalid_argument if plain is not valid for the encryption parameters + @throws std::invalid_argument if plain is in NTT form + @throws std::invalid_argument if destination has incorrect size + @throws std::invalid_argument if pool is uninitialized + */ + void decode(const Plaintext &plain, gsl::span destination, + MemoryPoolHandle pool = MemoryManager::GetPool()); +#ifdef SEAL_USE_MSGSL_MULTISPAN + /** + Inverse of encode. This function "unbatches" a given SEAL plaintext into a matrix + of integers modulo the plaintext modulus, and stores the result in the destination + parameter. The destination must have dimensions [2, N/2], where N denotes the degree + of the polynomial modulus, representing a 2 x (N/2) matrix. The input plaintext must + have degress less than the polynomial modulus, and coefficients less than the + plaintext modulus, i.e. it must be a valid plaintext for the encryption parameters. + Dynamic memory allocations in the process are allocated from the memory pool pointed + to by the given MemoryPoolHandle. + + @param[in] plain The plaintext polynomial to unbatch + @param[out] destination The matrix to be overwritten with the values in the slots + @param[in] pool The MemoryPoolHandle pointing to a valid memory pool + @throws std::invalid_argument if plain is not valid for the encryption parameters + @throws std::invalid_argument if plain is in NTT form + @throws std::invalid_argument if destination has incorrect size + @throws std::invalid_argument if pool is uninitialized + */ + inline void decode(const Plaintext &plain, + gsl::multi_span(2), + gsl::dynamic_range> destination, + MemoryPoolHandle pool = MemoryManager::GetPool()) + { + decode(plain, gsl::span(destination.data(), + destination.size()), std::move(pool)); + } + + /** + Inverse of encode. This function "unbatches" a given SEAL plaintext into a matrix + of integers modulo the plaintext modulus, and stores the result in the destination + parameter. The destination must have dimensions [2, N/2], where N denotes the degree + of the polynomial modulus, representing a 2 x (N/2) matrix. The input plaintext must + have degress less than the polynomial modulus, and coefficients less than the + plaintext modulus, i.e. it must be a valid plaintext for the encryption parameters. + Dynamic memory allocations in the process are allocated from the memory pool pointed + to by the given MemoryPoolHandle. + + @param[in] plain The plaintext polynomial to unbatch + @param[out] destination The matrix to be overwritten with the values in the slots + @param[in] pool The MemoryPoolHandle pointing to a valid memory pool + @throws std::invalid_argument if plain is not valid for the encryption parameters + @throws std::invalid_argument if plain is in NTT form + @throws std::invalid_argument if destination has incorrect size + @throws std::invalid_argument if pool is uninitialized + */ + inline void decode(const Plaintext &plain, + gsl::multi_span(2), + gsl::dynamic_range> destination, + MemoryPoolHandle pool = MemoryManager::GetPool()) + { + decode(plain, gsl::span(destination.data(), + destination.size()), std::move(pool)); + } +#endif +#endif + /** + Inverse of encode. This function "unbatches" a given SEAL plaintext in-place into + a matrix of integers modulo the plaintext modulus. The input plaintext must have + degress less than the polynomial modulus, and coefficients less than the plaintext + modulus, i.e. it must be a valid plaintext for the encryption parameters. Dynamic + memory allocations in the process are allocated from the memory pool pointed to by + the given MemoryPoolHandle. + + @param[in] plain The plaintext polynomial to unbatch + @param[in] pool The MemoryPoolHandle pointing to a valid memory pool + @throws std::invalid_argument if plain is not valid for the encryption parameters + @throws std::invalid_argument if plain is in NTT form + @throws std::invalid_argument if pool is uninitialized + */ + void decode(Plaintext &plain, MemoryPoolHandle pool = MemoryManager::GetPool()); + + /** + Returns the number of slots. + */ + inline auto slot_count() const noexcept + { + return slots_; + } + + private: + BatchEncoder(const BatchEncoder ©) = delete; + + BatchEncoder(BatchEncoder &&source) = delete; + + BatchEncoder &operator =(const BatchEncoder &assign) = delete; + + BatchEncoder &operator =(BatchEncoder &&assign) = delete; + + void populate_roots_of_unity_vector( + const SEALContext::ContextData &context_data); + + void populate_matrix_reps_index_map(); + + inline void reverse_bits(std::uint64_t *input) + { +#ifdef SEAL_DEBUG + if (input == nullptr) + { + throw std::invalid_argument("input cannot be null"); + } +#endif + std::size_t coeff_count = context_->context_data()->parms().poly_modulus_degree(); + int logn = util::get_power_of_two(coeff_count); + for (std::size_t i = 0; i < coeff_count; i++) + { + std::uint64_t reversed_i = util::reverse_bits(i, logn); + if (i < reversed_i) + { + std::swap(input[i], input[reversed_i]); + } + } + } + + MemoryPoolHandle pool_ = MemoryManager::GetPool(); + + std::shared_ptr context_{ nullptr }; + + std::size_t slots_; + + util::Pointer roots_of_unity_; + + util::Pointer matrix_reps_index_map_; + }; +} diff --git a/src/seal/biguint.cpp b/src/seal/biguint.cpp new file mode 100644 index 000000000..d6cd6d903 --- /dev/null +++ b/src/seal/biguint.cpp @@ -0,0 +1,312 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "seal/biguint.h" +#include "seal/util/common.h" +#include "seal/util/uintcore.h" +#include "seal/util/uintarith.h" +#include "seal/util/uintarithmod.h" +#include + +using namespace std; +using namespace seal::util; + +namespace seal +{ + BigUInt::BigUInt(int bit_count) + { + resize(bit_count); + } + + BigUInt::BigUInt(const string &hex_value) + { + operator =(hex_value); + } + + BigUInt::BigUInt(int bit_count, const string &hex_value) + { + resize(bit_count); + operator =(hex_value); + if (bit_count != bit_count_) + { + resize(bit_count); + } + } + + BigUInt::BigUInt(int bit_count, uint64_t *value) : + value_(decltype(value_)::Aliasing(value)), bit_count_(bit_count) + { + if (bit_count < 0) + { + throw invalid_argument("bit_count must be non-negative"); + } + if (value == nullptr && bit_count > 0) + { + throw invalid_argument("value must be non-null for non-zero bit count"); + } + } +#ifdef SEAL_USE_MSGSL_SPAN + BigUInt::BigUInt(gsl::span value) + { + if(unsigned_gt(value.size(), numeric_limits::max() / bits_per_uint64)) + { + throw std::invalid_argument("value has too large size"); + } + value_ = decltype(value_)::Aliasing(value.data()); + bit_count_ = static_cast(value.size()) * bits_per_uint64; + } +#endif + BigUInt::BigUInt(int bit_count, uint64_t value) + { + resize(bit_count); + operator =(value); + if (bit_count != bit_count_) + { + resize(bit_count); + } + } + + BigUInt::BigUInt(const BigUInt ©) + { + resize(copy.bit_count()); + operator =(copy); + } + + BigUInt::BigUInt(BigUInt &&source) noexcept : + pool_(move(source.pool_)), + value_(move(source.value_)), + bit_count_(source.bit_count_) + { + // Pointer in source has been taken over so set it to nullptr + source.bit_count_ = 0; + } + + BigUInt::~BigUInt() noexcept + { + reset(); + } + + string BigUInt::to_string() const + { + return uint_to_hex_string(value_.get(), uint64_count()); + } + + string BigUInt::to_dec_string() const + { + return uint_to_dec_string(value_.get(), uint64_count(), pool_); + } + + void BigUInt::resize(int bit_count) + { + if (bit_count < 0) + { + throw invalid_argument("bit_count must be non-negative"); + } + if (value_.is_alias()) + { + throw logic_error("Cannot resize an aliased BigUInt"); + } + if (bit_count == bit_count_) + { + return; + } + + // Lazy initialization of MemoryPoolHandle + if (!pool_) + { + pool_ = MemoryManager::GetPool(); + } + + // Fast path if allocation size doesn't change. + size_t old_uint64_count = uint64_count(); + size_t new_uint64_count = safe_cast( + divide_round_up(bit_count, bits_per_uint64)); + if (old_uint64_count == new_uint64_count) + { + bit_count_ = bit_count; + return; + } + + // Allocate new space. + decltype(value_) new_value; + if (new_uint64_count > 0) + { + new_value.swap_with(allocate_uint(new_uint64_count, pool_)); + } + + // Copy over old value. + if (new_uint64_count > 0) + { + set_uint_uint(value_.get(), old_uint64_count, new_uint64_count, new_value.get()); + filter_highbits_uint(new_value.get(), new_uint64_count, bit_count); + } + + // Deallocate any owned pointers. + reset(); + + // Update class. + value_.swap_with(new_value); + bit_count_ = bit_count; + } + + BigUInt &BigUInt::operator =(const BigUInt& assign) + { + // Do nothing if same thing. + if (&assign == this) + { + return *this; + } + + // Verify assigned value will fit within bit count. + int assign_sig_bit_count = assign.significant_bit_count(); + if (assign_sig_bit_count > bit_count_) + { + // Size is too large to currently fit, so resize. + resize(assign_sig_bit_count); + } + + // Copy over value. + size_t assign_uint64_count = safe_cast( + divide_round_up(assign_sig_bit_count, bits_per_uint64)); + if (uint64_count() > 0) + { + set_uint_uint(assign.value_.get(), assign_uint64_count, + uint64_count(), value_.get()); + } + return *this; + } + + BigUInt &BigUInt::operator =(const string &hex_value) + { + int hex_value_length = safe_cast(hex_value.size()); + + int assign_bit_count = get_hex_string_bit_count(hex_value.data(), hex_value_length); + if (assign_bit_count > bit_count_) + { + // Size is too large to currently fit, so resize. + resize(assign_bit_count); + } + if (bit_count_ > 0) + { + // Copy over value. + hex_string_to_uint(hex_value.data(), hex_value_length, uint64_count(), value_.get()); + } + return *this; + } + + BigUInt BigUInt::operator /(const BigUInt& operand2) const + { + int result_bits = significant_bit_count(); + int operand2_bits = operand2.significant_bit_count(); + if (operand2_bits == 0) + { + throw invalid_argument("operand2 must be positive"); + } + if (operand2_bits > result_bits) + { + BigUInt zero(result_bits); + return zero; + } + BigUInt result(result_bits); + BigUInt remainder(result_bits); + size_t result_uint64_count = result.uint64_count(); + if (result_uint64_count > operand2.uint64_count()) + { + BigUInt operand2resized(result_bits); + operand2resized = operand2; + divide_uint_uint(value_.get(), operand2resized.data(), result_uint64_count, + result.data(), remainder.data(), pool_); + } + else + { + divide_uint_uint(value_.get(), operand2.data(), result_uint64_count, + result.data(), remainder.data(), pool_); + } + return result; + } + + BigUInt BigUInt::divrem(const BigUInt& operand2, BigUInt &remainder) const + { + int result_bits = significant_bit_count(); + remainder = *this; + int operand2_bits = operand2.significant_bit_count(); + if (operand2_bits > result_bits) + { + BigUInt zero; + return zero; + } + BigUInt quotient(result_bits); + size_t uint64_count = remainder.uint64_count(); + if (uint64_count > operand2.uint64_count()) + { + BigUInt operand2resized(result_bits); + operand2resized = operand2; + divide_uint_uint_inplace(remainder.data(), operand2resized.data(), + uint64_count, quotient.data(), pool_); + } + else + { + divide_uint_uint_inplace(remainder.data(), operand2.data(), + uint64_count, quotient.data(), pool_); + } + return quotient; + } + + void BigUInt::save(ostream &stream) const + { + auto old_except_mask = stream.exceptions(); + try + { + // Throw exceptions on std::ios_base::badbit and std::ios_base::failbit + stream.exceptions(ios_base::badbit | ios_base::failbit); + + int32_t bit_count32 = safe_cast(bit_count_); + streamsize data_size = safe_cast(mul_safe(uint64_count(), sizeof(uint64_t))); + stream.write(reinterpret_cast(&bit_count32), sizeof(int32_t)); + stream.write(reinterpret_cast(value_.get()), data_size); + } + catch (const exception &) + { + stream.exceptions(old_except_mask); + throw; + } + + stream.exceptions(old_except_mask); + } + + void BigUInt::load(istream &stream) + { + auto old_except_mask = stream.exceptions(); + try + { + // Throw exceptions on std::ios_base::badbit and std::ios_base::failbit + stream.exceptions(ios_base::badbit | ios_base::failbit); + + int32_t read_bit_count = 0; + stream.read(reinterpret_cast(&read_bit_count), sizeof(int32_t)); + if (read_bit_count > bit_count_) + { + // Size is too large to currently fit, so resize. + resize(read_bit_count); + } + size_t read_uint64_count = safe_cast( + divide_round_up(read_bit_count, bits_per_uint64)); + streamsize data_size = safe_cast(mul_safe(read_uint64_count, sizeof(uint64_t))); + stream.read(reinterpret_cast(value_.get()), data_size); + + // Zero any extra space. + if (uint64_count() > read_uint64_count) + { + set_zero_uint(uint64_count() - read_uint64_count, + value_.get() + read_uint64_count); + } + } + catch (const exception &) + { + stream.exceptions(old_except_mask); + throw; + } + + stream.exceptions(old_except_mask); + } +} diff --git a/src/seal/biguint.h b/src/seal/biguint.h new file mode 100644 index 000000000..6d0b7189d --- /dev/null +++ b/src/seal/biguint.h @@ -0,0 +1,1648 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include +#include +#include +#include "seal/memorymanager.h" +#include "seal/util/pointer.h" +#include "seal/util/common.h" +#include "seal/util/uintcore.h" +#include "seal/util/uintarith.h" +#include "seal/util/uintarithmod.h" + +namespace seal +{ + /** + Represents an unsigned integer with a specified bit width. Non-const + BigUInts are mutable and able to be resized. The bit count for a BigUInt + (which can be read with bit_count()) is set initially by the constructor + and can be resized either explicitly with the resize() function or + implicitly with an assignment operation (e.g., operator=(), operator+=(), + etc.). A rich set of unsigned integer operations are provided by the + BigUInt class, including comparison, traditional arithmetic (addition, + subtraction, multiplication, division), and modular arithmetic functions. + + @par Backing Array + The backing array for a BigUInt stores its unsigned integer value as + a contiguous std::uint64_t array. Each std::uint64_t in the array + sequentially represents 64-bits of the integer value, with the least + significant quad-word storing the lower 64-bits and the order of the bits + for each quad word dependent on the architecture's std::uint64_t + representation. The size of the array equals the bit count of the BigUInt + (which can be read with bit_count()) rounded up to the next std::uint64_t + boundary (i.e., rounded up to the next 64-bit boundary). The uint64_count() + function returns the number of std::uint64_t in the backing array. The + data() function returns a pointer to the first std::uint64_t in the array. + Additionally, the operator [] function allows accessing the individual + bytes of the integer value in a platform-independent way - for example, + reading the third byte will always return bits 16-24 of the BigUInt value + regardless of the platform being little-endian or big-endian. + + @par Implicit Resizing + Both the copy constructor and operator=() allocate more memory for the + backing array when needed, i.e. when the source BigUInt has a larger + backing array than the destination. Conversely, when the destination + backing array is already large enough, the data is only copied and the + unnecessary higher order bits are set to zero. When new memory has to be + allocated, only the significant bits of the source BigUInt are taken + into account. This is is important, because it avoids unnecessary zero + bits to be included in the destination, which in some cases could + accumulate and result in very large unnecessary allocations. However, + sometimes it is necessary to preserve the original size, even if some + of the leading bits are zero. For this purpose BigUInt contains functions + duplicate_from and duplicate_to, which create an exact copy of the source + BigUInt. + + @par Alias BigUInts + An aliased BigUInt (which can be determined with is_alias()) is a special + type of BigUInt that does not manage its underlying std::uint64_t pointer + used to store the value. An aliased BigUInt supports most of the same + operations as a non-aliased BigUInt, including reading and writing the + value, however an aliased BigUInt does not internally allocate or + deallocate its backing array and, therefore, does not support resizing. + Any attempt, either explicitly or implicitly, to resize the BigUInt will + result in an exception being thrown. An aliased BigUInt can be created + with the BigUInt(int, std::uint64_t*) constructor or the alias() function. + Note that the pointer specified to be aliased must be deallocated + externally after the BigUInt is no longer in use. Aliasing is useful in + cases where it is desirable to not have each BigUInt manage its own memory + allocation and/or to prevent unnecessary copying. + + @par Thread Safety + In general, reading a BigUInt is thread-safe while mutating is not. + Specifically, the backing array may be freed whenever a resize occurs, + the BigUInt is destroyed, or alias() is called, which would invalidate + the address returned by data() and the byte references returned by + operator []. When it is known that a resize will not occur, concurrent + reading and mutating will not inherently fail but it is possible for + a read to see a partially updated value from a concurrent write. + A non-aliased BigUInt allocates its backing array from the global + (thread-safe) memory pool. Consequently, creating or resizing a large + number of BigUInt can result in a performance loss due to thread + contention. + */ + class BigUInt + { + public: + /** + Creates an empty BigUInt with zero bit width. No memory is allocated + by this constructor. + */ + BigUInt() = default; + + /** + Creates a zero-initialized BigUInt of the specified bit width. + + @param[in] bit_count The bit width + @throws std::invalid_argument if bit_count is negative + */ + BigUInt(int bit_count); + + /** + Creates a BigUInt initialized and minimally sized to fit the unsigned + hexadecimal integer specified by the string. The string matches the format + returned by to_string() and must consist of only the characters 0-9, A-F, + or a-f, most-significant nibble first. + + @param[in] hex_value The hexadecimal integer string specifying the initial + value + @throws std::invalid_argument if hex_value does not adhere to the expected + format + */ + BigUInt(const std::string &hex_value); + + /** + Creates a BigUInt of the specified bit width and initializes it with the + unsigned hexadecimal integer specified by the string. The string must match + the format returned by to_string() and must consist of only the characters + 0-9, A-F, or a-f, most-significant nibble first. + + @param[in] bit_count The bit width + @param[in] hex_value The hexadecimal integer string specifying the initial + value + @throws std::invalid_argument if bit_count is negative + @throws std::invalid_argument if hex_value does not adhere to the expected + format + */ + BigUInt(int bit_count, const std::string &hex_value); + + /** + Creates an aliased BigUInt with the specified bit width and backing array. + An aliased BigUInt does not internally allocate or deallocate the backing + array, and instead uses the specified backing array for all read/write + operations. Note that resizing is not supported by an aliased BigUInt and + any required deallocation of the specified backing array must occur + externally after the aliased BigUInt is no longer in use. + + @param[in] bit_count The bit width + @param[in] value The backing array to use + @throws std::invalid_argument if bit_count is negative or value is null + and bit_count is positive + */ + BigUInt(int bit_count, std::uint64_t *value); +#ifdef SEAL_USE_MSGSL_SPAN + /** + Creates an aliased BigUInt with given backing array and bit width set to + the size of the backing array. An aliased BigUInt does not internally + allocate or deallocate the backing array, and instead uses the specified + backing array for all read/write operations. Note that resizing is not + supported by an aliased BigUInt and any required deallocation of the + specified backing array must occur externally after the aliased BigUInt + is no longer in use. + + @param[in] value The backing array to use + @throws std::invalid_argument if value has too large size + */ + BigUInt(gsl::span value); +#endif + /** + Creates a BigUInt of the specified bit width and initializes it to the + specified unsigned integer value. + + @param[in] bit_count The bit width + @param[in] value The initial value to set the BigUInt + @throws std::invalid_argument if bit_count is negative + */ + BigUInt(int bit_count, std::uint64_t value); + + /** + Creates a deep copy of a BigUInt. The created BigUInt will have the same + bit count and value as the original. + + @param[in] copy The BigUInt to copy from + */ + BigUInt(const BigUInt ©); + + /** + Creates a new BigUInt by moving an old one. + + @param[in] source The BigUInt to move from + */ + BigUInt(BigUInt &&source) noexcept; + + /** + Destroys the BigUInt and deallocates the backing array if it is not + an aliased BigUInt. + */ + ~BigUInt() noexcept; + + /** + Returns whether or not the BigUInt is an alias. + + @see BigUInt for a detailed description of aliased BigUInt. + */ + inline bool is_alias() const noexcept + { + return value_.is_alias(); + } + + /** + Returns the bit count for the BigUInt. + + @see significant_bit_count() to instead ignore leading zero bits. + */ + inline int bit_count() const noexcept + { + return bit_count_; + } + + /** + Returns a pointer to the backing array storing the BigUInt value. + The pointer points to the beginning of the backing array at the + least-significant quad word. + + @warning The pointer is valid only until the backing array is freed, + which occurs when the BigUInt is resized, destroyed, or the alias() + function is called. + @see uint64_count() to determine the number of std::uint64_t values + in the backing array. + */ + inline std::uint64_t *data() + { + return value_.get(); + } + + /** + Returns a const pointer to the backing array storing the BigUInt value. + The pointer points to the beginning of the backing array at the + least-significant quad word. + + @warning The pointer is valid only until the backing array is freed, which + occurs when the BigUInt is resized, destroyed, or the alias() function is + called. + @see uint64_count() to determine the number of std::uint64_t values in the + backing array. + */ + inline const std::uint64_t *data() const noexcept + { + return value_.get(); + } +#ifdef SEAL_USE_MSGSL_SPAN + /** + Returns the backing array storing the BigUInt value. + + @warning The span is valid only until the backing array is freed, which + occurs when the BigUInt is resized, destroyed, or the alias() function is + called. + */ + inline gsl::span data_span() + { + return gsl::span(value_.get(), + static_cast(uint64_count())); + } + + /** + Returns the backing array storing the BigUInt value. + + @warning The span is valid only until the backing array is freed, which + occurs when the BigUInt is resized, destroyed, or the alias() function is + called. + */ + inline gsl::span data_span() const + { + return gsl::span(value_.get(), + static_cast(uint64_count())); + } +#endif + /** + Returns the number of bytes in the backing array used to store the BigUInt + value. + + @see BigUInt for a detailed description of the format of the backing array. + */ + inline std::size_t byte_count() const + { + return static_cast( + util::divide_round_up(bit_count_, util::bits_per_byte)); + } + + /** + Returns the number of std::uint64_t in the backing array used to store + the BigUInt value. + + @see BigUInt for a detailed description of the format of the backing array. + */ + inline std::size_t uint64_count() const + { + return static_cast( + util::divide_round_up(bit_count_, util::bits_per_uint64)); + } + + /** + Returns the number of significant bits for the BigUInt. + + @see bit_count() to instead return the bit count regardless of leading zero + bits. + */ + inline int significant_bit_count() const + { + if (bit_count_ == 0) + { + return 0; + } + return util::get_significant_bit_count_uint(value_.get(), uint64_count()); + } + + /** + Returns the BigUInt value as a double. Note that precision may be lost during + the conversion. + */ + double to_double() const noexcept + { + const double TwoToThe64 = 18446744073709551616.0; + double result = 0; + for (std::size_t i = uint64_count(); i--; ) + { + result *= TwoToThe64; + result += static_cast(value_[i]); + } + return result; + } + + /** + Returns the BigUInt value as a hexadecimal string. + */ + std::string to_string() const; + + /** + Returns the BigUInt value as a decimal string. + */ + std::string to_dec_string() const; + + /** + Returns whether or not the BigUInt has the value zero. + */ + inline bool is_zero() const + { + if (bit_count_ == 0) + { + return true; + } + return util::is_zero_uint(value_.get(), uint64_count()); + } + + /** + Returns the byte at the corresponding byte index of the BigUInt's integer + value. The bytes of the BigUInt are indexed least-significant byte first. + + @param[in] index The index of the byte to read + @throws std::out_of_range if index is not within [0, byte_count()) + @see BigUInt for a detailed description of the format of the backing array. + */ + inline const SEAL_BYTE &operator [](std::size_t index) const + { + if (index >= byte_count()) + { + throw std::out_of_range("index must be within [0, byte count)"); + } + return *util::get_uint64_byte(value_.get(), index); + } + + /** + Returns an byte reference that can read/write the byte at the corresponding + byte index of the BigUInt's integer value. The bytes of the BigUInt are + indexed least-significant byte first. + + @warning The returned byte is an reference backed by the BigUInt's backing + array. As such, it is only valid until the BigUInt is resized, destroyed, + or alias() is called. + + @param[in] index The index of the byte to read + @throws std::out_of_range if index is not within [0, byte_count()) + @see BigUInt for a detailed description of the format of the backing array. + */ + inline SEAL_BYTE &operator [](std::size_t index) + { + if (index >= byte_count()) + { + throw std::out_of_range("index must be within [0, byte count)"); + } + return *util::get_uint64_byte(value_.get(), index); + } + + /** + Sets the BigUInt value to zero. This does not resize the BigUInt. + */ + inline void set_zero() + { + if (bit_count_) + { + return util::set_zero_uint(uint64_count(), value_.get()); + } + } + + /** + Resizes the BigUInt to the specified bit width, copying over the old value + as much as will fit. + + @param[in] bit_count The bit width + @throws std::invalid_argument if bit_count is negative + @throws std::logic_error if the BigUInt is an alias + */ + void resize(int bit_count); + + /** + Makes the BigUInt an aliased BigUInt with the specified bit width and + backing array. An aliased BigUInt does not internally allocate or + deallocate the backing array, and instead uses the specified backing array + for all read/write operations. Note that resizing is not supported by + an aliased BigUInt and any required deallocation of the specified backing + array must occur externally after the aliased BigUInt is no longer in use. + + @param[in] bit_count The bit width + @param[in] value The backing array to use + @throws std::invalid_argument if bit_count is negative or value is null + */ + inline void alias(int bit_count, std::uint64_t *value) + { + if (bit_count < 0) + { + throw std::invalid_argument("bit_count must be non-negative"); + } + if (value == nullptr && bit_count > 0) + { + throw std::invalid_argument("value must be non-null for non-zero bit count"); + } + + // Deallocate any owned pointers. + reset(); + + // Update class. + value_ = util::Pointer::Aliasing(value); + bit_count_ = bit_count; + } +#ifdef SEAL_USE_MSGSL_SPAN + /** + Makes the BigUInt an aliased BigUInt with the given backing array + and bit width set equal to the size of the backing array. An aliased + BigUInt does not internally allocate or deallocate the backing array, + and instead uses the specified backing array for all read/write + operations. Note that resizing is not supported by an aliased BigUInt + and any required deallocation of the specified backing array must + occur externally after the aliased BigUInt is no longer in use. + + @param[in] value The backing array to use + @throws std::invalid_argument if value has too large size + */ + inline void alias(gsl::span value) + { + if(util::unsigned_gt(value.size(), std::numeric_limits::max())) + { + throw std::invalid_argument("value has too large size"); + } + + // Deallocate any owned pointers. + reset(); + + // Update class. + value_ = util::Pointer::Aliasing(value.data()); + bit_count_ = static_cast(value.size());; + } +#endif + /** + Resets an aliased BigUInt into an empty non-alias BigUInt with bit count + of zero. + + @throws std::logic_error if BigUInt is not an alias + */ + inline void unalias() + { + if (!value_.is_alias()) + { + throw std::logic_error("BigUInt is not an alias"); + } + + // Reset class. + reset(); + } + + /** + Overwrites the BigUInt with the value of the specified BigUInt, enlarging + if needed to fit the assigned value. Only significant bits are used to + size the BigUInt. + + @param[in] assign The BigUInt whose value should be assigned to the + current BigUInt + @throws std::logic_error if BigUInt is an alias and the assigned BigUInt is + too large to fit the current bit width + */ + BigUInt &operator =(const BigUInt &assign); + + /** + Overwrites the BigUInt with the unsigned hexadecimal value specified by + the string, enlarging if needed to fit the assigned value. The string must + match the format returned by to_string() and must consist of only the + characters 0-9, A-F, or a-f, most-significant nibble first. + + @param[in] hex_value The hexadecimal integer string specifying the value + to assign + @throws std::invalid_argument if hex_value does not adhere to the + expected format + @throws std::logic_error if BigUInt is an alias and the assigned value + is too large to fit the current bit width + */ + BigUInt &operator =(const std::string &hex_value); + + /** + Overwrites the BigUInt with the specified integer value, enlarging if + needed to fit the value. + + @param[in] value The value to assign + @throws std::logic_error if BigUInt is an alias and the significant bit + count of value is too large to fit the + current bit width + */ + inline BigUInt &operator =(std::uint64_t value) + { + int assign_bit_count = util::get_significant_bit_count(value); + if (assign_bit_count > bit_count_) + { + // Size is too large to currently fit, so resize. + resize(assign_bit_count); + } + if (bit_count_ > 0) + { + util::set_uint(value, uint64_count(), value_.get()); + } + return *this; + } + + /** + Returns a copy of the BigUInt value resized to the significant bit count. + */ + inline BigUInt operator +() const + { + BigUInt result; + result = *this; + return result; + } + + /** + Returns a negated copy of the BigUInt value. The bit count does not change. + */ + inline BigUInt operator -() const + { + BigUInt result(bit_count_); + util::negate_uint(value_.get(), result.uint64_count(), result.data()); + util::filter_highbits_uint(result.data(), result.uint64_count(), result.bit_count()); + return result; + } + + /** + Returns an inverted copy of the BigUInt value. The bit count does not change. + */ + inline BigUInt operator ~() const + { + BigUInt result(bit_count_); + util::not_uint(value_.get(), result.uint64_count(), result.data()); + util::filter_highbits_uint(result.data(), result.uint64_count(), result.bit_count()); + return result; + } + + /** + Increments the BigUInt and returns the incremented value. The BigUInt will + increment the bit count if needed to fit the carry. + + @throws std::logic_error if BigUInt is an alias and a carry occurs requiring + the BigUInt to be resized + */ + inline BigUInt &operator ++() + { + if (util::increment_uint(value_.get(), uint64_count(), value_.get())) + { + resize(util::add_safe(bit_count_, 1)); + util::set_bit_uint(value_.get(), uint64_count(), bit_count_); + } + bit_count_ = std::max(bit_count_, significant_bit_count()); + return *this; + } + + /** + Decrements the BigUInt and returns the decremented value. The bit count + does not change. + */ + inline BigUInt &operator --() + { + util::decrement_uint(value_.get(), uint64_count(), value_.get()); + util::filter_highbits_uint(value_.get(), uint64_count(), bit_count_); + return *this; + } + + /** + Increments the BigUInt but returns its old value. The BigUInt will increment + the bit count if needed to fit the carry. + */ + inline BigUInt operator ++(int postfix SEAL_MAYBE_UNUSED) + { + BigUInt result; + result = *this; + if (util::increment_uint(value_.get(), uint64_count(), value_.get())) + { + resize(util::add_safe(bit_count_, 1)); + util::set_bit_uint(value_.get(), uint64_count(), bit_count_); + } + bit_count_ = std::max(bit_count_, significant_bit_count()); + return result; + } + + /** + Decrements the BigUInt but returns its old value. The bit count does not change. + */ + inline BigUInt operator --(int postfix SEAL_MAYBE_UNUSED) + { + BigUInt result; + result = *this; + util::decrement_uint(value_.get(), uint64_count(), value_.get()); + util::filter_highbits_uint(value_.get(), uint64_count(), bit_count_); + return result; + } + + /** + Adds two BigUInts and returns the sum. The input operands are not modified. + The bit count of the sum is set to be one greater than the significant bit + count of the larger of the two input operands. + + @param[in] operand2 The second operand to add + */ + inline BigUInt operator +(const BigUInt &operand2) const + { + int result_bits = util::add_safe(std::max(significant_bit_count(), + operand2.significant_bit_count()), 1); + BigUInt result(result_bits); + util::add_uint_uint(value_.get(), uint64_count(), operand2.data(), + operand2.uint64_count(), false, result.uint64_count(), result.data()); + return result; + } + + /** + Adds a BigUInt and an unsigned integer and returns the sum. The input + operands are not modified. The bit count of the sum is set to be one greater + than the significant bit count of the larger of the two operands. + + @param[in] operand2 The second operand to add + */ + inline BigUInt operator +(std::uint64_t operand2) const + { + BigUInt operand2uint; + operand2uint = operand2; + return *this + operand2uint; + } + + /** + Subtracts two BigUInts and returns the difference. The input operands are + not modified. The bit count of the difference is set to be the significant + bit count of the larger of the two input operands. + + @param[in] operand2 The second operand to subtract + */ + inline BigUInt operator -(const BigUInt &operand2) const + { + int result_bits = std::max(bit_count_, operand2.bit_count()); + BigUInt result(result_bits); + util::sub_uint_uint(value_.get(), uint64_count(), operand2.data(), + operand2.uint64_count(), false, result.uint64_count(), result.data()); + util::filter_highbits_uint(result.data(), result.uint64_count(), result_bits); + return result; + } + + /** + Subtracts a BigUInt and an unsigned integer and returns the difference. + The input operands are not modified. The bit count of the difference is set + to be the significant bit count of the larger of the two operands. + + @param[in] operand2 The second operand to subtract + */ + inline BigUInt operator -(std::uint64_t operand2) const + { + BigUInt operand2uint; + operand2uint = operand2; + return *this - operand2uint; + } + + /** + Multiplies two BigUInts and returns the product. The input operands are + not modified. The bit count of the product is set to be the sum of the + significant bit counts of the two input operands. + + @param[in] operand2 The second operand to multiply + */ + inline BigUInt operator *(const BigUInt &operand2) const + { + int result_bits = util::add_safe(significant_bit_count(), + operand2.significant_bit_count()); + BigUInt result(result_bits); + util::multiply_uint_uint(value_.get(), uint64_count(), operand2.data(), + operand2.uint64_count(), result.uint64_count(), result.data()); + return result; + } + + /** + Multiplies a BigUInt and an unsigned integer and returns the product. + The input operands are not modified. The bit count of the product is set + to be the sum of the significant bit counts of the two input operands. + + @param[in] operand2 The second operand to multiply + */ + inline BigUInt operator *(std::uint64_t operand2) const + { + BigUInt operand2uint; + operand2uint = operand2; + return *this * operand2uint; + } + + /** + Divides two BigUInts and returns the quotient. The input operands are + not modified. The bit count of the quotient is set to be the significant + bit count of the first input operand. + + @param[in] operand2 The second operand to divide + @throws std::invalid_argument if operand2 is zero + */ + BigUInt operator /(const BigUInt &operand2) const; + + /** + Divides a BigUInt and an unsigned integer and returns the quotient. The + input operands are not modified. The bit count of the quotient is set + to be the significant bit count of the first input operand. + + @param[in] operand2 The second operand to divide + @throws std::invalid_argument if operand2 is zero + */ + inline BigUInt operator /(std::uint64_t operand2) const + { + BigUInt operand2uint; + operand2uint = operand2; + return *this / operand2uint; + } + + /** + Performs a bit-wise XOR operation between two BigUInts and returns the + result. The input operands are not modified. The bit count of the result + is set to the maximum of the two input operand bit counts. + + @param[in] operand2 The second operand to XOR + */ + inline BigUInt operator ^(const BigUInt &operand2) const + { + int result_bits = std::max(bit_count_, operand2.bit_count()); + BigUInt result(result_bits); + std::size_t uint64_count = result.uint64_count(); + if (uint64_count != this->uint64_count()) + { + result = *this; + util::xor_uint_uint(result.data(), operand2.data(), uint64_count, result.data()); + } + else if (uint64_count != operand2.uint64_count()) + { + result = operand2; + util::xor_uint_uint(result.data(), value_.get(), uint64_count, result.data()); + } + else + { + util::xor_uint_uint(value_.get(), operand2.data(), uint64_count, result.data()); + } + return result; + } + + /** + Performs a bit-wise XOR operation between a BigUInt and an unsigned + integer and returns the result. The input operands are not modified. + The bit count of the result is set to the maximum of the two input + operand bit counts. + + @param[in] operand2 The second operand to XOR + */ + inline BigUInt operator ^(std::uint64_t operand2) const + { + BigUInt operand2uint; + operand2uint = operand2; + return *this ^ operand2uint; + } + + /** + Performs a bit-wise AND operation between two BigUInts and returns the + result. The input operands are not modified. The bit count of the result + is set to the maximum of the two input operand bit counts. + + @param[in] operand2 The second operand to AND + */ + inline BigUInt operator &(const BigUInt &operand2) const + { + int result_bits = std::max(bit_count_, operand2.bit_count()); + BigUInt result(result_bits); + std::size_t uint64_count = result.uint64_count(); + if (uint64_count != this->uint64_count()) + { + result = *this; + util::and_uint_uint(result.data(), operand2.data(), uint64_count, result.data()); + } + else if (uint64_count != operand2.uint64_count()) + { + result = operand2; + util::and_uint_uint(result.data(), value_.get(), uint64_count, result.data()); + } + else + { + util::and_uint_uint(value_.get(), operand2.data(), uint64_count, result.data()); + } + return result; + } + + /** + Performs a bit-wise AND operation between a BigUInt and an unsigned + integer and returns the result. The input operands are not modified. + The bit count of the result is set to the maximum of the two input + operand bit counts. + + @param[in] operand2 The second operand to AND + */ + inline BigUInt operator &(std::uint64_t operand2) const + { + BigUInt operand2uint; + operand2uint = operand2; + return *this & operand2uint; + } + + /** + Performs a bit-wise OR operation between two BigUInts and returns the + result. The input operands are not modified. The bit count of the result + is set to the maximum of the two input operand bit counts. + + @param[in] operand2 The second operand to OR + */ + inline BigUInt operator |(const BigUInt &operand2) const + { + int result_bits = std::max(bit_count_, operand2.bit_count()); + BigUInt result(result_bits); + std::size_t uint64_count = result.uint64_count(); + if (uint64_count != this->uint64_count()) + { + result = *this; + util::or_uint_uint(result.data(), operand2.data(), uint64_count, result.data()); + } + else if (uint64_count != operand2.uint64_count()) + { + result = operand2; + util::or_uint_uint(result.data(), value_.get(), uint64_count, result.data()); + } + else + { + util::or_uint_uint(value_.get(), operand2.data(), uint64_count, result.data()); + } + return result; + } + + /** + Performs a bit-wise OR operation between a BigUInt and an unsigned + integer and returns the result. The input operands are not modified. + The bit count of the result is set to the maximum of the two input + operand bit counts. + + @param[in] operand2 The second operand to OR + */ + inline BigUInt operator |(std::uint64_t operand2) const + { + BigUInt operand2uint; + operand2uint = operand2; + return *this | operand2uint; + } + + /** + Compares two BigUInts and returns -1, 0, or 1 if the BigUInt is + less-than, equal-to, or greater-than the second operand respectively. + The input operands are not modified. + + @param[in] compare The value to compare against + */ + inline int compareto(const BigUInt &compare) const + { + return util::compare_uint_uint(value_.get(), uint64_count(), + compare.value_.get(), compare.uint64_count()); + } + + /** + Compares a BigUInt and an unsigned integer and returns -1, 0, or 1 if + the BigUInt is less-than, equal-to, or greater-than the second operand + respectively. The input operands are not modified. + + @param[in] compare The value to compare against + */ + inline int compareto(std::uint64_t compare) const + { + BigUInt compareuint; + compareuint = compare; + return compareto(compareuint); + } + + /** + Returns whether or not a BigUInt is less-than a second BigUInt. The + input operands are not modified. + + @param[in] operand2 The value to compare against + */ + inline bool operator <(const BigUInt &compare) const + { + return util::compare_uint_uint(value_.get(), uint64_count(), + compare.value_.get(), compare.uint64_count()) < 0; + } + + /** + Returns whether or not a BigUInt is less-than an unsigned integer. + The input operands are not modified. + + @param[in] operand2 The value to compare against + */ + inline bool operator <(std::uint64_t compare) const + { + BigUInt compareuint; + compareuint = compare; + return *this < compareuint; + } + + /** + Returns whether or not a BigUInt is greater-than a second BigUInt. + The input operands are not modified. + + @param[in] operand2 The value to compare against + */ + inline bool operator >(const BigUInt &compare) const + { + return util::compare_uint_uint(value_.get(), uint64_count(), + compare.value_.get(), compare.uint64_count()) > 0; + } + + /** + Returns whether or not a BigUInt is greater-than an unsigned integer. + The input operands are not modified. + + @param[in] operand2 The value to compare against + */ + inline bool operator >(std::uint64_t compare) const + { + BigUInt compareuint; + compareuint = compare; + return *this > compareuint; + } + + /** + Returns whether or not a BigUInt is less-than or equal to a second + BigUInt. The input operands are not modified. + + @param[in] operand2 The value to compare against + */ + inline bool operator <=(const BigUInt &compare) const + { + return util::compare_uint_uint(value_.get(), uint64_count(), + compare.value_.get(), compare.uint64_count()) <= 0; + } + + /** + Returns whether or not a BigUInt is less-than or equal to an unsigned + integer. The input operands are not modified. + + @param[in] operand2 The value to compare against + */ + inline bool operator <=(std::uint64_t compare) const + { + BigUInt compareuint; + compareuint = compare; + return *this <= compareuint; + } + + /** + Returns whether or not a BigUInt is greater-than or equal to a second + BigUInt. The input operands are not modified. + + @param[in] operand2 The value to compare against + */ + inline bool operator >=(const BigUInt &compare) const + { + return util::compare_uint_uint(value_.get(), uint64_count(), + compare.value_.get(), compare.uint64_count()) >= 0; + } + + /** + Returns whether or not a BigUInt is greater-than or equal to an unsigned + integer. The input operands are not modified. + + @param[in] operand2 The value to compare against + */ + inline bool operator >=(std::uint64_t compare) const + { + BigUInt compareuint; + compareuint = compare; + return *this >= compareuint; + } + + /** + Returns whether or not a BigUInt is equal to a second BigUInt. + The input operands are not modified. + + @param[in] compare The value to compare against + */ + inline bool operator ==(const BigUInt &compare) const + { + return util::compare_uint_uint(value_.get(), uint64_count(), + compare.value_.get(), compare.uint64_count()) == 0; + } + + /** + Returns whether or not a BigUInt is equal to an unsigned integer. + The input operands are not modified. + + @param[in] compare The value to compare against + */ + inline bool operator ==(std::uint64_t compare) const + { + BigUInt compareuint; + compareuint = compare; + return *this == compareuint; + } + + /** + Returns whether or not a BigUInt is not equal to a second BigUInt. + The input operands are not modified. + + @param[in] compare The value to compare against + */ + inline bool operator !=(const BigUInt &compare) const + { + return !(operator ==(compare)); + } + + /** + Returns whether or not a BigUInt is not equal to an unsigned integer. + The input operands are not modified. + + @param[in] operand2 The value to compare against + */ + inline bool operator !=(std::uint64_t compare) const + { + BigUInt compareuint; + compareuint = compare; + return *this != compareuint; + } + + /** + Returns a left-shifted copy of the BigUInt. The bit count of the + returned value is the sum of the original significant bit count and + the shift amount. + + @param[in] shift The number of bits to shift by + @throws std::invalid_argument if shift is negative + */ + inline BigUInt operator <<(int shift) const + { + if (shift < 0) + { + throw std::invalid_argument("shift must be non-negative"); + } + int result_bits = util::add_safe(significant_bit_count(), shift); + BigUInt result(result_bits); + result = *this; + util::left_shift_uint( + result.data(), shift, result.uint64_count(), result.data()); + return result; + } + + /** + Returns a right-shifted copy of the BigUInt. The bit count of the + returned value is the original significant bit count subtracted by + the shift amount (clipped to zero if negative). + + @param[in] shift The number of bits to shift by + @throws std::invalid_argument if shift is negative + */ + inline BigUInt operator >>(int shift) const + { + if (shift < 0) + { + throw std::invalid_argument("shift must be non-negative"); + } + int result_bits = util::sub_safe(significant_bit_count(), shift); + if (result_bits <= 0) + { + BigUInt zero; + return zero; + } + BigUInt result(result_bits); + result = *this; + util::right_shift_uint( + result.data(), shift, result.uint64_count(), result.data()); + return result; + } + + /** + Adds two BigUInts saving the sum to the first operand, returning + a reference of the first operand. The second input operand is not + modified. The first operand is resized if and only if its bit count + is smaller than one greater than the significant bit count of the + larger of the two input operands. + + @param[in] operand2 The second operand to add + @throws std::logic_error if the BigUInt is an alias and the operator + attempts to enlarge the BigUInt to fit the result + */ + inline BigUInt &operator +=(const BigUInt &operand2) + { + int result_bits = util::add_safe(std::max( + significant_bit_count(), operand2.significant_bit_count()), 1); + if (bit_count_ < result_bits) + { + resize(result_bits); + } + util::add_uint_uint(value_.get(), uint64_count(), operand2.data(), + operand2.uint64_count(), false, uint64_count(), value_.get()); + return *this; + } + + /** + Adds a BigUInt and an unsigned integer saving the sum to the first operand, + returning a reference of the first operand. The second input operand is not + modified. The first operand is resized if and only if its bit count is + smaller than one greater than the significant bit count of the larger of + the two input operands. + + @param[in] operand2 The second operand to add + @throws std::logic_error if the BigUInt is an alias and the operator + attempts to enlarge the BigUInt to fit the result + */ + inline BigUInt &operator +=(std::uint64_t operand2) + { + BigUInt operand2uint; + operand2uint = operand2; + return operator +=(operand2uint); + } + + /** + Subtracts two BigUInts saving the difference to the first operand, + returning a reference of the first operand. The second input operand is + not modified. The first operand is resized if and only if its bit count + is smaller than the significant bit count of the second operand. + + @param[in] operand2 The second operand to subtract + @throws std::logic_error if the BigUInt is an alias and the operator + attempts to enlarge the BigUInt to fit the result + */ + inline BigUInt &operator -=(const BigUInt &operand2) + { + int result_bits = std::max(bit_count_, operand2.bit_count()); + if (bit_count_ < result_bits) + { + resize(result_bits); + } + util::sub_uint_uint(value_.get(), uint64_count(), operand2.data(), + operand2.uint64_count(), false, uint64_count(), value_.get()); + util::filter_highbits_uint(value_.get(), uint64_count(), result_bits); + return *this; + } + + /** + Subtracts a BigUInt and an unsigned integer saving the difference to + the first operand, returning a reference of the first operand. The second + input operand is not modified. The first operand is resized if and only + if its bit count is smaller than the significant bit count of the second + operand. + + @param[in] operand2 The second operand to subtract + @throws std::logic_error if the BigUInt is an alias and the operator + attempts to enlarge the BigUInt to fit the result + */ + inline BigUInt &operator -=(std::uint64_t operand2) + { + BigUInt operand2uint; + operand2uint = operand2; + return operator -=(operand2uint); + } + + /** + Multiplies two BigUInts saving the product to the first operand, + returning a reference of the first operand. The second input operand + is not modified. The first operand is resized if and only if its bit + count is smaller than the sum of the significant bit counts of the two + input operands. + + @param[in] operand2 The second operand to multiply + @throws std::logic_error if the BigUInt is an alias and the operator + attempts to enlarge the BigUInt to fit the result + */ + inline BigUInt &operator *=(const BigUInt &operand2) + { + *this = *this * operand2; + return *this; + } + + /** + Multiplies a BigUInt and an unsigned integer saving the product to + the first operand, returning a reference of the first operand. The + second input operand is not modified. The first operand is resized if + and only if its bit count is smaller than the sum of the significant + bit counts of the two input operands. + + @param[in] operand2 The second operand to multiply + @throws std::logic_error if the BigUInt is an alias and the operator + attempts to enlarge the BigUInt to fit the result + */ + inline BigUInt &operator *=(std::uint64_t operand2) + { + BigUInt operand2uint; + operand2uint = operand2; + return operator *=(operand2uint); + } + + /** + Divides two BigUInts saving the quotient to the first operand, + returning a reference of the first operand. The second input operand + is not modified. The first operand is never resized. + + @param[in] operand2 The second operand to divide + @throws std::invalid_argument if operand2 is zero + */ + inline BigUInt &operator /=(const BigUInt &operand2) + { + *this = *this / operand2; + return *this; + } + + /** + Divides a BigUInt and an unsigned integer saving the quotient to + the first operand, returning a reference of the first operand. The + second input operand is not modified. The first operand is never resized. + + @param[in] operand2 The second operand to divide + @throws std::invalid_argument if operand2 is zero + */ + inline BigUInt &operator /=(std::uint64_t operand2) + { + BigUInt operand2uint; + operand2uint = operand2; + return operator /=(operand2uint); + } + + /** + Performs a bit-wise XOR operation between two BigUInts saving the result + to the first operand, returning a reference of the first operand. The + second input operand is not modified. The first operand is resized if + and only if its bit count is smaller than the significant bit count of + the second operand. + + @param[in] operand2 The second operand to XOR + @throws std::logic_error if the BigUInt is an alias and the operator + attempts to enlarge the BigUInt to fit the result + */ + inline BigUInt &operator ^=(const BigUInt &operand2) + { + int result_bits = std::max(bit_count_, operand2.bit_count()); + if (bit_count_ != result_bits) + { + resize(result_bits); + } + util::xor_uint_uint( + value_.get(), operand2.data(), operand2.uint64_count(), value_.get()); + return *this; + } + + /** + Performs a bit-wise XOR operation between a BigUInt and an unsigned integer + saving the result to the first operand, returning a reference of the first + operand. The second input operand is not modified. The first operand is + resized if and only if its bit count is smaller than the significant bit + count of the second operand. + + @param[in] operand2 The second operand to XOR + @throws std::logic_error if the BigUInt is an alias and the operator + attempts to enlarge the BigUInt to fit the result + */ + inline BigUInt &operator ^=(std::uint64_t operand2) + { + BigUInt operand2uint; + operand2uint = operand2; + return operator ^=(operand2uint); + } + + /** + Performs a bit-wise AND operation between two BigUInts saving the result + to the first operand, returning a reference of the first operand. The + second input operand is not modified. The first operand is resized if + and only if its bit count is smaller than the significant bit count of + the second operand. + + @param[in] operand2 The second operand to AND + @throws std::logic_error if the BigUInt is an alias and the operator + attempts to enlarge the BigUInt to fit the result + */ + inline BigUInt &operator &=(const BigUInt &operand2) + { + int result_bits = std::max(bit_count_, operand2.bit_count()); + if (bit_count_ != result_bits) + { + resize(result_bits); + } + util::and_uint_uint( + value_.get(), operand2.data(), operand2.uint64_count(), value_.get()); + return *this; + } + + /** + Performs a bit-wise AND operation between a BigUInt and an unsigned integer + saving the result to the first operand, returning a reference of the first + operand. The second input operand is not modified. The first operand is + resized if and only if its bit count is smaller than the significant bit + count of the second operand. + + @param[in] operand2 The second operand to AND + @throws std::logic_error if the BigUInt is an alias and the operator + attempts to enlarge the BigUInt to fit the result + */ + inline BigUInt &operator &=(std::uint64_t operand2) + { + BigUInt operand2uint; + operand2uint = operand2; + return operator &=(operand2uint); + } + + /** + Performs a bit-wise OR operation between two BigUInts saving the result to + the first operand, returning a reference of the first operand. The second + input operand is not modified. The first operand is resized if and only if + its bit count is smaller than the significant bit count of the second + operand. + + @param[in] operand2 The second operand to OR + @throws std::logic_error if the BigUInt is an alias and the operator + attempts to enlarge the BigUInt to fit the result + */ + inline BigUInt &operator |=(const BigUInt &operand2) + { + int result_bits = std::max(bit_count_, operand2.bit_count()); + if (bit_count_ != result_bits) + { + resize(result_bits); + } + util::or_uint_uint(value_.get(), operand2.data(), + operand2.uint64_count(), value_.get()); + return *this; + } + + /** + Performs a bit-wise OR operation between a BigUInt and an unsigned integer + saving the result to the first operand, returning a reference of the first + operand. The second input operand is not modified. The first operand is + resized if and only if its bit count is smaller than the significant bit + count of the second operand. + + @param[in] operand2 The second operand to OR + @throws std::logic_error if the BigUInt is an alias and the operator + attempts to enlarge the BigUInt to fit the result + */ + inline BigUInt &operator |=(std::uint64_t operand2) + { + BigUInt operand2uint; + operand2uint = operand2; + return operator |=(operand2uint); + } + + /** + Left-shifts a BigUInt by the specified amount. The BigUInt is resized if + and only if its bit count is smaller than the sum of its significant bit + count and the shift amount. + + @param[in] shift The number of bits to shift by + @throws std::Invalid_argument if shift is negative + @throws std::logic_error if the BigUInt is an alias and the operator + attempts to enlarge the BigUInt to fit the result + */ + inline BigUInt &operator <<=(int shift) + { + if (shift < 0) + { + throw std::invalid_argument("shift must be non-negative"); + } + int result_bits = util::add_safe(significant_bit_count(), shift); + if (bit_count_ < result_bits) + { + resize(result_bits); + } + util::left_shift_uint(value_.get(), shift, uint64_count(), value_.get()); + return *this; + } + + /** + Right-shifts a BigUInt by the specified amount. The BigUInt is never + resized. + + @param[in] shift The number of bits to shift by + @throws std::Invalid_argument if shift is negative + */ + inline BigUInt &operator >>=(int shift) + { + if (shift < 0) + { + throw std::invalid_argument("shift must be non-negative"); + } + if (shift > bit_count_) + { + set_zero(); + return *this; + } + util::right_shift_uint(value_.get(), shift, uint64_count(), value_.get()); + return *this; + } + + /** + Divides two BigUInts and returns the quotient and sets the remainder + parameter to the remainder. The bit count of the quotient is set to be + the significant bit count of the BigUInt. The remainder is resized if + and only if it is smaller than the bit count of the BigUInt. + + @param[in] operand2 The second operand to divide + @param[out] remainder The BigUInt to store the remainder + @throws std::Invalid_argument if operand2 is zero + @throws std::logic_error if the remainder is an alias and the operator + attempts to enlarge the BigUInt to fit the result + */ + BigUInt divrem(const BigUInt &operand2, BigUInt &remainder) const; + + /** + Divides a BigUInt and an unsigned integer and returns the quotient and + sets the remainder parameter to the remainder. The bit count of the + quotient is set to be the significant bit count of the BigUInt. The + remainder is resized if and only if it is smaller than the bit count + of the BigUInt. + + @param[in] operand2 The second operand to divide + @param[out] remainder The BigUInt to store the remainder + @throws std::Invalid_argument if operand2 is zero + @throws std::logic_error if the remainder is an alias which the + function attempts to enlarge to fit the result + */ + inline BigUInt divrem(std::uint64_t operand2, BigUInt &remainder) const + { + BigUInt operand2uint; + operand2uint = operand2; + return divrem(operand2uint, remainder); + } + + /** + Returns the inverse of a BigUInt with respect to the specified modulus. + The original BigUInt is not modified. The bit count of the inverse is + set to be the significant bit count of the modulus. + + @param[in] modulus The modulus to calculate the inverse with respect to + @throws std::Invalid_argument if modulus is zero + @throws std::Invalid_argument if modulus is not greater than the BigUInt value + @throws std::Invalid_argument if the BigUInt value and modulus are not co-prime + */ + inline BigUInt modinv(const BigUInt &modulus) const + { + if (modulus.is_zero()) + { + throw std::invalid_argument("modulus must be positive"); + } + int result_bits = modulus.significant_bit_count(); + if (*this >= modulus) + { + throw std::invalid_argument("modulus must be greater than BigUInt"); + } + BigUInt result(result_bits); + result = *this; + if (!util::try_invert_uint_mod(result.data(), modulus.data(), + result.uint64_count(), result.data(), pool_)) + { + throw std::invalid_argument("BigUInt and modulus are not co-prime"); + } + return result; + } + + /** + Returns the inverse of a BigUInt with respect to the specified modulus. + The original BigUInt is not modified. The bit count of the inverse is set + to be the significant bit count of the modulus. + + @param[in] modulus The modulus to calculate the inverse with respect to + @throws std::Invalid_argument if modulus is zero + @throws std::Invalid_argument if modulus is not greater than the BigUInt value + @throws std::Invalid_argument if the BigUInt value and modulus are not co-prime + */ + inline BigUInt modinv(std::uint64_t modulus) const + { + BigUInt modulusuint; + modulusuint = modulus; + return modinv(modulusuint); + } + + /** + Attempts to calculate the inverse of a BigUInt with respect to the + specified modulus, returning whether or not the inverse was successful + and setting the inverse parameter to the inverse. The original BigUInt + is not modified. The inverse parameter is resized if and only if its bit + count is smaller than the significant bit count of the modulus. + + @param[in] modulus The modulus to calculate the inverse with respect to + @param[out] inverse Stores the inverse if the inverse operation was + successful + @throws std::Invalid_argument if modulus is zero + @throws std::Invalid_argument if modulus is not greater than the BigUInt + value + @throws std::logic_error if the inverse is an alias which the function + attempts to enlarge to fit the result + */ + inline bool trymodinv(const BigUInt &modulus, BigUInt &inverse) const + { + if (modulus.is_zero()) + { + throw std::invalid_argument("modulus must be positive"); + } + int result_bits = modulus.significant_bit_count(); + if (*this >= modulus) + { + throw std::invalid_argument("modulus must be greater than BigUInt"); + } + if (inverse.bit_count() < result_bits) + { + inverse.resize(result_bits); + } + inverse = *this; + return util::try_invert_uint_mod(inverse.data(), modulus.data(), + inverse.uint64_count(), inverse.data(), pool_); + } + + /** + Attempts to calculate the inverse of a BigUInt with respect to the + specified modulus, returning whether or not the inverse was successful + and setting the inverse parameter to the inverse. The original BigUInt + is not modified. The inverse parameter is resized if and only if its + bit count is smaller than the significant bit count of the modulus. + + @param[in] modulus The modulus to calculate the inverse with respect to + @param[out] inverse Stores the inverse if the inverse operation was + successful + @throws std::Invalid_argument if modulus is zero + @throws std::Invalid_argument if modulus is not greater than the BigUInt + value + @throws std::logic_error if the inverse is an alias which the function + attempts to enlarge to fit the result + */ + inline bool trymodinv(std::uint64_t modulus, BigUInt &inverse) const + { + BigUInt modulusuint; + modulusuint = modulus; + return trymodinv(modulusuint, inverse); + } + + /** + Saves the BigUInt to an output stream. The full state of the BigUInt is + serialized, including insignificant bits. The output is in binary format + and not human-readable. The output stream must have the "binary" flag set. + + @param[in] stream The stream to save the BigUInt to + @throws std::exception if the BigUInt could not be written to stream + */ + void save(std::ostream &stream) const; + + /** + Loads a BigUInt from an input stream overwriting the current BigUInt + and enlarging if needed to fit the loaded BigUInt. + + @param[in] stream The stream to load the BigUInt from + @throws std::logic_error if BigUInt is an alias and the loaded BigUInt + is too large to fit with the current bit + @throws std::exception if a valid BigUInt could not be read from stream + */ + void load(std::istream &stream); + + /** + Creates a minimally sized BigUInt initialized to the specified unsigned + integer value. + + @param[in] value The value to initialized the BigUInt to + */ + inline static BigUInt of(std::uint64_t value) + { + BigUInt result; + result = value; + return result; + } + + /** + Duplicates the current BigUInt. The bit count and the value of the + given BigUInt are set to be exactly the same as in the current one. + + @param[out] destination The BigUInt to overwrite with the duplicate + @throws std::logic_error if the destination BigUInt is an alias + */ + inline void duplicate_to(BigUInt &destination) const + { + destination.resize(this->bit_count_); + destination = *this; + } + + /** + Duplicates a given BigUInt. The bit count and the value of the current + BigUInt are set to be exactly the same as in the given one. + + @param[in] value The BigUInt to duplicate + @throws std::logic_error if the current BigUInt is an alias + */ + inline void duplicate_from(const BigUInt &value) + { + this->resize(value.bit_count_); + *this = value; + } + + private: + MemoryPoolHandle pool_; + + /** + Resets the entire state of the BigUInt to an empty, zero-sized state, + freeing any memory it internally allocated. If the BigUInt was an alias, + the backing array is not freed but the alias is no longer referenced. + */ + inline void reset() noexcept + { + value_.release(); + bit_count_ = 0; + } + + /** + Points to the backing array for the BigUInt. This pointer will be set + to nullptr if and only if the bit count is zero. This pointer is + automatically allocated and freed by the BigUInt if and only if + the BigUInt is not an alias. If the BigUInt is an alias, then the + pointer was passed-in to a constructor or alias() call, and will not be + deallocated by the BigUInt. + + @see BigUInt for more information about aliased BigUInts or the format + of the backing array. + */ + util::Pointer value_; + + /** + The bit count for the BigUInt. + */ + int bit_count_ = 0; + }; +} diff --git a/src/seal/ciphertext.cpp b/src/seal/ciphertext.cpp new file mode 100644 index 000000000..b2c84a78c --- /dev/null +++ b/src/seal/ciphertext.cpp @@ -0,0 +1,254 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "seal/ciphertext.h" +#include "seal/util/polycore.h" + +using namespace std; +using namespace seal::util; + +namespace seal +{ + Ciphertext &Ciphertext::operator =(const Ciphertext &assign) + { + // Check for self-assignment + if (this == &assign) + { + return *this; + } + + // Copy over fields + parms_id_ = assign.parms_id_; + is_ntt_form_ = assign.is_ntt_form_; + scale_ = assign.scale_; + + // Then resize + resize_internal(assign.size_, assign.poly_modulus_degree_, + assign.coeff_mod_count_); + + // Size is guaranteed to be OK now so copy over + copy(assign.data_.cbegin(), assign.data_.cend(), data_.begin()); + + return *this; + } + + void Ciphertext::reserve(shared_ptr context, + parms_id_type parms_id, size_type size_capacity) + { + // Verify parameters + if (!context) + { + throw invalid_argument("invalid context"); + } + if (!context->parameters_set()) + { + throw invalid_argument("encryption parameters are not set correctly"); + } + + auto context_data_ptr = context->context_data(parms_id); + if (!context_data_ptr) + { + throw invalid_argument("parms_id is not valid for encryption parameters"); + } + + // Need to set parms_id first + auto &parms = context_data_ptr->parms(); + parms_id_ = parms.parms_id(); + + reserve_internal(size_capacity, parms.poly_modulus_degree(), + safe_cast(parms.coeff_modulus().size())); + } + + void Ciphertext::reserve_internal(size_type size_capacity, + size_type poly_modulus_degree, size_type coeff_mod_count) + { + if (size_capacity < SEAL_CIPHERTEXT_SIZE_MIN || + size_capacity > SEAL_CIPHERTEXT_SIZE_MAX) + { + throw invalid_argument("invalid size_capacity"); + } + + size_type new_data_capacity = + mul_safe(size_capacity, poly_modulus_degree, coeff_mod_count); + size_type new_data_size = min(new_data_capacity, data_.size()); + + // First reserve, then resize + data_.reserve(new_data_capacity); + data_.resize(new_data_size); + + // Set the size and size_capacity + size_capacity_ = size_capacity; + size_ = min(size_capacity, size_); + poly_modulus_degree_ = poly_modulus_degree; + coeff_mod_count_ = coeff_mod_count; + } + + void Ciphertext::resize(shared_ptr context, + parms_id_type parms_id, size_type size) + { + // Verify parameters + if (!context) + { + throw invalid_argument("invalid context"); + } + if (!context->parameters_set()) + { + throw invalid_argument("encryption parameters are not set correctly"); + } + + auto context_data_ptr = context->context_data(parms_id); + if (!context_data_ptr) + { + throw invalid_argument("parms_id is not valid for encryption parameters"); + } + + // Need to set parms_id first + auto &parms = context_data_ptr->parms(); + parms_id_ = parms.parms_id(); + + resize_internal(size, parms.poly_modulus_degree(), + safe_cast(parms.coeff_modulus().size())); + } + + void Ciphertext::resize_internal(size_type size, + size_type poly_modulus_degree, size_type coeff_mod_count) + { + if ((size < SEAL_CIPHERTEXT_SIZE_MIN && size != 0) || + size > SEAL_CIPHERTEXT_SIZE_MAX) + { + throw invalid_argument("invalid size"); + } + + // Resize the data + size_type new_data_size = + mul_safe(size, poly_modulus_degree, coeff_mod_count); + data_.resize(new_data_size); + + // Set the size parameters + size_ = size; + poly_modulus_degree_ = poly_modulus_degree; + coeff_mod_count_ = coeff_mod_count; + } + + bool Ciphertext::is_valid_for(shared_ptr context) const noexcept + { + // Verify parameters + if (!context || !context->parameters_set()) + { + return false; + } + + auto context_data_ptr = context->context_data(parms_id_); + if (!context_data_ptr) + { + return false; + } + + auto &coeff_modulus = context_data_ptr->parms().coeff_modulus(); + size_t poly_modulus_degree = context_data_ptr->parms().poly_modulus_degree(); + if ((coeff_modulus.size() != coeff_mod_count_) || + (poly_modulus_degree != poly_modulus_degree_)) + { + return false; + } + + const ct_coeff_type *ptr = data(); + for (size_t i = 0; i < size_; i++) + { + for (size_t j = 0; j < coeff_mod_count_; j++) + { + uint64_t modulus = coeff_modulus[j].value(); + for (size_t k = 0; k < poly_modulus_degree_; k++, ptr++) + { + if (*ptr >= modulus) + { + return false; + } + } + } + } + + return true; + } + + void Ciphertext::save(ostream &stream) const + { + auto old_except_mask = stream.exceptions(); + try + { + // Throw exceptions on std::ios_base::badbit and std::ios_base::failbit + stream.exceptions(ios_base::badbit | ios_base::failbit); + + stream.write(reinterpret_cast(&parms_id_), sizeof(parms_id_type)); + SEAL_BYTE is_ntt_form_byte = static_cast(is_ntt_form_); + stream.write(reinterpret_cast(&is_ntt_form_byte), sizeof(SEAL_BYTE)); + uint64_t size64 = safe_cast(size_); + stream.write(reinterpret_cast(&size64), sizeof(uint64_t)); + uint64_t poly_modulus_degree64 = safe_cast(poly_modulus_degree_); + stream.write(reinterpret_cast(&poly_modulus_degree64), sizeof(uint64_t)); + uint64_t coeff_mod_count64 = safe_cast(coeff_mod_count_); + stream.write(reinterpret_cast(&coeff_mod_count64), sizeof(uint64_t)); + stream.write(reinterpret_cast(&scale_), sizeof(double)); + + // Save the data + data_.save(stream); + } + catch (const exception &) + { + stream.exceptions(old_except_mask); + throw; + } + + stream.exceptions(old_except_mask); + } + + void Ciphertext::unsafe_load(istream &stream) + { + auto old_except_mask = stream.exceptions(); + try + { + // Throw exceptions on std::ios_base::badbit and std::ios_base::failbit + stream.exceptions(ios_base::badbit | ios_base::failbit); + + parms_id_type parms_id{}; + stream.read(reinterpret_cast(&parms_id), sizeof(parms_id_type)); + SEAL_BYTE is_ntt_form_byte; + stream.read(reinterpret_cast(&is_ntt_form_byte), sizeof(SEAL_BYTE)); + uint64_t size64 = 0; + stream.read(reinterpret_cast(&size64), sizeof(uint64_t)); + uint64_t poly_modulus_degree64 = 0; + stream.read(reinterpret_cast(&poly_modulus_degree64), sizeof(uint64_t)); + uint64_t coeff_mod_count64 = 0; + stream.read(reinterpret_cast(&coeff_mod_count64), sizeof(uint64_t)); + double scale = 0; + stream.read(reinterpret_cast(&scale), sizeof(double)); + + // Load the data + IntArray new_data(data_.pool()); + new_data.load(stream); + if (unsigned_neq(new_data.size(), + mul_safe(size64, poly_modulus_degree64, coeff_mod_count64))) + { + throw invalid_argument("ciphertext data is invalid"); + } + + // Set values + parms_id_ = parms_id; + is_ntt_form_ = (is_ntt_form_byte == SEAL_BYTE(0)) ? false : true; + size_ = safe_cast(size64); + poly_modulus_degree_ = safe_cast(poly_modulus_degree64); + coeff_mod_count_ = safe_cast(coeff_mod_count64); + scale_ = scale; + + // Set the data + data_.swap_with(new_data); + } + catch (const exception &) + { + stream.exceptions(old_except_mask); + throw; + } + + stream.exceptions(old_except_mask); + } +} diff --git a/src/seal/ciphertext.h b/src/seal/ciphertext.h new file mode 100644 index 000000000..008aae4b6 --- /dev/null +++ b/src/seal/ciphertext.h @@ -0,0 +1,642 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include +#include +#include +#include +#include +#include "seal/util/defines.h" +#include "seal/context.h" +#include "seal/memorymanager.h" +#include "seal/intarray.h" + +namespace seal +{ + /** + Class to store a ciphertext element. The data for a ciphertext consists + of two or more polynomials, which are in SEAL stored in a CRT form with + respect to the factors of the coefficient modulus. This data itself is + not meant to be modified directly by the user, but is instead operated + on by functions in the Evaluator class. The size of the backing array of + a ciphertext depends on the encryption parameters and the size of the + ciphertext (at least 2). If the degree of the poly_modulus encryption + parameter is N, and the number of primes in the coeff_modulus encryption + parameter is K, then the ciphertext backing array requires precisely + 8*N*K*size bytes of memory. A ciphertext also carries with it the + parms_id of its associated encryption parameters, which is used to check + the validity of the ciphertext for homomorphic operations and decryption. + + @par Memory Management + The size of a ciphertext refers to the number of polynomials it contains, + whereas its capacity refers to the number of polynomials that fit in the + current memory allocation. In high-performance applications unnecessary + re-allocations should be avoided by reserving enough memory for the + ciphertext to begin with either by providing the desired capacity to the + constructor as an extra argument, or by calling the reserve function at + any time. + + @par Thread Safety + In general, reading from ciphertext is thread-safe as long as no other + thread is concurrently mutating it. This is due to the underlying data + structure storing the ciphertext not being thread-safe. + + @see Plaintext for the class that stores plaintexts. + */ + class Ciphertext + { + public: + using ct_coeff_type = std::uint64_t; + + using size_type = IntArray::size_type; + + /** + Constructs an empty ciphertext allocating no memory. + + @param[in] pool The MemoryPoolHandle pointing to a valid memory pool + @throws std::invalid_argument if pool is uninitialized + */ + Ciphertext(MemoryPoolHandle pool = MemoryManager::GetPool()) : + data_(std::move(pool)) + { + } + + /** + Constructs an empty ciphertext with capacity 2. In addition to the + capacity, the allocation size is determined by the highest-level + parameters associated to the given SEALContext. + + @param[in] context The SEALContext + @param[in] pool The MemoryPoolHandle pointing to a valid memory pool + @throws std::invalid_argument if the context is not set or encryption + parameters are not valid + @throws std::invalid_argument if pool is uninitialized + */ + explicit Ciphertext(std::shared_ptr context, + MemoryPoolHandle pool = MemoryManager::GetPool()) : + data_(std::move(pool)) + { + // Allocate memory but don't resize + reserve(std::move(context), 2); + } + + /** + Constructs an empty ciphertext with capacity 2. In addition to the + capacity, the allocation size is determined by the encryption parameters + with given parms_id. + + @param[in] context The SEALContext + @param[in] parms_id The parms_id corresponding to the encryption + parameters to be used + @param[in] pool The MemoryPoolHandle pointing to a valid memory pool + @throws std::invalid_argument if the context is not set or encryption + parameters are not valid + @throws std::invalid_argument if parms_id is not valid for the encryption + parameters + @throws std::invalid_argument if pool is uninitialized + */ + explicit Ciphertext(std::shared_ptr context, + parms_id_type parms_id, + MemoryPoolHandle pool = MemoryManager::GetPool()) : + data_(std::move(pool)) + { + // Allocate memory but don't resize + reserve(std::move(context), parms_id, 2); + } + + /** + Constructs an empty ciphertext with given capacity. In addition to + the capacity, the allocation size is determined by the given + encryption parameters. + + @param[in] context The SEALContext + @param[in] parms_id The parms_id corresponding to the encryption + parameters to be used + @param[in] size_capacity The capacity + @param[in] pool The MemoryPoolHandle pointing to a valid memory pool + @throws std::invalid_argument if the context is not set or encryption + parameters are not valid + @throws std::invalid_argument if parms_id is not valid for the encryption + parameters + @throws std::invalid_argument if size_capacity is less than 2 or too large + @throws std::invalid_argument if pool is uninitialized + */ + explicit Ciphertext(std::shared_ptr context, + parms_id_type parms_id, size_type size_capacity, + MemoryPoolHandle pool = MemoryManager::GetPool()) : + data_(std::move(pool)) + { + // Allocate memory but don't resize + reserve(std::move(context), parms_id, size_capacity); + } + + /** + Constructs a new ciphertext by copying a given one. + + @param[in] copy The ciphertext to copy from + */ + Ciphertext(const Ciphertext ©) = default; + + /** + Creates a new ciphertext by moving a given one. + + @param[in] source The ciphertext to move from + */ + Ciphertext(Ciphertext &&source) = default; + + /** + Allocates enough memory to accommodate the backing array of a ciphertext + with given capacity. In addition to the capacity, the allocation size is + determined by the encryption parameters corresponing to the given + parms_id. + + @param[in] context The SEALContext + @param[in] parms_id The parms_id corresponding to the encryption + parameters to be used + @param[in] size_capacity The capacity + @throws std::invalid_argument if the context is not set or encryption + parameters are not valid + @throws std::invalid_argument if parms_id is not valid for the encryption + parameters + @throws std::invalid_argument if size_capacity is less than 2 or too large + */ + void reserve(std::shared_ptr context, + parms_id_type parms_id, size_type size_capacity); + + /** + Allocates enough memory to accommodate the backing array of a ciphertext + with given capacity. In addition to the capacity, the allocation size is + determined by the highest-level parameters associated to the given + SEALContext. + + @param[in] context The SEALContext + @param[in] size_capacity The capacity + @throws std::invalid_argument if the context is not set or encryption + parameters are not valid + @throws std::invalid_argument if size_capacity is less than 2 or too large + */ + inline void reserve(std::shared_ptr context, + size_type size_capacity) + { + // Verify parameters + if (!context) + { + throw std::invalid_argument("invalid context"); + } + auto parms_id = context->first_parms_id(); + reserve(std::move(context), parms_id, size_capacity); + } + + /** + Allocates enough memory to accommodate the backing array of a ciphertext + with given capacity. In addition to the capacity, the allocation size is + determined by the current encryption parameters. + + @param[in] size_capacity The capacity + @throws std::invalid_argument if size_capacity is less than 2 or too large + @throws std::logic_error if the encryption parameters are not + */ + inline void reserve(size_type size_capacity) + { + // Note: poly_modulus_degree_ and coeff_mod_count_ are either valid + // or coeff_mod_count_ is zero (in which case no memory is allocated). + reserve_internal(size_capacity, poly_modulus_degree_, + coeff_mod_count_); + } + + /** + Resizes the ciphertext to given size, reallocating if the capacity + of the ciphertext is too small. The ciphertext parameters are + determined by the given SEALContext and parms_id. + + This function is mainly intended for internal use and is called + automatically by functions such as Evaluator::multiply and + Evaluator::relinearize. A normal user should never have a reason + to manually resize a ciphertext. + + @param[in] context The SEALContext + @param[in] parms_id The parms_id corresponding to the encryption + parameters to be used + @param[in] size The new size + @throws std::invalid_argument if the context is not set or encryption + parameters are not valid + @throws std::invalid_argument if parms_id is not valid for the encryption + parameters + @throws std::invalid_argument if size is less than 2 or too large + */ + void resize(std::shared_ptr context, + parms_id_type parms_id, size_type size); + + /** + Resizes the ciphertext to given size, reallocating if the capacity + of the ciphertext is too small. The ciphertext parameters are + determined by the highest-level parameters associated to the given + SEALContext. + + This function is mainly intended for internal use and is called + automatically by functions such as Evaluator::multiply and + Evaluator::relinearize. A normal user should never have a reason + to manually resize a ciphertext. + + @param[in] context The SEALContext + @param[in] size The new size + @throws std::invalid_argument if the context is not set or encryption + parameters are not valid + @throws std::invalid_argument if size is less than 2 or too large + */ + inline void resize(std::shared_ptr context, + size_type size) + { + // Verify parameters + if (!context) + { + throw std::invalid_argument("invalid context"); + } + auto parms_id = context->first_parms_id(); + resize(std::move(context), parms_id, size); + } + + /** + Resizes the ciphertext to given size, reallocating if the capacity + of the ciphertext is too small. + + This function is mainly intended for internal use and is called + automatically by functions such as Evaluator::multiply and + Evaluator::relinearize. A normal user should never have a reason + to manually resize a ciphertext. + + @param[in] size The new size + @throws std::invalid_argument if size is less than 2 or too large + */ + inline void resize(size_type size) + { + // Note: poly_modulus_degree_ and coeff_mod_count_ are either valid + // or coeff_mod_count_ is zero (in which case no memory is allocated). + resize_internal(size, poly_modulus_degree_, coeff_mod_count_); + } + + /** + Resets the ciphertext. This function releases any memory allocated + by the ciphertext, returning it to the memory pool. It also sets all + encryption parameter specific size information to zero. + */ + inline void release() noexcept + { + parms_id_ = parms_id_zero; + is_ntt_form_ = false; + size_capacity_ = 2; + size_ = 0; + poly_modulus_degree_ = 0; + coeff_mod_count_ = 0; + scale_ = 1.0; + data_.release(); + } + + /** + Copies a given ciphertext to the current one. + + @param[in] assign The ciphertext to copy from + */ + Ciphertext &operator =(const Ciphertext &assign); + + /** + Moves a given ciphertext to the current one. + + @param[in] assign The ciphertext to move from + */ + Ciphertext &operator =(Ciphertext &&assign) = default; + + /** + Returns a pointer to the beginning of the ciphertext data. + */ + inline ct_coeff_type *data() noexcept + { + return data_.begin(); + } + + /** + Returns a const pointer to the beginning of the ciphertext data. + */ + inline const ct_coeff_type *data() const noexcept + { + return data_.cbegin(); + } +#ifdef SEAL_USE_MSGSL_MULTISPAN + /** + Returns the ciphertext data. + */ + inline gsl::multi_span< + ct_coeff_type, + gsl::dynamic_range, + gsl::dynamic_range, + gsl::dynamic_range> data_span() + { + return gsl::as_multi_span< + ct_coeff_type, + gsl::dynamic_range, + gsl::dynamic_range, + gsl::dynamic_range>( + data_.begin(), + util::safe_cast(size_), + util::safe_cast(coeff_mod_count_), + util::safe_cast(poly_modulus_degree_)); + } + + /** + Returns the backing array storing all of the coefficient values. + */ + inline gsl::multi_span< + const ct_coeff_type, + gsl::dynamic_range, + gsl::dynamic_range, + gsl::dynamic_range> data_span() const + { + return gsl::as_multi_span< + const ct_coeff_type, + gsl::dynamic_range, + gsl::dynamic_range, + gsl::dynamic_range>( + data_.cbegin(), + util::safe_cast(size_), + util::safe_cast(coeff_mod_count_), + util::safe_cast(poly_modulus_degree_)); + } +#endif + /** + Returns a pointer to a particular polynomial in the ciphertext + data. Note that SEAL stores each polynomial in the ciphertext + modulo all of the K primes in the coefficient modulus. The pointer + returned by this function is to the beginning (constant coefficient) + of the first one of these K polynomials. + + @param[in] poly_index The index of the polynomial in the ciphertext + @throws std::out_of_range if poly_index is less than 0 or bigger + than the size of the ciphertext + */ + inline ct_coeff_type *data(size_type poly_index) + { + auto poly_uint64_count = util::mul_safe( + poly_modulus_degree_, coeff_mod_count_); + if (poly_uint64_count == 0) + { + return nullptr; + } + if (poly_index >= size_) + { + throw std::out_of_range("poly_index must be within [0, size)"); + } + return data_.begin() + util::safe_cast( + util::mul_safe(poly_index, poly_uint64_count)); + } + + /** + Returns a const pointer to a particular polynomial in the + ciphertext data. Note that SEAL stores each polynomial in the + ciphertext modulo all of the K primes in the coefficient modulus. + The pointer returned by this function is to the beginning + (constant coefficient) of the first one of these K polynomials. + + @param[in] poly_index The index of the polynomial in the ciphertext + @throws std::out_of_range if poly_index is out of range + */ + inline const ct_coeff_type *data(size_type poly_index) const + { + auto poly_uint64_count = util::mul_safe( + poly_modulus_degree_, coeff_mod_count_); + if (poly_uint64_count == 0) + { + return nullptr; + } + if (poly_index >= size_) + { + throw std::out_of_range("poly_index must be within [0, size)"); + } + return data_.cbegin() + util::safe_cast( + util::mul_safe(poly_index, poly_uint64_count)); + } + + /** + Returns a reference to a polynomial coefficient at a particular + index in the ciphertext data. If the polynomial modulus has degree N, + and the number of primes in the coefficient modulus is K, then the + ciphertext contains size*N*K coefficients. Thus, the coeff_index has + a range of [0, size*N*K). + + @param[in] coeff_index The index of the coefficient + @throws std::out_of_range if coeff_index is out of range + */ + inline ct_coeff_type &operator [](size_type coeff_index) + { + return data_.at(coeff_index); + } + + /** + Returns a const reference to a polynomial coefficient at a particular + index in the ciphertext data. If the polynomial modulus has degree N, + and the number of primes in the coefficient modulus is K, then the + ciphertext contains size*N*K coefficients. Thus, the coeff_index has + a range of [0, size*N*K). + + @param[in] coeff_index The index of the coefficient + @throws std::out_of_range if coeff_index is out of range + */ + inline const ct_coeff_type &operator [](size_type coeff_index) const + { + return data_.at(coeff_index); + } + + /** + Returns the number of primes in the coefficient modulus of the + associated encryption parameters. This directly affects the + allocation size of the ciphertext. + */ + inline size_type coeff_mod_count() const noexcept + { + return coeff_mod_count_; + } + + /** + Returns the degree of the polynomial modulus of the associated + encryption parameters. This directly affects the allocation size + of the ciphertext. + */ + inline size_type poly_modulus_degree() const noexcept + { + return poly_modulus_degree_; + } + + /** + Returns the capacity of the allocation. This means the largest size + of the ciphertext that can be stored in the current allocation with + the current encryption parameters. + */ + inline size_type size_capacity() const noexcept + { + return size_capacity_; + } + + /** + Returns the size of the ciphertext. + */ + inline size_type size() const noexcept + { + return size_; + } + + /** + Returns the total size of the current allocation in 64-bit words. + */ + inline size_type uint64_count_capacity() const noexcept + { + return data_.capacity(); + } + + /** + Returns the total size of the current ciphertext in 64-bit words. + */ + inline size_type uint64_count() const noexcept + { + return data_.size(); + } + + /** + Check whether the current ciphertext is valid for a given SEALContext. + If the given SEALContext is not set, the encryption parameters are invalid, + or the ciphertext data does not match the SEALContext, this function + returns false. Otherwise, returns true. + + @param[in] context The SEALContext + */ + bool is_valid_for(std::shared_ptr context) const noexcept; + + /** + Saves the ciphertext to an output stream. The output is in binary format + and not human-readable. The output stream must have the "binary" flag set. + + @param[in] stream The stream to save the ciphertext to + @throws std::exception if the ciphertext could not be written to stream + */ + void save(std::ostream &stream) const; + + /** + Loads a ciphertext from an input stream overwriting the current ciphertext. + No checking of the validity of the ciphertext data against encryption + parameters is performed. This function should not be used unless the + ciphertext comes from a fully trusted source. + + @param[in] stream The stream to load the ciphertext from + @throws std::exception if a valid ciphertext could not be read from stream + */ + void unsafe_load(std::istream &stream); + + /** + Loads a ciphertext from an input stream overwriting the current ciphertext. + The loaded ciphertext is verified to be valid for the given SEALContext. + + @param[in] context The SEALContext + @param[in] stream The stream to load the ciphertext from + @throws std::invalid_argument if the context is not set or encryption + parameters are not valid + @throws std::exception if a valid ciphertext could not be read from stream + @throws std::invalid_argument if the loaded ciphertext is invalid for the + context + */ + inline void load(std::shared_ptr context, + std::istream &stream) + { + unsafe_load(stream); + if (!is_valid_for(std::move(context))) + { + throw std::invalid_argument("ciphertext data is invalid"); + } + } + + /** + Returns whether the ciphertext is in NTT form. + */ + inline bool is_ntt_form() const noexcept + { + return is_ntt_form_; + } + + /** + Returns whether the ciphertext is in NTT form. + */ + inline bool &is_ntt_form() noexcept + { + return is_ntt_form_; + } + + /** + Returns a reference to parms_id. + + @see EncryptionParameters for more information about parms_id. + */ + inline auto &parms_id() noexcept + { + return parms_id_; + } + + /** + Returns a const reference to parms_id. + + @see EncryptionParameters for more information about parms_id. + */ + inline auto &parms_id() const noexcept + { + return parms_id_; + } + + /** + Returns a reference to the scale. This is only needed when using the + CKKS encryption scheme. The user should have little or no reason to ever + change the scale by hand. + */ + inline auto &scale() noexcept + { + return scale_; + } + + /** + Returns a constant reference to the scale. This is only needed when + using the CKKS encryption scheme. + */ + inline auto &scale() const noexcept + { + return scale_; + } + + /** + Returns the currently used MemoryPoolHandle. + */ + inline MemoryPoolHandle pool() const noexcept + { + return data_.pool(); + } + + private: + void reserve_internal(size_type size_capacity, + size_type poly_modulus_degree, size_type coeff_mod_count); + + void resize_internal(size_type size, size_type poly_modulus_degree, + size_type coeff_mod_count); + + parms_id_type parms_id_ = parms_id_zero; + + bool is_ntt_form_ = false; + + size_type size_capacity_ = 2; + + size_type size_ = 0; + + size_type poly_modulus_degree_ = 0; + + size_type coeff_mod_count_ = 0; + + double scale_ = 1.0; + + IntArray data_; + }; +} diff --git a/src/seal/ckks.cpp b/src/seal/ckks.cpp new file mode 100644 index 000000000..c4f3ab695 --- /dev/null +++ b/src/seal/ckks.cpp @@ -0,0 +1,274 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include +#include +#include +#include +#include "seal/ckks.h" + +using namespace std; +using namespace seal::util; + +namespace seal +{ + // For C++14 compatibility need to define static constexpr + // member variables with no initialization here. + constexpr double CKKSEncoder::PI_; + + CKKSEncoder::CKKSEncoder(shared_ptr context) : + context_(context) + { + // Verify parameters + if (!context_) + { + throw invalid_argument("invalid context"); + } + if (!context_->parameters_set()) + { + throw invalid_argument("encryption parameters are not set correctly"); + } + + auto &context_data = *context_->context_data(); + if (context_data.parms().scheme() != scheme_type::CKKS) + { + throw invalid_argument("unsupported scheme"); + } + + size_t coeff_count = context_data.parms().poly_modulus_degree(); + slots_ = coeff_count >> 1; + int logn = get_power_of_two(coeff_count); + + matrix_reps_index_map_ = allocate_uint(coeff_count, pool_); + + // Copy from the matrix to the value vectors + uint64_t gen = 3; + uint64_t pos = 1; + uint64_t m = coeff_count << 1; + for (size_t i = 0; i < slots_; i++) + { + // Position in normal bit order + uint64_t index1 = (pos - 1) >> 1; + uint64_t index2 = (m - pos - 1) >> 1; + + // Set the bit-reversed locations + matrix_reps_index_map_[i] = reverse_bits(index1, logn); + matrix_reps_index_map_[slots_ | i] = reverse_bits(index2, logn); + + // Next primitive root + pos *= gen; + pos &= (m - 1); + } + + roots_ = allocate>(coeff_count, pool_); + inv_roots_ = allocate>(coeff_count, pool_); + complex psi{ cos((2 * PI_) / static_cast(m)), + sin((2 * PI_) / static_cast(m)) }; + for (size_t i = 0; i < coeff_count; i++) + { + roots_[i] = pow(psi, static_cast(reverse_bits(i, logn))); + inv_roots_[i] = 1.0 / roots_[i]; + } + } + + void CKKSEncoder::encode_internal(double value, parms_id_type parms_id, + double scale, Plaintext &destination, MemoryPoolHandle pool) + { + // Verify parameters. + auto context_data_ptr = context_->context_data(parms_id); + if (!context_data_ptr) + { + throw invalid_argument("parms_id is not valid for encryption parameters"); + } + if (!pool) + { + throw invalid_argument("pool is uninitialized"); + } + + auto &context_data = *context_data_ptr; + auto &parms = context_data.parms(); + auto &coeff_modulus = parms.coeff_modulus(); + size_t coeff_mod_count = coeff_modulus.size(); + size_t coeff_count = parms.poly_modulus_degree(); + + // Quick sanity check + if (!product_fits_in(coeff_mod_count, coeff_count)) + { + throw logic_error("invalid parameters"); + } + + // Check that scale is positive and not too large + if (scale <= 0 || (static_cast(log2(scale)) >= + context_data.total_coeff_modulus_bit_count())) + { + throw invalid_argument("scale out of bounds"); + } + + // Compute the scaled value + value *= scale; + + int coeff_bit_count = static_cast(log2(fabs(value))) + 2; + if (coeff_bit_count >= context_data.total_coeff_modulus_bit_count()) + { + throw invalid_argument("encoded value is too large"); + } + + double two_pow_64 = pow(2.0, 64); + + // Resize destination to appropriate size + // Need to first set parms_id to zero, otherwise resize + // will throw an exception. + destination.parms_id() = parms_id_zero; + destination.resize(coeff_count * coeff_mod_count); + + double coeffd = round(value); + bool is_negative = signbit(coeffd); + coeffd = fabs(coeffd); + + // Use faster decomposition methods when possible + if (coeff_bit_count <= 64) + { + uint64_t coeffu = static_cast(fabs(coeffd)); + + if (is_negative) + { + for (size_t j = 0; j < coeff_mod_count; j++) + { + fill_n(destination.data() + (j * coeff_count), coeff_count, + negate_uint_mod(coeffu % coeff_modulus[j].value(), + coeff_modulus[j])); + } + } + else + { + for (size_t j = 0; j < coeff_mod_count; j++) + { + fill_n(destination.data() + (j * coeff_count), coeff_count, + coeffu % coeff_modulus[j].value()); + } + } + } + else if (coeff_bit_count <= 128) + { + uint64_t coeffu[2]{ + static_cast(fmod(coeffd, two_pow_64)), + static_cast(coeffd / two_pow_64) }; + + if (is_negative) + { + for (size_t j = 0; j < coeff_mod_count; j++) + { + fill_n(destination.data() + (j * coeff_count), coeff_count, + negate_uint_mod(barrett_reduce_128( + coeffu, coeff_modulus[j]), coeff_modulus[j])); + } + } + else + { + for (size_t j = 0; j < coeff_mod_count; j++) + { + fill_n(destination.data() + (j * coeff_count), coeff_count, + barrett_reduce_128(coeffu, coeff_modulus[j])); + } + } + } + else + { + // Slow case + auto coeffu(allocate_uint(coeff_mod_count, pool)); + auto decomp_coeffu(allocate_uint(coeff_mod_count, pool)); + + // We are at this point guaranteed to fit in the allocated space + set_zero_uint(coeff_mod_count, coeffu.get()); + auto coeffu_ptr = coeffu.get(); + while (coeffd >= 1) + { + *coeffu_ptr++ = static_cast(fmod(coeffd, two_pow_64)); + coeffd /= two_pow_64; + } + + // Next decompose this coefficient + decompose_single_coeff(context_data, coeffu.get(), decomp_coeffu.get(), pool); + + // Finally replace the sign if necessary + if (is_negative) + { + for (size_t j = 0; j < coeff_mod_count; j++) + { + fill_n(destination.data() + (j * coeff_count), coeff_count, + negate_uint_mod(decomp_coeffu[j], coeff_modulus[j])); + } + } + else + { + for (size_t j = 0; j < coeff_mod_count; j++) + { + fill_n(destination.data() + (j * coeff_count), coeff_count, + decomp_coeffu[j]); + } + } + } + + destination.parms_id() = parms_id; + destination.scale() = scale; + } + + void CKKSEncoder::encode_internal(int64_t value, parms_id_type parms_id, + Plaintext &destination) + { + // Verify parameters. + auto context_data_ptr = context_->context_data(parms_id); + if (!context_data_ptr) + { + throw invalid_argument("parms_id is not valid for encryption parameters"); + } + + auto &context_data = *context_data_ptr; + auto &parms = context_data.parms(); + auto &coeff_modulus = parms.coeff_modulus(); + size_t coeff_mod_count = coeff_modulus.size(); + size_t coeff_count = parms.poly_modulus_degree(); + + // Quick sanity check + if (!product_fits_in(coeff_mod_count, coeff_count)) + { + throw logic_error("invalid parameters"); + } + + int coeff_bit_count = get_significant_bit_count( + static_cast(llabs(value))) + 2; + if (coeff_bit_count >= context_data.total_coeff_modulus_bit_count()) + { + throw invalid_argument("encoded value is too large"); + } + + // Resize destination to appropriate size + // Need to first set parms_id to zero, otherwise resize + // will throw an exception. + destination.parms_id() = parms_id_zero; + destination.resize(coeff_count * coeff_mod_count); + + if (value < 0) + { + for (size_t j = 0; j < coeff_mod_count; j++) + { + uint64_t tmp = static_cast(value); + tmp += coeff_modulus[j].value(); + tmp %= coeff_modulus[j].value(); + fill_n(destination.data() + (j * coeff_count), coeff_count, tmp); + } + } + else + { + for (size_t j = 0; j < coeff_mod_count; j++) + { + uint64_t tmp = static_cast(value); + tmp %= coeff_modulus[j].value(); + fill_n(destination.data() + (j * coeff_count), coeff_count, tmp); + } + } + + destination.parms_id() = parms_id; + destination.scale() = 1.0; + } +} diff --git a/src/seal/ckks.h b/src/seal/ckks.h new file mode 100644 index 000000000..9e2a943c3 --- /dev/null +++ b/src/seal/ckks.h @@ -0,0 +1,745 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include "seal/plaintext.h" +#include "seal/context.h" +#include "seal/util/common.h" +#include "seal/util/uintcore.h" +#include "seal/util/uintarithsmallmod.h" + +namespace seal +{ + template::value || + std::is_same>::value>> + inline T_out from_complex(std::complex in); + + template<> + inline double from_complex(std::complex in) + { + return in.real(); + } + + template<> + inline std::complex from_complex(std::complex in) + { + return in; + } + + /** + Provides functionality for encoding vectors of complex or real numbers into plaintext + polynomials to be encrypted and computed on using the CKKS scheme. If the polynomial + modulus degree is N, then CKKSEncoder converts vectors of N/2 complex numbers into + plaintext elements. Homomorphic operations performed on such encrypted vectors are + applied coefficient (slot-)wise, enabling powerful SIMD functionality for computations + that are vectorizable. This functionality is often called "batching" in the homomorphic + encryption literature. + + @par Mathematical Background + Mathematically speaking, if the polynomial modulus is X^N+1, N is a power of two, the + CKKSEncoder implements an approximation of the canonical embedding of the ring of + integers Z[X]/(X^N+1) into C^(N/2), where C denotes the complex numbers. The Galois + group of the extension is (Z/2NZ)* ~= Z/2Z x Z/(N/2) whose action on the primitive roots + of unity modulo coeff_modulus is easy to describe. Since the batching slots correspond + 1-to-1 to the primitive roots of unity, applying Galois automorphisms on the plaintext + acts by permuting the slots. By applying generators of the two cyclic subgroups of the + Galois group, we can effectively enable cyclic rotations and complex conjugations of + the encrypted complex vectors. + */ + class CKKSEncoder + { + public: + /** + Creates a CKKSEncoder instance initialized with the specified SEALContext. + + @param[in] context The SEALContext + @throws std::invalid_argument if the context is not set or encryption parameters + are not valid + @throws std::invalid_argument if scheme is not scheme_type::CKKS + */ + CKKSEncoder(std::shared_ptr context); + + /** + Encodes double-precision floating-point real or complex numbers into a plaintext + polynomial. Dynamic memory allocations in the process are allocated from the + memory pool pointed to by the given MemoryPoolHandle. + + @tparam T Vector value type (double or std::complex) + @param[in] values The vector of double-precision floating-point numbers + (of type T) to encode + @param[in] parms_id parms_id determining the encryption parameters to be used + by the result plaintext + @param[in] scale Scaling parameter defining encoding precision + @param[out] destination The plaintext polynomial to overwrite with the result + @param[in] pool The MemoryPoolHandle pointing to a valid memory pool + @throws std::invalid_argument if values has invalid size + @throws std::invalid_argument if parms_id is not valid for the encryption + parameters + @throws std::invalid_argument if scale is not strictly positive + @throws std::invalid_argument if encoding is too large for the encryption + parameters + @throws std::invalid_argument if pool is uninitialized + */ + template::value || + std::is_same>::value>> + inline void encode(const std::vector &values, + parms_id_type parms_id, double scale, Plaintext &destination, + MemoryPoolHandle pool = MemoryManager::GetPool()) + { + encode_internal(values, parms_id, scale, destination, std::move(pool)); + } + + /** + Encodes double-precision floating-point real or complex numbers into + a plaintext polynomial. The encryption parameters used are the top level + parameters for the given context. Dynamic memory allocations in the process + are allocated from the memory pool pointed to by the given MemoryPoolHandle. + + @tparam T Vector value type (double or std::complex) + @param[in] values The vector of double-precision floating-point numbers + (of type T) to encode + @param[in] scale Scaling parameter defining encoding precision + @param[out] destination The plaintext polynomial to overwrite with the result + @param[in] pool The MemoryPoolHandle pointing to a valid memory pool + @throws std::invalid_argument if values has invalid size + @throws std::invalid_argument if scale is not strictly positive + @throws std::invalid_argument if encoding is too large for the encryption + parameters + @throws std::invalid_argument if pool is uninitialized + */ + template::value || + std::is_same>::value>> + inline void encode(const std::vector &values, + double scale, Plaintext &destination, + MemoryPoolHandle pool = MemoryManager::GetPool()) + { + encode(values, context_->first_parms_id(), scale, + destination, std::move(pool)); + } + + /** + Encodes a double-precision floating-point number into a plaintext polynomial. + Dynamic memory allocations in the process are allocated from the memory pool + pointed to by the given MemoryPoolHandle. + + @param[in] value The double-precision floating-point number to encode + @param[in] parms_id parms_id determining the encryption parameters to be used + by the result plaintext + @param[in] scale Scaling parameter defining encoding precision + @param[out] destination The plaintext polynomial to overwrite with the result + @param[in] pool The MemoryPoolHandle pointing to a valid memory pool + @throws std::invalid_argument if parms_id is not valid for the encryption + parameters + @throws std::invalid_argument if scale is not strictly positive + @throws std::invalid_argument if encoding is too large for the encryption + parameters + @throws std::invalid_argument if pool is uninitialized + */ + inline void encode(double value, parms_id_type parms_id, + double scale, Plaintext &destination, + MemoryPoolHandle pool = MemoryManager::GetPool()) + { + encode_internal(value, parms_id, scale, destination, std::move(pool)); + } + + /** + Encodes a double-precision floating-point number into a plaintext polynomial. + The encryption parameters used are the top level parameters for the given context. + Dynamic memory allocations in the process are allocated from the memory pool + pointed to by the given MemoryPoolHandle. + + @param[in] value The double-precision floating-point number to encode + @param[in] scale Scaling parameter defining encoding precision + @param[out] destination The plaintext polynomial to overwrite with the result + @param[in] pool The MemoryPoolHandle pointing to a valid memory pool + @throws std::invalid_argument if scale is not strictly positive + @throws std::invalid_argument if encoding is too large for the encryption + parameters + @throws std::invalid_argument if pool is uninitialized + */ + inline void encode(double value, + double scale, Plaintext &destination, + MemoryPoolHandle pool = MemoryManager::GetPool()) + { + encode(value, context_->first_parms_id(), scale, + destination, std::move(pool)); + } + + /** + Encodes a double-precision complex number into a plaintext polynomial. Dynamic + memory allocations in the process are allocated from the memory pool pointed to + by the given MemoryPoolHandle. + + @param[in] value The double-precision complex number to encode + @param[in] parms_id parms_id determining the encryption parameters to be used + by the result plaintext + @param[in] scale Scaling parameter defining encoding precision + @param[out] destination The plaintext polynomial to overwrite with the result + @param[in] pool The MemoryPoolHandle pointing to a valid memory pool + @throws std::invalid_argument if parms_id is not valid for the encryption + parameters + @throws std::invalid_argument if scale is not strictly positive + @throws std::invalid_argument if encoding is too large for the encryption + parameters + @throws std::invalid_argument if pool is uninitialized + */ + inline void encode(std::complex value, + parms_id_type parms_id, double scale, Plaintext &destination, + MemoryPoolHandle pool = MemoryManager::GetPool()) + { + encode_internal(value, parms_id, scale, destination, std::move(pool)); + } + + /** + Encodes a double-precision complex number into a plaintext polynomial. The + encryption parameters used are the top level parameters for the given context. + Dynamic memory allocations in the process are allocated from the memory pool + pointed to by the given MemoryPoolHandle. + + @param[in] value The double-precision complex number to encode + @param[in] scale Scaling parameter defining encoding precision + @param[out] destination The plaintext polynomial to overwrite with the result + @param[in] pool The MemoryPoolHandle pointing to a valid memory pool + @throws std::invalid_argument if scale is not strictly positive + @throws std::invalid_argument if encoding is too large for the encryption + parameters + @throws std::invalid_argument if pool is uninitialized + */ + inline void encode(std::complex value, + double scale, Plaintext &destination, + MemoryPoolHandle pool = MemoryManager::GetPool()) + { + encode(value, context_->first_parms_id(), scale, + destination, std::move(pool)); + } + + /** + Encodes an integer number into a plaintext polynomial without any scaling. + + @param[in] value The integer number to encode + @param[in] parms_id parms_id determining the encryption parameters to be used + by the result plaintext + @param[out] destination The plaintext polynomial to overwrite with the result + @throws std::invalid_argument if parms_id is not valid for the encryption + parameters + */ + inline void encode(std::int64_t value, + parms_id_type parms_id, Plaintext &destination) + { + encode_internal(value, parms_id, destination); + } + + /** + Encodes an integer number into a plaintext polynomial without any scaling. The + encryption parameters used are the top level parameters for the given context. + + @param[in] value The integer number to encode + @param[out] destination The plaintext polynomial to overwrite with the result + */ + inline void encode(std::int64_t value, Plaintext &destination) + { + encode(value, context_->first_parms_id(), destination); + } + + /** + Decodes a plaintext polynomial into double-precision floating-point real or + complex numbers. Dynamic memory allocations in the process are allocated from + the memory pool pointed to by the given MemoryPoolHandle. + + @tparam T Vector value type (double or std::complex) + @param[in] plain The plaintext to decode + @param[out] destination The vector to be overwritten with the values in the slots + @param[in] pool The MemoryPoolHandle pointing to a valid memory pool + @throws std::invalid_argument if plain is not in NTT form or is invalid for the + encryption parameters + @throws std::invalid_argument if pool is uninitialized + */ + template::value || + std::is_same>::value>> + inline void decode(const Plaintext &plain, std::vector &destination, + MemoryPoolHandle pool = MemoryManager::GetPool()) + { + decode_internal(plain, destination, std::move(pool)); + } + + /** + Returns the number of complex numbers encoded. + */ + inline std::size_t slot_count() const noexcept + { + return slots_; + } + + private: + // This is the same function as in evaluator.h + inline void decompose_single_coeff( + const SEALContext::ContextData &context_data, const std::uint64_t *value, + std::uint64_t *destination, util::MemoryPool &pool) + { + auto &parms = context_data.parms(); + auto &coeff_modulus = parms.coeff_modulus(); + std::size_t coeff_mod_count = coeff_modulus.size(); +#ifdef SEAL_DEBUG + if (value == nullptr) + { + throw std::invalid_argument("value cannot be null"); + } + if (destination == nullptr) + { + throw std::invalid_argument("destination cannot be null"); + } + if (destination == value) + { + throw std::invalid_argument("value cannot be the same as destination"); + } +#endif + if (coeff_mod_count == 1) + { + util::set_uint_uint(value, coeff_mod_count, destination); + return; + } + + auto value_copy(util::allocate_uint(coeff_mod_count, pool)); + for (std::size_t j = 0; j < coeff_mod_count; j++) + { + //destination[j] = util::modulo_uint( + // value, coeff_mod_count, coeff_modulus_[j], pool); + + // Manually inlined for efficiency + // Make a fresh copy of value + util::set_uint_uint(value, coeff_mod_count, value_copy.get()); + + // Starting from the top, reduce always 128-bit blocks + for (std::size_t k = coeff_mod_count - 1; k--; ) + { + value_copy[k] = util::barrett_reduce_128( + value_copy.get() + k, coeff_modulus[j]); + } + destination[j] = value_copy[0]; + } + } + + template::value || + std::is_same>::value>> + void encode_internal(const std::vector &values, + parms_id_type parms_id, double scale, Plaintext &destination, + MemoryPoolHandle pool) + { + // Verify parameters. + auto context_data_ptr = context_->context_data(parms_id); + if (!context_data_ptr) + { + throw std::invalid_argument("parms_id is not valid for encryption parameters"); + } + if (values.size() > slots_) + { + throw std::invalid_argument("values has invalid size"); + } + if (!pool) + { + throw std::invalid_argument("pool is uninitialized"); + } + + auto &context_data = *context_data_ptr; + auto &parms = context_data.parms(); + auto &coeff_modulus = parms.coeff_modulus(); + std::size_t coeff_mod_count = coeff_modulus.size(); + std::size_t coeff_count = parms.poly_modulus_degree(); + + // Quick sanity check + if (!util::product_fits_in(coeff_mod_count, coeff_count)) + { + throw std::logic_error("invalid parameters"); + } + + // Check that scale is positive and not too large + if (scale <= 0 || (static_cast(log2(scale)) + 1 >= + context_data.total_coeff_modulus_bit_count())) + { + throw std::invalid_argument("scale out of bounds"); + } + + auto &small_ntt_tables = context_data.small_ntt_tables(); + + // input_size is guaranteed to be no bigger than slots_ + std::size_t input_size = values.size(); + std::size_t n = util::mul_safe(slots_, std::size_t(2)); + + auto conj_values = util::allocate>(n, pool, 0); + for (std::size_t i = 0; i < input_size; i++) + { + conj_values[matrix_reps_index_map_[i]] = values[i]; + conj_values[matrix_reps_index_map_[i + slots_]] = std::conj(values[i]); + } + + int logn = util::get_power_of_two(n); + std::size_t tt = 1; + for (int i = 0; i < logn; i++) + { + std::size_t mm = std::size_t(1) << (logn - i); + std::size_t k_start = 0; + std::size_t h = mm / 2; + + for (std::size_t j = 0; j < h; j++) + { + std::size_t k_end = k_start + tt; + auto s = inv_roots_[h + j]; + + for (std::size_t k = k_start; k < k_end; k++) + { + auto u = conj_values[k]; + auto v = conj_values[k + tt]; + conj_values[k] = u + v; + conj_values[k + tt] = (u - v) * s; + } + + k_start += 2 * tt; + } + tt *= 2; + } + + double n_inv = double(1.0) / static_cast(n); + + // Put the scale in at this point + n_inv *= scale; + + int max_coeff_bit_count = 1; + for (std::size_t i = 0; i < n; i++) + { + // Multiply by scale and n_inv (see above) + conj_values[i] *= n_inv; + + // Verify that the values are not too large to fit in coeff_modulus + // Note that we have an extra + 1 for the sign bit + max_coeff_bit_count = std::max(max_coeff_bit_count, + static_cast(std::log2(std::fabs(conj_values[i].real()))) + 2); + } + if (max_coeff_bit_count >= context_data.total_coeff_modulus_bit_count()) + { + throw std::invalid_argument("encoded values are too large"); + } + + double two_pow_64 = std::pow(2.0, 64); + + // Resize destination to appropriate size + // Need to first set parms_id to zero, otherwise resize + // will throw an exception. + destination.parms_id() = parms_id_zero; + destination.resize(util::mul_safe(coeff_count, coeff_mod_count)); + + // Use faster decomposition methods when possible + if (max_coeff_bit_count <= 64) + { + for (std::size_t i = 0; i < n; i++) + { + double coeffd = std::round(conj_values[i].real()); + bool is_negative = std::signbit(coeffd); + + std::uint64_t coeffu = + static_cast(std::fabs(coeffd)); + + if (is_negative) + { + for (std::size_t j = 0; j < coeff_mod_count; j++) + { + destination[i + (j * coeff_count)] = util::negate_uint_mod( + coeffu % coeff_modulus[j].value(), coeff_modulus[j]); + } + } + else + { + for (std::size_t j = 0; j < coeff_mod_count; j++) + { + destination[i + (j * coeff_count)] = + coeffu % coeff_modulus[j].value(); + } + } + } + } + else if (max_coeff_bit_count <= 128) + { + for (std::size_t i = 0; i < n; i++) + { + double coeffd = std::round(conj_values[i].real()); + bool is_negative = std::signbit(coeffd); + coeffd = std::fabs(coeffd); + + std::uint64_t coeffu[2]{ + static_cast(std::fmod(coeffd, two_pow_64)), + static_cast(coeffd / two_pow_64) }; + + if (is_negative) + { + for (std::size_t j = 0; j < coeff_mod_count; j++) + { + destination[i + (j * coeff_count)] = + util::negate_uint_mod(util::barrett_reduce_128( + coeffu, coeff_modulus[j]), coeff_modulus[j]); + } + } + else + { + for (std::size_t j = 0; j < coeff_mod_count; j++) + { + destination[i + (j * coeff_count)] = + util::barrett_reduce_128(coeffu, coeff_modulus[j]); + } + } + } + } + else + { + // Slow case + auto coeffu(util::allocate_uint(coeff_mod_count, pool)); + auto decomp_coeffu(util::allocate_uint(coeff_mod_count, pool)); + for (std::size_t i = 0; i < n; i++) + { + double coeffd = std::round(conj_values[i].real()); + bool is_negative = std::signbit(coeffd); + coeffd = std::fabs(coeffd); + + // We are at this point guaranteed to fit in the allocated space + util::set_zero_uint(coeff_mod_count, coeffu.get()); + auto coeffu_ptr = coeffu.get(); + while (coeffd >= 1) + { + *coeffu_ptr++ = static_cast( + std::fmod(coeffd, two_pow_64)); + coeffd /= two_pow_64; + } + + // Next decompose this coefficient + decompose_single_coeff(context_data, coeffu.get(), + decomp_coeffu.get(), pool); + + // Finally replace the sign if necessary + if (is_negative) + { + for (std::size_t j = 0; j < coeff_mod_count; j++) + { + destination[i + (j * coeff_count)] = + util::negate_uint_mod(decomp_coeffu[j], coeff_modulus[j]); + } + } + else + { + for (std::size_t j = 0; j < coeff_mod_count; j++) + { + destination[i + (j * coeff_count)] = decomp_coeffu[j]; + } + } + } + } + + // Transform to NTT domain + for (std::size_t i = 0; i < coeff_mod_count; i++) + { + util::ntt_negacyclic_harvey( + destination.data(i * coeff_count), small_ntt_tables[i]); + } + + destination.parms_id() = parms_id; + destination.scale() = scale; + } + + template::value || + std::is_same>::value>> + void decode_internal(const Plaintext &plain, std::vector &destination, + MemoryPoolHandle pool) + { + // Verify parameters. + if (!plain.is_ntt_form()) + { + throw std::invalid_argument("plain is not in NTT form"); + } + if (!pool) + { + throw std::invalid_argument("pool is uninitialized"); + } + + auto context_data_ptr = context_->context_data(plain.parms_id()); + if (!context_data_ptr) + { + throw std::invalid_argument("parms_id is not valid for encryption parameters"); + } + + auto &parms = context_data_ptr->parms(); + auto &coeff_modulus = parms.coeff_modulus(); + std::size_t coeff_mod_count = coeff_modulus.size(); + std::size_t coeff_count = parms.poly_modulus_degree(); + std::size_t rns_poly_uint64_count = + util::mul_safe(coeff_count, coeff_mod_count); + + auto &small_ntt_tables = context_data_ptr->small_ntt_tables(); + + // Check that scale is positive and not too large + if (plain.scale() <= 0 || (static_cast(log2(plain.scale())) >= + context_data_ptr->total_coeff_modulus_bit_count())) + { + throw std::invalid_argument("scale out of bounds"); + } + + auto decryption_modulus = context_data_ptr->total_coeff_modulus(); + auto upper_half_threshold = context_data_ptr->upper_half_threshold(); + + auto &inv_coeff_products_mod_coeff_array = + context_data_ptr->base_converter()->get_inv_coeff_mod_coeff_array(); + auto coeff_products_array = + context_data_ptr->base_converter()->get_coeff_products_array(); + + int logn = util::get_power_of_two(coeff_count); + + // Quick sanity check + if ((logn < 0) || (coeff_count < SEAL_POLY_MOD_DEGREE_MIN) || + (coeff_count > SEAL_POLY_MOD_DEGREE_MAX)) + { + throw std::logic_error("invalid parameters"); + } + + double inv_scale = double(1.0) / plain.scale(); + + // Create mutable copy of input + auto plain_copy = util::allocate_uint(rns_poly_uint64_count, pool); + util::set_uint_uint(plain.data(), rns_poly_uint64_count, plain_copy.get()); + + // Array to keep number bigger than std::uint64_t + auto temp(util::allocate_uint(coeff_mod_count, pool)); + + // destination mod q + auto wide_tmp_dest(util::allocate_zero_uint(rns_poly_uint64_count, pool)); + + // Transform each polynomial from NTT domain + for (std::size_t i = 0; i < coeff_mod_count; i++) + { + util::inverse_ntt_negacyclic_harvey( + plain_copy.get() + (i * coeff_count), small_ntt_tables[i]); + } + + auto res = util::allocate>(coeff_count, pool); + + double two_pow_64 = std::pow(2.0, 64); + for (std::size_t i = 0; i < coeff_count; i++) + { + for (std::size_t j = 0; j < coeff_mod_count; j++) + { + std::uint64_t tmp = util::multiply_uint_uint_mod( + plain_copy[(j * coeff_count) + i], + inv_coeff_products_mod_coeff_array[j], // (qi/q * plain[i]) mod qi + coeff_modulus[j]); + util::multiply_uint_uint64( + coeff_products_array + (j * coeff_mod_count), + coeff_mod_count, tmp, coeff_mod_count, temp.get()); + util::add_uint_uint_mod(temp.get(), + wide_tmp_dest.get() + (i * coeff_mod_count), + decryption_modulus, coeff_mod_count, + wide_tmp_dest.get() + (i * coeff_mod_count)); + } + + double res_accum = 0.0; + if (util::is_greater_than_or_equal_uint_uint( + wide_tmp_dest.get() + (i * coeff_mod_count), + upper_half_threshold, coeff_mod_count)) + { + for (std::size_t j = 0; j < coeff_mod_count; j++) + { + double diff = 0.0; + if (wide_tmp_dest[i * coeff_mod_count + j] > decryption_modulus[j]) + { + diff = static_cast(wide_tmp_dest[i * coeff_mod_count + j] + - decryption_modulus[j]); + } + else + { + diff = -static_cast(decryption_modulus[j] + - wide_tmp_dest[i * coeff_mod_count + j]); + } + res_accum += diff * pow(two_pow_64, j); + } + } + else + { + for (std::size_t j = 0; j < coeff_mod_count; j++) + { + res_accum += static_cast( + wide_tmp_dest[i * coeff_mod_count + j]) * pow(two_pow_64, j); + } + } + + res[i] = res_accum * inv_scale; + } + + std::size_t tt = coeff_count; + for (int i = 0; i < logn; i++) + { + std::size_t mm = std::size_t(1) << i; + tt >>= 1; + + for (std::size_t j = 0; j < mm; j++) + { + std::size_t j1 = 2 * j * tt; + std::size_t j2 = j1 + tt - 1; + auto s = roots_[mm + j]; + + for (std::size_t k = j1; k < j2 + 1; k++) + { + auto u = res[k]; + auto v = res[k + tt] * s; + res[k] = u + v; + res[k + tt] = u - v; + } + } + } + + destination.clear(); + destination.reserve(slots_); + for (std::size_t i = 0; i < slots_; i++) + { + destination.emplace_back( + from_complex(res[matrix_reps_index_map_[i]])); + } + } + + void encode_internal(double value, parms_id_type parms_id, + double scale, Plaintext &destination, MemoryPoolHandle pool); + + inline void encode_internal(std::complex value, + parms_id_type parms_id, double scale, Plaintext &destination, + MemoryPoolHandle pool) + { + encode_internal(std::vector>(1, value), + parms_id, scale, destination, std::move(pool)); + } + + void encode_internal(std::int64_t value, + parms_id_type parms_id, Plaintext &destination); + + MemoryPoolHandle pool_ = MemoryManager::GetPool(); + + static constexpr double PI_ = 3.14159265358979323846; + + static const double two_pow_64_; + + std::shared_ptr context_{ nullptr }; + + std::size_t slots_; + + util::Pointer> roots_; + + util::Pointer> inv_roots_; + + util::Pointer matrix_reps_index_map_; + }; +} diff --git a/src/seal/context.cpp b/src/seal/context.cpp new file mode 100644 index 000000000..e1055b53a --- /dev/null +++ b/src/seal/context.cpp @@ -0,0 +1,381 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "seal/context.h" +#include "seal/util/pointer.h" +#include "seal/util/polycore.h" +#include "seal/util/uintarith.h" +#include "seal/util/uintarithsmallmod.h" +#include "seal/util/numth.h" +#include "seal/defaultparams.h" +#include +#include + +using namespace std; +using namespace seal::util; + +namespace seal +{ + SEALContext::ContextData SEALContext::validate(EncryptionParameters parms) + { + ContextData context_data(parms, pool_); + context_data.qualifiers_.parameters_set = true; + + auto &coeff_modulus = parms.coeff_modulus(); + auto &plain_modulus = parms.plain_modulus(); + + // The number of coeff moduli is restricted to 62 for lazy reductions + // in baseconverter.cpp to work + if (coeff_modulus.size() > SEAL_COEFF_MOD_COUNT_MAX || + coeff_modulus.size() < SEAL_COEFF_MOD_COUNT_MIN) + { + context_data.qualifiers_.parameters_set = false; + return context_data; + } + + size_t coeff_mod_count = coeff_modulus.size(); + for (size_t i = 0; i < coeff_mod_count; i++) + { + // Coeff moduli must be at least 2 and at most USER_MODULO_BOUND bits + if (coeff_modulus[i].value() >> SEAL_USER_MOD_BIT_COUNT_MAX || + coeff_modulus[i].value() < (uint64_t(1) << SEAL_USER_MOD_BIT_COUNT_MIN)) + { + context_data.qualifiers_.parameters_set = false; + return context_data; + } + + // Check that all coeff moduli are pairwise relatively prime + for (size_t j = 0; j < i; j++) + { + if (gcd(coeff_modulus[i].value(), coeff_modulus[j].value()) > 1) + { + context_data.qualifiers_.parameters_set = false; + return context_data; + } + } + } + + // Compute the product of all coeff moduli + context_data.total_coeff_modulus_ = allocate_uint(coeff_mod_count, pool_); + auto temp(allocate_uint(coeff_mod_count, pool_)); + set_uint(1, coeff_mod_count, context_data.total_coeff_modulus_.get()); + for (size_t i = 0; i < coeff_mod_count; i++) + { + multiply_uint_uint64(context_data.total_coeff_modulus_.get(), + coeff_mod_count, coeff_modulus[i].value(), coeff_mod_count, + temp.get()); + set_uint_uint(temp.get(), coeff_mod_count, + context_data.total_coeff_modulus_.get()); + } + context_data.total_coeff_modulus_bit_count_ = get_significant_bit_count_uint( + context_data.total_coeff_modulus_.get(), coeff_mod_count); + + // Check polynomial modulus degree and create poly_modulus + size_t poly_modulus_degree = parms.poly_modulus_degree(); + int coeff_count_power = get_power_of_two(poly_modulus_degree); + if (poly_modulus_degree < SEAL_POLY_MOD_DEGREE_MIN || + poly_modulus_degree > SEAL_POLY_MOD_DEGREE_MAX || + coeff_count_power < 0) + { + // Parameters are not valid + context_data.qualifiers_.parameters_set = false; + return context_data; + } + + // Quick sanity check + if (!product_fits_in(coeff_mod_count, poly_modulus_degree)) + { + throw logic_error("invalid parameters"); + } + + // Polynomial modulus X^(2^k) + 1 is guaranteed at this point + context_data.qualifiers_.using_fft = true; + + // Verify that noise_standard_deviation is positive + if (parms.noise_standard_deviation() < 0 || + parms.noise_max_deviation() < 0) + { + // Parameters are not valid + context_data.qualifiers_.parameters_set = false; + return context_data; + } + + // Assume parameters are secure according to HomomorphicEncryption.org + // security standard + context_data.qualifiers_.using_he_std_security = true; + + // Check if the noise_standard_deviation is less than the default value + if (parms.noise_standard_deviation() < + util::global_variables::default_noise_standard_deviation) + { + // Not secure according to HomomorphicEncryption.org security standard + context_data.qualifiers_.using_he_std_security = false; +#ifdef SEAL_ENFORCE_HE_STD_SECURITY + // Parameters are not valid + context_data.qualifiers_.parameters_set = false; + return context_data; +#endif + } + + // Check if the parameters are secure according to HomomorphicEncryption.org + // security standard + if (util::global_variables:: + max_secure_coeff_modulus_bit_count.count(poly_modulus_degree) && + (context_data.total_coeff_modulus_bit_count_ > util::global_variables:: + max_secure_coeff_modulus_bit_count.at(poly_modulus_degree))) + { + // Not secure according to HomomorphicEncryption.org security standard + context_data.qualifiers_.using_he_std_security = false; +#ifdef SEAL_ENFORCE_HE_STD_SECURITY + // Parameters are not valid + context_data.qualifiers_.parameters_set = false; + return context_data; +#endif + } + + // Can we use NTT with coeff_modulus? + context_data.qualifiers_.using_ntt = true; + context_data.small_ntt_tables_ = + allocate(coeff_mod_count, pool_, pool_); + for (size_t i = 0; i < coeff_mod_count; i++) + { + if (!context_data.small_ntt_tables_[i].generate(coeff_count_power, + coeff_modulus[i])) + { + // Parameters are not valid + context_data.qualifiers_.using_ntt = false; + context_data.qualifiers_.parameters_set = false; + return context_data; + } + } + + if (parms.scheme() == scheme_type::BFV) + { + // Plain modulus must be at least 2 and at most 60 bits + if (plain_modulus.value() < SEAL_PLAIN_MOD_MIN || + plain_modulus.value() > SEAL_PLAIN_MOD_MAX) + { + context_data.qualifiers_.parameters_set = false; + return context_data; + } + + // Check that all coeff moduli are relatively prime to plain_modulus + for (size_t i = 0; i < coeff_mod_count; i++) + { + if (gcd(coeff_modulus[i].value(), plain_modulus.value()) > 1) + { + context_data.qualifiers_.parameters_set = false; + return context_data; + } + } + + // Check that plain_modulus is smaller than total coeff modulus + if (!is_less_than_uint_uint(plain_modulus.data(), plain_modulus.uint64_count(), + context_data.total_coeff_modulus_.get(), coeff_mod_count)) + { + // Parameters are not valid + context_data.qualifiers_.parameters_set = false; + return context_data; + } + + // Can we use batching? (NTT with plain_modulus) + context_data.qualifiers_.using_batching = false; + context_data.plain_ntt_tables_ = allocate(pool_); + if (context_data.plain_ntt_tables_->generate(coeff_count_power, plain_modulus)) + { + context_data.qualifiers_.using_batching = true; + } + + // Check for plain_lift + // If all the small coefficient moduli are larger than plain modulus, + // we can quickly lift plain coefficients to RNS form + context_data.qualifiers_.using_fast_plain_lift = true; + for (size_t i = 0; i < coeff_mod_count; i++) + { + context_data.qualifiers_.using_fast_plain_lift &= + (coeff_modulus[i].value() > plain_modulus.value()); + } + + // Calculate coeff_div_plain_modulus (BFV-"Delta") and the remainder + // upper_half_increment + context_data.coeff_div_plain_modulus_ = allocate_uint(coeff_mod_count, pool_); + context_data.upper_half_increment_ = allocate_uint(coeff_mod_count, pool_); + auto wide_plain_modulus(duplicate_uint_if_needed(plain_modulus.data(), + plain_modulus.uint64_count(), coeff_mod_count, false, pool_)); + divide_uint_uint(context_data.total_coeff_modulus_.get(), + wide_plain_modulus.get(), coeff_mod_count, + context_data.coeff_div_plain_modulus_.get(), + context_data.upper_half_increment_.get(), pool_); + + // Decompose coeff_div_plain_modulus into RNS factors + for (size_t i = 0; i < coeff_mod_count; i++) + { + temp[i] = modulo_uint(context_data.coeff_div_plain_modulus_.get(), + coeff_mod_count, coeff_modulus[i], pool_); + } + set_uint_uint(temp.get(), coeff_mod_count, + context_data.coeff_div_plain_modulus_.get()); + + // Decompose upper_half_increment into RNS factors + for (size_t i = 0; i < coeff_mod_count; i++) + { + temp[i] = modulo_uint(context_data.upper_half_increment_.get(), + coeff_mod_count, coeff_modulus[i], pool_); + } + set_uint_uint(temp.get(), coeff_mod_count, + context_data.upper_half_increment_.get()); + + // Calculate (plain_modulus + 1) / 2. + context_data.plain_upper_half_threshold_ = (plain_modulus.value() + 1) >> 1; + + // Calculate coeff_modulus - plain_modulus. + context_data.plain_upper_half_increment_ = + allocate_uint(coeff_mod_count, pool_); + if (context_data.qualifiers_.using_fast_plain_lift) + { + // Calculate coeff_modulus[i] - plain_modulus if using_fast_plain_lift + for (size_t i = 0; i < coeff_mod_count; i++) + { + context_data.plain_upper_half_increment_[i] = + coeff_modulus[i].value() - plain_modulus.value(); + } + } + else + { + sub_uint_uint(context_data.total_coeff_modulus(), + wide_plain_modulus.get(), coeff_mod_count, + context_data.plain_upper_half_increment_.get()); + } + } + else if (parms.scheme() == scheme_type::CKKS) + { + // Check that plain_modulus is set to zero + if (!plain_modulus.is_zero()) + { + // Parameters are not valid + context_data.qualifiers_.parameters_set = false; + return context_data; + } + + // When using CKKS batching (BatchEncoder) is always enabled + context_data.qualifiers_.using_batching = true; + + // Cannot use fast_plain_lift for CKKS since the plaintext coefficients + // can easily be larger than coefficient moduli + context_data.qualifiers_.using_fast_plain_lift = false; + + // Calculate 2^64 / 2 (most negative plaintext coefficient value) + context_data.plain_upper_half_threshold_ = uint64_t(1) << 63; + + // Calculate plain_upper_half_increment = 2^64 mod coeff_modulus for CKKS plaintexts + context_data.plain_upper_half_increment_ = allocate_uint(coeff_mod_count, pool_); + for (size_t i = 0; i < coeff_mod_count; i++) + { + uint64_t tmp = (uint64_t(1) << 63) % coeff_modulus[i].value(); + context_data.plain_upper_half_increment_[i] = multiply_uint_uint_mod( + tmp, + sub_safe(coeff_modulus[i].value(), uint64_t(2)), + coeff_modulus[i]); + } + + // Compute the upper_half_threshold for this modulus. + context_data.upper_half_threshold_ = allocate_uint( + coeff_mod_count, pool_); + increment_uint(context_data.total_coeff_modulus(), + coeff_mod_count, context_data.upper_half_threshold_.get()); + right_shift_uint(context_data.upper_half_threshold_.get(), 1, + coeff_mod_count, context_data.upper_half_threshold_.get()); + } + else + { + throw invalid_argument("unsupported scheme"); + } + + // Create BaseConverter + context_data.base_converter_ = allocate(pool_, pool_); + context_data.base_converter_->generate(coeff_modulus, poly_modulus_degree, + plain_modulus); + if (!context_data.base_converter_->is_generated()) + { + // Parameters are not valid + context_data.qualifiers_.parameters_set = false; + return context_data; + } + + // Done with validation and pre-computations + return context_data; + } + + SEALContext::SEALContext(EncryptionParameters parms, bool expand_mod_chain, + MemoryPoolHandle pool) : pool_(move(pool)) + { + if (!pool_) + { + throw invalid_argument("pool is uninitialized"); + } + + // Set random generator + if (!parms.random_generator()) + { + parms.set_random_generator( + UniformRandomGeneratorFactory::default_factory()); + } + + // Validate parameters and add new ContextData to the map + // Note that this happens even if parameters are not valid + context_data_map_.emplace(make_pair(parms.parms_id(), + make_shared(validate(parms)))); + + first_parms_id_ = parms.parms_id(); + last_parms_id_ = first_parms_id_; + + // If modulus switching is to be created then compute the remaining parameter + // sets as long as they are valid to use (parameters_set == true) + if (expand_mod_chain && + context_data_map_.at(first_parms_id_)->qualifiers_.parameters_set) + { + auto prev_parms_id = first_parms_id_; + while (context_data_map_.at(prev_parms_id)->parms().coeff_modulus().size() > 1) + { + // Create the next set of parameters by removing last modulus + auto next_parms = context_data_map_.at(prev_parms_id)->parms_; + auto next_coeff_modulus = next_parms.coeff_modulus(); + next_coeff_modulus.pop_back(); + next_parms.set_coeff_modulus(next_coeff_modulus); + auto next_parms_id = next_parms.parms_id(); + + // Validate next parameters + auto next_context_data = validate(next_parms); + + // If not valid then break + if (!next_context_data.qualifiers_.parameters_set) + { + break; + } + + // Add them to the context_data_map_ + context_data_map_.emplace(make_pair(next_parms_id, + make_shared(move(next_context_data)))); + + // Add pointer to next context_data to the previous one (linked list) + // We need to remove constness first to modify this + const_pointer_cast( + context_data_map_.at(prev_parms_id))->next_context_data_ = + context_data_map_.at(next_parms_id); + prev_parms_id = next_parms_id; + last_parms_id_ = prev_parms_id; + } + } + + // Set the chain_index for each context_data + size_t parms_count = context_data_map_.size(); + auto context_data_ptr = context_data_map_.at(first_parms_id_); + while (context_data_ptr) + { + // We need to remove constness first to modify this + const_pointer_cast( + context_data_ptr)->chain_index_ = --parms_count; + context_data_ptr = context_data_ptr->next_context_data_; + } + } +} diff --git a/src/seal/context.h b/src/seal/context.h new file mode 100644 index 000000000..e5b0f6588 --- /dev/null +++ b/src/seal/context.h @@ -0,0 +1,415 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include +#include +#include +#include "seal/encryptionparams.h" +#include "seal/memorymanager.h" +#include "seal/util/smallntt.h" +#include "seal/util/baseconverter.h" +#include "seal/util/pointer.h" + +namespace seal +{ + /** + Stores a set of attributes (qualifiers) of a set of encryption parameters. + These parameters are mainly used internally in various parts of the library, e.g. + to determine which algorithmic optimizations the current support. The qualifiers + are automatically created by the SEALContext class, silently passed on to classes + such as Encryptor, Evaluator, and Decryptor, and the only way to change them is by + changing the encryption parameters themselves. In other words, a user will never + have to create their own instance of EncryptionParameterQualifiers, and in most + cases never have to worry about them at all. + + @see EncryptionParameters::GetQualifiers for obtaining the EncryptionParameterQualifiers + corresponding to the current parameter set. + */ + struct EncryptionParameterQualifiers + { + /** + If the encryption parameters are set in a way that is considered valid by SEAL, the + variable parameters_set is set to true. + */ + bool parameters_set; + + /** + Tells whether FFT can be used for polynomial multiplication. If the polynomial modulus + is of the form X^N+1, where N is a power of two, then FFT can be used for fast + multiplication of polynomials modulo the polynomial modulus. In this case the + variable using_fft will be set to true. However, currently SEAL requires this + to be the case for the parameters to be valid. Therefore, parameters_set can only + be true if using_fft is true. + */ + bool using_fft; + + /** + Tells whether NTT can be used for polynomial multiplication. If the primes in the + coefficient modulus are congruent to 1 modulo 2N, where X^N+1 is the polynomial + modulus and N is a power of two, then the number-theoretic transform (NTT) can be + used for fast multiplications of polynomials modulo the polynomial modulus and + coefficient modulus. In this case the variable using_ntt will be set to true. However, + currently SEAL requires this to be the case for the parameters to be valid. Therefore, + parameters_set can only be true if using_ntt is true. + */ + bool using_ntt; + + /** + Tells whether batching is supported by the encryption parameters. If the plaintext + modulus is congruent to 1 modulo 2N, where X^N+1 is the polynomial modulus and N is + a power of two, then it is possible to use the BatchEncoder class to view plaintext + elements as 2-by-(N/2) matrices of integers modulo the plaintext modulus. This is + called batching, and allows the user to operate on the matrix elements (slots) in + a SIMD fashion, and rotate the matrix rows and columns. When the computation is + easily vectorizable, using batching can yield a huge performance boost. If the + encryption parameters support batching, the variable using_batching is set to true. + */ + bool using_batching; + + /** + Tells whether fast plain lift is supported by the encryption parameters. A certain + performance optimization in multiplication of a ciphertext by a plaintext + (Evaluator::multiply_plain) and in transforming a plaintext element to NTT domain + (Evaluator::transform_to_ntt) can be used when the plaintext modulus is smaller than + each prime in the coefficient modulus. In this case the variable using_fast_plain_lift + is set to true. + */ + bool using_fast_plain_lift; + + /** + Tells whether the encryption parameters are secure based on the standard parameters + from HomomorphicEncryption.org security standard. + */ + bool using_he_std_security; + + private: + EncryptionParameterQualifiers() : + parameters_set(false), + using_fft(false), + using_ntt(false), + using_batching(false), + using_fast_plain_lift(false), + using_he_std_security(false) + { + } + + friend class SEALContext; + }; + + /** + Performs sanity checks (validation) and pre-computations for a given set of encryption + parameters. While the EncryptionParameters class is intended to be a light-weight class + to store the encryption parameters, the SEALContext class is a heavy-weight class that + is constructed from a given set of encryption parameters. It validates the parameters + for correctness, evaluates their properties, and performs and stores the results of + several costly pre-computations. + + After the user has set at least the poly_modulus, coeff_modulus, and plain_modulus + parameters in a given EncryptionParameters instance, the parameters can be validated + for correctness and functionality by constructing an instance of SEALContext. The + constructor of SEALContext does all of its work automatically, and concludes by + constructing and storing an instance of the EncryptionParameterQualifiers class, with + its flags set according to the properties of the given parameters. If the created + instance of EncryptionParameterQualifiers has the parameters_set flag set to true, the + given parameter set has been deemed valid and is ready to be used. If the parameters + were for some reason not appropriately set, the parameters_set flag will be false, + and a new SEALContext will have to be created after the parameters are corrected. + + @see EncryptionParameters for more details on the parameters. + @see EncryptionParameterQualifiers for more details on the qualifiers. + */ + class SEALContext + { + public: + class ContextData + { + friend class SEALContext; + + public: + ContextData() = delete; + + ContextData(const ContextData ©) = delete; + + ContextData(ContextData &&move) = default; + + ContextData &operator =(ContextData &&move) = default; + + /** + Returns a const reference to the underlying encryption parameters. + */ + inline auto &parms() const + { + return parms_; + } + + /** + Returns a copy of EncryptionParameterQualifiers corresponding to the + current encryption parameters. Note that to change the qualifiers it is + necessary to create a new instance of SEALContext once appropriate changes + to the encryption parameters have been made. + */ + inline auto qualifiers() const + { + return qualifiers_; + } + + /** + Returns a pointer to a pre-computed product of all primes in the coefficient + modulus. The security of the encryption parameters largely depends on the + bit-length of this product, and on the degree of the polynomial modulus. + */ + inline const std::uint64_t *total_coeff_modulus() const + { + return total_coeff_modulus_.get(); + } + + /** + Returns the significant bit count of the total coefficient modulus. + */ + inline auto total_coeff_modulus_bit_count() const + { + return total_coeff_modulus_bit_count_; + } + + /** + Returns a const reference to the base converter. + */ + inline auto &base_converter() const + { + return base_converter_; + } + + /** + Returns a const reference to the NTT tables. + */ + inline auto &small_ntt_tables() const + { + return small_ntt_tables_; + } + + /** + Returns a const reference to the NTT tables. + */ + inline auto &plain_ntt_tables() const + { + return plain_ntt_tables_; + } + + /** + Return a pointer to BFV "Delta", i.e. coefficient modulus divided by + plaintext modulus. + */ + inline const std::uint64_t *coeff_div_plain_modulus() const + { + return coeff_div_plain_modulus_.get(); + } + + /** + Return the threshold for the upper half of integers modulo plain_modulus. + This is simply (plain_modulus + 1) / 2. + */ + inline std::uint64_t plain_upper_half_threshold() const + { + return plain_upper_half_threshold_; + } + + /** + Return a pointer to the plaintext upper half increment, i.e. coeff_modulus + minus plain_modulus. The upper half increment is represented as an integer + for the full product coeff_modulus if using_fast_plain_lift is false and is + otherwise represented modulo each of the coeff_modulus primes in order. + */ + inline const std::uint64_t *plain_upper_half_increment() const + { + return plain_upper_half_increment_.get(); + } + + /** + Return a pointer to the upper half threshold with respect to the total + coefficient modulus. This is needed in CKKS decryption. + */ + inline const std::uint64_t *upper_half_threshold() const + { + return upper_half_threshold_.get(); + } + + /** + Return a pointer to the upper half increment used for computing Delta*m + and converting the coefficients to modulo coeff_modulus. For example, + t-1 in plaintext should change into + q - Delta = Delta*t + r_t(q) - Delta + = Delta*(t-1) + r_t(q) + so multiplying the message by Delta is not enough and requires also an + addition of r_t(q). This is precisely the upper_half_increment. Note that + this operation is only done for negative message coefficients, i.e. those + that exceed plain_upper_half_threshold. + */ + inline const std::uint64_t *upper_half_increment() const + { + return upper_half_increment_.get(); + } + + /** + Returns a shared_ptr to the context data corresponding to the next parameters + in the modulus switching chain. If the current data is the last one in the + chain, then the result is nullptr. + */ + inline auto next_context_data() const + { + return next_context_data_; + } + + /** + Returns the index of the parameter set in a chain. The initial parameters + have index 0 and the index increases sequentially in the parameter chain. + */ + inline std::size_t chain_index() const + { + return chain_index_; + } + + private: + ContextData(EncryptionParameters parms, MemoryPoolHandle pool) : + pool_(std::move(pool)), parms_(parms) + { + if (!pool_) + { + throw std::invalid_argument("pool is uninitialized"); + } + } + + MemoryPoolHandle pool_; + + EncryptionParameters parms_; + + EncryptionParameterQualifiers qualifiers_; + + util::Pointer base_converter_; + + util::Pointer small_ntt_tables_; + + util::Pointer plain_ntt_tables_; + + util::Pointer total_coeff_modulus_; + + int total_coeff_modulus_bit_count_; + + util::Pointer coeff_div_plain_modulus_; + + std::uint64_t plain_upper_half_threshold_; + + util::Pointer plain_upper_half_increment_; + + util::Pointer upper_half_threshold_; + + util::Pointer upper_half_increment_; + + std::shared_ptr next_context_data_{ nullptr }; + + std::size_t chain_index_ = 0; + }; + + SEALContext() = delete; + + /** + Creates an instance of SEALContext, and performs several pre-computations + on the given EncryptionParameters. + + @param[in] parms The encryption parameters + @param[in] expand_mod_chain Determines whether the modulus switching chain + should be created + */ + static auto Create(const EncryptionParameters &parms, + bool expand_mod_chain = true) + { + return std::shared_ptr( + new SEALContext(parms, expand_mod_chain, + MemoryManager::GetPool())); + } + + /** + Returns a const reference to ContextData class corresponding to the + encryption parameters. This is the first set of parameters in a chain + of parameters when modulus switching is used. + */ + inline auto context_data() const + { + return context_data_map_.at(first_parms_id_); + } + + /** + Returns an optional const reference to ContextData class corresponding to + the parameters with a given parms_id. If parameters with the given parms_id + are not found then the function returns nullptr. + + @param[in] parms_id The parms_id of the encryption parameters + */ + inline auto context_data(parms_id_type parms_id) const + { + auto data = context_data_map_.find(parms_id); + return (data != context_data_map_.end()) ? + data->second : std::shared_ptr{ nullptr }; + } + + /** + Returns whether the encryption parameters are valid. + */ + inline auto parameters_set() const + { + return context_data()->qualifiers_.parameters_set; + } + + /** + Returns a parms_id_type corresponding to the first set + of encryption parameters. + */ + inline auto &first_parms_id() const + { + return first_parms_id_; + } + + /** + Returns a parms_id_type corresponding to the last set + of encryption parameters. + */ + inline auto &last_parms_id() const + { + return last_parms_id_; + } + + private: + SEALContext(const SEALContext ©) = delete; + + SEALContext(SEALContext &&source) = delete; + + SEALContext &operator =(const SEALContext &assign) = delete; + + SEALContext &operator =(SEALContext &&assign) = delete; + + /** + Creates an instance of SEALContext, and performs several pre-computations + on the given EncryptionParameters. + + @param[in] parms The encryption parameters + @param[in] expand_mod_chain Determines whether the modulus switching chain + should be created + @param[in] pool The MemoryPoolHandle pointing to a valid memory pool + @throws std::invalid_argument if pool is uninitialized + */ + SEALContext(EncryptionParameters parms, bool expand_mod_chain, + MemoryPoolHandle pool); + + ContextData validate(EncryptionParameters parms); + + MemoryPoolHandle pool_; + + parms_id_type first_parms_id_; + + parms_id_type last_parms_id_; + + std::unordered_map< + parms_id_type, std::shared_ptr> context_data_map_{}; + }; +} diff --git a/src/seal/decryptor.cpp b/src/seal/decryptor.cpp new file mode 100644 index 000000000..0ca0c9b23 --- /dev/null +++ b/src/seal/decryptor.cpp @@ -0,0 +1,553 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include +#include +#include "seal/decryptor.h" +#include "seal/util/common.h" +#include "seal/util/uintcore.h" +#include "seal/util/uintarith.h" +#include "seal/util/uintarithmod.h" +#include "seal/util/uintarithsmallmod.h" +#include "seal/util/polycore.h" +#include "seal/util/polyarithmod.h" +#include "seal/util/polyarithsmallmod.h" + +using namespace std; +using namespace seal::util; + +namespace seal +{ + Decryptor::Decryptor(std::shared_ptr context, + const SecretKey &secret_key) : context_(move(context)) + { + // Verify parameters + if (!context_) + { + throw invalid_argument("invalid context"); + } + if (!context_->parameters_set()) + { + throw invalid_argument("encryption parameters are not set correctly"); + } + if (secret_key.parms_id() != context_->first_parms_id()) + { + throw invalid_argument("secret key is not valid for encryption parameters"); + } + + auto &parms = context_->context_data()->parms(); + auto &coeff_modulus = parms.coeff_modulus(); + size_t coeff_count = parms.poly_modulus_degree(); + size_t coeff_mod_count = coeff_modulus.size(); + + // Allocate secret_key_ and copy over value + secret_key_ = allocate_poly(coeff_count, coeff_mod_count, pool_); + set_poly_poly(secret_key.data().data(), coeff_count, coeff_mod_count, + secret_key_.get()); + + // Set the secret_key_array to have size 1 (first power of secret) + secret_key_array_ = allocate_poly(coeff_count, coeff_mod_count, pool_); + set_poly_poly(secret_key_.get(), coeff_count, coeff_mod_count, + secret_key_array_.get()); + secret_key_array_size_ = 1; + } + + void Decryptor::decrypt(const Ciphertext &encrypted, Plaintext &destination) + { + // Verify parameters. + if (!context_->context_data(encrypted.parms_id())) + { + throw invalid_argument("encrypted is not valid for encryption parameters"); + } + + auto &context_data = *context_->context_data(); + auto &parms = context_data.parms(); + + switch (parms.scheme()) + { + case scheme_type::BFV: + bfv_decrypt(encrypted, destination, pool_); + return; + + case scheme_type::CKKS: + ckks_decrypt(encrypted, destination, pool_); + return; + + default: + throw invalid_argument("unsupported scheme"); + } + } + + void Decryptor::bfv_decrypt(const Ciphertext &encrypted, + Plaintext &destination, MemoryPoolHandle pool) + { + if (encrypted.is_ntt_form()) + { + throw invalid_argument("encrypted cannot be in NTT form"); + } + + auto &context_data = *context_->context_data(encrypted.parms_id()); + auto &parms = context_data.parms(); + auto &coeff_modulus = parms.coeff_modulus(); + size_t coeff_count = parms.poly_modulus_degree(); + size_t coeff_mod_count = coeff_modulus.size(); + size_t rns_poly_uint64_count = mul_safe(coeff_count, coeff_mod_count); + size_t first_rns_poly_uint64_count = mul_safe(coeff_count, + context_->context_data()->parms().coeff_modulus().size()); + size_t encrypted_size = encrypted.size(); + + auto &small_ntt_tables = context_data.small_ntt_tables(); + auto &base_converter = context_data.base_converter(); + auto &plain_gamma_product = base_converter->get_plain_gamma_product(); + auto &plain_gamma_array = base_converter->get_plain_gamma_array(); + auto &neg_inv_coeff = base_converter->get_neg_inv_coeff(); + auto inv_gamma = base_converter->get_inv_gamma(); + + // The number of uint64 count for plain_modulus and gamma together + size_t plain_gamma_uint64_count = 2; + + // Allocate a full size destination to write to + auto wide_destination(allocate_uint(coeff_count, pool)); + + // Make sure we have enough secret key powers computed + compute_secret_key_array(encrypted_size - 1); + + /* + Firstly find c_0 + c_1 *s + ... + c_{count-1} * s^{count-1} mod q + This is equal to Delta m + v where ||v|| < Delta/2. + So, add Delta / 2 and now we have something which is Delta * (m + epsilon) where epsilon < 1 + Therefore, we can (integer) divide by Delta and the answer will round down to m. + */ + + // Make a temp destination for all the arithmetic mod qi before calling FastBConverse + auto tmp_dest_modq(allocate_zero_poly(coeff_count, coeff_mod_count, pool)); + + // put < (c_1 , c_2, ... , c_{count-1}) , (s,s^2,...,s^{count-1}) > mod q in destination + + // Now do the dot product of encrypted_copy and the secret key array using NTT. + // The secret key powers are already NTT transformed. + auto copy_operand1(allocate_uint(coeff_count, pool)); + for (size_t i = 0; i < coeff_mod_count; i++) + { + // Initialize pointers for multiplication + const uint64_t *current_array1 = encrypted.data(1) + (i * coeff_count); + const uint64_t *current_array2 = secret_key_array_.get() + (i * coeff_count); + + for (size_t j = 0; j < encrypted_size - 1; j++) + { + // Perform the dyadic product. + set_uint_uint(current_array1, coeff_count, copy_operand1.get()); + + // Lazy reduction + ntt_negacyclic_harvey_lazy(copy_operand1.get(), small_ntt_tables[i]); + + dyadic_product_coeffmod(copy_operand1.get(), current_array2, coeff_count, + coeff_modulus[i], copy_operand1.get()); + add_poly_poly_coeffmod(tmp_dest_modq.get() + (i * coeff_count), + copy_operand1.get(), coeff_count, coeff_modulus[i], + tmp_dest_modq.get() + (i * coeff_count)); + + current_array1 += rns_poly_uint64_count; + current_array2 += first_rns_poly_uint64_count; + } + + // Perform inverse NTT + inverse_ntt_negacyclic_harvey(tmp_dest_modq.get() + (i * coeff_count), + small_ntt_tables[i]); + } + + // add c_0 into destination + for (size_t i = 0; i < coeff_mod_count; i++) + { + //add_poly_poly_coeffmod(tmp_dest_modq.get() + (i * coeff_count), + // encrypted.data() + (i * coeff_count), coeff_count, coeff_modulus_[i], + // tmp_dest_modq.get() + (i * coeff_count)); + + // Lazy reduction + for (size_t j = 0; j < coeff_count; j++) + { + tmp_dest_modq[j + (i * coeff_count)] += encrypted[j + (i * coeff_count)]; + } + + // Compute |gamma * plain|qi * ct(s) + multiply_poly_scalar_coeffmod(tmp_dest_modq.get() + (i * coeff_count), coeff_count, + plain_gamma_product[i], coeff_modulus[i], tmp_dest_modq.get() + (i * coeff_count)); + } + + // Make another temp destination to get the poly in mod {gamma U plain_modulus} + auto tmp_dest_plain_gamma(allocate_poly(coeff_count, plain_gamma_uint64_count, pool)); + + // Compute FastBConvert from q to {gamma, plain_modulus} + base_converter->fastbconv_plain_gamma(tmp_dest_modq.get(), tmp_dest_plain_gamma.get(), pool); + + // Compute result multiply by coeff_modulus inverse in mod {gamma U plain_modulus} + for (size_t i = 0; i < plain_gamma_uint64_count; i++) + { + multiply_poly_scalar_coeffmod(tmp_dest_plain_gamma.get() + (i * coeff_count), + coeff_count, neg_inv_coeff[i], plain_gamma_array[i], + tmp_dest_plain_gamma.get() + (i * coeff_count)); + } + + // First correct the values which are larger than floor(gamma/2) + uint64_t gamma_div_2 = plain_gamma_array[1].value() >> 1; + + // Now compute the subtraction to remove error and perform final multiplication by + // gamma inverse mod plain_modulus + for (size_t i = 0; i < coeff_count; i++) + { + // Need correction beacuse of center mod + if (tmp_dest_plain_gamma[i + coeff_count] > gamma_div_2) + { + // Compute -(gamma - a) instead of (a - gamma) + tmp_dest_plain_gamma[i + coeff_count] = plain_gamma_array[1].value() - + tmp_dest_plain_gamma[i + coeff_count]; + tmp_dest_plain_gamma[i + coeff_count] %= plain_gamma_array[0].value(); + wide_destination[i] = add_uint_uint_mod(tmp_dest_plain_gamma[i], + tmp_dest_plain_gamma[i + coeff_count], plain_gamma_array[0]); + } + // No correction needed + else + { + tmp_dest_plain_gamma[i + coeff_count] %= plain_gamma_array[0].value(); + wide_destination[i] = sub_uint_uint_mod(tmp_dest_plain_gamma[i], + tmp_dest_plain_gamma[i + coeff_count], plain_gamma_array[0]); + } + } + + // How many non-zero coefficients do we really have in the result? + size_t plain_coeff_count = get_significant_uint64_count_uint( + wide_destination.get(), coeff_count); + + // Resize destination to appropriate size + destination.resize(max(plain_coeff_count, size_t(1))); + destination.parms_id() = parms_id_zero; + + // Perform final multiplication by gamma inverse mod plain_modulus + multiply_poly_scalar_coeffmod(wide_destination.get(), + max(plain_coeff_count, size_t(1)), + inv_gamma, plain_gamma_array[0], destination.data()); + } + + void Decryptor::ckks_decrypt(const Ciphertext &encrypted, + Plaintext &destination, MemoryPoolHandle pool) + { + if (!encrypted.is_ntt_form()) + { + throw invalid_argument("encrypted must be in NTT form"); + } + + // We already know that the parameters are valid + auto &context_data = *context_->context_data(encrypted.parms_id()); + auto &parms = context_data.parms(); + auto &coeff_modulus = parms.coeff_modulus(); + size_t coeff_count = parms.poly_modulus_degree(); + size_t coeff_mod_count = coeff_modulus.size(); + size_t rns_poly_uint64_count = mul_safe(coeff_count, coeff_mod_count); + size_t first_rns_poly_uint64_count = mul_safe(coeff_count, + context_->context_data()->parms().coeff_modulus().size()); + size_t encrypted_size = encrypted.size(); + + // Make sure we have enough secret key powers computed + compute_secret_key_array(encrypted_size - 1); + + /* + Decryption consists in finding c_0 + c_1 *s + ... + c_{count-1} * s^{count-1} mod q_1 * q_2 * q_3 + as long as ||m + v|| < q_1 * q_2 * q_3 + This is equal to m + v where ||v|| is small enough. + */ + + // Since we overwrite destination, we zeroize destination parameters + // This is necessary, otherwise resize will throw an exception. + destination.parms_id() = parms_id_zero; + + // Resize destination to appropriate size + destination.resize(rns_poly_uint64_count); + + // Make a temp destination for all the arithmetic mod q1, q2, q3 + //auto tmp_dest_modq(allocate_zero_poly(coeff_count, decryption_coeff_mod_count, pool)); + + // put < (c_1 , c_2, ... , c_{count-1}) , (s,s^2,...,s^{count-1}) > mod q in destination + + // Now do the dot product of encrypted_copy and the secret key array using NTT. + // The secret key powers are already NTT transformed. + + auto copy_operand1(allocate_uint(coeff_count, pool)); + for (size_t i = 0; i < coeff_mod_count; i++) + { + // Initialize pointers for multiplication + // c_1 mod qi + const uint64_t *current_array1 = encrypted.data(1) + (i * coeff_count); + // s mod qi + const uint64_t *current_array2 = secret_key_array_.get() + (i * coeff_count); + // set destination coefficients to zero modulo q_i + set_zero_uint(coeff_count, destination.data() + (i * coeff_count)); + + for (size_t j = 0; j < encrypted_size - 1; j++) + { + // Perform the dyadic product. + set_uint_uint(current_array1, coeff_count, copy_operand1.get()); + + // Lazy reduction + //ntt_negacyclic_harvey_lazy(copy_operand1.get(), small_ntt_tables[i]); + dyadic_product_coeffmod(copy_operand1.get(), current_array2, coeff_count, + coeff_modulus[i], copy_operand1.get()); + add_poly_poly_coeffmod(destination.data() + (i * coeff_count), + copy_operand1.get(), coeff_count, coeff_modulus[i], + destination.data() + (i * coeff_count)); + + // go to c_{1+j+1} and s^{1+j+1} mod qi + current_array1 += rns_poly_uint64_count; + current_array2 += first_rns_poly_uint64_count; + } + + // add c_0 into destination + add_poly_poly_coeffmod(destination.data() + (i * coeff_count), + encrypted.data() + (i * coeff_count), coeff_count, + coeff_modulus[i], destination.data() + (i * coeff_count)); + } + + // Set destination parameters as in encrypted + //destination.parms_id() = last_parms_id; + destination.parms_id() = encrypted.parms_id(); + destination.scale() = encrypted.scale(); + } + + void Decryptor::compute_secret_key_array(size_t max_power) + { +#ifdef SEAL_DEBUG + if (max_power < 1) + { + throw invalid_argument("max_power must be at least 1"); + } +#endif + // WARNING: This function must be called with the original context_data + auto &context_data = *context_->context_data(); + auto &parms = context_data.parms(); + auto &coeff_modulus = parms.coeff_modulus(); + size_t coeff_count = parms.poly_modulus_degree(); + size_t coeff_mod_count = coeff_modulus.size(); + size_t rns_poly_uint64_count = mul_safe(coeff_count, coeff_mod_count); + + ReaderLock reader_lock(secret_key_array_locker_.acquire_read()); + + size_t old_size = secret_key_array_size_; + size_t new_size = max(max_power, old_size); + + if (old_size == new_size) + { + return; + } + + reader_lock.unlock(); + + // Need to extend the array + // Compute powers of secret key until max_power + auto new_secret_key_array(allocate_poly(mul_safe(new_size, coeff_count), + coeff_mod_count, pool_)); + set_poly_poly(secret_key_array_.get(), mul_safe(old_size, coeff_count), + coeff_mod_count, new_secret_key_array.get()); + + uint64_t *prev_poly_ptr = new_secret_key_array.get() + + mul_safe(old_size - 1, rns_poly_uint64_count); + uint64_t *next_poly_ptr = prev_poly_ptr + rns_poly_uint64_count; + + // Since all of the key powers in secret_key_array_ are already NTT transformed, + // to get the next one we simply need to compute a dyadic product of the last + // one with the first one [which is equal to NTT(secret_key_)]. + for (size_t i = old_size; i < new_size; i++) + { + for (size_t j = 0; j < coeff_mod_count; j++) + { + dyadic_product_coeffmod(prev_poly_ptr + (j * coeff_count), + new_secret_key_array.get() + (j * coeff_count), + coeff_count, coeff_modulus[j], + next_poly_ptr + (j * coeff_count)); + } + prev_poly_ptr = next_poly_ptr; + next_poly_ptr += rns_poly_uint64_count; + } + + + // Take writer lock to update array + WriterLock writer_lock(secret_key_array_locker_.acquire_write()); + + // Do we still need to update size? + old_size = secret_key_array_size_; + new_size = max(max_power, secret_key_array_size_); + + if (old_size == new_size) + { + return; + } + + // Acquire new array + secret_key_array_size_ = new_size; + secret_key_array_.acquire(new_secret_key_array); + } + + void Decryptor::compose( + const SEALContext::ContextData &context_data, uint64_t *value) + { +#ifdef SEAL_DEBUG + if (value == nullptr) + { + throw invalid_argument("input cannot be null"); + } +#endif + auto &parms = context_data.parms(); + auto &coeff_modulus = parms.coeff_modulus(); + size_t coeff_count = parms.poly_modulus_degree(); + size_t coeff_mod_count = coeff_modulus.size(); + size_t rns_poly_uint64_count = mul_safe(coeff_count, coeff_mod_count); + + auto &base_converter = context_data.base_converter(); + auto coeff_products_array = base_converter->get_coeff_products_array(); + auto &inv_coeff_mod_coeff_array = base_converter->get_inv_coeff_mod_coeff_array(); + + // Set temporary coefficients_ptr pointer to point to either an existing + // allocation given as parameter, or else to a new allocation from the memory pool. + auto coefficients(allocate_uint(rns_poly_uint64_count, pool_)); + uint64_t *coefficients_ptr = coefficients.get(); + + // Re-merge the coefficients first + for (size_t i = 0; i < coeff_count; i++) + { + for (size_t j = 0; j < coeff_mod_count; j++) + { + coefficients_ptr[j + (i * coeff_mod_count)] = value[(j * coeff_count) + i]; + } + } + + auto temp(allocate_uint(coeff_mod_count, pool_)); + set_zero_uint(rns_poly_uint64_count, value); + + for (size_t i = 0; i < coeff_count; i++) + { + for (size_t j = 0; j < coeff_mod_count; j++) + { + uint64_t tmp = multiply_uint_uint_mod(coefficients_ptr[j], + inv_coeff_mod_coeff_array[j], coeff_modulus[j]); + multiply_uint_uint64(coeff_products_array + (j * coeff_mod_count), + coeff_mod_count, tmp, coeff_mod_count, temp.get()); + add_uint_uint_mod(temp.get(), value + (i * coeff_mod_count), + context_data.total_coeff_modulus(), + coeff_mod_count, value + (i * coeff_mod_count)); + } + set_zero_uint(coeff_mod_count, temp.get()); + coefficients_ptr += coeff_mod_count; + } + } + + int Decryptor::invariant_noise_budget(const Ciphertext &encrypted) + { + if (context_->context_data()->parms().scheme() != scheme_type::BFV) + { + throw logic_error("unsupported scheme"); + } + + // Verify parameters. + auto context_data_ptr = context_->context_data(encrypted.parms_id()); + if (!context_data_ptr) + { + throw invalid_argument("encrypted is not valid for encryption parameters"); + } + if (encrypted.is_ntt_form()) + { + throw invalid_argument("encrypted cannot be in NTT form"); + } + + auto &context_data = *context_data_ptr; + auto &parms = context_data.parms(); + auto &coeff_modulus = parms.coeff_modulus(); + size_t coeff_count = parms.poly_modulus_degree(); + size_t coeff_mod_count = coeff_modulus.size(); + size_t rns_poly_uint64_count = mul_safe(coeff_count, coeff_mod_count); + size_t encrypted_size = encrypted.size(); + uint64_t plain_modulus = parms.plain_modulus().value(); + + auto &small_ntt_tables = context_data.small_ntt_tables(); + + // Storage for noise uint + auto destination(allocate_uint(coeff_mod_count, pool_)); + + // Storage for noise poly + auto noise_poly(allocate_zero_poly(coeff_count, coeff_mod_count, pool_)); + + // Now need to compute c(s) - Delta*m (mod q) + + // Make sure we have enough secret keys computed + compute_secret_key_array(encrypted_size - 1); + + /* + Firstly find c_0 + c_1 *s + ... + c_{count-1} * s^{count-1} mod q + This is equal to Delta m + v where ||v|| < Delta/2. + */ + // put < (c_1 , c_2, ... , c_{count-1}) , (s,s^2,...,s^{count-1}) > mod q + // in destination_poly. + // Make a copy of the encryption for NTT (except the first polynomial is + // not needed). + auto encrypted_copy(allocate_poly( + mul_safe(encrypted_size - 1, coeff_count), coeff_mod_count, pool_)); + set_poly_poly(encrypted.data(1), mul_safe(encrypted_size - 1, coeff_count), + coeff_mod_count, encrypted_copy.get()); + + // Now do the dot product of encrypted_copy and the secret key array using NTT. + // The secret key powers are already NTT transformed. + auto copy_operand1(allocate_uint(coeff_count, pool_)); + for (size_t i = 0; i < coeff_mod_count; i++) + { + // Initialize pointers for multiplication + const uint64_t *current_array1 = encrypted.data(1) + (i * coeff_count); + const uint64_t *current_array2 = secret_key_array_.get() + (i * coeff_count); + + for (size_t j = 0; j < encrypted_size - 1; j++) + { + // Perform the dyadic product. + set_uint_uint(current_array1, coeff_count, copy_operand1.get()); + + // Lazy reduction + ntt_negacyclic_harvey_lazy(copy_operand1.get(), small_ntt_tables[i]); + + dyadic_product_coeffmod(copy_operand1.get(), current_array2, coeff_count, + coeff_modulus[i], copy_operand1.get()); + add_poly_poly_coeffmod(noise_poly.get() + (i * coeff_count), + copy_operand1.get(), + coeff_count, coeff_modulus[i], + noise_poly.get() + (i * coeff_count)); + + current_array1 += rns_poly_uint64_count; + current_array2 += rns_poly_uint64_count; + } + + // Perform inverse NTT + inverse_ntt_negacyclic_harvey(noise_poly.get() + (i * coeff_count), + small_ntt_tables[i]); + } + + for (size_t i = 0; i < coeff_mod_count; i++) + { + // add c_0 into noise_poly + add_poly_poly_coeffmod(noise_poly.get() + (i * coeff_count), + encrypted.data() + (i * coeff_count), coeff_count, coeff_modulus[i], + noise_poly.get() + (i * coeff_count)); + + // Multiply by parms.plain_modulus() and reduce mod parms.coeff_modulus() to get + // parms.coeff_modulus()*noise + multiply_poly_scalar_coeffmod(noise_poly.get() + (i * coeff_count), + coeff_count, plain_modulus, coeff_modulus[i], + noise_poly.get() + (i * coeff_count)); + } + + // Compose the noise + compose(context_data, noise_poly.get()); + + // Next we compute the infinity norm mod parms.coeff_modulus() + poly_infty_norm_coeffmod(noise_poly.get(), coeff_count, coeff_mod_count, + context_data.total_coeff_modulus(), destination.get(), pool_); + + // The -1 accounts for scaling the invariant noise by 2 + int bit_count_diff = context_data.total_coeff_modulus_bit_count() - + get_significant_bit_count_uint(destination.get(), coeff_mod_count) - 1; + return max(0, bit_count_diff); + } +} diff --git a/src/seal/decryptor.h b/src/seal/decryptor.h new file mode 100644 index 000000000..9d90f6928 --- /dev/null +++ b/src/seal/decryptor.h @@ -0,0 +1,133 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include +#include "seal/randomgen.h" +#include "seal/encryptionparams.h" +#include "seal/context.h" +#include "seal/util/smallntt.h" +#include "seal/memorymanager.h" +#include "seal/ciphertext.h" +#include "seal/plaintext.h" +#include "seal/secretkey.h" +#include "seal/util/baseconverter.h" +#include "seal/smallmodulus.h" +#include "seal/util/locks.h" + +namespace seal +{ + /** + Decrypts Ciphertext objects into Plaintext objects. Constructing a Decryptor + requires a SEALContext with valid encryption parameters, and the secret key. + The Decryptor is also used to compute the invariant noise budget in a given + ciphertext. + + @par Overloads + For the decrypt function we provide two overloads concerning the memory pool + used in allocations needed during the operation. In one overload the global + memory pool is used for this purpose, and in another overload the user can + supply a MemoryPoolHandle to be used instead. This is to allow one single + Decryptor to be used concurrently by several threads without running into + thread contention in allocations taking place during operations. For example, + one can share one single Decryptor across any number of threads, but in each + thread call the decrypt function by giving it a thread-local MemoryPoolHandle + to use. It is important for a developer to understand how this works to avoid + unnecessary performance bottlenecks. + + + @par NTT form + When using the BFV scheme (scheme_type::BFV), all plaintext and ciphertexts + should remain by default in the usual coefficient representation, i.e. not in + NTT form. When using the CKKS scheme (scheme_type::CKKS), all plaintexts and + ciphertexts should remain by default in NTT form. We call these scheme-specific + NTT states the "default NTT form". Decryption requires the input ciphertexts + to be in the default NTT form, and will throw an exception if this is not the + case. + */ + class Decryptor + { + public: + /** + Creates a Decryptor instance initialized with the specified SEALContext + and secret key. + + @param[in] context The SEALContext + @param[in] secret_key The secret key + @throws std::invalid_argument if the context is not set or encryption + parameters are not valid + @throws std::invalid_argument if secret_key is not valid + */ + Decryptor(std::shared_ptr context, const SecretKey &secret_key); + + /* + Decrypts a Ciphertext and stores the result in the destination parameter. + + @param[in] encrypted The ciphertext to decrypt + @param[out] destination The plaintext to overwrite with the decrypted ciphertext + @throws std::invalid_argument if encrypted is not valid for the encryption parameters + @throws std::invalid_argument if encrypted is not in the default NTT form + */ + void decrypt(const Ciphertext &encrypted, Plaintext &destination); + + /* + Computes the invariant noise budget (in bits) of a ciphertext. The invariant + noise budget measures the amount of room there is for the noise to grow while + ensuring correct decryptions. This function works only with the BFV scheme. + + @par Invariant Noise Budget + The invariant noise polynomial of a ciphertext is a rational coefficient + polynomial, such that a ciphertext decrypts correctly as long as the + coefficients of the invariantnoise polynomial are of absolute value less + than 1/2. Thus, we call the infinity-norm of the invariant noise polynomial + the invariant noise, and for correct decryption requireit to be less than + 1/2. If v denotes the invariant noise, we define the invariant noise budget + as -log2(2v). Thus, the invariant noise budget starts from some initial + value, which depends on the encryption parameters, and decreases when + computations are performed. When the budget reaches zero, the ciphertext + becomes too noisy to decrypt correctly. + + @param[in] encrypted The ciphertext + @throws std::invalid_argument if the scheme is not BFV + @throws std::invalid_argument if encrypted is not valid for the encryption parameters + @throws std::invalid_argument if encrypted is in NTT form + */ + int invariant_noise_budget(const Ciphertext &encrypted); + + private: + void bfv_decrypt(const Ciphertext &encrypted, Plaintext &destination, + MemoryPoolHandle pool); + + void ckks_decrypt(const Ciphertext &encrypted, Plaintext &destination, + MemoryPoolHandle pool); + + Decryptor(const Decryptor ©) = delete; + + Decryptor(Decryptor &&source) = delete; + + Decryptor &operator =(const Decryptor &assign) = delete; + + Decryptor &operator =(Decryptor &&assign) = delete; + + void compute_secret_key_array(std::size_t max_power); + + void compose(const SEALContext::ContextData &context_data, + std::uint64_t *value); + + /** + We use a fresh memory pool with `clear_on_destruction' enabled + */ + MemoryPoolHandle pool_ = MemoryManager::GetPool(mm_prof_opt::FORCE_NEW, true); + + std::shared_ptr context_{ nullptr }; + + util::Pointer secret_key_; + + std::size_t secret_key_array_size_ = 0; + + util::Pointer secret_key_array_; + + mutable util::ReaderWriterLocker secret_key_array_locker_; + }; +} diff --git a/src/seal/defaultparams.h b/src/seal/defaultparams.h new file mode 100644 index 000000000..7d0874811 --- /dev/null +++ b/src/seal/defaultparams.h @@ -0,0 +1,175 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include "seal/util/globals.h" +#include "seal/smallmodulus.h" +#include "seal/util/defines.h" +#include +#include +#include + +namespace seal +{ + /** + Returns the default coefficients modulus for a given polynomial modulus degree. + The polynomial modulus and the coefficient modulus obtained in this way should + provide approdimately 128 bits of security against the best known attacks, + assuming the standard deviation of the noise distribution is left to its default + value. + + @param[in] poly_modulus_degree The degree of the polynomial modulus + @throws std::out_of_range if poly_modulus_degree is not 1024, 2048, 4096, 8192, 16384, or 32768 + */ + inline std::vector coeff_modulus_128(std::size_t poly_modulus_degree) + { + try + { + return util::global_variables::default_coeff_modulus_128.at(poly_modulus_degree); + } + catch (const std::exception &) + { + throw std::out_of_range("no default parameters found"); + } + return {}; + } + + /** + Returns the default coefficients modulus for a given polynomial modulus degree. + The polynomial modulus and the coefficient modulus obtained in this way should + provide approdimately 192 bits of security against the best known attacks, + assuming the standard deviation of the noise distribution is left to its default + value. + + @param[in] poly_modulus_degree The degree of the polynomial modulus + @throws std::out_of_range if poly_modulus_degree is not 1024, 2048, 4096, 8192, 16384, or 32768 + */ + inline std::vector coeff_modulus_192(std::size_t poly_modulus_degree) + { + try + { + return util::global_variables::default_coeff_modulus_192.at(poly_modulus_degree); + } + catch (const std::exception &) + { + throw std::out_of_range("no default parameters found"); + } + return {}; + } + + /** + Returns the default coefficients modulus for a given polynomial modulus degree. + The polynomial modulus and the coefficient modulus obtained in this way should + provide approdimately 256 bits of security against the best known attacks, + assuming the standard deviation of the noise distribution is left to its default + value. + + @param[in] poly_modulus_degree The degree of the polynomial modulus + @throws std::out_of_range if poly_modulus_degree is not 1024, 2048, 4096, 8192, 16384, or 32768 + */ + inline std::vector coeff_modulus_256(std::size_t poly_modulus_degree) + { + try + { + return util::global_variables::default_coeff_modulus_256.at(poly_modulus_degree); + } + catch (const std::exception &) + { + throw std::out_of_range("no default parameters found"); + } + return {}; + } + + /** + Returns a 60-bit coefficient modulus prime. + + @param[in] index The list index of the prime + @throws std::out_of_range if index is not within [0, 64) + */ + inline SmallModulus small_mods_60bit(std::size_t index) + { + try + { + return util::global_variables::small_mods_60bit.at(index); + } + catch (const std::exception &) + { + throw std::out_of_range("index out of range"); + } + return 0; + } + + /** + Returns a 50-bit coefficient modulus prime. + + @param[in] index The list index of the prime + @throws std::out_of_range if index is not within [0, 64) + */ + inline SmallModulus small_mods_50bit(std::size_t index) + { + try + { + return util::global_variables::small_mods_50bit.at(index); + } + catch (const std::exception &) + { + throw std::out_of_range("index out of range"); + } + return 0; + } + + /** + Returns a 40-bit coefficient modulus prime. + + @param[in] index The list index of the prime + @throws std::out_of_range if index is not within [0, 64) + */ + inline SmallModulus small_mods_40bit(std::size_t index) + { + try + { + return util::global_variables::small_mods_40bit.at(index); + } + catch (const std::exception &) + { + throw std::out_of_range("index out of range"); + } + return 0; + } + + /** + Returns a 30-bit coefficient modulus prime. + + @param[in] index The list index of the prime + @throws std::out_of_range if index is not within [0, 64) + */ + inline SmallModulus small_mods_30bit(std::size_t index) + { + try + { + return util::global_variables::small_mods_30bit.at(index); + } + catch (const std::exception &) + { + throw std::out_of_range("index out of range"); + } + return 0; + } + + /** + Returns the largest allowed decomposition bit count (60). + */ + constexpr int dbc_max() + { + return SEAL_DBC_MAX; + } + + /** + Returns the smallest allowed decomposition bit count (1). + */ + constexpr int dbc_min() + { + return SEAL_DBC_MIN; + } +} diff --git a/src/seal/encoder.cpp b/src/seal/encoder.cpp new file mode 100644 index 000000000..5991cbbe0 --- /dev/null +++ b/src/seal/encoder.cpp @@ -0,0 +1,1379 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include +#include +#include +#include "seal/encoder.h" +#include "seal/util/common.h" +#include "seal/util/polyarith.h" +#include "seal/util/pointer.h" +#include "seal/util/defines.h" +#include "seal/util/uintarithsmallmod.h" + +using namespace std; +using namespace seal::util; + +namespace seal +{ + BinaryEncoder::BinaryEncoder(const SmallModulus &plain_modulus) : + plain_modulus_(plain_modulus), + coeff_neg_threshold_((plain_modulus.value() + 1) >> 1), + neg_one_(plain_modulus_.value() - 1) + { + if (plain_modulus.bit_count() <= 1) + { + throw invalid_argument("plain_modulus must be at least 2"); + } + } + + Plaintext BinaryEncoder::encode(uint64_t value) + { + Plaintext result; + encode(value, result); + return result; + } + + void BinaryEncoder::encode(uint64_t value, Plaintext &destination) + { + size_t encode_coeff_count = safe_cast( + get_significant_bit_count(value)); + destination.resize(encode_coeff_count); + destination.set_zero(); + + size_t coeff_index = 0; + while (value != 0) + { + if ((value & 1) != 0) + { + destination[coeff_index] = 1; + } + value >>= 1; + coeff_index++; + } + } + + Plaintext BinaryEncoder::encode(int64_t value) + { + Plaintext result; + encode(value, result); + return result; + } + + void BinaryEncoder::encode(int64_t value, Plaintext &destination) + { + if (value < 0) + { + uint64_t pos_value = static_cast(-value); + size_t encode_coeff_count = safe_cast( + get_significant_bit_count(pos_value)); + destination.resize(encode_coeff_count); + destination.set_zero(); + + size_t coeff_index = 0; + while (pos_value != 0) + { + if ((pos_value & 1) != 0) + { + destination[coeff_index] = neg_one_; + } + pos_value >>= 1; + coeff_index++; + } + } + else + { + encode(static_cast(value), destination); + } + } + + Plaintext BinaryEncoder::encode(const BigUInt &value) + { + Plaintext result; + encode(value, result); + return result; + } + + void BinaryEncoder::encode(const BigUInt &value, Plaintext &destination) + { + size_t encode_coeff_count = safe_cast( + value.significant_bit_count()); + destination.resize(encode_coeff_count); + destination.set_zero(); + + size_t coeff_index = 0; + size_t coeff_count = safe_cast(value.significant_bit_count()); + size_t coeff_uint64_count = value.uint64_count(); + while (coeff_index < coeff_count) + { + if (is_bit_set_uint(value.data(), coeff_uint64_count, + safe_cast(coeff_index))) + { + destination[coeff_index] = 1; + } + coeff_index++; + } + } + + uint32_t BinaryEncoder::decode_uint32(const Plaintext &plain) + { + uint64_t value64 = decode_uint64(plain); + if (value64 > UINT32_MAX) + { + throw invalid_argument("output out of range"); + } + return static_cast(value64); + } + + uint64_t BinaryEncoder::decode_uint64(const Plaintext &plain) + { + BigUInt bigvalue = decode_biguint(plain); + int bit_count = bigvalue.significant_bit_count(); + if (bit_count > bits_per_uint64) + { + // Decoded value has more bits than fit in a 64-bit uint. + throw invalid_argument("output out of range"); + } + return bit_count > 0 ? bigvalue.data()[0] : 0; + } + + int32_t BinaryEncoder::decode_int32(const Plaintext &plain) + { + int64_t value64 = decode_int64(plain); + return safe_cast(value64); + } + + int64_t BinaryEncoder::decode_int64(const Plaintext &plain) + { + unsigned long long pos_value; + + // Determine coefficient threshold for negative numbers. + int64_t result = 0; + for (size_t bit_index = plain.significant_coeff_count(); bit_index--; ) + { + unsigned long long coeff = plain[bit_index]; + + // Left shift result. + int64_t next_result = result << 1; + if ((next_result < 0) != (result < 0)) + { + // Check for overflow. + throw invalid_argument("output out of range"); + } + + // Get sign/magnitude of coefficient. + int coeff_bit_count = get_significant_bit_count(coeff); + if (coeff >= plain_modulus_.value()) + { + // Coefficient is bigger than plaintext modulus + throw invalid_argument("plain does not represent a valid plaintext polynomial"); + } + bool coeff_is_negative = coeff >= coeff_neg_threshold_; + const unsigned long long *pos_pointer; + if (coeff_is_negative) + { + if (sub_uint64(plain_modulus_.value(), coeff, 0, &pos_value)) + { + // Check for borrow, which means value is greater than plain_modulus. + throw invalid_argument("plain does not represent a valid plaintext polynomial"); + } + pos_pointer = &pos_value; + coeff_bit_count = get_significant_bit_count(pos_value); + } + else + { + pos_pointer = &coeff; + } + if (coeff_bit_count > bits_per_uint64 - 1) + { + // Absolute value of coefficient is too large to represent in a int64_t, so overflow. + throw invalid_argument("output out of range"); + } + int64_t coeff_value = safe_cast(*pos_pointer); + if (coeff_is_negative) + { + coeff_value = -coeff_value; + } + bool next_result_was_negative = next_result < 0; + next_result += coeff_value; + bool next_result_is_negative = next_result < 0; + if ((next_result_was_negative == coeff_is_negative) && + (next_result_was_negative != next_result_is_negative)) + { + // Accumulation and coefficient had same signs, but accumulator changed signs after addition, so must be overflow. + throw invalid_argument("output out of range"); + } + result = next_result; + } + return result; + } + + BigUInt BinaryEncoder::decode_biguint(const Plaintext &plain) + { + unsigned long long pos_value; + + // Determine coefficient threshold for negative numbers. + size_t result_uint64_count = 1; + size_t bits_per_uint64_sz = safe_cast(bits_per_uint64); + size_t result_bit_capacity = result_uint64_count * bits_per_uint64_sz; + BigUInt resultint(safe_cast(result_bit_capacity)); + bool result_is_negative = false; + uint64_t *result = resultint.data(); + for (size_t bit_index = plain.significant_coeff_count(); bit_index--; ) + { + unsigned long long coeff = plain[bit_index]; + + // Left shift result, resizing if highest bit set. + if (is_bit_set_uint(result, result_uint64_count, + safe_cast(result_bit_capacity) - 1)) + { + // Resize to make bigger. + result_uint64_count++; + result_bit_capacity = mul_safe(result_uint64_count, bits_per_uint64_sz); + resultint.resize(safe_cast(result_bit_capacity)); + result = resultint.data(); + } + left_shift_uint(result, 1, result_uint64_count, result); + + // Get sign/magnitude of coefficient. + if (coeff >= plain_modulus_.value()) + { + // Coefficient is bigger than plaintext modulus + throw invalid_argument("plain does not represent a valid plaintext polynomial"); + } + bool coeff_is_negative = coeff >= coeff_neg_threshold_; + const unsigned long long *pos_pointer; + if (coeff_is_negative) + { + if (sub_uint64(plain_modulus_.value(), coeff, 0, &pos_value)) + { + // Check for borrow, which means value is greater than plain_modulus. + throw invalid_argument("plain does not represent a valid plaintext polynomial"); + } + pos_pointer = &pos_value; + } + else + { + pos_pointer = &coeff; + } + + // Add or subtract-in coefficient. + if (result_is_negative == coeff_is_negative) + { + // Result and coefficient have same signs so add. + if (add_uint_uint64(result, *pos_pointer, result_uint64_count, result)) + { + // Add produced a carry that didn't fit, so resize and put it in. + int carry_bit_index = safe_cast(mul_safe( + result_uint64_count, bits_per_uint64_sz)); + result_uint64_count++; + result_bit_capacity = mul_safe( + result_uint64_count, bits_per_uint64_sz); + resultint.resize(safe_cast(result_bit_capacity)); + result = resultint.data(); + set_bit_uint(result, result_uint64_count, carry_bit_index); + } + } + else + { + // Result and coefficient have opposite signs so subtract. + if (sub_uint_uint64(result, *pos_pointer, result_uint64_count, result)) + { + // Subtraction produced a borrow so coefficient is larger (in magnitude) + // than result, so need to negate result. + negate_uint(result, result_uint64_count, result); + result_is_negative = !result_is_negative; + } + } + } + + // Verify result is non-negative. + if (result_is_negative && !resultint.is_zero()) + { + throw invalid_argument("poly must decode to positive value"); + } + return resultint; + } + + void BinaryEncoder::decode_biguint(const Plaintext &plain, BigUInt &destination) + { + unsigned long long pos_value; + + // Determine coefficient threshold for negative numbers. + destination.set_zero(); + size_t bits_per_uint64_sz = static_cast(bits_per_uint64); + size_t result_uint64_count = destination.uint64_count(); + size_t result_bit_capacity = result_uint64_count * bits_per_uint64_sz; + bool result_is_negative = false; + uint64_t *result = destination.data(); + for (size_t bit_index = plain.significant_coeff_count(); bit_index--; ) + { + unsigned long long coeff = plain[bit_index]; + + // Left shift result, failing if highest bit set. + if (is_bit_set_uint(result, result_uint64_count, + safe_cast(result_bit_capacity) - 1)) + { + throw invalid_argument("plain does not represent a valid plaintext polynomial"); + } + left_shift_uint(result, 1, result_uint64_count, result); + + // Get sign/magnitude of coefficient. + if (coeff >= plain_modulus_.value()) + { + // Coefficient is bigger than plaintext modulus. + throw invalid_argument("plain does not represent a valid plaintext polynomial"); + } + bool coeff_is_negative = coeff >= coeff_neg_threshold_; + const unsigned long long *pos_pointer; + if (coeff_is_negative) + { + if (sub_uint64(plain_modulus_.value(), coeff, 0, &pos_value)) + { + // Check for borrow, which means value is greater than plain_modulus. + throw invalid_argument("plain does not represent a valid plaintext polynomial"); + } + pos_pointer = &pos_value; + } + else + { + pos_pointer = &coeff; + } + + // Add or subtract-in coefficient. + if (result_is_negative == coeff_is_negative) + { + // Result and coefficient have same signs so add. + if (add_uint_uint64(result, *pos_pointer, result_uint64_count, result)) + { + // Add produced a carry that didn't fit. + throw invalid_argument("output out of range"); + } + } + else + { + // Result and coefficient have opposite signs so subtract. + if (sub_uint_uint64(result, *pos_pointer, result_uint64_count, result)) + { + // Subtraction produced a borrow so coefficient is larger (in magnitude) + // than result, so need to negate result. + negate_uint(result, result_uint64_count, result); + result_is_negative = !result_is_negative; + } + } + } + + // Verify result is non-negative. + if (result_is_negative && !destination.is_zero()) + { + throw invalid_argument("poly must decode to a positive value"); + } + + // Verify result fits in actual bit-width (as opposed to capacity) of destination. + if (destination.significant_bit_count() > destination.bit_count()) + { + throw invalid_argument("output out of range"); + } + } + + BalancedEncoder::BalancedEncoder(const SmallModulus &plain_modulus, uint64_t base) : + plain_modulus_(plain_modulus), + base_(base), + coeff_neg_threshold_((plain_modulus.value() + 1) >> 1) + { + if (base <= 2) + { + throw invalid_argument("base must be at least 3"); + } + if (*plain_modulus.data() < base) + { + throw invalid_argument("plain_modulus must be at least b"); + } + } + + Plaintext BalancedEncoder::encode(uint64_t value) + { + Plaintext result; + encode(value, result); + + return result; + } + + void BalancedEncoder::encode(uint64_t value, Plaintext &destination) + { + // We estimate the number of coefficients in the expansion + size_t encode_coeff_count = static_cast(ceil( + static_cast(get_significant_bit_count(value)) / log2(base_)) + 1); + destination.resize(encode_coeff_count); + destination.set_zero(); + + size_t coeff_index = 0; + while (value) + { + uint64_t remainder = value % base_; + if (0 < remainder && remainder <= (base_ - 1) / 2) + { + destination[coeff_index] = remainder; + } + else if (remainder > (base_ - 1) / 2) + { + destination[coeff_index] = plain_modulus_.value() - base_ + remainder; + } + value = (value + base_ / 2) / base_; + + coeff_index++; + } + } + + Plaintext BalancedEncoder::encode(int64_t value) + { + Plaintext result; + encode(value, result); + return result; + } + + void BalancedEncoder::encode(int64_t value, Plaintext &destination) + { + if (value < 0) + { + uint64_t pos_value = static_cast(-value); + + // We estimate the number of coefficients in the expansion + size_t encode_coeff_count = static_cast(ceil( + static_cast(get_significant_bit_count(pos_value)) / log2(base_)) + 1); + destination.resize(encode_coeff_count); + destination.set_zero(); + + size_t coeff_index = 0; + while (pos_value) + { + uint64_t remainder = pos_value % base_; + if (0 < remainder && remainder <= (base_ - 1) / 2) + { + destination[coeff_index] = plain_modulus_.value() - remainder; + } + else if (remainder > (base_ - 1) / 2) + { + destination[coeff_index] = base_ - remainder; + + if ((base_ % 2 == 0) && (remainder == base_ / 2)) + { + destination[coeff_index] = + plain_modulus_.value() - destination[coeff_index]; + } + } + + // Note that we are adding now (base_-1)/2 instead of base_/2 as in the even case, + // because value is negative. + pos_value = (pos_value + ((base_ - 1) / 2)) / base_; + + coeff_index++; + } + } + else + { + encode(static_cast(value), destination); + } + } + + Plaintext BalancedEncoder::encode(const BigUInt &value) + { + Plaintext result; + encode(value, result); + return result; + } + + void BalancedEncoder::encode(const BigUInt &value, Plaintext &destination) + { + if (value.is_zero()) + { + destination.set_zero(); + return; + } + + // We estimate the number of coefficients in the expansion + size_t bits_per_uint64_sz = static_cast(bits_per_uint64); + size_t encode_coeff_count = static_cast(ceil( + static_cast(value.significant_bit_count()) / log2(base_)) + 1); + size_t encode_uint64_count = + divide_round_up(encode_coeff_count, bits_per_uint64_sz); + + destination.resize(encode_coeff_count); + destination.set_zero(); + + auto base_uint(allocate_uint(encode_uint64_count, pool_)); + set_uint(base_, encode_uint64_count, base_uint.get()); + auto base_div_two_uint(allocate_uint(encode_uint64_count, pool_)); + right_shift_uint(base_uint.get(), 1, encode_uint64_count, base_div_two_uint.get()); + uint64_t mod_minus_base = plain_modulus_.value() - base_; + + auto quotient(allocate_uint(encode_uint64_count, pool_)); + auto remainder(allocate_uint(encode_uint64_count, pool_)); + auto temp(allocate_uint(value.uint64_count(), pool_)); + set_uint_uint(value.data(), value.uint64_count(), temp.get()); + + size_t coeff_index = 0; + while (!is_zero_uint(temp.get(), value.uint64_count())) + { + divide_uint_uint(temp.get(), base_uint.get(), encode_uint64_count, + quotient.get(), remainder.get(), pool_); + uint64_t *dest_coeff = destination.data() + coeff_index; + if (is_greater_than_uint_uint(remainder.get(), base_div_two_uint.get(), + encode_uint64_count)) + { + *dest_coeff = mod_minus_base + remainder[0]; + } + else if (!is_zero_uint(remainder.get(), encode_uint64_count)) + { + *dest_coeff = remainder[0]; + } + add_uint_uint(temp.get(), base_div_two_uint.get(), encode_uint64_count, temp.get()); + divide_uint_uint(temp.get(), base_uint.get(), encode_uint64_count, + quotient.get(), remainder.get(), pool_); + set_uint_uint(quotient.get(), encode_uint64_count, temp.get()); + + coeff_index++; + } + } + + uint32_t BalancedEncoder::decode_uint32(const Plaintext &plain) + { + return safe_cast(decode_uint64(plain)); + } + + uint64_t BalancedEncoder::decode_uint64(const Plaintext &plain) + { + BigUInt bigvalue = decode_biguint(plain); + int bit_count = bigvalue.significant_bit_count(); + if (bit_count > bits_per_uint64) + { + // Decoded value has more bits than fit in a 64-bit uint. + throw invalid_argument("output out of range"); + } + return bit_count > 0 ? bigvalue.data()[0] : 0; + } + + int32_t BalancedEncoder::decode_int32(const Plaintext &plain) + { + return safe_cast(decode_int64(plain)); + } + + int64_t BalancedEncoder::decode_int64(const Plaintext &plain) + { + unsigned long long pos_value; + + // Determine coefficient threshold for negative numbers. + int64_t result = 0; + int64_t base_int = safe_cast(base_); + for (size_t bit_index = plain.significant_coeff_count(); bit_index--; ) + { + unsigned long long coeff = plain[bit_index]; + + // Multiply result by base. + int64_t next_result = mul_safe(result, base_int); + + // Get sign/magnitude of coefficient. + int coeff_bit_count = get_significant_bit_count(coeff); + if (coeff >= plain_modulus_.value()) + { + // Coefficient is bigger than plaintext modulus + throw invalid_argument("plain does not represent a valid plaintext polynomial"); + } + bool coeff_is_negative = coeff >= coeff_neg_threshold_; + const unsigned long long *pos_pointer; + if (coeff_is_negative) + { + if (sub_uint64(plain_modulus_.value(), coeff, 0, &pos_value)) + { + // Check for borrow, which means value is greater than plain_modulus. + throw invalid_argument("plain does not represent a valid plaintext polynomial"); + } + pos_pointer = &pos_value; + coeff_bit_count = get_significant_bit_count(pos_value); + } + else + { + pos_pointer = &coeff; + } + if (coeff_bit_count > bits_per_uint64 - 1) + { + // Absolute value of coefficient is too large to represent in a int64_t, so overflow. + throw invalid_argument("output out of range"); + } + int64_t coeff_value = static_cast(*pos_pointer); + if (coeff_is_negative) + { + coeff_value = -coeff_value; + } + bool next_result_was_negative = next_result < 0; + next_result += coeff_value; + bool next_result_is_negative = next_result < 0; + if ((next_result_was_negative == coeff_is_negative) && + (next_result_was_negative != next_result_is_negative)) + { + // Accumulation and coefficient had same signs, but accumulator changed signs after + // addition, so must be overflow. + throw invalid_argument("output out of range"); + } + result = next_result; + } + return result; + } + + BigUInt BalancedEncoder::decode_biguint(const Plaintext &plain) + { + // Determine plain_modulus width. + unsigned long long pos_value; + + // Determine coefficient threshold for negative numbers. + size_t bits_per_uint64_sz = static_cast(bits_per_uint64); + size_t result_uint64_count = 1; + size_t result_bit_capacity = result_uint64_count * bits_per_uint64_sz; + + // Quick sanity check + if (!fits_in(result_bit_capacity)) + { + throw logic_error("invalid parameters"); + } + + BigUInt resultint(static_cast(result_bit_capacity)); + bool result_is_negative = false; + uint64_t *result = resultint.data(); + + BigUInt base_uint(static_cast(result_bit_capacity)); + base_uint = base_; + BigUInt temp_result(static_cast(result_bit_capacity)); + + for (size_t bit_index = plain.significant_coeff_count(); bit_index--; ) + { + unsigned long long coeff = plain[bit_index]; + + // Multiply result by base. Resize if highest bit set. + if (is_bit_set_uint(result, result_uint64_count, + safe_cast(result_bit_capacity) - 1)) + { + // Resize to make bigger. + result_uint64_count++; + result_bit_capacity = mul_safe(result_uint64_count, bits_per_uint64_sz); + resultint.resize(safe_cast(result_bit_capacity)); + result = resultint.data(); + } + set_uint_uint(result, result_uint64_count, temp_result.data()); + multiply_uint_uint(temp_result.data(), result_uint64_count, base_uint.data(), + result_uint64_count, result_uint64_count, result); + + // Get sign/magnitude of coefficient. + if (coeff >= plain_modulus_.value()) + { + // Coefficient is bigger than plaintext modulus. + throw invalid_argument("plain does not represent a valid plaintext polynomial"); + } + bool coeff_is_negative = coeff >= coeff_neg_threshold_; + const unsigned long long *pos_pointer; + if (coeff_is_negative) + { + if (sub_uint64(plain_modulus_.value(), coeff, 0, &pos_value)) + { + // Check for borrow, which means value is greater than plain_modulus. + throw invalid_argument("plain does not represent a valid plaintext polynomial"); + } + pos_pointer = &pos_value; + } + else + { + pos_pointer = &coeff; + } + + // Add or subtract-in coefficient. + if (result_is_negative == coeff_is_negative) + { + // Result and coefficient have same signs so add. + if (add_uint_uint64(result, *pos_pointer, result_uint64_count, result)) + { + // Add produced a carry that didn't fit, so resize and put it in. + int carry_bit_index = safe_cast( + mul_safe(result_uint64_count, bits_per_uint64_sz)); + result_uint64_count++; + result_bit_capacity = mul_safe(result_uint64_count, bits_per_uint64_sz); + resultint.resize(safe_cast(result_bit_capacity)); + result = resultint.data(); + set_bit_uint(result, result_uint64_count, carry_bit_index); + } + } + else + { + // Result and coefficient have opposite signs so subtract. + if (sub_uint_uint64(result, *pos_pointer, result_uint64_count, result)) + { + // Subtraction produced a borrow so coefficient is larger (in magnitude) than result, so need to negate result. + negate_uint(result, result_uint64_count, result); + result_is_negative = !result_is_negative; + } + } + } + + // Verify result is non-negative. + if (result_is_negative && !resultint.is_zero()) + { + throw invalid_argument("poly must decode to a positive value"); + } + return resultint; + } + + void BalancedEncoder::decode_biguint(const Plaintext &plain, BigUInt &destination) + { + // Determine plain_modulus width. + unsigned long long pos_value; + + // Determine coefficient threshold for negative numbers. + destination.set_zero(); + size_t bits_per_uint64_sz = static_cast(bits_per_uint64); + size_t result_uint64_count = destination.uint64_count(); + size_t result_bit_capacity = result_uint64_count * bits_per_uint64_sz; + bool result_is_negative = false; + uint64_t *result = destination.data(); + + // Quick sanity check + if (!fits_in(result_bit_capacity)) + { + throw logic_error("invalid parameters"); + } + + BigUInt base_uint(static_cast(result_bit_capacity)); + BigUInt temp_result(static_cast(result_bit_capacity)); + base_uint = base_; + + for (size_t bit_index = plain.significant_coeff_count(); bit_index--; ) + { + unsigned long long coeff = plain[bit_index]; + + // Multiply result by base, failing if highest bit set. + if (is_bit_set_uint(result, result_uint64_count, + safe_cast(result_bit_capacity) - 1)) + { + throw invalid_argument("plain does not represent a valid plaintext polynomial"); + } + set_uint_uint(result, result_uint64_count, temp_result.data()); + multiply_truncate_uint_uint(temp_result.data(), base_uint.data(), + result_uint64_count, result); + + // Get sign/magnitude of coefficient. + if (coeff >= plain_modulus_.value()) + { + // Coefficient has more bits than plain_modulus. + throw invalid_argument("plain does not represent a valid plaintext polynomial"); + } + bool coeff_is_negative = coeff >= coeff_neg_threshold_; + const unsigned long long *pos_pointer; + if (coeff_is_negative) + { + if (sub_uint64(plain_modulus_.value(), coeff, 0, &pos_value)) + { + // Check for borrow, which means value is greater than plain_modulus. + throw invalid_argument("plain does not represent a valid plaintext polynomial"); + } + pos_pointer = &pos_value; + } + else + { + pos_pointer = &coeff; + } + + // Add or subtract-in coefficient. + if (result_is_negative == coeff_is_negative) + { + // Result and coefficient have same signs so add. + if (add_uint_uint64(result, *pos_pointer, result_uint64_count, result)) + { + // Add produced a carry that didn't fit. + throw invalid_argument("output out of range"); + } + } + else + { + // Result and coefficient have opposite signs so subtract. + if (sub_uint_uint64(result, *pos_pointer, result_uint64_count, result)) + { + // Subtraction produced a borrow so coefficient is larger (in magnitude) than result, + // so need to negate result. + negate_uint(result, result_uint64_count, result); + result_is_negative = !result_is_negative; + } + } + } + + // Verify result is non-negative. + if (result_is_negative && !destination.is_zero()) + { + throw invalid_argument("poly must decode to a positive value"); + } + + // Verify result fits in actual bit-width (as opposed to capacity) of destination. + if (destination.significant_bit_count() > destination.bit_count()) + { + throw invalid_argument("output out of range"); + } + } + + BinaryFractionalEncoder::BinaryFractionalEncoder( + const SmallModulus &plain_modulus, + size_t poly_modulus_degree, size_t integer_coeff_count, + size_t fraction_coeff_count) : + encoder_(plain_modulus), + fraction_coeff_count_(fraction_coeff_count), + integer_coeff_count_(integer_coeff_count), + poly_modulus_degree_(poly_modulus_degree) + { + if (integer_coeff_count <= 0) + { + throw invalid_argument("integer_coeff_count must be positive"); + } + if (fraction_coeff_count <= 0) + { + throw invalid_argument("fraction_coeff_count must be positive"); + } + if (poly_modulus_degree_ < SEAL_POLY_MOD_DEGREE_MIN || + poly_modulus_degree_ > SEAL_POLY_MOD_DEGREE_MAX) + { + throw invalid_argument("poly_modulus_degree is invalid"); + } + if (add_safe(integer_coeff_count_, fraction_coeff_count_) > poly_modulus_degree_) + { + throw invalid_argument("integer/fractional parts are too large for poly_modulus_degree"); + } + } + + Plaintext BinaryFractionalEncoder::encode(double value) + { + // Take care of the integral part + int64_t value_int = safe_cast(value); + Plaintext encoded_int; + encoder_.encode(value_int, encoded_int); + value -= static_cast(value_int); + + // If the fractional part is zero, return encoded_int + if (value == 0) + { + return encoded_int; + } + + bool is_negative = value < 0; + + //Extract the fractional part + Plaintext encoded_fract(poly_modulus_degree_); + for (size_t i = 0; i < fraction_coeff_count_; i++) + { + value *= 2; + value_int = safe_cast(value); + value -= static_cast(value_int); + + // We want to encode the least significant bit of value_int to the least + // significant bit of encoded_fract. First set it to 1 if it is to be set + // at all. Later we will negate them all if the number was negative. + encoded_fract[0] = static_cast(value_int & 1); + + // Shift encoded_fract by one coefficient unless we are at the last coefficient + if (i < fraction_coeff_count_ - 1) + { + left_shift_uint(encoded_fract.data(), bits_per_uint64, + poly_modulus_degree_, encoded_fract.data()); + } + } + + // We negate the coefficients only if the number was NOT negative. + // This is because the coefficients will have to be negated in any case (sign changes + // at "wrapping around" the polynomial modulus). + if (!is_negative) + { + for (size_t i = 0; i < fraction_coeff_count_; i++) + { + if (encoded_fract[i] != 0) + { + encoded_fract[i] = encoder_.neg_one_; + } + } + } + + // Shift the fractional part to top of polynomial + left_shift_uint(encoded_fract.data(), mul_safe(bits_per_uint64, + safe_cast(poly_modulus_degree_ - fraction_coeff_count_)), + poly_modulus_degree_, encoded_fract.data()); + + // Combine everything together + set_uint_uint(encoded_int.data(), encoded_int.coeff_count(), encoded_fract.data()); + + return encoded_fract; + } + + double BinaryFractionalEncoder::decode(const Plaintext &plain) + { + // Validate input parameters + if (plain.coeff_count() > poly_modulus_degree_) + { + throw invalid_argument("plain is not valid for encryption parameters"); + } +#ifdef SEAL_DEBUG + if (!are_poly_coefficients_less_than(plain.data(), + plain.coeff_count(), 1, encoder_.plain_modulus_.data(), + encoder_.plain_modulus().uint64_count())) + { + throw invalid_argument("plain is not valid for encryption parameters"); + } +#endif + // Do we have an empty plaintext + if (plain.coeff_count() == 0) + { + return 0; + } + + // plain might be smaller than expected if leading coefficients are missing + auto plain_copy(allocate_zero_uint(poly_modulus_degree_, pool_)); + set_uint_uint(plain.data(), plain.coeff_count(), plain_copy.get()); + + // Extract the fractional and integral parts + Plaintext encoded_int(integer_coeff_count_); + auto encoded_fract(allocate_zero_uint(fraction_coeff_count_, pool_)); + + // Integer part + set_uint_uint(plain_copy.get(), integer_coeff_count_, encoded_int.data()); + + // Read from the top of the poly all the way to the top of the integral part + // to obtain the fractional part + set_uint_uint(plain_copy.get() + poly_modulus_degree_ - fraction_coeff_count_, + fraction_coeff_count_, encoded_fract.get()); + + // Decode integral part + int64_t integral_part = encoder_.decode_int64(encoded_int); + + // Decode fractional part (or rather negative of it), one coefficient at a time + double fractional_part = 0; + Plaintext temp_int_part(1); + for (size_t i = 0; i < fraction_coeff_count_; i++) + { + temp_int_part[0] = encoded_fract[i]; + fractional_part += static_cast(encoder_.decode_int64(temp_int_part)); + fractional_part /= 2; + } + + return static_cast(integral_part) - fractional_part; + } + + BalancedFractionalEncoder::BalancedFractionalEncoder( + const SmallModulus &plain_modulus, + size_t poly_modulus_degree, size_t integer_coeff_count, + size_t fraction_coeff_count, uint64_t base) : + encoder_(plain_modulus, base), + fraction_coeff_count_(fraction_coeff_count), + integer_coeff_count_(integer_coeff_count), + poly_modulus_degree_(poly_modulus_degree) + { + if (integer_coeff_count == 0) + { + throw invalid_argument("integer_coeff_count must be positive"); + } + if (fraction_coeff_count == 0) + { + throw invalid_argument("fraction_coeff_count must be positive"); + } + if (poly_modulus_degree_ < SEAL_POLY_MOD_DEGREE_MIN || + poly_modulus_degree_ > SEAL_POLY_MOD_DEGREE_MAX) + { + throw invalid_argument("poly_modulus_degree is invalid"); + } + if (add_safe(integer_coeff_count_, fraction_coeff_count_) > poly_modulus_degree_) + { + throw invalid_argument("integer/fractional parts are too large for poly_modulus_degree"); + } + } + + // We encode differently based on whether the base is odd or even. + Plaintext BalancedFractionalEncoder::encode(double value) + { + if (encoder_.base_ & 1) + { + return encode_odd(value); + } + else + { + return encode_even(value); + } + } + + Plaintext BalancedFractionalEncoder::encode_odd(double value) + { + // Take care of the integral part + int64_t value_int = safe_cast(round(value)); + Plaintext encoded_int; + encoder_.encode(value_int, encoded_int); + value -= static_cast(value_int); + + // If the fractional part is zero, return encoded_int + if (value == 0) + { + return encoded_int; + } + + // Extract the fractional part + Plaintext encoded_fract(poly_modulus_degree_); + for (size_t i = 0; i < fraction_coeff_count_; i++) + { + value *= static_cast(encoder_.base()); + + // When computing the next value_int we need to round e.g. 0.5 to 0 (not to 1) and + // -0.5 to 0 (not to -1), i.e. always towards zero. + int sign = (value >= 0 ? 1 : -1); + value_int = safe_cast(sign * ceil(abs(value) - 0.5)); + value -= static_cast(value_int); + + // We store the representative of value_int modulo the base (symmetric representative) + // as the absolute value (in value_int_mod_base) and as the sign (in is_negative). + bool is_negative = false; + + if (value_int < 0) + { + is_negative = true; + value_int = -value_int; + } + + // Set the constant coefficient of encoded_fract to be the correct absolute value. + encoded_fract[0] = static_cast(value_int); + // And negate it modulo plain_modulus if it was NOT supposed to be negative, because the + // fractional encoding requires the signs of the fractional coefficients to be negatives of + // what one might naively expect, as they change sign when "wrapping around" the polynomial modulus. + if (!is_negative && value_int != 0) + { + encoded_fract[0] = encoder_.plain_modulus_.value() - encoded_fract[0]; + } + + // Shift encoded_fract by one coefficient unless we are at the last coefficient + if (i < fraction_coeff_count_ - 1) + { + left_shift_uint(encoded_fract.data(), bits_per_uint64, + poly_modulus_degree_, encoded_fract.data()); + } + } + + // Shift the fractional part to top of polynomial + left_shift_uint(encoded_fract.data(), mul_safe(bits_per_uint64, + safe_cast(poly_modulus_degree_ - fraction_coeff_count_)), + poly_modulus_degree_, encoded_fract.data()); + + // Combine everything together + set_uint_uint(encoded_int.data(), encoded_int.coeff_count(), encoded_fract.data()); + + return encoded_fract; + } + + Plaintext BalancedFractionalEncoder::encode_even(double value) + { + // Take care of the integral part + int64_t value_int = safe_cast(round(value)); + + // We store the integral part for further use, since we may end up changing the integral + // part based on our encoding of the fractional part + int64_t initial = value_int; + + Plaintext encoded_int(poly_modulus_degree_); + encoder_.encode(value_int, encoded_int); + value -= static_cast(value_int); + + // If the fractional part is zero, return encoded_int + if (value == 0) + { + return encoded_int; + } + + // Extract the fractional part + // We will first compute the balanced base b encoding of the fractional part, allowing + // coefficients in the range -b/2, ..., b/2. We use Pointer carry to mark the coefficients + // that are equal to b/2, and we use Pointer is_less_than_neg_one to mark the coefficients + // that are less than -1 (we need this because when we encounter a coefficient greater than + // or equal to b/2, we need to store base - coefficient instead and add 1 to the coefficient + // to the left, which might change the sign of the coefficient to the left). + + Plaintext encoded_fract(poly_modulus_degree_); + auto carry(allocate_zero_uint(poly_modulus_degree_, pool_)); + auto is_less_than_neg_one(allocate_zero_uint(poly_modulus_degree_, pool_)); + auto is_negative(allocate_zero_uint(poly_modulus_degree_, pool_)); + + for (size_t i = 0; i < fraction_coeff_count_; i++) + { + value *= static_cast(encoder_.base()); + + // When computing the next value_int we need to round e.g. 0.5 to 0 (not to 1) and + // -0.5 to 0 (not to -1), i.e. always towards zero. + int sign = (value >= 0 ? 1 : -1); + value_int = safe_cast(sign * ceil(abs(value) - 0.5)); + value -= static_cast(value_int); + + // Set the constant coefficients of carry, is_less_than_neg_one, is_negative and + // encoded_fract to be the correct values. + if ((static_cast(abs(value_int)) >= encoder_.base_ / 2) + && (value_int >= 0)) + { + carry[0] = uint64_t(1); + } + if (value_int < -1) + { + is_less_than_neg_one[0] = uint64_t(1); + } + if (value_int < 0) + { + is_negative[0] = uint64_t(1); + value_int = -value_int; + } + + // Set the constant coefficient of encoded_fract to be the correct absolute value. + encoded_fract[0] = static_cast(value_int); + + // Shift all the polynomials by one coefficient unless we are at the last coefficient + if (i < fraction_coeff_count_ - 1) + { + left_shift_uint(encoded_fract.data(), bits_per_uint64, poly_modulus_degree_, + encoded_fract.data()); + left_shift_uint(carry.get(), bits_per_uint64, poly_modulus_degree_, carry.get()); + left_shift_uint(is_less_than_neg_one.get(), bits_per_uint64, poly_modulus_degree_, + is_less_than_neg_one.get()); + left_shift_uint(is_negative.get(), bits_per_uint64, poly_modulus_degree_, + is_negative.get()); + } + } + + uint64_t *encoded_fract_ptr = encoded_fract.data(); + uint64_t *is_negative_ptr = is_negative.get(); + uint64_t base_div_two = encoder_.base_ / 2; + + // Now we get rid of those coefficients that are greater than or equal to base / 2 + for (size_t i = 0; i < fraction_coeff_count_ - 1; i++) + { + if (carry[i] != 0) + { + // Set the sign of the current coefficient to be negative + is_negative[i] = uint64_t(1); + + // Store base - current coefficient + *encoded_fract_ptr = encoder_.base_ - *encoded_fract_ptr; + + // Add 1 to the coefficient to the left. Update the carry entry for the coefficient + // to the left. + if (is_negative[i + 1] == 0) + { + encoded_fract_ptr[1]++; + } + else + { + encoded_fract_ptr[1]--; + + // Update the sign of the coefficient to the left if needed + if (!is_less_than_neg_one[i + 1]) + { + is_negative[i + 1] = 0; + } + } + + if (encoded_fract_ptr[1] >= base_div_two) + // if (is_greater_than_or_equal_uint_uint(encoded_fract_ptr + plain_uint64_count, + // &base_div_two, plain_uint64_count)) + { + carry[i + 1] = uint64_t(1); + } + } + + encoded_fract_ptr++; + is_negative_ptr++; + } + + // Do we need to change the integral part? + bool change_int = (carry[fraction_coeff_count_ - 1] != 0); + if (change_int) + { + *encoded_fract_ptr = encoder_.base_ - *encoded_fract_ptr; + is_negative[fraction_coeff_count_ - 1] = uint64_t(1); + } + + // And negate it modulo plain_modulus if it was NOT supposed to be negative, because the + // fractional encoding requires the signs of the fractional coefficients to be negatives of + // what one might naively expect, as they change sign when "wrapping around" the polynomial modulus. + for (size_t i = fraction_coeff_count_; i--; encoded_fract_ptr--) + { + if ((!is_negative[i]) && (encoded_fract[i] != 0)) + { + encoded_fract_ptr[0] = encoder_.plain_modulus_.value() - encoded_fract_ptr[0]; + } + } + + // Shift the fractional part to top of polynomial + left_shift_uint(encoded_fract.data(), mul_safe(bits_per_uint64, + safe_cast(poly_modulus_degree_ - fraction_coeff_count_)), + poly_modulus_degree_, encoded_fract.data()); + + // If change_int is true, then we need to add 1 to the integral part and re-encode it. + if (change_int) + { + encoder_.encode(initial + 1, encoded_int); + } + + // Combine everything together + set_uint_uint(encoded_int.data(), encoded_int.coeff_count(), encoded_fract.data()); + + return encoded_fract; + } + + double BalancedFractionalEncoder::decode(const Plaintext &plain) + { + // Validate input parameters + if (plain.coeff_count() > poly_modulus_degree_) + { + throw invalid_argument("plain is not valid for encryption parameters"); + } +#ifdef SEAL_DEBUG + if (!are_poly_coefficients_less_than(plain.data(), + plain.coeff_count(), 1, encoder_.plain_modulus_.data(), + encoder_.plain_modulus_.uint64_count())) + { + throw invalid_argument("plain is not valid for encryption parameters"); + } +#endif + // plain might be smaller than expected if leading coefficients are missing + auto plain_copy(allocate_zero_uint(poly_modulus_degree_, pool_)); + set_uint_uint(plain.data(), plain.coeff_count(), plain_copy.get()); + + // Extract the fractional and integral parts + Plaintext encoded_int(integer_coeff_count_); + auto encoded_fract(allocate_zero_uint(fraction_coeff_count_, pool_)); + + // Integer part + set_uint_uint(plain_copy.get(), integer_coeff_count_, encoded_int.data()); + + // Read from the top of the poly all the way to the top of the integral part to obtain the fractional part + set_uint_uint(plain_copy.get() + poly_modulus_degree_ - fraction_coeff_count_, + fraction_coeff_count_, encoded_fract.get()); + + // Decode integral part + int64_t integral_part = encoder_.decode_int64(encoded_int); + + // Decode fractional part (or rather negative of it), one coefficient at a time + double fractional_part = 0; + Plaintext temp_int_part(1); + for (size_t i = 0; i < fraction_coeff_count_; i++) + { + temp_int_part[0] = encoded_fract[i]; + fractional_part += static_cast(encoder_.decode_int64(temp_int_part)); + fractional_part /= static_cast(encoder_.base()); + } + + return static_cast(integral_part) - fractional_part; + } + + IntegerEncoder::IntegerEncoder(const SmallModulus &plain_modulus, uint64_t base) + { + if (base == 2) + { + encoder_ = new BinaryEncoder(plain_modulus); + } + else + { + encoder_ = new BalancedEncoder(plain_modulus, base); + } + } + + IntegerEncoder::IntegerEncoder(const IntegerEncoder ©) + { + if (copy.base() == 2) + { + encoder_ = new BinaryEncoder(*dynamic_cast(copy.encoder_)); + } + else + { + encoder_ = new BalancedEncoder(*dynamic_cast(copy.encoder_)); + } + } + + IntegerEncoder::~IntegerEncoder() + { + if (encoder_ != nullptr) + { + delete encoder_; + encoder_ = nullptr; + } + } + + void IntegerEncoder::encode(uint64_t value, Plaintext &destination) + { + encoder_->encode(value, destination); + + // Resize to correct size + destination.resize(destination.significant_coeff_count()); + } + + void IntegerEncoder::encode(int64_t value, Plaintext &destination) + { + encoder_->encode(value, destination); + + // Resize to correct size + destination.resize(destination.significant_coeff_count()); + } + + void IntegerEncoder::encode(const BigUInt &value, Plaintext &destination) + { + encoder_->encode(value, destination); + + // Resize to correct size + destination.resize(destination.significant_coeff_count()); + } + + void IntegerEncoder::encode(int32_t value, Plaintext &destination) + { + encoder_->encode(value, destination); + + // Resize to correct size + destination.resize(destination.significant_coeff_count()); + } + + void IntegerEncoder::encode(uint32_t value, Plaintext &destination) + { + encoder_->encode(value, destination); + + // Resize to correct size + destination.resize(destination.significant_coeff_count()); + } + + FractionalEncoder::FractionalEncoder(const SmallModulus &plain_modulus, + size_t poly_modulus_degree, size_t integer_coeff_count, size_t fraction_coeff_count, + uint64_t base) + { + if (base == 2) + { + encoder_ = new BinaryFractionalEncoder(plain_modulus, poly_modulus_degree, + integer_coeff_count, fraction_coeff_count); + } + else + { + encoder_ = new BalancedFractionalEncoder(plain_modulus, poly_modulus_degree, + integer_coeff_count, fraction_coeff_count, base); + } + } + + FractionalEncoder::FractionalEncoder(const FractionalEncoder ©) + { + if (copy.base() == 2) + { + encoder_ = new BinaryFractionalEncoder( + *dynamic_cast(copy.encoder_)); + } + else + { + encoder_ = new BalancedFractionalEncoder( + *dynamic_cast(copy.encoder_)); + } + } + + FractionalEncoder::~FractionalEncoder() + { + if (encoder_ != nullptr) + { + delete encoder_; + encoder_ = nullptr; + } + } +} diff --git a/src/seal/encoder.h b/src/seal/encoder.h new file mode 100644 index 000000000..2a4bcb75c --- /dev/null +++ b/src/seal/encoder.h @@ -0,0 +1,1376 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include +#include "seal/context.h" +#include "seal/biguint.h" +#include "seal/plaintext.h" +#include "seal/smallmodulus.h" +#include "seal/memorymanager.h" + +namespace seal +{ + // Abstract base class for integer encoders. + class AbstractIntegerEncoder + { + public: + virtual ~AbstractIntegerEncoder() = default; + + virtual Plaintext encode(std::uint64_t value) = 0; + + virtual void encode(std::uint64_t value, Plaintext &destination) = 0; + + virtual std::uint32_t decode_uint32(const Plaintext &plain) = 0; + + virtual std::uint64_t decode_uint64(const Plaintext &plain) = 0; + + virtual Plaintext encode(std::int64_t value) = 0; + + virtual void encode(std::int64_t value, Plaintext &destination) = 0; + + virtual Plaintext encode(const BigUInt &value) = 0; + + virtual void encode(const BigUInt &value, Plaintext &destination) = 0; + + virtual std::int32_t decode_int32(const Plaintext &plain) = 0; + + virtual std::int64_t decode_int64(const Plaintext &plain) = 0; + + virtual BigUInt decode_biguint(const Plaintext &plain) = 0; + + virtual void decode_biguint(const Plaintext &plain, BigUInt &destination) = 0; + + virtual Plaintext encode(std::int32_t value) = 0; + + virtual Plaintext encode(std::uint32_t value) = 0; + + virtual void encode(std::int32_t value, Plaintext &destination) = 0; + + virtual void encode(std::uint32_t value, Plaintext &destination) = 0; + + virtual const SmallModulus &plain_modulus() const = 0; + + virtual std::uint64_t base() const = 0; + + private: + }; + + // Abstract base class for fractional encoders. + class AbstractFractionalEncoder + { + public: + virtual ~AbstractFractionalEncoder() = default; + + virtual Plaintext encode(double value) = 0; + + virtual double decode(const Plaintext &plain) = 0; + + virtual const SmallModulus &plain_modulus() const = 0; + + virtual std::size_t poly_modulus_degree() const = 0; + + virtual std::size_t fraction_coeff_count() const = 0; + + virtual std::size_t integer_coeff_count() const = 0; + + virtual std::uint64_t base() const = 0; + + private: + }; + + /** + Encodes integers into plaintext polynomials that Encryptor can encrypt. An instance of + the BinaryEncoder class converts an integer into a plaintext polynomial by placing its + binary digits as the coefficients of the polynomial. Decoding the integer amounts to + evaluating the plaintext polynomial at X=2. + + Addition and multiplication on the integer side translate into addition and multiplication + on the encoded plaintext polynomial side, provided that the length of the polynomial + never grows to be of the size of the polynomial modulus (poly_modulus), and that the + coefficients of the plaintext polynomials appearing throughout the computations never + experience coefficients larger than the plaintext modulus (plain_modulus). + + @par Negative Integers + Negative integers are represented by using -1 instead of 1 in the binary representation, + and the negative coefficients are stored in the plaintext polynomials as unsigned integers + that represent them modulo the plaintext modulus. Thus, for example, a coefficient of -1 + would be stored as a polynomial coefficient plain_modulus-1. + + @see BinaryFractionalEncoder for encoding real numbers. + @see BalancedEncoder for encoding using base-b representation for b greater than 2. + @see IntegerEncoder for a common interface to all integer encoders. + */ + class BinaryEncoder : public AbstractIntegerEncoder + { + public: + /** + Creates a BinaryEncoder object. The constructor takes as input a reference + to the plaintext modulus (represented by SmallModulus). + + @param[in] plain_modulus The plaintext modulus (represented by SmallModulus) + @throws std::invalid_argument if plain_modulus is not at least 2 + */ + BinaryEncoder(const SmallModulus &plain_modulus); + + /** + Creates a copy of a BinaryEncoder. + + @param[in] copy The BinaryEncoder to copy from + */ + BinaryEncoder(const BinaryEncoder ©) = default; + + /** + Creates a new BinaryEncoder by moving an old one. + + @param[in] source The BinaryEncoder to move from + */ + BinaryEncoder(BinaryEncoder &&source) = default; + + /** + Destroys the BinaryEncoder. + */ + virtual ~BinaryEncoder() override + { + } + + /** + Encodes an unsigned integer (represented by std::uint64_t) into a plaintext polynomial. + + @param[in] value The unsigned integer to encode + */ + virtual Plaintext encode(std::uint64_t value) override; + + /** + Encodes an unsigned integer (represented by std::uint64_t) into a plaintext polynomial. + + @param[in] value The unsigned integer to encode + @param[out] destination The plaintext to overwrite with the encoding + */ + virtual void encode(std::uint64_t value, Plaintext &destination) override; + + /** + Decodes a plaintext polynomial and returns the result as std::uint32_t. + Mathematically this amounts to evaluating the input polynomial at X=2. + + @param[in] plain The plaintext to be decoded + @throws std::invalid_argument if the output does not fit in std::uint32_t + */ + virtual std::uint32_t decode_uint32(const Plaintext &plain) override; + + /** + Decodes a plaintext polynomial and returns the result as std::uint64_t. + Mathematically this amounts to evaluating the input polynomial at X=2. + + @param[in] plain The plaintext to be decoded + @throws std::invalid_argument if the output does not fit in std::uint64_t + */ + virtual std::uint64_t decode_uint64(const Plaintext &plain) override; + + /** + Encodes a signed integer (represented by std::uint64_t) into a plaintext polynomial. + + @par Negative Integers + Negative integers are represented by using -1 instead of 1 in the binary representation, + and the negative coefficients are stored in the plaintext polynomials as unsigned integers + that represent them modulo the plaintext modulus. Thus, for example, a coefficient of -1 + would be stored as a polynomial coefficient plain_modulus-1. + + @param[in] value The signed integer to encode + */ + virtual Plaintext encode(std::int64_t value) override; + + /** + Encodes a signed integer (represented by std::int64_t) into a plaintext polynomial. + + @par Negative Integers + Negative integers are represented by using -1 instead of 1 in the binary representation, + and the negative coefficients are stored in the plaintext polynomials as unsigned integers + that represent them modulo the plaintext modulus. Thus, for example, a coefficient of -1 + would be stored as a polynomial coefficient plain_modulus-1. + + @param[in] value The signed integer to encode + @param[out] destination The plaintext to overwrite with the encoding + */ + virtual void encode(std::int64_t value, Plaintext &destination) override; + + /** + Encodes an unsigned integer (represented by BigUInt) into a plaintext polynomial. + + @param[in] value The unsigned integer to encode + */ + virtual Plaintext encode(const BigUInt &value) override; + + /** + Encodes an unsigned integer (represented by BigUInt) into a plaintext polynomial. + + @param[in] value The unsigned integer to encode + @param[out] destination The plaintext to overwrite with the encoding + */ + virtual void encode(const BigUInt &value, Plaintext &destination) override; + + /** + Decodes a plaintext polynomial and returns the result as std::int32_t. + Mathematically this amounts to evaluating the input polynomial at X=2. + + @param[in] plain The plaintext to be decoded + @throws std::invalid_argument if plain does not represent a valid plaintext polynomial + @throws std::invalid_argument if the output does not fit in std::int32_t + */ + virtual std::int32_t decode_int32(const Plaintext &plain) override; + + /** + Decodes a plaintext polynomial and returns the result as std::int64_t. + Mathematically this amounts to evaluating the input polynomial at X=2. + + @param[in] plain The plaintext to be decoded + @throws std::invalid_argument if plain does not represent a valid plaintext polynomial + @throws std::invalid_argument if the output does not fit in std::int64_t + */ + virtual std::int64_t decode_int64(const Plaintext &plain) override; + + /** + Decodes a plaintext polynomial and returns the result as BigUInt. + Mathematically this amounts to evaluating the input polynomial at X=2. + + @param[in] plain The plaintext to be decoded + @throws std::invalid_argument if plain does not represent a valid plaintext polynomial + @throws std::invalid_argument if the output is negative + */ + virtual BigUInt decode_biguint(const Plaintext &plain) override; + + /** + Decodes a plaintext polynomial and stores the result in a given BigUInt. + Mathematically this amounts to evaluating the input polynomial at X=2. + + @param[in] plain The plaintext to be decoded + @param[out] destination The BigUInt to overwrite with the decoding + @throws std::invalid_argument if plain does not represent a valid plaintext polynomial + @throws std::invalid_argument if the output does not fit in destination + @throws std::invalid_argument if the output is negative + */ + virtual void decode_biguint(const Plaintext &plain, BigUInt &destination) override; + + /** + Encodes a signed integer (represented by std::int32_t) into a plaintext polynomial. + + @par Negative Integers + Negative integers are represented by using -1 instead of 1 in the binary representation, + and the negative coefficients are stored in the plaintext polynomials as unsigned integers + that represent them modulo the plaintext modulus. Thus, for example, a coefficient of -1 + would be stored as a polynomial coefficient plain_modulus-1. + + @param[in] value The signed integer to encode + */ + virtual Plaintext encode(std::int32_t value) override + { + return encode(static_cast(value)); + } + + /** + Encodes an unsigned integer (represented by std::uint32_t) into a plaintext polynomial. + + @param[in] value The unsigned integer to encode + */ + virtual Plaintext encode(std::uint32_t value) override + { + return encode(static_cast(value)); + } + + /** + Encodes a signed integer (represented by std::int32_t) into a plaintext polynomial. + + @par Negative Integers + Negative integers are represented by using -1 instead of 1 in the binary representation, + and the negative coefficients are stored in the plaintext polynomials as unsigned integers + that represent them modulo the plaintext modulus. Thus, for example, a coefficient of -1 + would be stored as a polynomial coefficient plain_modulus-1. + + @param[in] value The signed integer to encode + @param[out] destination The plaintext to overwrite with the encoding + */ + virtual void encode(std::int32_t value, Plaintext &destination) override + { + encode(static_cast(value), destination); + } + + /** + Encodes an unsigned integer (represented by std::uint32_t) into a plaintext polynomial. + + @param[in] value The unsigned integer to encode + @param[out] destination The plaintext to overwrite with the encoding + */ + virtual void encode(std::uint32_t value, Plaintext &destination) override + { + encode(static_cast(value), destination); + } + + /** + Returns a reference to the plaintext modulus. + */ + virtual const SmallModulus &plain_modulus() const override + { + return plain_modulus_; + } + + /** + Returns the base used for encoding (2). + */ + virtual std::uint64_t base() const override + { + return 2; + } + + private: + BinaryEncoder &operator =(const BinaryEncoder &assign) = delete; + + BinaryEncoder &operator =(BinaryEncoder &&assign) = delete; + + MemoryPoolHandle pool_ = MemoryManager::GetPool(); + + SmallModulus plain_modulus_; + + std::uint64_t coeff_neg_threshold_; + + std::uint64_t neg_one_; + + friend class BinaryFractionalEncoder; + }; + + /** + Encodes integers into plaintext polynomials that Encryptor can encrypt. An instance of + the BalancedEncoder class converts an integer into a plaintext polynomial by placing its + digits in balanced base-b representation as the coefficients of the polynomial. The base + b must be a positive integer at least 3 (which is the default value). When b is odd, + digits in such a balanced representation are integers in the range + -(b-1)/2,...,(b-1)/2. When b is even, digits are integers in the range -b/2,..., b/2-1. + Note that the default value 3 for the base b allows for more compact representation than + BinaryEncoder without increasing the sizes of the coefficients of freshly encoded plaintext + polynomials. A larger base allows for an even more compact representation at the cost of + having larger coefficients in freshly encoded plaintext polynomials. Decoding the integer + amounts to evaluating the plaintext polynomial at X=b. + + Addition and multiplication on the integer side translate into addition and multiplication + on the encoded plaintext polynomial side, provided that the length of the polynomial + never grows to be of the size of the polynomial modulus (poly_modulus), and that the + coefficients of the plaintext polynomials appearing throughout the computations never + experience coefficients larger than the plaintext modulus (plain_modulus). + + @par Negative Integers + Negative integers in the balanced base-b encoding are represented the same way as + positive integers, namely, both positive and negative integers can have both positive and negative + digits in their balanced base-b representation. Negative coefficients are stored in the + plaintext polynomials as unsigned integers that represent them modulo the plaintext modulus. + Thus, for example, a coefficient of -1 would be stored as a polynomial coefficient plain_modulus-1. + + @see BalancedFractionalEncoder for encoding real numbers. + @see BinaryEncoder for encoding using the binary representation. + @see IntegerEncoder for a common interface to all integer encoders. + */ + class BalancedEncoder : public AbstractIntegerEncoder + { + public: + /** + Creates a BalancedEncoder object. The constructor takes as input a reference + to the plaintext modulus (represented by SmallModulus), and optionally an integer, + at least 3, that is used as a base in the encoding. + + @param[in] plain_modulus The plaintext modulus (represented by SmallModulus) + @param[in] base The base to be used for encoding (default value is 3) + @throws std::invalid_argument if base is not an integer and at least 3 + @throws std::invalid_argument if plain_modulus is not at least base + */ + BalancedEncoder(const SmallModulus &plain_modulus, std::uint64_t base = 3); + + /** + Creates a copy of a BalancedEncoder. + + @param[in] copy The BalancedEncoder to copy from + */ + BalancedEncoder(const BalancedEncoder ©) = default; + + /** + Creates a new BalancedEncoder by moving an old one. + + @param[in] source The BalancedEncoder to move from + */ + BalancedEncoder(BalancedEncoder &&source) = default; + + /** + Destroys the BalancedEncoder. + */ + virtual ~BalancedEncoder() + { + } + + /** + Encodes an unsigned integer (represented by std::uint64_t) into a plaintext polynomial. + + @param[in] value The unsigned integer to encode + */ + virtual Plaintext encode(std::uint64_t value) override; + + /** + Encodes an unsigned integer (represented by std::uint64_t) into a plaintext polynomial. + + @param[in] value The unsigned integer to encode + @param[out] destination The plaintext to overwrite with the encoding + */ + virtual void encode(std::uint64_t value, Plaintext &destination) override; + + /** + Decodes a plaintext polynomial and returns the result as std::uint32_t. + Mathematically this amounts to evaluating the input polynomial at X=base. + + @param[in] plain The plaintext to be decoded + @throws std::invalid_argument if plain does not represent a valid plaintext polynomial + @throws std::invalid_argument if the output does not fit in std::uint32_t + */ + virtual std::uint32_t decode_uint32(const Plaintext &plain) override; + + /** + Decodes a plaintext polynomial and returns the result as std::uint64_t. + Mathematically this amounts to evaluating the input polynomial at X=base. + + @param[in] plain The plaintext to be decoded + @throws std::invalid_argument if plain does not represent a valid plaintext polynomial + @throws std::invalid_argument if the output does not fit in std::uint64_t + */ + virtual std::uint64_t decode_uint64(const Plaintext &plain) override; + + /** + Encodes a signed integer (represented by std::int64_t) into a plaintext polynomial. + + @par Negative Integers + Negative integers in the balanced base-b encoding are represented the same way as + positive integers, namely, both positive and negative integers can have both positive and negative + digits in their balanced base-b representation. Negative coefficients are stored in the + plaintext polynomials as unsigned integers that represent them modulo the plaintext modulus. + Thus, for example, a coefficient of -1 would be stored as a polynomial coefficient plain_modulus-1. + + @param[in] value The signed integer to encode + */ + virtual Plaintext encode(std::int64_t value) override; + + /** + Encodes a signed integer (represented by std::int64_t) into a plaintext polynomial. + + @par Negative Integers + Negative integers in the balanced base-b encoding are represented the same way as + positive integers, namely, both positive and negative integers can have both positive and negative + digits in their balanced base-b representation. Negative coefficients are stored in the + plaintext polynomials as unsigned integers that represent them modulo the plaintext modulus. + Thus, for example, a coefficient of -1 would be stored as a polynomial coefficient plain_modulus-1. + + @param[in] value The signed integer to encode + @param[out] destination The plaintext to overwrite with the encoding + */ + virtual void encode(std::int64_t value, Plaintext &destination) override; + + /** + Encodes an unsigned integer (represented by BigUInt) into a plaintext polynomial. + + @param[in] value The unsigned integer to encode + */ + virtual Plaintext encode(const BigUInt &value) override; + + /** + Encodes an unsigned integer (represented by BigUInt) into a plaintext polynomial. + + @param[in] value The unsigned integer to encode + @param[out] destination The plaintext to overwrite with the encoding + */ + virtual void encode(const BigUInt &value, Plaintext &destination) override; + + /** + Decodes a plaintext polynomial and returns the result as std::int32_t. + Mathematically this amounts to evaluating the input polynomial at X=base. + + @param[in] plain The plaintext to be decoded + @throws std::invalid_argument if plain does not represent a valid plaintext polynomial + @throws std::invalid_argument if the output does not fit in std::int32_t + */ + virtual std::int32_t decode_int32(const Plaintext &plain) override; + + /** + Decodes a plaintext polynomial and returns the result as std::int64_t. + Mathematically this amounts to evaluating the input polynomial at X=base. + + @param[in] plain The plaintext to be decoded + @throws std::invalid_argument if plain does not represent a valid plaintext polynomial + @throws std::invalid_argument if the output does not fit in std::int64_t + */ + virtual std::int64_t decode_int64(const Plaintext &plain) override; + + /** + Decodes a plaintext polynomial and returns the result as BigUInt. + Mathematically this amounts to evaluating the input polynomial at X=base. + + @param[in] plain The plaintext to be decoded + @throws std::invalid_argument if plain does not represent a valid plaintext polynomial + @throws std::invalid_argument if the output is negative + */ + virtual BigUInt decode_biguint(const Plaintext &plain) override; + + /** + Decodes a plaintext polynomial and stores the result in a given BigUInt. + Mathematically this amounts to evaluating the input polynomial at X=base. + + @param[in] plain The plaintext to be decoded + @param[out] destination The BigUInt to overwrite with the decoding + @throws std::invalid_argument if plain does not represent a valid plaintext polynomial + @throws std::invalid_argument if the output does not fit in destination + @throws std::invalid_argument if the output is negative + */ + virtual void decode_biguint(const Plaintext &plain, BigUInt &destination) override; + + /** + Encodes a signed integer (represented by std::int32_t) into a plaintext polynomial. + + @par Negative Integers + Negative integers in the balanced base-b encoding are represented the same way as + positive integers, namely, both positive and negative integers can have both positive and negative + digits in their balanced base-b representation. Negative coefficients are stored in the + plaintext polynomials as unsigned integers that represent them modulo the plaintext modulus. + Thus, for example, a coefficient of -1 would be stored as a polynomial coefficient plain_modulus-1. + + @param[in] value The signed integer to encode + */ + virtual Plaintext encode(std::int32_t value) override + { + return encode(static_cast(value)); + } + + /** + Encodes an unsigned integer (represented by std::uint32_t) into a plaintext polynomial. + + @param[in] value The unsigned integer to encode + */ + virtual Plaintext encode(std::uint32_t value) override + { + return encode(static_cast(value)); + } + + /** + Encodes a signed integer (represented by std::int32_t) into a plaintext polynomial. + + @par Negative Integers + Negative integers in the balanced base-b encoding are represented the same way as + positive integers, namely, both positive and negative integers can have both positive and negative + digits in their balanced base-b representation. Negative coefficients are stored in the + plaintext polynomials as unsigned integers that represent them modulo the plaintext modulus. + Thus, for example, a coefficient of -1 would be stored as a polynomial coefficient plain_modulus-1. + + @param[in] value The signed integer to encode + @param[out] destination The plaintext to overwrite with the encoding + */ + virtual void encode(std::int32_t value, Plaintext &destination) override + { + encode(static_cast(value), destination); + } + + /** + Encodes an unsigned integer (represented by std::uint32_t) into a plaintext polynomial. + + @param[in] value The unsigned integer to encode + @param[out] destination The plaintext to overwrite with the encoding + */ + virtual void encode(std::uint32_t value, Plaintext &destination) override + { + encode(static_cast(value), destination); + } + + /** + Returns a reference to the plaintext modulus. + */ + virtual const SmallModulus &plain_modulus() const override + { + return plain_modulus_; + } + + /** + Returns the base used for encoding. + */ + virtual std::uint64_t base() const override + { + return base_; + } + + private: + BalancedEncoder &operator =(const BalancedEncoder &assign) = delete; + + BalancedEncoder &operator =(BalancedEncoder &&assign) = delete; + + MemoryPoolHandle pool_ = MemoryManager::GetPool(); + + SmallModulus plain_modulus_; + + std::uint64_t base_; + + std::uint64_t coeff_neg_threshold_; + + friend class BalancedFractionalEncoder; + }; + + /** + Encodes floating point numbers into plaintext polynomials that Encryptor can encrypt. + An instance of the BinaryFractionalEncoder class converts a double-precision floating-point + number into a plaintext polynomial by computing its binary representation, encoding the + integral part as in BinaryEncoder, and the fractional part as the highest degree + terms of the plaintext polynomial, with signs inverted. Decoding the polynomial + back into a double amounts to evaluating the low degree part at X=2, negating the + coefficients of the high degree part and evaluating it at X=1/2. + + Addition and multiplication on the double side translate into addition and multiplication + on the encoded plaintext polynomial side, provided that the integral part never mixes + with the fractional part in the plaintext polynomials, and that the + coefficients of the plaintext polynomials appearing throughout the computations never + experience coefficients larger than the plaintext modulus (plain_modulus). + + @par Integral and Fractional Parts + When homomorphic multiplications are performed, the integral part "grows up" to higher + degree coefficients of the plaintext polynomial space, and the fractional part "grows down" + from the top degree coefficients towards the lower degree coefficients. For decoding to work, + these parts must not interfere with each other. When setting up the BinaryFractionalEncoder, + one must specify how many coefficients of a plaintext polynomial are reserved for the integral + part and how many for the fractional. The sum of these numbers can be at most equal to the + degree of the polynomial modulus minus one. If homomorphic multiplications are performed, it is + also necessary to leave enough room for the fractional part to "grow down". + + @par Negative Integers + Negative integers are represented by using -1 instead of 1 in the binary representation, + and the negative coefficients are stored in the plaintext polynomials as unsigned integers + that represent them modulo the plaintext modulus. Thus, for example, a coefficient of -1 + would be stored as a polynomial coefficient plain_modulus-1. + + @see BinaryEncoder for encoding integers. + @see BalancedFractionalEncoder for encoding using base-b representation for b greater than 2. + @see FractionalEncoder for a common interface to all fractional encoders. + */ + class BinaryFractionalEncoder : public AbstractFractionalEncoder + { + public: + /** + Creates a new BinaryFractionalEncoder object. The constructor takes as input a reference + to the plaintext modulus, the degree of the polynomial modulus, and the numbers of + coefficients that are reserved for the integral and fractional parts. The coefficients + for the integral part are counted starting from the low-degree end of the polynomial, + and the coefficients for the fractional part are counted from the high-degree end. + + @param[in] plain_modulus The plaintext modulus (represented by SmallModulus) + @param[in] poly_modulus_degree The degree of the polynomial modulus + @param[in] integer_coeff_count The number of polynomial coefficients reserved for the integral part + @param[in] fraction_coeff_count The number of polynomial coefficients reserved for the fractional part + @throws std::invalid_argument if plain_modulus_degree is invalid + @throws std::invalid_argument if integer_coeff_count is not strictly positive + @throws std::invalid_argument if fraction_coeff_count is not strictly positive + @throws std::invalid_argument if poly_modulus_degree is too small for the integral and fractional parts + */ + BinaryFractionalEncoder(const SmallModulus &plain_modulus, std::size_t poly_modulus_degree, + std::size_t integer_coeff_count, std::size_t fraction_coeff_count); + + /** + Creates a copy of a BinaryFractionalEncoder. + + @param[in] copy The BinaryFractionalEncoder to copy from + */ + BinaryFractionalEncoder(const BinaryFractionalEncoder ©) = default; + + /** + Creates a new BinaryFractionalEncoder by moving an old one. + + @param[in] source The BinaryFractionalEncoder to move from + */ + BinaryFractionalEncoder(BinaryFractionalEncoder &&source) = default; + + /** + Destroys the BinaryFractionalEncoder. + */ + virtual ~BinaryFractionalEncoder() + { + } + + /** + Encodes a double precision floating point number into a plaintext polynomial. + + @param[in] value The double-precision floating-point number to encode + */ + virtual Plaintext encode(double value) override; + + /** + Decodes a plaintext polynomial and returns the result as a double-precision + floating-point number. + + @param[in] plain The plaintext to be decoded + @throws std::invalid_argument if plain does not represent a valid plaintext polynomial + @throws std::invalid_argument if the integral part does not fit in std::int64_t + */ + virtual double decode(const Plaintext &plain) override; + + /** + Returns a reference to the plaintext modulus. + */ + virtual const SmallModulus &plain_modulus() const override + { + return encoder_.plain_modulus(); + } + + /** + Returns the degree of the polynomial modulus. + */ + virtual std::size_t poly_modulus_degree() const override + { + return poly_modulus_degree_; + } + + /** + Returns the base used for encoding (2). + */ + virtual std::uint64_t base() const override + { + return 2; + } + + /** + Returns the number of coefficients reserved for the fractional part. + */ + virtual std::size_t fraction_coeff_count() const override + { + return fraction_coeff_count_; + } + + /** + Returns the number of coefficients reserved for the integral part. + */ + virtual std::size_t integer_coeff_count() const override + { + return integer_coeff_count_; + } + + private: + BinaryFractionalEncoder &operator =(const BinaryFractionalEncoder &assign) = delete; + + BinaryFractionalEncoder &operator =(BinaryFractionalEncoder &&assign) = delete; + + MemoryPoolHandle pool_ = MemoryManager::GetPool(); + + BinaryEncoder encoder_; + + std::size_t fraction_coeff_count_; + + std::size_t integer_coeff_count_; + + std::size_t poly_modulus_degree_; + }; + + /** + Encodes floating point numbers into plaintext polynomials that Encryptor can encrypt. + An instance of the BalancedFractionalEncoder class converts a double-precision floating-point + number into a plaintext polynomial by computing its balanced base-b representation, encoding the + integral part as in BalancedEncoder, and the fractional part as the highest degree + terms of the plaintext polynomial, with signs inverted. For an even base b, the + coefficients of the polynomial are in the range -b/2,...,b/2-1. Decoding the polynomial back + into a double amounts to evaluating the low degree part at X=b, negating the coefficients + of the high degree part and evaluating it at X=1/b. + + Addition and multiplication on the double side translate into addition and multiplication + on the encoded plaintext polynomial side, provided that the integral part never mixes + with the fractional part in the plaintext polynomials, and that the + coefficients of the plaintext polynomials appearing throughout the computations never + experience coefficients larger than the plaintext modulus (plain_modulus). + + @par Integral and Fractional Parts + When homomorphic multiplications are performed, the integral part "grows up" to higher + degree coefficients of the plaintext polynomial space, and the fractional part "grows down" + from the top degree coefficients towards the lower degree coefficients. For decoding to work, + these parts must not interfere with each other. When setting up the BalancedFractionalEncoder, + one must specify how many coefficients of a plaintext polynomial are reserved for the integral + part and how many for the fractional. The sum of these numbers can be at most equal to the + degree of the polynomial modulus minus one. If homomorphic multiplications are performed, it is + also necessary to leave enough room for the fractional part to "grow down". + + @par Negative Integers + Negative integers in the balanced base-b encoding are represented the same way as + positive integers, namely, both positive and negative integers can have both positive and negative + digits in their balanced base-b representation. Negative coefficients are stored in the + plaintext polynomials as unsigned integers that represent them modulo the plaintext modulus. + Thus, for example, a coefficient of -1 would be stored as a polynomial coefficient plain_modulus-1. + + @see BalancedEncoder for encoding integers. + @see BinaryFractionalEncoder for encoding using the binary representation. + @see FractionalEncoder for a common interface to all fractional encoders. + */ + class BalancedFractionalEncoder : public AbstractFractionalEncoder + { + public: + /** + Creates a new BalancedFractionalEncoder object. The constructor takes as input a reference + to the plaintext modulus, the degree of the polynomial modulus, and the numbers of + coefficients that are reserved for the integral and fractional parts, and optionally + an integer, at least 3, that is used as the base in the encoding. The coefficients for the + integral part are counted starting from the low-degree end of the polynomial, and the + coefficients for the fractional part are counted from the high-degree end. + + @param[in] plain_modulus The plaintext modulus (represented by SmallModulus) + @param[in] poly_modulus_degree The degree of the polynomial modulus + @param[in] integer_coeff_count The number of polynomial coefficients reserved for the integral part + @param[in] fraction_coeff_count The number of polynomial coefficients reserved for the fractional part + @param[in] base The base to be used for encoding (default value is 3) + @throws std::invalid_argument if plain_modulus is not at least base + @throws std::invalid_argument if integer_coeff_count is not strictly positive + @throws std::invalid_argument if fraction_coeff_count is not strictly positive + @throws std::invalid_argument if poly_modulus_degree is too small for the integral and fractional parts + @throws std::invalid_argument if base is not an integer and at least 3 + */ + BalancedFractionalEncoder(const SmallModulus &plain_modulus, std::size_t poly_modulus_degree, + std::size_t integer_coeff_count, std::size_t fraction_coeff_count, std::uint64_t base = 3); + + /** + Creates a copy of a BalancedFractionalEncoder. + + @param[in] copy The BalancedFractionalEncoder to copy from + */ + BalancedFractionalEncoder(const BalancedFractionalEncoder ©) = default; + + /** + Creates a new BalancedFractionalEncoder by moving an old one. + + @param[in] source The BalancedFractionalEncoder to move from + */ + BalancedFractionalEncoder(BalancedFractionalEncoder &&source) = default; + + /** + Destroys the BalancedFractionalEncoder. + */ + virtual ~BalancedFractionalEncoder() + { + } + + /** + Encodes a double precision floating point number into a plaintext polynomial. + + @param[in] value The double-precision floating-point number to encode + */ + virtual Plaintext encode(double value) override; + + /** + Decodes a plaintext polynomial and returns the result as a double-precision + floating-point number. + + @param[in] plain The plaintext to be decoded + @throws std::invalid_argument if plain does not represent a valid plaintext polynomial + @throws std::invalid_argument if the integral part does not fit in std::int64_t + */ + virtual double decode(const Plaintext &plain) override; + + /** + Returns a reference to the plaintext modulus. + */ + virtual const SmallModulus &plain_modulus() const override + { + return encoder_.plain_modulus(); + } + + /** + Returns the degree of the polynomial modulus. + */ + virtual std::size_t poly_modulus_degree() const override + { + return poly_modulus_degree_; + } + + /** + Returns the base used for encoding. + */ + virtual std::uint64_t base() const override + { + return encoder_.base(); + } + + /** + Returns the number of coefficients reserved for the fractional part. + */ + virtual std::size_t fraction_coeff_count() const override + { + return fraction_coeff_count_; + } + + /** + Returns the number of coefficients reserved for the integral part. + */ + virtual std::size_t integer_coeff_count() const override + { + return integer_coeff_count_; + } + + private: + BalancedFractionalEncoder &operator =(const BalancedFractionalEncoder &assign) = delete; + + BalancedFractionalEncoder &operator =(BalancedFractionalEncoder &&assign) = delete; + + Plaintext encode_even(double value); + + Plaintext encode_odd(double value); + + MemoryPoolHandle pool_ = MemoryManager::GetPool(); + + BalancedEncoder encoder_; + + std::size_t fraction_coeff_count_; + + std::size_t integer_coeff_count_; + + std::size_t poly_modulus_degree_; + }; + + /** + Encodes integers into plaintext polynomials that Encryptor can encrypt. An instance of + the IntegerEncoder class converts an integer into a plaintext polynomial by placing its + digits in balanced base-b representation as the coefficients of the polynomial. The base + b must be a positive integer at least 2 (which is the default value). When b is odd, + digits in such a balanced representation are integers in the range -(b-1)/2,...,(b-1)/2. + When b is even, digits are integers in the range -b/2,...,b/2-1. When b is 2, the + coefficients are either all non-negative (0 and 1), or all non-positive (0 and -1). A larger + base allows for more compact representation at the cost of having larger coefficients in + freshly encoded plaintext polynomials. Decoding the integer amounts to evaluating the + plaintext polynomial at X=b. + + Addition and multiplication on the integer side translate into addition and multiplication + on the encoded plaintext polynomial side, provided that the length of the polynomial + never grows to be of the size of the polynomial modulus (poly_modulus), and that the + coefficients of the plaintext polynomials appearing throughout the computations never + experience coefficients larger than the plaintext modulus (plain_modulus). + + @par Negative Integers + Negative integers in the base-b encoding are represented the same way as positive integers, + namely, both positive and negative integers can have both positive and negative digits in their + base-b representation. Negative coefficients are stored in the plaintext polynomials as unsigned + integers that represent them modulo the plaintext modulus. Thus, for example, a coefficient of -1 + would be stored as a polynomial coefficient plain_modulus-1. + + @par BinaryEncoder and BalancedEncoder + Under the hood IntegerEncoder uses either the BinaryEncoder or the BalancedEncoder classes + to do the encoding. The first one is used when the base is 2, and the second one when the + base is at least 3. Currently the BinaryEncoder and BalancedEncoder classes can also be used + directly, but this might change in future releases. + + @see BinaryEncoder for encoding using the binary representation. + @see BalancedEncoder for encoding using base-b representation for b greater than 2. + @see FractionalEncoder for encoding real numbers. + */ + class IntegerEncoder : public AbstractIntegerEncoder + { + public: + /** + Creates an IntegerEncoder object. The constructor takes as input a reference + to the plaintext modulus (represented by SmallModulus), and optionally an integer, + at least 2, that is used as a base in the encoding. + + @param[in] plain_modulus The plaintext modulus (represented by SmallModulus) + @param[in] base The base to be used for encoding (default value is 2) + @throws std::invalid_argument if base is not an integer and at least 2 + @throws std::invalid_argument if plain_modulus is not at least base + */ + IntegerEncoder(const SmallModulus &plain_modulus, std::uint64_t base = 2); + + /** + Creates a copy of a IntegerEncoder. + + @param[in] copy The IntegerEncoder to copy from + */ + IntegerEncoder(const IntegerEncoder ©); + + /** + Creates a new IntegerEncoder by moving an old one. + + @param[in] source The IntegerEncoder to move from + */ + IntegerEncoder(IntegerEncoder &&source) = default; + + /** + Destroys the IntegerEncoder. + */ + virtual ~IntegerEncoder() override; + + /** + Encodes an unsigned integer (represented by std::uint64_t) into a plaintext polynomial. + + @param[in] value The unsigned integer to encode + */ + virtual Plaintext encode(std::uint64_t value) override + { + return encoder_->encode(value); + } + + /** + Encodes an unsigned integer (represented by std::uint64_t) into a plaintext polynomial. + + @param[in] value The unsigned integer to encode + @param[out] destination The plaintext to overwrite with the encoding + */ + virtual void encode(std::uint64_t value, Plaintext &destination) override; + + /** + Decodes a plaintext polynomial and returns the result as std::uint32_t. + Mathematically this amounts to evaluating the input polynomial at X=base. + + @param[in] plain The plaintext to be decoded + @throws std::invalid_argument if plain does not represent a valid plaintext polynomial + @throws std::invalid_argument if the output does not fit in std::uint32_t + */ + virtual std::uint32_t decode_uint32(const Plaintext &plain) override + { + return encoder_->decode_uint32(plain); + } + + /** + Decodes a plaintext polynomial and returns the result as std::uint64_t. + Mathematically this amounts to evaluating the input polynomial at X=base. + + @param[in] plain The plaintext to be decoded + @throws std::invalid_argument if plain does not represent a valid plaintext polynomial + @throws std::invalid_argument if the output does not fit in std::uint64_t + */ + virtual std::uint64_t decode_uint64(const Plaintext &plain) override + { + return encoder_->decode_uint64(plain); + } + + /** + Encodes a signed integer (represented by std::int64_t) into a plaintext polynomial. + + @par Negative Integers + Negative integers in the base-b encoding are represented the same way as positive integers, + namely, both positive and negative integers can have both positive and negative digits in their + base-b representation. Negative coefficients are stored in the plaintext polynomials as unsigned + integers that represent them modulo the plaintext modulus. Thus, for example, a coefficient of -1 + would be stored as a polynomial coefficient plain_modulus-1. + + @param[in] value The signed integer to encode + */ + virtual Plaintext encode(std::int64_t value) override + { + return encoder_->encode(value); + } + + /** + Encodes a signed integer (represented by std::int64_t) into a plaintext polynomial. + + @par Negative Integers + Negative integers in the base-b encoding are represented the same way as positive integers, + namely, both positive and negative integers can have both positive and negative digits in their + base-b representation. Negative coefficients are stored in the plaintext polynomials as unsigned + integers that represent them modulo the plaintext modulus. Thus, for example, a coefficient of -1 + would be stored as a polynomial coefficient plain_modulus-1. + + @param[in] value The signed integer to encode + @param[out] destination The plaintext to overwrite with the encoding + */ + virtual void encode(std::int64_t value, Plaintext &destination) override; + + /** + Encodes an unsigned integer (represented by BigUInt) into a plaintext polynomial. + + @param[in] value The unsigned integer to encode + */ + virtual Plaintext encode(const BigUInt &value) override + { + return encoder_->encode(value); + } + + /** + Encodes an unsigned integer (represented by BigUInt) into a plaintext polynomial. + + @param[in] value The unsigned integer to encode + @param[out] destination The plaintext to overwrite with the encoding + */ + virtual void encode(const BigUInt &value, Plaintext &destination) override; + + /** + Decodes a plaintext polynomial and returns the result as std::int32_t. + Mathematically this amounts to evaluating the input polynomial at X=base. + + @param[in] plain The plaintext to be decoded + @throws std::invalid_argument if plain does not represent a valid plaintext polynomial + @throws std::invalid_argument if the output does not fit in std::int32_t + */ + virtual std::int32_t decode_int32(const Plaintext &plain) override + { + return encoder_->decode_int32(plain); + } + + /** + Decodes a plaintext polynomial and returns the result as std::int64_t. + Mathematically this amounts to evaluating the input polynomial at X=base. + + @param[in] plain The plaintext to be decoded + @throws std::invalid_argument if plain does not represent a valid plaintext polynomial + @throws std::invalid_argument if the output does not fit in std::int64_t + */ + virtual std::int64_t decode_int64(const Plaintext &plain) override + { + return encoder_->decode_int64(plain); + } + + /** + Decodes a plaintext polynomial and returns the result as BigUInt. + Mathematically this amounts to evaluating the input polynomial at X=base. + + @param[in] plain The plaintext to be decoded + @throws std::invalid_argument if plain does not represent a valid plaintext polynomial + @throws std::invalid_argument if the output is negative + */ + virtual BigUInt decode_biguint(const Plaintext &plain) override + { + return encoder_->decode_biguint(plain); + } + + /** + Decodes a plaintext polynomial and stores the result in a given BigUInt. + Mathematically this amounts to evaluating the input polynomial at X=base. + + @param[in] plain The plaintext to be decoded + @param[out] destination The BigUInt to overwrite with the decoding + @throws std::invalid_argument if plain does not represent a valid plaintext polynomial + @throws std::invalid_argument if the output does not fit in destination + @throws std::invalid_argument if the output is negative + */ + virtual void decode_biguint(const Plaintext &plain, BigUInt &destination) override + { + encoder_->decode_biguint(plain, destination); + } + + /** + Encodes a signed integer (represented by std::int32_t) into a plaintext polynomial. + + @par Negative Integers + Negative integers in the base-b encoding are represented the same way as positive integers, + namely, both positive and negative integers can have both positive and negative digits in their + base-b representation. Negative coefficients are stored in the plaintext polynomials as unsigned + integers that represent them modulo the plaintext modulus. Thus, for example, a coefficient of -1 + would be stored as a polynomial coefficient plain_modulus-1. + + @param[in] value The signed integer to encode + */ + virtual Plaintext encode(std::int32_t value) override + { + return encoder_->encode(value); + } + + /** + Encodes an unsigned integer (represented by std::uint32_t) into a plaintext polynomial. + + @param[in] value The unsigned integer to encode + */ + virtual Plaintext encode(std::uint32_t value) override + { + return encoder_->encode(value); + } + + /** + Encodes a signed integer (represented by std::int32_t) into a plaintext polynomial. + + @par Negative Integers + Negative integers in the base-b encoding are represented the same way as positive integers, + namely, both positive and negative integers can have both positive and negative digits in their + base-b representation. Negative coefficients are stored in the plaintext polynomials as unsigned + integers that represent them modulo the plaintext modulus. Thus, for example, a coefficient of -1 + would be stored as a polynomial coefficient plain_modulus-1. + + @param[in] value The signed integer to encode + @param[out] destination The plaintext to overwrite with the encoding + */ + virtual void encode(std::int32_t value, Plaintext &destination) override; + + /** + Encodes an unsigned integer (represented by std::uint32_t) into a plaintext polynomial. + + @param[in] value The unsigned integer to encode + @param[out] destination The plaintext to overwrite with the encoding + */ + virtual void encode(std::uint32_t value, Plaintext &destination) override; + + /** + Returns a reference to the plaintext modulus. + */ + virtual const SmallModulus &plain_modulus() const override + { + return encoder_->plain_modulus(); + } + + /** + Returns the base used for encoding. + */ + virtual std::uint64_t base() const override + { + return encoder_->base(); + } + + private: + IntegerEncoder &operator =(const IntegerEncoder &assign) = delete; + + IntegerEncoder &operator =(IntegerEncoder &&assign) = delete; + + AbstractIntegerEncoder *encoder_; + }; + + /** + Encodes floating point numbers into plaintext polynomials that Encryptor can encrypt. + An instance of the FractionalEncoder class converts a double-precision floating-point + number into a plaintext polynomial by computing its balanced base-b representation, + encoding the integral part as in IntegerEncoder, and the fractional part as the highest + degree terms of the plaintext polynomial, with signs inverted. For an even base b, the + coefficients of the polynomial are in the range -b/2,...,b/2-1. When b is 2, the + coefficients are either all non-negative (0 and 1), or all non-positive (0 and -1). + Decoding the polynomial back into a double amounts to evaluating the low degree part + at X=b, negating the coefficients of the high degree part and evaluating it at X=1/b. + + Addition and multiplication on the double side translate into addition and multiplication + on the encoded plaintext polynomial side, provided that the integral part never mixes + with the fractional part in the plaintext polynomials, and that the + coefficients of the plaintext polynomials appearing throughout the computations never + experience coefficients larger than the plaintext modulus (plain_modulus). + + @par Integral and Fractional Parts + When homomorphic multiplications are performed, the integral part "grows up" to higher + degree coefficients of the plaintext polynomial space, and the fractional part "grows down" + from the top degree coefficients towards the lower degree coefficients. For decoding to work, + these parts must not interfere with each other. When setting up the BalancedFractionalEncoder, + one must specify how many coefficients of a plaintext polynomial are reserved for the integral + part and how many for the fractional. The sum of these numbers can be at most equal to the + degree of the polynomial modulus minus one. If homomorphic multiplications are performed, it is + also necessary to leave enough room for the fractional part to "grow down". + + @par Negative Integers + Negative integers in the base-b encoding are represented the same way as positive integers, + namely, both positive and negative integers can have both positive and negative digits in their + base-b representation. Negative coefficients are stored in the plaintext polynomials as unsigned + integers that represent them modulo the plaintext modulus. Thus, for example, a coefficient of -1 + would be stored as a polynomial coefficient plain_modulus-1. + + @par BinaryFractionalEncoder and BalancedFractionalEncoder + Under the hood FractionalEncoder uses either the BinaryFractionalEncoder or the + BalancedFractionalEncoder classes to do the encoding. The first one is used when the base is 2, + and the second one when the base is at least 3. Currently the BinaryFractionalEncoder and + BalancedFractionalEncoder classes can also be used directly, but this might change in future releases. + + @see BinaryFractionalEncoder for encoding using the binary representation. + @see BalancedFractionalEncoder for encoding using base-b representation for b greater than 2. + @see IntegerEncoder for encoding integers. + */ + class FractionalEncoder : public AbstractFractionalEncoder + { + public: + /** + Creates a new FractionalEncoder object. The constructor takes as input a reference + to the plaintext modulus, the degree of the polynomial modulus, and the numbers of + coefficients that are reserved for the integral and fractional parts, and optionally + an integer, at least 2, that is used as the base in the encoding. The coefficients + for the integral part are counted starting from the low-degree end of the polynomial, + and the coefficients for the fractional part are counted from the high-degree end. + + @param[in] plain_modulus The plaintext modulus (represented by SmallModulus) + @param[in] poly_modulus_degree The degree of the polynomial modulus + @param[in] integer_coeff_count The number of polynomial coefficients reserved for the integral part + @param[in] fraction_coeff_count The number of polynomial coefficients reserved for the fractional part + @param[in] base The base to be used for encoding (default value is 2) + @throws std::invalid_argument if plain_modulus is not at least base + @throws std::invalid_argument if integer_coeff_count is not strictly positive + @throws std::invalid_argument if fraction_coeff_count is not strictly positive + @throws std::invalid_argument if poly_modulus_degree is too small for the integral and fractional parts + @throws std::invalid_argument if base is not an integer and at least 2 + */ + FractionalEncoder(const SmallModulus &plain_modulus, std::size_t poly_modulus_degree, + std::size_t integer_coeff_count, std::size_t fraction_coeff_count, std::uint64_t base = 2); + + /** + Creates a copy of a FractionalEncoder. + + @param[in] copy The FractionalEncoder to copy from + */ + FractionalEncoder(const FractionalEncoder ©); + + /** + Creates a new FractionalEncoder by moving an old one. + + @param[in] source The FractionalEncoder to move from + */ + FractionalEncoder(FractionalEncoder &&source) = default; + + /** + Destroys the FractionalEncoder. + */ + virtual ~FractionalEncoder() override; + + /** + Encodes a double precision floating point number into a plaintext polynomial. + + @param[in] value The double-precision floating-point number to encode + */ + virtual Plaintext encode(double value) override + { + return encoder_->encode(value); + } + + /** + Decodes a plaintext polynomial and returns the result as a double-precision + floating-point number. + + @param[in] plain The plaintext to be decoded + @throws std::invalid_argument if plain does not represent a valid plaintext polynomial + @throws std::invalid_argument if the integral part does not fit in std::int64_t + */ + virtual double decode(const Plaintext &plain) override + { + return encoder_->decode(plain); + } + + /** + Returns a reference to the plaintext modulus. + */ + virtual const SmallModulus &plain_modulus() const override + { + return encoder_->plain_modulus(); + } + + /** + Returns the degree of the polynomial modulus. + */ + virtual std::size_t poly_modulus_degree() const override + { + return encoder_->poly_modulus_degree(); + } + + /** + Returns the base used for encoding. + */ + virtual std::uint64_t base() const override + { + return encoder_->base(); + } + + /** + Returns the number of coefficients reserved for the fractional part. + */ + virtual std::size_t fraction_coeff_count() const override + { + return encoder_->fraction_coeff_count(); + } + + /** + Returns the number of coefficients reserved for the integral part. + */ + virtual std::size_t integer_coeff_count() const override + { + return encoder_->integer_coeff_count(); + } + + private: + FractionalEncoder &operator =(const FractionalEncoder &assign) = delete; + + FractionalEncoder &operator =(FractionalEncoder &&assign) = delete; + + AbstractFractionalEncoder *encoder_; + }; +} diff --git a/src/seal/encryptionparams.cpp b/src/seal/encryptionparams.cpp new file mode 100644 index 000000000..b0b525307 --- /dev/null +++ b/src/seal/encryptionparams.cpp @@ -0,0 +1,151 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "seal/encryptionparams.h" +#include + +using namespace std; +using namespace seal::util; + +namespace seal +{ + void EncryptionParameters::Save(const EncryptionParameters &parms, ostream &stream) + { + // Throw exceptions on std::ios_base::badbit and std::ios_base::failbit + auto old_except_mask = stream.exceptions(); + try + { + stream.exceptions(ios_base::badbit | ios_base::failbit); + + uint64_t poly_modulus_degree64 = static_cast(parms.poly_modulus_degree()); + uint64_t coeff_mod_count64 = static_cast(parms.coeff_modulus().size()); + auto scheme = parms.scheme(); + + stream.write(reinterpret_cast(&scheme), sizeof(scheme_type)); + stream.write(reinterpret_cast(&poly_modulus_degree64), sizeof(uint64_t)); + stream.write(reinterpret_cast(&coeff_mod_count64), sizeof(uint64_t)); + for (const auto &mod : parms.coeff_modulus()) + { + mod.save(stream); + } + parms.plain_modulus().save(stream); + double noise_standard_deviation = parms.noise_standard_deviation(); + stream.write(reinterpret_cast(&noise_standard_deviation), sizeof(double)); + } + catch (const exception &) + { + stream.exceptions(old_except_mask); + throw; + } + } + + EncryptionParameters EncryptionParameters::Load(istream &stream) + { + // Throw exceptions on std::ios_base::badbit and std::ios_base::failbit + auto old_except_mask = stream.exceptions(); + try + { + stream.exceptions(ios_base::badbit | ios_base::failbit); + + // Read the scheme identifier + scheme_type scheme; + stream.read(reinterpret_cast(&scheme), sizeof(scheme_type)); + + // This constructor will throw if scheme is invalid + EncryptionParameters parms(scheme); + + // Read the poly_modulus_degree + uint64_t poly_modulus_degree64 = 0; + stream.read(reinterpret_cast(&poly_modulus_degree64), sizeof(uint64_t)); + if (poly_modulus_degree64 < SEAL_POLY_MOD_DEGREE_MIN || + poly_modulus_degree64 > SEAL_POLY_MOD_DEGREE_MAX) + { + throw invalid_argument("poly_modulus_degree is invalid"); + } + + // Read the coeff_modulus size + uint64_t coeff_mod_count64 = 0; + stream.read(reinterpret_cast(&coeff_mod_count64), sizeof(uint64_t)); + if (coeff_mod_count64 > SEAL_COEFF_MOD_COUNT_MAX || + coeff_mod_count64 < SEAL_COEFF_MOD_COUNT_MIN) + { + throw invalid_argument("coeff_modulus is invalid"); + } + + // Read the coeff_modulus + vector coeff_modulus(coeff_mod_count64); + for (auto &mod : coeff_modulus) + { + mod.load(stream); + } + + // Read the plain_modulus + SmallModulus plain_modulus; + plain_modulus.load(stream); + + // Read noise_standard_deviation + double noise_standard_deviation; + stream.read(reinterpret_cast(&noise_standard_deviation), sizeof(double)); + + // Supposedly everything worked so set the values of member variables + parms.set_poly_modulus_degree(safe_cast(poly_modulus_degree64)); + parms.set_coeff_modulus(coeff_modulus); + parms.set_plain_modulus(plain_modulus); + parms.set_noise_standard_deviation(noise_standard_deviation); + + stream.exceptions(old_except_mask); + return parms; + } + catch (const exception &) + { + stream.exceptions(old_except_mask); + throw; + } + catch (...) + { + stream.exceptions(old_except_mask); + throw; + } + } + + void EncryptionParameters::compute_parms_id() + { + size_t coeff_mod_count = coeff_modulus_.size(); + + size_t total_uint64_count = add_safe( + size_t(1), // scheme + size_t(1), // poly_modulus_degree + coeff_mod_count, + plain_modulus_.uint64_count(), + size_t(1) // noise_standard_deviation + ); + + auto param_data(allocate_uint(total_uint64_count, pool_)); + uint64_t *param_data_ptr = param_data.get(); + + // Write the scheme identifier + *param_data_ptr++ = static_cast(scheme_); + + // Write the poly_modulus_degree. Note that it will always be positive. + *param_data_ptr++ = static_cast(poly_modulus_degree_); + + for(const auto &mod : coeff_modulus_) + { + *param_data_ptr++ = mod.value(); + } + + set_uint_uint(plain_modulus_.data(), plain_modulus_.uint64_count(), param_data_ptr); + param_data_ptr += plain_modulus_.uint64_count(); + + memcpy(param_data_ptr++, &noise_standard_deviation_, sizeof(double)); + + HashFunction::sha3_hash(param_data.get(), total_uint64_count, parms_id_); + + // Did we somehow manage to get a zero block as result? This is reserved for + // plaintexts to indicate non-NTT-transformed form. + if (parms_id_ == parms_id_zero) + { + throw logic_error("parms_id cannot be zero"); + } + } +} diff --git a/src/seal/encryptionparams.h b/src/seal/encryptionparams.h new file mode 100644 index 000000000..b3a07279c --- /dev/null +++ b/src/seal/encryptionparams.h @@ -0,0 +1,416 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include +#include +#include +#include "seal/util/defines.h" +#include "seal/util/globals.h" +#include "seal/randomgen.h" +#include "seal/smallmodulus.h" +#include "seal/util/hash.h" +#include "seal/memorymanager.h" + +namespace seal +{ + enum class scheme_type : std::uint8_t + { + BFV = 0x1, + CKKS = 0x2 + }; + + inline bool is_valid_scheme(scheme_type scheme) noexcept + { + return (scheme == scheme_type::BFV) || + (scheme == scheme_type::CKKS); + } + + /** + The data type to store unique identifiers of encryption parameters. + */ + using parms_id_type = util::HashFunction::sha3_block_type; + + /** + A parms_id_type value consisting of zeros. + */ + static constexpr parms_id_type parms_id_zero = + util::HashFunction::sha3_zero_block; + + /** + Represents user-customizable encryption scheme settings. The parameters (most + importantly poly_modulus, coeff_modulus, plain_modulus) significantly affect + the performance, capabilities, and security of the encryption scheme. Once + an instance of EncryptionParameters is populated with appropriate parameters, + it can be used to create an instance of the SEALContext class, which verifies + the validity of the parameters, and performs necessary pre-computations. + + Picking appropriate encryption parameters is essential to enable a particular + application while balancing performance and security. Some encryption settings + will not allow some inputs (e.g. attempting to encrypt a polynomial with more + coefficients than poly_modulus or larger coefficients than plain_modulus) or, + support the desired computations (with noise growing too fast due to too large + plain_modulus and too small coeff_modulus). + + @par parms_id + The EncryptionParameters class maintains at all times a 256-bit SHA-3 hash of + the currently set encryption parameters. This hash acts as a unique identifier + of the encryption parameters and is used by all further objects created for + these encryption parameters. The parms_id is not intended to be directly modified + by the user but is used internally for pre-computation data lookup and input + validity checks. In modulus switching the user can use the parms_id to map the + chain of encryption parameters. + + @par Thread Safety + In general, reading from EncryptionParameters is thread-safe, while mutating + is not. + + @warning Choosing inappropriate encryption parameters may lead to an encryption + scheme that is not secure, does not perform well, and/or does not support the + input and computation of the desired application. We highly recommend consulting + an expert in RLWE-based encryption when selecting parameters, as this is where + inexperienced users seem to most often make critical mistakes. + */ + class EncryptionParameters + { + public: + /** + Creates an empty set of encryption parameters. At a minimum, the user needs + to specify the parameters poly_modulus, coeff_modulus, and plain_modulus + for the parameters to be usable. + + @throw std::invalid_argument if scheme is not supported + @see scheme_type for the supported schemes + */ + EncryptionParameters(scheme_type scheme) + { + // Check that a valid scheme is given + if (!is_valid_scheme(scheme)) + { + throw std::invalid_argument("unsupported scheme"); + } + + scheme_ = scheme; + compute_parms_id(); + } + + /** + Creates a copy of a given instance of EncryptionParameters. + + @param[in] copy The EncryptionParameters to copy from + */ + EncryptionParameters(const EncryptionParameters ©) = default; + + /** + Overwrites the EncryptionParameters instance with a copy of a given instance. + + @param[in] assign The EncryptionParameters to copy from + */ + EncryptionParameters &operator =(const EncryptionParameters &assign) = default; + + /** + Creates a new EncryptionParameters instance by moving a given instance. + + @param[in] source The EncryptionParameters to move from + */ + EncryptionParameters(EncryptionParameters &&source) = default; + + /** + Overwrites the EncryptionParameters instance by moving a given instance. + + @param[in] assign The EncryptionParameters to move from + */ + EncryptionParameters &operator =(EncryptionParameters &&assign) = default; + + /** + Sets the degree of the polynomial modulus parameter to the specified value. + The polynomial modulus directly affects the number of coefficients in + plaintext polynomials, the size of ciphertext elements, the computational + performance of the scheme (bigger is worse), and the security level (bigger + is better). In SEAL the degree of the polynomial modulus must be a power + of 2 (e.g. 1024, 2048, 4096, 8192, 16384, or 32768). + + @param[in] poly_modulus_degree The new polynomial modulus degree + */ + inline void set_poly_modulus_degree(std::size_t poly_modulus_degree) + { + // Set the degree + poly_modulus_degree_ = poly_modulus_degree; + + // Re-compute the parms_id + compute_parms_id(); + } + + /** + Sets the coefficient modulus parameter. The coefficient modulus consists + of a list of distinct prime numbers, and is represented by a vector of + SmallModulus objects. The coefficient modulus directly affects the size + of ciphertext elements, the amount of computation that the scheme can perform + (bigger is better), and the security level (bigger is worse). In SEAL each + of the prime numbers in the coefficient modulus must be at most 60 bits, + and must be congruent to 1 modulo 2*degree(poly_modulus). + + @param[in] coeff_modulus The new coefficient modulus + @throws std::invalid_argument if size of coeff_modulus is invalid + */ + inline void set_coeff_modulus(const std::vector &coeff_modulus) + { + // Set the coeff_modulus_ + if (coeff_modulus.size() > SEAL_COEFF_MOD_COUNT_MAX || + coeff_modulus.size() < SEAL_COEFF_MOD_COUNT_MIN) + { + throw std::invalid_argument("coeff_modulus is invalid"); + } + + coeff_modulus_ = coeff_modulus; + + // Re-compute the parms_id + compute_parms_id(); + } + + /** + Sets the plaintext modulus parameter. The plaintext modulus is an integer + modulus represented by the SmallModulus class. The plaintext modulus + determines the largest coefficient that plaintext polynomials can represent. + It also affects the amount of computation that the scheme can perform + (bigger is worse). In SEAL the plaintext modulus can be at most 60 bits + long, but can otherwise be any integer. Note, however, that some features + (e.g. batching) require the plaintext modulus to be of a particular form. + + @param[in] plain_modulus The new plaintext modulus + @throws std::logic_error if scheme is not scheme_type::BFV + */ + inline void set_plain_modulus(const SmallModulus &plain_modulus) + { + // CKKS does not use plain_modulus + if (scheme_ != scheme_type::BFV) + { + throw std::logic_error("unsupported scheme"); + } + + plain_modulus_ = plain_modulus; + + // Re-compute the parms_id + compute_parms_id(); + } + + /** + Sets the plaintext modulus parameter. The plaintext modulus is an integer + modulus represented by the SmallModulus class. This constructor instead + takes a std::uint64_t and automatically creates the SmallModulus object. + The plaintext modulus determines the largest coefficient that plaintext + polynomials can represent. It also affects the amount of computation that + the scheme can perform (bigger is worse). In SEAL the plaintext modulus + can be at most 60 bits long, but can otherwise be any integer. Note, + however, that some features (e.g. batching) require the plaintext modulus + to be of a particular form. + + @param[in] plain_modulus The new plaintext modulus + @throws std::invalid_argument if plain_modulus is invalid + */ + inline void set_plain_modulus(std::uint64_t plain_modulus) + { + set_plain_modulus(SmallModulus(plain_modulus)); + + // Re-compute the parms_id + compute_parms_id(); + } + + /** + Sets the standard deviation of the noise distribution used for error + sampling. This parameter directly affects the security level of the scheme. + However, it should not be necessary for most users to change this parameter + from its default value. + + @param[in] noise_standard_deviation The new standard deviation + @throw std::invalid_argument if noise_standard_deviation is negative or + too large + */ + inline void set_noise_standard_deviation(double noise_standard_deviation) + { + if (std::signbit(noise_standard_deviation) || + (noise_standard_deviation > std::numeric_limits::max() / + util::global_variables::noise_distribution_width_multiplier)) + { + throw std::invalid_argument("noise_standard_deviation is invalid"); + } + + noise_standard_deviation_ = noise_standard_deviation; + noise_max_deviation_ = + util::global_variables::noise_distribution_width_multiplier * + noise_standard_deviation_; + + // Re-compute the parms_id + compute_parms_id(); + } + + /** + Sets the random number generator factory to use for encryption. By default, + the random generator is set to UniformRandomGeneratorFactory::default_factory(). + Setting this value allows a user to specify a custom random number generator + source. + + @param[in] random_generator Pointer to the random generator factory + */ + inline void set_random_generator( + std::shared_ptr random_generator) + { + random_generator_ = std::move(random_generator); + } + + /** + Returns the encryption scheme type. + */ + inline scheme_type scheme() const + { + return scheme_; + } + + /** + Returns the degree of the polynomial modulus parameter. + */ + inline std::size_t poly_modulus_degree() const + { + return poly_modulus_degree_; + } + + /** + Returns a const reference to the currently set coefficient modulus parameter. + */ + inline const std::vector &coeff_modulus() const + { + return coeff_modulus_; + } + + /** + Returns a const reference to the currently set plaintext modulus parameter. + */ + inline const SmallModulus &plain_modulus() const + { + return plain_modulus_; + } + + /** + Returns the currently set standard deviation of the noise distribution. + */ + inline double noise_standard_deviation() const + { + return noise_standard_deviation_; + } + + /** + Returns the currently set maximum deviation of the noise distribution. + This value cannot be directly controlled by the user, and is automatically + set to be an appropriate multiple of the noise_standard_deviation parameter. + */ + inline double noise_max_deviation() const + { + return noise_max_deviation_; + } + + /** + Returns a pointer to the random number generator factory to use for encryption. + */ + inline std::shared_ptr random_generator() const + { + return random_generator_; + } + + /** + Compares a given set of encryption parameters to the current set of + encryption parameters. The comparison is performed by comparing the + parms_ids of the parameter sets rather than comparing the parameters + individually. + + @parms[in] other The EncryptionParameters to compare against + */ + inline bool operator ==(const EncryptionParameters &other) const + { + return (parms_id_ == other.parms_id_); + } + + /** + Compares a given set of encryption parameters to the current set of + encryption parameters. The comparison is performed by comparing + parms_ids of the parameter sets rather than comparing the parameters + individually. + + @parms[in] other The EncryptionParameters to compare against + */ + inline bool operator !=(const EncryptionParameters &other) const + { + return (parms_id_ != other.parms_id_); + } + + /** + Returns the parms_id of the current parameters. This function is intended + mainly for internal use. + */ + inline auto &parms_id() const + { + return parms_id_; + } + + /** + Saves EncryptionParameters to an output stream. The output is in binary + format and is not human-readable. The output stream must have the "binary" + flag set. + + @param[in] stream The stream to save the EncryptionParameters to + @throws std::exception if the EncryptionParameters could not be written + to stream + */ + static void Save(const EncryptionParameters &parms, std::ostream &stream); + + /** + Loads EncryptionParameters from an input stream. + + @param[in] stream The stream to load the EncryptionParameters from + @throws std::exception if valid EncryptionParameters could not be read + from stream + */ + static EncryptionParameters Load(std::istream &stream); + + private: + void compute_parms_id(); + + MemoryPoolHandle pool_ = MemoryManager::GetPool(); + + scheme_type scheme_; + + std::size_t poly_modulus_degree_ = 0; + + std::vector coeff_modulus_{}; + + double noise_standard_deviation_ = + util::global_variables::default_noise_standard_deviation; + + double noise_max_deviation_ = + util::global_variables::noise_distribution_width_multiplier * + util::global_variables::default_noise_standard_deviation; + + std::shared_ptr random_generator_{ nullptr }; + + SmallModulus plain_modulus_{}; + + parms_id_type parms_id_ = parms_id_zero; + }; +} + +/** +Specializes the std::hash template for parms_id_type. +*/ +namespace std +{ + template<> + struct hash + { + std::size_t operator()( + const seal::parms_id_type &parms_id) const + { + return std::accumulate(parms_id.begin(), parms_id.end(), std::size_t(0), + [](std::size_t acc, std::uint64_t curr) { return acc ^ curr; }); + } + }; +} diff --git a/src/seal/encryptor.cpp b/src/seal/encryptor.cpp new file mode 100644 index 000000000..745e4b189 --- /dev/null +++ b/src/seal/encryptor.cpp @@ -0,0 +1,413 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include +#include +#include "seal/encryptor.h" +#include "seal/util/common.h" +#include "seal/util/uintarith.h" +#include "seal/util/polyarithsmallmod.h" +#include "seal/util/clipnormal.h" +#include "seal/util/randomtostd.h" +#include "seal/util/smallntt.h" +#include "seal/smallmodulus.h" + +using namespace std; +using namespace seal::util; + +namespace seal +{ + Encryptor::Encryptor(shared_ptr context, + const PublicKey &public_key) : context_(move(context)) + { + // Verify parameters + if (!context_) + { + throw invalid_argument("invalid context"); + } + if (!context_->parameters_set()) + { + throw invalid_argument("encryption parameters are not set correctly"); + } + if (public_key.parms_id() != context_->first_parms_id()) + { + throw invalid_argument("public key is not valid for encryption parameters"); + } + + auto &parms = context_->context_data()->parms(); + auto &coeff_modulus = parms.coeff_modulus(); + size_t coeff_count = parms.poly_modulus_degree(); + size_t coeff_mod_count = coeff_modulus.size(); + + // Quick sanity check + if (!product_fits_in(coeff_count, coeff_mod_count, size_t(2))) + { + throw logic_error("invalid parameters"); + } + + // Allocate space and copy over key + public_key_ = allocate_poly(2 * coeff_count, coeff_mod_count, pool_); + set_poly_poly(public_key.data().data(0), 2 * coeff_count, coeff_mod_count, + public_key_.get()); + } + + void Encryptor::encrypt(const Plaintext &plain, + Ciphertext &destination, MemoryPoolHandle pool) + { + // Verify parameters. + if (!pool) + { + throw invalid_argument("pool is uninitialized"); + } + + auto &context_data = *context_->context_data(); + auto &parms = context_data.parms(); + + switch (parms.scheme()) + { + case scheme_type::BFV: + bfv_encrypt(plain, destination, move(pool)); + return; + + case scheme_type::CKKS: + ckks_encrypt(plain, destination, move(pool)); + return; + + default: + throw invalid_argument("unsupported scheme"); + } + } + + void Encryptor::bfv_encrypt(const Plaintext &plain, + Ciphertext &destination, MemoryPoolHandle pool) + { + if (plain.is_ntt_form()) + { + throw invalid_argument("plain cannot be in NTT form"); + } + + auto &context_data = *context_->context_data(); + auto &parms = context_data.parms(); + auto &coeff_modulus = parms.coeff_modulus(); + size_t coeff_count = parms.poly_modulus_degree(); + size_t first_coeff_mod_count = + context_->context_data()->parms().coeff_modulus().size(); + size_t coeff_mod_count = coeff_modulus.size(); + + // Verify more parameters. + if (plain.coeff_count() > coeff_count) + { + throw invalid_argument("plain is not valid for encryption parameters"); + } + + auto &small_ntt_tables = context_data.small_ntt_tables(); +#ifdef SEAL_DEBUG + if (!are_poly_coefficients_less_than(plain.data(), + plain.coeff_count(), parms.plain_modulus().value())) + { + throw invalid_argument("plain is not valid for encryption parameters"); + } +#endif + // Make destination have right size and parms_id + destination.resize(context_, parms.parms_id(), 2); + destination.is_ntt_form() = false; + + /* + Ciphertext (c_0,c_1) + c_0 = Delta * m + public_key_[0] * u + e_1 where u sampled from R_2 and e_1 sampled from chi. + c_1 = public_key_[1] * u + e_2 where e_2 sampled from chi. + */ + + // Generate u + auto u(allocate_poly(coeff_count, coeff_mod_count, pool)); + shared_ptr random(parms.random_generator()->create()); + + set_poly_coeffs_zero_one_negone(u.get(), random, context_data); + + // Multiply both u * public_key_[0] and u * public_key_[1] using the same FFT + for (size_t i = 0; i < coeff_mod_count; i++) + { + ntt_negacyclic_harvey_lazy(u.get() + (i * coeff_count), small_ntt_tables[i]); + + dyadic_product_coeffmod(u.get() + (i * coeff_count), + public_key_.get() + (i * coeff_count), coeff_count, + coeff_modulus[i], destination.data() + (i * coeff_count)); + inverse_ntt_negacyclic_harvey(destination.data() + (i * coeff_count), + small_ntt_tables[i]); + + dyadic_product_coeffmod(u.get() + (i * coeff_count), + public_key_.get() + (coeff_count * first_coeff_mod_count) + (i * coeff_count), + coeff_count, coeff_modulus[i], destination.data(1) + (i * coeff_count)); + inverse_ntt_negacyclic_harvey(destination.data(1) + (i * coeff_count), + small_ntt_tables[i]); + } + + // Multiply plain by scalar coeff_div_plaintext and reposition if in upper-half. + // Result gets added into the c_0 term of ciphertext (c_0,c_1). + preencrypt(plain.data(), plain.coeff_count(), context_data, destination.data()); + + // Generate e_0, add this value into destination[0]. + set_poly_coeffs_normal(u.get(), random, context_data); + for (size_t i = 0; i < coeff_mod_count; i++) + { + add_poly_poly_coeffmod(u.get() + (i * coeff_count), + destination.data() + (i * coeff_count), coeff_count, + coeff_modulus[i], destination.data() + (i * coeff_count)); + } + // Generate e_1, add this value into destination[1]. + set_poly_coeffs_normal(u.get(), random, context_data); + for (size_t i = 0; i < coeff_mod_count; i++) + { + add_poly_poly_coeffmod(u.get() + (i * coeff_count), + destination.data(1) + (i * coeff_count), coeff_count, + coeff_modulus[i], destination.data(1) + (i * coeff_count)); + } + } + + void Encryptor::ckks_encrypt(const Plaintext &plain, + Ciphertext &destination, MemoryPoolHandle pool) + { + if (!plain.is_ntt_form()) + { + throw invalid_argument("plain must be in NTT form"); + } + + auto context_data_ptr = context_->context_data(plain.parms_id()); + if (!context_data_ptr) + { + throw invalid_argument("plain is not valid for encryption parameters"); + } + + auto &context_data = *context_data_ptr; + auto &parms = context_data.parms(); + auto &coeff_modulus = parms.coeff_modulus(); + size_t coeff_count = parms.poly_modulus_degree(); + size_t first_coeff_mod_count = + context_->context_data()->parms().coeff_modulus().size(); + size_t coeff_mod_count = coeff_modulus.size(); + + auto &small_ntt_tables = context_data.small_ntt_tables(); +#ifdef SEAL_DEBUG + // Check that the plaintext doesn't have more coefficients than allowed + if (unsigned_gt(plain.coeff_count(), mul_safe(coeff_count, coeff_mod_count))) + { + throw invalid_argument("plain is not valid for encryption parameters"); + } +#endif + // Make destination have right size and hash block + destination.resize(context_, parms.parms_id(), 2); + destination.is_ntt_form() = true; + destination.scale() = plain.scale(); + + /* + Ciphertext (c_0,c_1) + c_0 = m + public_key_[0] * u + e_1 where u sampled from R_2 and e_1 sampled from chi. + c_1 = public_key_[1] * u + e_2 where e_2 sampled from chi. + */ + + // Generate u + auto u(allocate_poly(coeff_count, coeff_mod_count, pool)); + shared_ptr random(parms.random_generator()->create()); + + set_poly_coeffs_zero_one_negone(u.get(), random, context_data); + + // Multiply both u * public_key_[0] and u * public_key_[1] using the same FFT + for (size_t i = 0; i < coeff_mod_count; i++) + { + ntt_negacyclic_harvey(u.get() + (i * coeff_count), small_ntt_tables[i]); + dyadic_product_coeffmod( + u.get() + (i * coeff_count), + public_key_.get() + (i * coeff_count), + coeff_count, + coeff_modulus[i], + destination.data() + (i * coeff_count)); + dyadic_product_coeffmod( + u.get() + (i * coeff_count), + public_key_.get() + (coeff_count * first_coeff_mod_count) + (i * coeff_count), + coeff_count, + coeff_modulus[i], + destination.data(1) + (i * coeff_count)); + } + + auto tmp(allocate_uint(coeff_count, pool)); + // The plaintext gets added into the c_0 term of ciphertext (c_0,c_1). + for (size_t i = 0; i < coeff_mod_count; i++) + { + add_poly_poly_coeffmod(destination.data() + (i * coeff_count), + plain.data() + (i * coeff_count), coeff_count, + coeff_modulus[i], destination.data() + (i * coeff_count)); + } + + // Generate e_0, add this value into destination[0]. + set_poly_coeffs_normal(u.get(), random, context_data); + + for (size_t i = 0; i < coeff_mod_count; i++) + { + ntt_negacyclic_harvey(u.get() + (i * coeff_count), small_ntt_tables[i]); + add_poly_poly_coeffmod(u.get() + (i * coeff_count), + destination.data() + (i * coeff_count), coeff_count, + coeff_modulus[i], destination.data() + (i * coeff_count)); + } + // Generate e_1, add this value into destination[1]. + set_poly_coeffs_normal(u.get(), random, context_data); + + for (size_t i = 0; i < coeff_mod_count; i++) + { + ntt_negacyclic_harvey(u.get() + (i * coeff_count), small_ntt_tables[i]); + add_poly_poly_coeffmod(u.get() + (i * coeff_count), + destination.data(1) + (i * coeff_count), coeff_count, + coeff_modulus[i], destination.data(1) + (i * coeff_count)); + } + } + + void Encryptor::preencrypt(const uint64_t *plain, size_t plain_coeff_count, + const SEALContext::ContextData &context_data, uint64_t *destination) + { + auto &parms = context_data.parms(); + auto &coeff_modulus = parms.coeff_modulus(); + size_t coeff_count = parms.poly_modulus_degree(); + size_t coeff_mod_count = coeff_modulus.size(); + + auto coeff_div_plain_modulus = context_data.coeff_div_plain_modulus(); + auto plain_upper_half_threshold = context_data.plain_upper_half_threshold(); + auto upper_half_increment = context_data.upper_half_increment(); + + // Multiply plain by scalar coeff_div_plain_modulus_ and reposition if in upper-half. + for (size_t i = 0; i < plain_coeff_count; i++) + { + if (plain[i] >= plain_upper_half_threshold) + { + // Loop over primes + for (size_t j = 0; j < coeff_mod_count; j++) + { + unsigned long long temp[2]{ 0, 0 }; + multiply_uint64(coeff_div_plain_modulus[j], plain[i], temp); + temp[1] += add_uint64(temp[0], upper_half_increment[j], 0, temp); + uint64_t scaled_plain_coeff = barrett_reduce_128(temp, coeff_modulus[j]); + destination[j * coeff_count] = add_uint_uint_mod( + destination[j * coeff_count], scaled_plain_coeff, coeff_modulus[j]); + } + } + else + { + for (size_t j = 0; j < coeff_mod_count; j++) + { + uint64_t scaled_plain_coeff = multiply_uint_uint_mod( + coeff_div_plain_modulus[j], plain[i], coeff_modulus[j]); + destination[j * coeff_count] = add_uint_uint_mod( + destination[j * coeff_count], scaled_plain_coeff, coeff_modulus[j]); + } + } + destination++; + } + } + + void Encryptor::set_poly_coeffs_zero_one_negone(uint64_t *poly, + std::shared_ptr random, + const SEALContext::ContextData &context_data) const + { + auto &parms = context_data.parms(); + auto &coeff_modulus = parms.coeff_modulus(); + size_t coeff_count = parms.poly_modulus_degree(); + size_t coeff_mod_count = coeff_modulus.size(); + + RandomToStandardAdapter engine(random); + uniform_int_distribution dist(-1, 1); + for (size_t i = 0; i < coeff_count; i++) + { + int rand_index = dist(engine); + if (rand_index == 1) + { + for (size_t j = 0; j < coeff_mod_count; j++) + { + poly[i + (j * coeff_count)] = 1; + } + } + else if (rand_index == -1) + { + for (size_t j = 0; j < coeff_mod_count; j++) + { + poly[i + (j * coeff_count)] = coeff_modulus[j].value() - 1; + } + } + else + { + for (size_t j = 0; j < coeff_mod_count; j++) + { + poly[i + (j * coeff_count)] = 0; + } + } + } + } + + void Encryptor::set_poly_coeffs_zero_one(uint64_t *poly, + std::shared_ptr random, + const SEALContext::ContextData &context_data) const + { + auto &parms = context_data.parms(); + auto &coeff_modulus = parms.coeff_modulus(); + size_t coeff_count = parms.poly_modulus_degree(); + size_t coeff_mod_count = coeff_modulus.size(); + + RandomToStandardAdapter engine(random); + uniform_int_distribution dist(0, 1); + + set_zero_poly(coeff_count, coeff_mod_count, poly); + for (size_t i = 0; i < coeff_count; i++) + { + int rand_index = dist(engine); + for (size_t j = 0; j < coeff_mod_count; j++) + { + poly[i + (j * coeff_count)] = static_cast(rand_index); + } + } + } + + void Encryptor::set_poly_coeffs_normal(uint64_t *poly, + std::shared_ptr random, + const SEALContext::ContextData &context_data) const + { + auto &parms = context_data.parms(); + auto &coeff_modulus = parms.coeff_modulus(); + size_t coeff_count = parms.poly_modulus_degree(); + size_t coeff_mod_count = coeff_modulus.size(); + + if ((parms.noise_standard_deviation() == 0.0) || + (parms.noise_max_deviation() == 0.0)) + { + set_zero_poly(coeff_count, coeff_mod_count, poly); + return; + } + + RandomToStandardAdapter engine(random); + ClippedNormalDistribution dist(0, parms.noise_standard_deviation(), + parms.noise_max_deviation()); + for (size_t i = 0; i < coeff_count; i++) + { + int64_t noise = static_cast(dist(engine)); + if (noise > 0) + { + for (size_t j = 0; j < coeff_mod_count; j++) + { + poly[i + (j * coeff_count)] = static_cast(noise); + } + } + else if (noise < 0) + { + noise = -noise; + for (size_t j = 0; j < coeff_mod_count; j++) + { + poly[i + (j * coeff_count)] = + coeff_modulus[j].value() - static_cast(noise); + } + } + else + { + for (size_t j = 0; j < coeff_mod_count; j++) + { + poly[i + (j * coeff_count)] = 0; + } + } + } + } +} diff --git a/src/seal/encryptor.h b/src/seal/encryptor.h new file mode 100644 index 000000000..9fb43e481 --- /dev/null +++ b/src/seal/encryptor.h @@ -0,0 +1,108 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include +#include +#include "seal/encryptionparams.h" +#include "seal/plaintext.h" +#include "seal/ciphertext.h" +#include "seal/memorymanager.h" +#include "seal/context.h" +#include "seal/util/smallntt.h" +#include "seal/publickey.h" + +namespace seal +{ + /** + Encrypts Plaintext objects into Ciphertext objects. Constructing an Encryptor + requires a SEALContext with valid encryption parameters, and the public key. + + @par Overloads + For the encrypt function we provide two overloads concerning the memory pool + used in allocations needed during the operation. In one overload the global + memory pool is used for this purpose, and in another overload the user can + supply a MemoryPoolHandle to to be used instead. This is to allow one single + Encryptor to be used concurrently by several threads without running into thread + contention in allocations taking place during operations. For example, one can + share one single Encryptor across any number of threads, but in each thread + call the encrypt function by giving it a thread-local MemoryPoolHandle to use. + It is important for a developer to understand how this works to avoid unnecessary + performance bottlenecks. + + @par NTT form + When using the BFV scheme (scheme_type::BFV), all plaintext and ciphertexts should + remain by default in the usual coefficient representation, i.e. not in NTT form. + When using the CKKS scheme (scheme_type::CKKS), all plaintexts and ciphertexts + should remain by default in NTT form. We call these scheme-specific NTT states + the "default NTT form". Decryption requires the input ciphertexts to be in + the default NTT form, and will throw an exception if this is not the case. + */ + class Encryptor + { + public: + /** + Creates an Encryptor instance initialized with the specified SEALContext + and public key. + + @param[in] context The SEALContext + @param[in] public_key The public key + @throws std::invalid_argument if the context is not set or encryption + parameters are not valid + @throws std::invalid_argument if public_key is not valid + */ + Encryptor(std::shared_ptr context, const PublicKey &public_key); + + /** + Encrypts a Plaintext and stores the result in the destination parameter. + Dynamic memory allocations in the process are allocated from the memory + pool pointed to by the given MemoryPoolHandle. + + @param[in] plain The plaintext to encrypt + @param[out] destination The ciphertext to overwrite with the encrypted plaintext + @param[in] pool The MemoryPoolHandle pointing to a valid memory pool + @throws std::invalid_argument if plain is not valid for the encryption parameters + @throws std::invalid_argument if plain is not in default NTT form + @throws std::invalid_argument if pool is uninitialized + */ + void encrypt(const Plaintext &plain, Ciphertext &destination, + MemoryPoolHandle pool = MemoryManager::GetPool()); + + private: + Encryptor(const Encryptor ©) = delete; + + Encryptor(Encryptor &&source) = delete; + + Encryptor &operator =(const Encryptor &assign) = delete; + + Encryptor &operator =(Encryptor &&assign) = delete; + + void preencrypt(const std::uint64_t *plain, std::size_t plain_coeff_count, + const SEALContext::ContextData &context_data, std::uint64_t *destination); + + void set_poly_coeffs_normal(std::uint64_t *poly, + std::shared_ptr random, + const SEALContext::ContextData &context_data) const; + + void set_poly_coeffs_zero_one_negone(uint64_t *poly, + std::shared_ptr random, + const SEALContext::ContextData &context_data) const; + + void set_poly_coeffs_zero_one(uint64_t *poly, + std::shared_ptr random, + const SEALContext::ContextData &context_data) const; + + void bfv_encrypt(const Plaintext &plain, Ciphertext &destination, + MemoryPoolHandle pool); + + void ckks_encrypt(const Plaintext &plain, Ciphertext &destination, + MemoryPoolHandle pool); + + MemoryPoolHandle pool_ = MemoryManager::GetPool(); + + std::shared_ptr context_{ nullptr }; + + util::Pointer public_key_; + }; +} diff --git a/src/seal/evaluator.cpp b/src/seal/evaluator.cpp new file mode 100644 index 000000000..d72178628 --- /dev/null +++ b/src/seal/evaluator.cpp @@ -0,0 +1,3233 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include +#include +#include +#include +#include +#include "seal/evaluator.h" +#include "seal/util/common.h" +#include "seal/util/uintarith.h" +#include "seal/util/polycore.h" +#include "seal/util/polyarithsmallmod.h" + +using namespace std; +using namespace seal::util; + +namespace seal +{ + namespace + { + template + bool are_same_scale(T value1, S value2) + { + return util::are_close(value1.scale(), value2.scale()); + } + } + + Evaluator::Evaluator(shared_ptr context) : context_(move(context)) + { + // Verify parameters + if (!context_) + { + throw invalid_argument("invalid context"); + } + if (!context_->parameters_set()) + { + throw invalid_argument("encryption parameters are not set correctly"); + } + + // Calculate map from Zmstar to generator representation + populate_Zmstar_to_generator(); + } + + void Evaluator::populate_Zmstar_to_generator() + { + uint64_t n = static_cast( + context_->context_data()->parms().poly_modulus_degree()); + uint64_t m = n << 1; + + for (uint64_t i = 0; i < n / 2; i++) + { + uint64_t galois_elt = exponentiate_uint64(3, i) & (m - 1); + pair temp_pair1{ i, 0 }; + Zmstar_to_generator_.emplace(galois_elt, temp_pair1); + galois_elt = (exponentiate_uint64(3, i) * (m - 1)) & (m - 1); + pair temp_pair2{ i, 1 }; + Zmstar_to_generator_.emplace(galois_elt, temp_pair2); + } + } + + void Evaluator::negate_inplace(Ciphertext &encrypted) + { + // Verify parameters. + auto context_data_ptr = context_->context_data(encrypted.parms_id()); + if (!context_data_ptr) + { + throw invalid_argument("encrypted is not valid for encryption parameters"); + } + + // Extract encryption parameters. + auto &context_data = *context_data_ptr; + auto &parms = context_data.parms(); + auto &coeff_modulus = parms.coeff_modulus(); + size_t coeff_count = parms.poly_modulus_degree(); + size_t coeff_mod_count = coeff_modulus.size(); + size_t encrypted_size = encrypted.size(); + + // Negate each poly in the array + for (size_t j = 0; j < encrypted_size; j++) + { + for (size_t i = 0; i < coeff_mod_count; i++) + { + negate_poly_coeffmod(encrypted.data(j) + (i * coeff_count), + coeff_count, coeff_modulus[i], encrypted.data(j) + (i * coeff_count)); + } + } + } + + void Evaluator::add_inplace(Ciphertext &encrypted1, const Ciphertext &encrypted2) + { + // Verify parameters. + auto context_data_ptr = context_->context_data(encrypted1.parms_id()); + if (!context_data_ptr) + { + throw invalid_argument("encrypted1 is not valid for encryption parameters"); + } + if (encrypted1.parms_id() != encrypted2.parms_id()) + { + throw invalid_argument("encrypted1 and encrypted2 parameter mismatch"); + } + if (encrypted1.is_ntt_form() != encrypted2.is_ntt_form()) + { + throw invalid_argument("NTT form mismatch"); + } + if (!are_same_scale(encrypted1, encrypted2)) + { + throw invalid_argument("scale mismatch"); + } + + // Extract encryption parameters. + auto &context_data = *context_data_ptr; + auto &parms = context_data.parms(); + auto &coeff_modulus = parms.coeff_modulus(); + size_t coeff_count = parms.poly_modulus_degree(); + size_t coeff_mod_count = coeff_modulus.size(); + size_t encrypted1_size = encrypted1.size(); + size_t encrypted2_size = encrypted2.size(); + size_t max_count = max(encrypted1_size, encrypted2_size); + size_t min_count = min(encrypted1_size, encrypted2_size); + + // Size check + if (!product_fits_in(max_count, coeff_count)) + { + throw logic_error("invalid parameters"); + } + + // Prepare destination + encrypted1.resize(context_, parms.parms_id(), max_count); + + // Add ciphertexts + for (size_t j = 0; j < min_count; j++) + { + for (size_t i = 0; i < coeff_mod_count; i++) + { + add_poly_poly_coeffmod(encrypted1.data(j) + (i * coeff_count), + encrypted2.data(j) + (i * coeff_count), coeff_count, coeff_modulus[i], + encrypted1.data(j) + (i * coeff_count)); + } + } + + // Copy the remainding polys of the array with larger count into encrypted1 + if (encrypted1_size < encrypted2_size) + { + set_poly_poly(encrypted2.data(min_count), + coeff_count * (encrypted2_size - encrypted1_size), + coeff_mod_count, encrypted1.data(encrypted1_size)); + } + } + + void Evaluator::add_many(const vector &encrypteds, Ciphertext &destination) + { + if (encrypteds.empty()) + { + throw invalid_argument("encrypteds cannot be empty"); + } + for (size_t i = 0; i < encrypteds.size(); i++) + { + if (&encrypteds[i] == &destination) + { + throw invalid_argument("encrypteds must be different from destination"); + } + } + destination = encrypteds[0]; + for (size_t i = 1; i < encrypteds.size(); i++) + { + add_inplace(destination, encrypteds[i]); + } + } + + void Evaluator::sub_inplace(Ciphertext &encrypted1, const Ciphertext &encrypted2) + { + // Verify parameters. + auto context_data_ptr = context_->context_data(encrypted1.parms_id()); + if (!context_data_ptr) + { + throw invalid_argument("encrypted1 is not valid for encryption parameters"); + } + if (encrypted1.parms_id() != encrypted2.parms_id()) + { + throw invalid_argument("encrypted1 and encrypted2 parameter mismatch"); + } + if (encrypted1.is_ntt_form() != encrypted2.is_ntt_form()) + { + throw invalid_argument("NTT form mismatch"); + } + if (!are_same_scale(encrypted1, encrypted2)) + { + throw invalid_argument("scale mismatch"); + } + + // Extract encryption parameters. + auto &context_data = *context_data_ptr; + auto &parms = context_data.parms(); + auto &coeff_modulus = parms.coeff_modulus(); + size_t coeff_count = parms.poly_modulus_degree(); + size_t coeff_mod_count = coeff_modulus.size(); + size_t encrypted1_size = encrypted1.size(); + size_t encrypted2_size = encrypted2.size(); + size_t max_count = max(encrypted1_size, encrypted2_size); + size_t min_count = min(encrypted1_size, encrypted2_size); + + // Size check + if (!product_fits_in(max_count, coeff_count)) + { + throw logic_error("invalid parameters"); + } + + // Prepare destination + encrypted1.resize(context_, parms.parms_id(), max_count); + + // Subtract polynomials. + for (size_t j = 0; j < min_count; j++) + { + for (size_t i = 0; i < coeff_mod_count; i++) + { + sub_poly_poly_coeffmod(encrypted1.data(j) + (i * coeff_count), + encrypted2.data(j) + (i * coeff_count), coeff_count, coeff_modulus[i], + encrypted1.data(j) + (i * coeff_count)); + } + } + + // If encrypted2 has larger count, negate remaining entries + if (encrypted1_size < encrypted2_size) + { + for (size_t i = 0; i < coeff_mod_count; i++) + { + negate_poly_coeffmod(encrypted2.data(encrypted1_size) + (i * coeff_count), + coeff_count * (encrypted2_size - encrypted1_size), coeff_modulus[i], + encrypted1.data(encrypted1_size) + (i * coeff_count)); + } + } + } + + void Evaluator::multiply_inplace(Ciphertext &encrypted1, + const Ciphertext &encrypted2, MemoryPoolHandle pool) + { + // Verify parameters. + auto context_data_ptr = context_->context_data(encrypted1.parms_id()); + if (!context_data_ptr) + { + throw invalid_argument("encrypted1 is not valid for encryption parameters"); + } + if (encrypted1.parms_id() != encrypted2.parms_id()) + { + throw invalid_argument("encrypted1 and encrypted2 parameter mismatch"); + } + + switch (context_data_ptr->parms().scheme()) + { + case scheme_type::BFV: + bfv_multiply(encrypted1, encrypted2, pool); + return; + + case scheme_type::CKKS: + ckks_multiply(encrypted1, encrypted2, pool); + return; + + default: + throw invalid_argument("unsupported scheme"); + } + } + + void Evaluator::bfv_multiply(Ciphertext &encrypted1, + const Ciphertext &encrypted2, MemoryPoolHandle pool) + { + if (encrypted1.is_ntt_form() || encrypted2.is_ntt_form()) + { + throw invalid_argument("encrypted1 or encrypted2 cannot be in NTT form"); + } + + // Extract encryption parameters. + auto &context_data = *context_->context_data(encrypted1.parms_id()); + auto &parms = context_data.parms(); + auto &coeff_modulus = parms.coeff_modulus(); + size_t coeff_count = parms.poly_modulus_degree(); + size_t coeff_mod_count = coeff_modulus.size(); + size_t encrypted1_size = encrypted1.size(); + size_t encrypted2_size = encrypted2.size(); + + uint64_t plain_modulus = parms.plain_modulus().value(); + auto &base_converter = context_data.base_converter(); + auto &bsk_modulus = base_converter->get_bsk_mod_array(); + size_t bsk_base_mod_count = base_converter->bsk_base_mod_count(); + size_t bsk_mtilde_count = add_safe(bsk_base_mod_count, size_t(1)); + auto &coeff_small_ntt_tables = context_data.small_ntt_tables(); + auto &bsk_small_ntt_tables = base_converter->get_bsk_small_ntt_tables(); + + // Determine destination.size() + // Default is 3 (c_0, c_1, c_2) + size_t dest_count = sub_safe(add_safe(encrypted1_size, encrypted2_size), size_t(1)); + + // Size check + if (!product_fits_in(dest_count, coeff_count, bsk_mtilde_count)) + { + throw logic_error("invalid parameters"); + } + + // Prepare destination + encrypted1.resize(context_, parms.parms_id(), dest_count); + + size_t encrypted_ptr_increment = coeff_count * coeff_mod_count; + size_t encrypted_bsk_mtilde_ptr_increment = coeff_count * bsk_mtilde_count; + size_t encrypted_bsk_ptr_increment = coeff_count * bsk_base_mod_count; + + // Make temp polys for FastBConverter result from q ---> Bsk U {m_tilde} + auto tmp_encrypted1_bsk_mtilde(allocate_poly( + coeff_count * encrypted1_size, bsk_mtilde_count, pool)); + auto tmp_encrypted2_bsk_mtilde(allocate_poly( + coeff_count * encrypted2_size, bsk_mtilde_count, pool)); + + // Make temp polys for FastBConverter result from Bsk U {m_tilde} -----> Bsk + auto tmp_encrypted1_bsk(allocate_poly( + coeff_count * encrypted1_size, bsk_base_mod_count, pool)); + auto tmp_encrypted2_bsk(allocate_poly( + coeff_count * encrypted2_size, bsk_base_mod_count, pool)); + + // Step 0: fast base convert from q to Bsk U {m_tilde} + // Step 1: reduce q-overflows in Bsk + // Iterate over all the ciphertexts inside encrypted1 + for (size_t i = 0; i < encrypted1_size; i++) + { + base_converter->fastbconv_mtilde( + encrypted1.data(i), + tmp_encrypted1_bsk_mtilde.get() + (i * encrypted_bsk_mtilde_ptr_increment), + pool); + base_converter->mont_rq( + tmp_encrypted1_bsk_mtilde.get() + (i * encrypted_bsk_mtilde_ptr_increment), + tmp_encrypted1_bsk.get() + (i * encrypted_bsk_ptr_increment)); + } + + // Iterate over all the ciphertexts inside encrypted2 + for (size_t i = 0; i < encrypted2_size; i++) + { + base_converter->fastbconv_mtilde( + encrypted2.data(i), + tmp_encrypted2_bsk_mtilde.get() + (i * encrypted_bsk_mtilde_ptr_increment), pool); + base_converter->mont_rq( + tmp_encrypted2_bsk_mtilde.get() + (i * encrypted_bsk_mtilde_ptr_increment), + tmp_encrypted2_bsk.get() + (i * encrypted_bsk_ptr_increment)); + } + + // Step 2: compute product and multiply plain modulus to the result + // We need to multiply both in q and Bsk. Values in encrypted_safe are in + // base q and values in tmp_encrypted_bsk are in base Bsk. We iterate over + // destination poly array and generate each poly based on the indices of + // inputs (arbitrary sizes for ciphertexts). First allocate two temp polys: + // one for results in base q and the other for the result in base Bsk. These + // need to be zero for the arbitrary size multiplication; not for 2x2 though + auto tmp_des_coeff_base(allocate_zero_poly( + coeff_count * dest_count, coeff_mod_count, pool)); + auto tmp_des_bsk_base(allocate_zero_poly( + coeff_count * dest_count, bsk_base_mod_count, pool)); + + // Allocate two tmp polys: one for NTT multiplication results in base q and + // one for result in base Bsk + auto tmp1_poly_coeff_base(allocate_poly(coeff_count, coeff_mod_count, pool)); + auto tmp1_poly_bsk_base(allocate_poly(coeff_count, bsk_base_mod_count, pool)); + auto tmp2_poly_coeff_base(allocate_poly(coeff_count, coeff_mod_count, pool)); + auto tmp2_poly_bsk_base(allocate_poly(coeff_count, bsk_base_mod_count, pool)); + + size_t current_encrypted1_limit = 0; + + // First convert all the inputs into NTT form + auto copy_encrypted1_ntt_coeff_mod(allocate_poly( + coeff_count * encrypted1_size, coeff_mod_count, pool)); + set_poly_poly(encrypted1.data(), coeff_count * encrypted1_size, + coeff_mod_count, copy_encrypted1_ntt_coeff_mod.get()); + + auto copy_encrypted1_ntt_bsk_base_mod(allocate_poly( + coeff_count * encrypted1_size, bsk_base_mod_count, pool)); + set_poly_poly(tmp_encrypted1_bsk.get(), coeff_count * encrypted1_size, + bsk_base_mod_count, copy_encrypted1_ntt_bsk_base_mod.get()); + + auto copy_encrypted2_ntt_coeff_mod(allocate_poly( + coeff_count * encrypted2_size, coeff_mod_count, pool)); + set_poly_poly(encrypted2.data(), coeff_count * encrypted2_size, + coeff_mod_count, copy_encrypted2_ntt_coeff_mod.get()); + + auto copy_encrypted2_ntt_bsk_base_mod(allocate_poly( + coeff_count * encrypted2_size, bsk_base_mod_count, pool)); + set_poly_poly(tmp_encrypted2_bsk.get(), coeff_count * encrypted2_size, + bsk_base_mod_count, copy_encrypted2_ntt_bsk_base_mod.get()); + + for (size_t i = 0; i < encrypted1_size; i++) + { + for (size_t j = 0; j < coeff_mod_count; j++) + { + // Lazy reduction + ntt_negacyclic_harvey_lazy(copy_encrypted1_ntt_coeff_mod.get() + + (j * coeff_count) + (i * encrypted_ptr_increment), coeff_small_ntt_tables[j]); + } + for (size_t j = 0; j < bsk_base_mod_count; j++) + { + // Lazy reduction + ntt_negacyclic_harvey_lazy(copy_encrypted1_ntt_bsk_base_mod.get() + + (j * coeff_count) + (i * encrypted_bsk_ptr_increment), bsk_small_ntt_tables[j]); + } + } + + for (size_t i = 0; i < encrypted2_size; i++) + { + for (size_t j = 0; j < coeff_mod_count; j++) + { + // Lazy reduction + ntt_negacyclic_harvey_lazy(copy_encrypted2_ntt_coeff_mod.get() + + (j * coeff_count) + (i * encrypted_ptr_increment), coeff_small_ntt_tables[j]); + } + for (size_t j = 0; j < bsk_base_mod_count; j++) + { + // Lazy reduction + ntt_negacyclic_harvey_lazy(copy_encrypted2_ntt_bsk_base_mod.get() + + (j * coeff_count) + (i * encrypted_bsk_ptr_increment), bsk_small_ntt_tables[j]); + } + } + + // Perform Karatsuba multiplication on size 2 ciphertexts + if (encrypted1_size == 2 && encrypted2_size == 2) + { + auto tmp_first_mul_coeff_base(allocate_poly(coeff_count, coeff_mod_count, pool)); + + // Compute c0 + c1 and c0*d0 in base q + uint64_t *temp_ptr_1 = tmp1_poly_coeff_base.get(); + uint64_t *temp_ptr_2 = copy_encrypted1_ntt_coeff_mod.get(); + uint64_t *temp_ptr_3 = temp_ptr_2 + encrypted_ptr_increment; + for (size_t i = 0; i < coeff_mod_count; i++) + { + //add_poly_poly_coeffmod(copy_encrypted1_ntt_coeff_mod.get() + (i * coeff_count), + // copy_encrypted1_ntt_coeff_mod.get() + (i * coeff_count) + encrypted_ptr_increment, + // coeff_count, coeff_modulus_[i], tmp1_poly_coeff_base.get() + (i * coeff_count)); + + // Lazy reduction + for (size_t j = 0; j < coeff_count; j++) + { + *temp_ptr_1++ = *temp_ptr_2++ + *temp_ptr_3++; + } + dyadic_product_coeffmod( + copy_encrypted1_ntt_coeff_mod.get() + (i * coeff_count), + copy_encrypted2_ntt_coeff_mod.get() + (i * coeff_count), + coeff_count, coeff_modulus[i], + tmp_first_mul_coeff_base.get() + (i * coeff_count)); + } + + auto tmp_first_mul_bsk_base(allocate_poly(coeff_count, bsk_base_mod_count, pool)); + + // Compute c0 + c1 and c0*d0 in base bsk + temp_ptr_1 = tmp1_poly_bsk_base.get(); + temp_ptr_2 = copy_encrypted1_ntt_bsk_base_mod.get(); + temp_ptr_3 = temp_ptr_2 + encrypted_bsk_ptr_increment; + for (size_t i = 0; i < bsk_base_mod_count; i++) + { + //add_poly_poly_coeffmod(copy_encrypted1_ntt_bsk_base_mod.get() + (i * coeff_count), + // copy_encrypted1_ntt_bsk_base_mod.get() + (i * coeff_count) + encrypted_bsk_ptr_increment, + // coeff_count, bsk_mod_array_[i], tmp1_poly_bsk_base.get() + (i * coeff_count)); + for (size_t j = 0; j < coeff_count; j++) + { + *temp_ptr_1++ = *temp_ptr_2++ + *temp_ptr_3++; + } + dyadic_product_coeffmod( + copy_encrypted1_ntt_bsk_base_mod.get() + (i * coeff_count), + copy_encrypted2_ntt_bsk_base_mod.get() + (i * coeff_count), + coeff_count, bsk_modulus[i], + tmp_first_mul_bsk_base.get() + (i * coeff_count)); + } + + auto tmp_second_mul_coeff_base(allocate_poly(coeff_count, coeff_mod_count, pool)); + + // Compute d0 + d1 and c1*d1 in base q + temp_ptr_1 = tmp2_poly_coeff_base.get(); + temp_ptr_2 = copy_encrypted2_ntt_coeff_mod.get(); + temp_ptr_3 = temp_ptr_2 + encrypted_ptr_increment; + for (size_t i = 0; i < coeff_mod_count; i++) + { + //add_poly_poly_coeffmod(copy_encrypted2_ntt_coeff_mod.get() + (i * coeff_count), + // copy_encrypted2_ntt_coeff_mod.get() + (i * coeff_count) + encrypted_ptr_increment, + // coeff_count, coeff_modulus_[i], tmp2_poly_coeff_base.get() + (i * coeff_count)); + for (size_t j = 0; j < coeff_count; j++) + { + *temp_ptr_1++ = *temp_ptr_2++ + *temp_ptr_3++; + } + dyadic_product_coeffmod( + copy_encrypted1_ntt_coeff_mod.get() + (i * coeff_count) + encrypted_ptr_increment, + copy_encrypted2_ntt_coeff_mod.get() + (i * coeff_count) + encrypted_ptr_increment, + coeff_count, coeff_modulus[i], tmp_second_mul_coeff_base.get() + (i * coeff_count)); + } + + auto tmp_second_mul_bsk_base(allocate_poly(coeff_count, bsk_base_mod_count, pool)); + + // Compute d0 + d1 and c1*d1 in base bsk + temp_ptr_1 = tmp2_poly_bsk_base.get(); + temp_ptr_2 = copy_encrypted2_ntt_bsk_base_mod.get(); + temp_ptr_3 = temp_ptr_2 + encrypted_bsk_ptr_increment; + for (size_t i = 0; i < bsk_base_mod_count; i++) + { + //add_poly_poly_coeffmod(copy_encrypted2_ntt_bsk_base_mod.get() + (i * coeff_count), + // copy_encrypted2_ntt_bsk_base_mod.get() + (i * coeff_count) + encrypted_bsk_ptr_increment, + // coeff_count, bsk_mod_array_[i], tmp2_poly_bsk_base.get() + (i * coeff_count)); + for (size_t j = 0; j < coeff_count; j++) + { + *temp_ptr_1++ = *temp_ptr_2++ + *temp_ptr_3++; + } + dyadic_product_coeffmod( + copy_encrypted1_ntt_bsk_base_mod.get() + (i * coeff_count) + encrypted_bsk_ptr_increment, + copy_encrypted2_ntt_bsk_base_mod.get() + (i * coeff_count) + encrypted_bsk_ptr_increment, + coeff_count, bsk_modulus[i], tmp_second_mul_bsk_base.get() + (i * coeff_count)); + } + + auto tmp_mul_poly_coeff_base(allocate_poly(coeff_count, coeff_mod_count, pool)); + auto tmp_mul_poly_bsk_base(allocate_poly(coeff_count, bsk_base_mod_count, pool)); + + // Set destination first and third polys in base q + // Des[0] in base q + set_poly_poly(tmp_first_mul_coeff_base.get(), coeff_count, + coeff_mod_count, tmp_des_coeff_base.get()); + + // Des[2] in base q + set_poly_poly(tmp_second_mul_coeff_base.get(), coeff_count, + coeff_mod_count, tmp_des_coeff_base.get() + 2 * encrypted_ptr_increment); + + // Compute (c0 + c1)*(d0 + d1) - c0*d0 - c1*d1 in base q + for (size_t i = 0; i < coeff_mod_count; i++) + { + dyadic_product_coeffmod( + tmp1_poly_coeff_base.get() + (i * coeff_count), + tmp2_poly_coeff_base.get() + (i * coeff_count), + coeff_count, coeff_modulus[i], + tmp_mul_poly_coeff_base.get() + (i * coeff_count)); + sub_poly_poly_coeffmod( + tmp_mul_poly_coeff_base.get() + (i * coeff_count), + tmp_first_mul_coeff_base.get() + (i * coeff_count), + coeff_count, coeff_modulus[i], + tmp_mul_poly_coeff_base.get() + (i * coeff_count)); + + // Des[1] in base q + sub_poly_poly_coeffmod( + tmp_mul_poly_coeff_base.get() + (i * coeff_count), + tmp_second_mul_coeff_base.get() + (i * coeff_count), + coeff_count, coeff_modulus[i], + tmp_des_coeff_base.get() + (i * coeff_count) + encrypted_ptr_increment); + } + + // Set destination first and third polys in base bsk + // Des[0] in base bsk + set_poly_poly(tmp_first_mul_bsk_base.get(), coeff_count, + bsk_base_mod_count, tmp_des_bsk_base.get()); + + // Des[2] in base q + set_poly_poly(tmp_second_mul_bsk_base.get(), coeff_count, bsk_base_mod_count, + tmp_des_bsk_base.get() + 2 * encrypted_bsk_ptr_increment); + + // Compute (c0 + c1)*(d0 + d1) - c0d0 - c1d1 in base bsk + for (size_t i = 0; i < bsk_base_mod_count; i++) + { + dyadic_product_coeffmod( + tmp1_poly_bsk_base.get() + (i * coeff_count), + tmp2_poly_bsk_base.get() + (i * coeff_count), + coeff_count, bsk_modulus[i], + tmp_mul_poly_bsk_base.get() + (i * coeff_count)); + sub_poly_poly_coeffmod( + tmp_mul_poly_bsk_base.get() + (i * coeff_count), + tmp_first_mul_bsk_base.get() + (i * coeff_count), + coeff_count, bsk_modulus[i], + tmp_mul_poly_bsk_base.get() + (i * coeff_count)); + + // Des[1] in bsk + sub_poly_poly_coeffmod( + tmp_mul_poly_bsk_base.get() + (i * coeff_count), + tmp_second_mul_bsk_base.get() + (i * coeff_count), + coeff_count, bsk_modulus[i], + tmp_des_bsk_base.get() + (i * coeff_count) + encrypted_bsk_ptr_increment); + } + } + else + { + // Perform multiplication on arbitrary size ciphertexts + for (size_t secret_power_index = 0; + secret_power_index < dest_count; secret_power_index++) + { + // Loop over encrypted1 components [i], seeing if a match exists with an encrypted2 + // component [j] such that [i+j]=[secret_power_index] + // Only need to check encrypted1 components up to and including [secret_power_index], + // and strictly less than [encrypted_array.size()] + current_encrypted1_limit = min(encrypted1_size, secret_power_index + 1); + + for (size_t encrypted1_index = 0; + encrypted1_index < current_encrypted1_limit; encrypted1_index++) + { + // check if a corresponding component in encrypted2 exists + if (encrypted2_size > secret_power_index - encrypted1_index) + { + size_t encrypted2_index = secret_power_index - encrypted1_index; + + // NTT Multiplication and addition for results in q + for (size_t i = 0; i < coeff_mod_count; i++) + { + dyadic_product_coeffmod( + copy_encrypted1_ntt_coeff_mod.get() + (i * coeff_count) + + (encrypted_ptr_increment * encrypted1_index), + copy_encrypted2_ntt_coeff_mod.get() + (i * coeff_count) + + (encrypted_ptr_increment * encrypted2_index), + coeff_count, coeff_modulus[i], + tmp1_poly_coeff_base.get() + (i * coeff_count)); + add_poly_poly_coeffmod( + tmp1_poly_coeff_base.get() + (i * coeff_count), + tmp_des_coeff_base.get() + (i * coeff_count) + + (secret_power_index * coeff_count * coeff_mod_count), + coeff_count, coeff_modulus[i], + tmp_des_coeff_base.get() + (i * coeff_count) + + (secret_power_index * coeff_count * coeff_mod_count)); + } + + // NTT Multiplication and addition for results in Bsk + for (size_t i = 0; i < bsk_base_mod_count; i++) + { + dyadic_product_coeffmod( + copy_encrypted1_ntt_bsk_base_mod.get() + (i * coeff_count) + + (encrypted_bsk_ptr_increment * encrypted1_index), + copy_encrypted2_ntt_bsk_base_mod.get() + (i * coeff_count) + + (encrypted_bsk_ptr_increment * encrypted2_index), + coeff_count, bsk_modulus[i], + tmp1_poly_bsk_base.get() + (i * coeff_count)); + add_poly_poly_coeffmod( + tmp1_poly_bsk_base.get() + (i * coeff_count), + tmp_des_bsk_base.get() + (i * coeff_count) + + (secret_power_index * coeff_count * bsk_base_mod_count), + coeff_count, bsk_modulus[i], + tmp_des_bsk_base.get() + (i * coeff_count) + + (secret_power_index * coeff_count * bsk_base_mod_count)); + } + } + } + } + } + // Convert back outputs from NTT form + for (size_t i = 0; i < dest_count; i++) + { + for (size_t j = 0; j < coeff_mod_count; j++) + { + inverse_ntt_negacyclic_harvey( + tmp_des_coeff_base.get() + (i * (encrypted_ptr_increment)) + + (j * coeff_count), coeff_small_ntt_tables[j]); + } + for (size_t j = 0; j < bsk_base_mod_count; j++) + { + inverse_ntt_negacyclic_harvey( + tmp_des_bsk_base.get() + (i * (encrypted_bsk_ptr_increment)) + + (j * coeff_count), bsk_small_ntt_tables[j]); + } + } + + // Now we multiply plain modulus to both results in base q and Bsk and allocate them together in one + // container as (te0)q(te'0)Bsk | ... |te count)q (te' count)Bsk to make it ready for fast_floor + auto tmp_coeff_bsk_together(allocate_poly( + coeff_count, dest_count * (coeff_mod_count + bsk_base_mod_count), pool)); + uint64_t *tmp_coeff_bsk_together_ptr = tmp_coeff_bsk_together.get(); + + // Base q + for (size_t i = 0; i < dest_count; i++) + { + for (size_t j = 0; j < coeff_mod_count; j++) + { + multiply_poly_scalar_coeffmod( + tmp_des_coeff_base.get() + (j * coeff_count) + (i * encrypted_ptr_increment), + coeff_count, plain_modulus, coeff_modulus[j], + tmp_coeff_bsk_together_ptr + (j * coeff_count)); + } + tmp_coeff_bsk_together_ptr += encrypted_ptr_increment; + + for (size_t k = 0; k < bsk_base_mod_count; k++) + { + multiply_poly_scalar_coeffmod( + tmp_des_bsk_base.get() + (k * coeff_count) + (i * encrypted_bsk_ptr_increment), + coeff_count, plain_modulus, bsk_modulus[k], + tmp_coeff_bsk_together_ptr + (k * coeff_count)); + } + tmp_coeff_bsk_together_ptr += encrypted_bsk_ptr_increment; + } + + // Allocate a new poly for fast floor result in Bsk + auto tmp_result_bsk(allocate_poly( + coeff_count, dest_count * bsk_base_mod_count, pool)); + for (size_t i = 0; i < dest_count; i++) + { + // Step 3: fast floor from q U {Bsk} to Bsk + base_converter->fast_floor( + tmp_coeff_bsk_together.get() + + (i * (encrypted_ptr_increment + encrypted_bsk_ptr_increment)), + tmp_result_bsk.get() + (i * encrypted_bsk_ptr_increment), pool); + + // Step 4: fast base convert from Bsk to q + base_converter->fastbconv_sk( + tmp_result_bsk.get() + (i * encrypted_bsk_ptr_increment), + encrypted1.data(i), pool); + } + } + + void Evaluator::ckks_multiply(Ciphertext &encrypted1, + const Ciphertext &encrypted2, MemoryPoolHandle pool) + { + if (!(encrypted1.is_ntt_form() && encrypted2.is_ntt_form())) + { + throw invalid_argument("encrypted1 or encrypted2 must be in NTT form"); + } + + // Extract encryption parameters. + auto &context_data = *context_->context_data(encrypted1.parms_id()); + auto &parms = context_data.parms(); + auto &coeff_modulus = parms.coeff_modulus(); + size_t coeff_count = parms.poly_modulus_degree(); + size_t coeff_mod_count = coeff_modulus.size(); + size_t encrypted1_size = encrypted1.size(); + size_t encrypted2_size = encrypted2.size(); + + double new_scale = encrypted1.scale() * encrypted2.scale(); + + // Check that scale is positive and not too large + if (new_scale <= 0 || (static_cast(log2(new_scale)) >= + context_data.total_coeff_modulus_bit_count())) + { + throw invalid_argument("scale out of bounds"); + } + + // Determine destination.size() + // Default is 3 (c_0, c_1, c_2) + size_t dest_count = sub_safe(add_safe(encrypted1_size, encrypted2_size), size_t(1)); + + // Size check + if (!product_fits_in(dest_count, coeff_count, coeff_mod_count)) + { + throw logic_error("invalid parameters"); + } + + // Prepare destination + encrypted1.resize(context_, parms.parms_id(), dest_count); + + //pointer increment to switch to a next polynomial + size_t encrypted_ptr_increment = coeff_count * coeff_mod_count; + + //Step 1: naive multiplication modulo the coefficient modulus + //First allocate two temp polys : + //one for results in base q. This need to be zero + //for the arbitrary size multiplication; not for 2x2 though + auto tmp_des(allocate_zero_poly( + coeff_count * dest_count, coeff_mod_count, pool)); + + //Allocate tmp polys for NTT multiplication results in base q + auto tmp1_poly(allocate_poly(coeff_count, coeff_mod_count, pool)); + auto tmp2_poly(allocate_poly(coeff_count, coeff_mod_count, pool)); + + // First convert all the inputs into NTT form + auto copy_encrypted1_ntt(allocate_poly( + coeff_count * encrypted1_size, coeff_mod_count, pool)); + set_poly_poly(encrypted1.data(), coeff_count * encrypted1_size, + coeff_mod_count, copy_encrypted1_ntt.get()); + + auto copy_encrypted2_ntt(allocate_poly( + coeff_count * encrypted2_size, coeff_mod_count, pool)); + set_poly_poly(encrypted2.data(), coeff_count * encrypted2_size, + coeff_mod_count, copy_encrypted2_ntt.get()); + + // Perform Karatsuba multiplication on size 2 ciphertexts + if (encrypted1_size == 2 && encrypted2_size == 2) + { + //Compute c0 + c1 and c0*d0 modulo q + //tmp poly to keep c0 * d0 + auto tmp_first_mul(allocate_poly(coeff_count, coeff_mod_count, pool)); + + uint64_t *temp_ptr_1 = tmp1_poly.get(); //pointer to the result of c0 + c1 in NTT + uint64_t *temp_ptr_2 = copy_encrypted1_ntt.get(); //Pointer to NTT version of c0 + uint64_t *temp_ptr_3 = temp_ptr_2 + encrypted_ptr_increment; //Pointer to NTT version of c1 + + for (size_t i = 0; i < coeff_mod_count; i++) + { + //Lazy reduction (c0 + c1) + for (size_t j = 0; j < coeff_count; j++) + { + *temp_ptr_1++ = *temp_ptr_2++ + *temp_ptr_3++; + } + //c0 * d0 in NTT + dyadic_product_coeffmod( + copy_encrypted1_ntt.get() + (i * coeff_count), + copy_encrypted2_ntt.get() + (i * coeff_count), + coeff_count, coeff_modulus[i], + tmp_first_mul.get() + (i * coeff_count)); + } + + //Compute d0 + d1 and c1 * d1 modulo q + //tmp poly to keep c1 * d1 + auto tmp_second_mul(allocate_poly(coeff_count, coeff_mod_count, pool)); + + // Compute d0 + d1 and c1*d1 in base q + temp_ptr_1 = tmp2_poly.get(); //Pointer to the result of d0 + d1 + temp_ptr_2 = copy_encrypted2_ntt.get(); //Pointer to d0 + temp_ptr_3 = temp_ptr_2 + encrypted_ptr_increment; //Pointer to d1 + + for (size_t i = 0; i < coeff_mod_count; i++) + { + //Lazy reduction (d0 + d1) + for (size_t j = 0; j < coeff_count; j++) + { + *temp_ptr_1++ = *temp_ptr_2++ + *temp_ptr_3++; + } + + //c1 * d1 in NTT + dyadic_product_coeffmod( + copy_encrypted1_ntt.get() + (i * coeff_count) + encrypted_ptr_increment, + copy_encrypted2_ntt.get() + (i * coeff_count) + encrypted_ptr_increment, + coeff_count, coeff_modulus[i], tmp_second_mul.get() + (i * coeff_count)); + } + + // Set destination of the first and third polys in base q + // Des[0] in base q + set_poly_poly(tmp_first_mul.get(), coeff_count, + coeff_mod_count, tmp_des.get()); + + // Des[2] in base q + set_poly_poly(tmp_second_mul.get(), coeff_count, + coeff_mod_count, tmp_des.get() + 2 * encrypted_ptr_increment); + + // Compute (c0 + c1) * (d0 + d1) - c0 * d0 - c1 * d1 modulo q + auto tmp_mul_poly(allocate_poly(coeff_count, coeff_mod_count, pool)); + + for (size_t i = 0; i < coeff_mod_count; i++) + { + // (c0 + c1) * (d0 + d1) in NTT + dyadic_product_coeffmod( + tmp1_poly.get() + (i * coeff_count), + tmp2_poly.get() + (i * coeff_count), + coeff_count, coeff_modulus[i], + tmp_mul_poly.get() + (i * coeff_count)); + // (c0 + c1) * (d0 + d1) - c0 * d0 in NTT + sub_poly_poly_coeffmod( + tmp_mul_poly.get() + (i * coeff_count), + tmp_first_mul.get() + (i * coeff_count), + coeff_count, coeff_modulus[i], + tmp_mul_poly.get() + (i * coeff_count)); + // (c0 + c1) * (d0 + d1) - c0 * d0 - c1 * d1 in NTT + // set the result to Des[1] + sub_poly_poly_coeffmod( + tmp_mul_poly.get() + (i * coeff_count), + tmp_second_mul.get() + (i * coeff_count), + coeff_count, coeff_modulus[i], + tmp_des.get() + (i * coeff_count) + encrypted_ptr_increment); + } + } + else + { + // Perform multiplication on arbitrary size ciphertexts + + // Loop over encrypted1 components [i], seeing if a match exists with an encrypted2 + // component [j] such that [i+j]=[secret_power_index] + // Only need to check encrypted1 components up to and including [secret_power_index], + // and strictly less than [encrypted_array.size()] + + // Number of encrypted1 components to check + size_t current_encrypted1_limit = 0; + + for (size_t secret_power_index = 0; + secret_power_index < dest_count; secret_power_index++) + { + current_encrypted1_limit = min(encrypted1_size, secret_power_index + 1); + + for (size_t encrypted1_index = 0; + encrypted1_index < current_encrypted1_limit; encrypted1_index++) + { + // check if a corresponding component in encrypted2 exists + if (encrypted2_size > secret_power_index - encrypted1_index) + { + size_t encrypted2_index = secret_power_index - encrypted1_index; + + // NTT Multiplication and addition for results in q + for (size_t i = 0; i < coeff_mod_count; i++) + { + // ci * dj + dyadic_product_coeffmod( + copy_encrypted1_ntt.get() + (i * coeff_count) + + (encrypted_ptr_increment * encrypted1_index), + copy_encrypted2_ntt.get() + (i * coeff_count) + + (encrypted_ptr_increment * encrypted2_index), + coeff_count, coeff_modulus[i], + tmp1_poly.get() + (i * coeff_count)); + // Dest[i+j] + add_poly_poly_coeffmod( + tmp1_poly.get() + (i * coeff_count), + tmp_des.get() + (i * coeff_count) + + (secret_power_index * coeff_count * coeff_mod_count), + coeff_count, coeff_modulus[i], + tmp_des.get() + (i * coeff_count) + + (secret_power_index * coeff_count * coeff_mod_count)); + } + } + } + } + } + + // Set the final result + set_poly_poly(tmp_des.get(), coeff_count * dest_count, + coeff_mod_count, encrypted1.data()); + + // Set the scale + encrypted1.scale() = new_scale; + } + + void Evaluator::square_inplace(Ciphertext &encrypted, MemoryPoolHandle pool) + { + // Verify parameters. + auto context_data_ptr = context_->context_data(encrypted.parms_id()); + if (!context_data_ptr) + { + throw invalid_argument("encrypted is not valid for encryption parameters"); + } + + switch (context_data_ptr->parms().scheme()) + { + case scheme_type::BFV: + bfv_square(encrypted, move(pool)); + return; + + case scheme_type::CKKS: + ckks_square(encrypted, move(pool)); + return; + + default: + throw invalid_argument("unsupported scheme"); + } + } + + void Evaluator::bfv_square(Ciphertext &encrypted, MemoryPoolHandle pool) + { + if (encrypted.is_ntt_form()) + { + throw invalid_argument("encrypted cannot be in NTT form"); + } + + // Extract encryption parameters. + auto &context_data = *context_->context_data(encrypted.parms_id()); + auto &parms = context_data.parms(); + auto &coeff_modulus = parms.coeff_modulus(); + size_t coeff_count = parms.poly_modulus_degree(); + size_t coeff_mod_count = coeff_modulus.size(); + size_t encrypted_size = encrypted.size(); + + uint64_t plain_modulus = parms.plain_modulus().value(); + auto &base_converter = context_data.base_converter(); + auto &bsk_modulus = base_converter->get_bsk_mod_array(); + size_t bsk_base_mod_count = base_converter->bsk_base_mod_count(); + size_t bsk_mtilde_count = add_safe(bsk_base_mod_count, size_t(1)); + auto &coeff_small_ntt_tables = context_data.small_ntt_tables(); + auto &bsk_small_ntt_tables = base_converter->get_bsk_small_ntt_tables(); + + // Optimization implemented currently only for size 2 ciphertexts + if (encrypted_size != 2) + { + bfv_multiply(encrypted, encrypted, move(pool)); + return; + } + + // Determine destination_array.size() + size_t dest_count = sub_safe(add_safe(encrypted_size, encrypted_size), size_t(1)); + + // Size check + if (!product_fits_in(dest_count, coeff_count, bsk_mtilde_count)) + { + throw logic_error("invalid parameters"); + } + + size_t encrypted_ptr_increment = coeff_count * coeff_mod_count; + size_t encrypted_bsk_mtilde_ptr_increment = coeff_count * bsk_mtilde_count; + size_t encrypted_bsk_ptr_increment = coeff_count * bsk_base_mod_count; + + // Prepare destination + encrypted.resize(context_, parms.parms_id(), dest_count); + + // Make temp poly for FastBConverter result from q ---> Bsk U {m_tilde} + auto tmp_encrypted_bsk_mtilde(allocate_poly( + coeff_count * encrypted_size, bsk_mtilde_count, pool)); + + // Make temp poly for FastBConverter result from Bsk U {m_tilde} -----> Bsk + auto tmp_encrypted_bsk(allocate_poly( + coeff_count * encrypted_size, bsk_base_mod_count, pool)); + + // Step 0: fast base convert from q to Bsk U {m_tilde} + // Step 1: reduce q-overflows in Bsk + // Iterate over all the ciphertexts inside encrypted1 + for (size_t i = 0; i < encrypted_size; i++) + { + base_converter->fastbconv_mtilde( + encrypted.data(i), + tmp_encrypted_bsk_mtilde.get() + + (i * encrypted_bsk_mtilde_ptr_increment), pool); + base_converter->mont_rq( + tmp_encrypted_bsk_mtilde.get() + + (i * encrypted_bsk_mtilde_ptr_increment), + tmp_encrypted_bsk.get() + (i * encrypted_bsk_ptr_increment)); + } + + // Step 2: compute product and multiply plain modulus to the result. + // We need to multiply both in q and Bsk. Values in encrypted_safe are + // in base q and values in tmp_encrypted_bsk are in base Bsk. We iterate + // over destination poly array and generate each poly based on the indices + // of inputs (arbitrary sizes for ciphertexts). First allocate two temp polys: + // one for results in base q and the other for the result in base Bsk. + auto tmp_des_coeff_base(allocate_poly( + coeff_count * dest_count, coeff_mod_count, pool)); + auto tmp_des_bsk_base(allocate_poly( + coeff_count * dest_count, bsk_base_mod_count, pool)); + + // First convert all the inputs into NTT form + auto copy_encrypted_ntt_coeff_mod(allocate_poly( + coeff_count * encrypted_size, coeff_mod_count, pool)); + set_poly_poly(encrypted.data(), coeff_count * encrypted_size, + coeff_mod_count, copy_encrypted_ntt_coeff_mod.get()); + + auto copy_encrypted_ntt_bsk_base_mod(allocate_poly( + coeff_count * encrypted_size, bsk_base_mod_count, pool)); + set_poly_poly(tmp_encrypted_bsk.get(), coeff_count * encrypted_size, + bsk_base_mod_count, copy_encrypted_ntt_bsk_base_mod.get()); + + for (size_t i = 0; i < encrypted_size; i++) + { + for (size_t j = 0; j < coeff_mod_count; j++) + { + ntt_negacyclic_harvey_lazy( + copy_encrypted_ntt_coeff_mod.get() + (j * coeff_count) + + (i * encrypted_ptr_increment), coeff_small_ntt_tables[j]); + } + for (size_t j = 0; j < bsk_base_mod_count; j++) + { + ntt_negacyclic_harvey_lazy( + copy_encrypted_ntt_bsk_base_mod.get() + (j * coeff_count) + + (i * encrypted_bsk_ptr_increment), bsk_small_ntt_tables[j]); + } + } + + // Perform fast squaring + // Compute c0^2 in base q + for (size_t i = 0; i < coeff_mod_count; i++) + { + // Des[0] in q + dyadic_product_coeffmod( + copy_encrypted_ntt_coeff_mod.get() + (i * coeff_count), + copy_encrypted_ntt_coeff_mod.get() + (i * coeff_count), + coeff_count, coeff_modulus[i], + tmp_des_coeff_base.get() + (i * coeff_count)); + + // Des[2] in q + dyadic_product_coeffmod( + copy_encrypted_ntt_coeff_mod.get() + (i * coeff_count) + encrypted_ptr_increment, + copy_encrypted_ntt_coeff_mod.get() + (i * coeff_count) + encrypted_ptr_increment, + coeff_count, coeff_modulus[i], + tmp_des_coeff_base.get() + (i * coeff_count) + (2 * encrypted_ptr_increment)); + } + + // Compute c0^2 in base bsk + for (size_t i = 0; i < bsk_base_mod_count; i++) + { + // Des[0] in bsk + dyadic_product_coeffmod( + copy_encrypted_ntt_bsk_base_mod.get() + (i * coeff_count), + copy_encrypted_ntt_bsk_base_mod.get() + (i * coeff_count), + coeff_count, bsk_modulus[i], + tmp_des_bsk_base.get() + (i * coeff_count)); + + // Des[2] in bsk + dyadic_product_coeffmod( + copy_encrypted_ntt_bsk_base_mod.get() + (i * coeff_count) + encrypted_bsk_ptr_increment, + copy_encrypted_ntt_bsk_base_mod.get() + (i * coeff_count) + encrypted_bsk_ptr_increment, + coeff_count, bsk_modulus[i], + tmp_des_bsk_base.get() + (i * coeff_count) + (2 * encrypted_bsk_ptr_increment)); + } + + auto tmp_second_mul_coeff_base(allocate_poly(coeff_count, coeff_mod_count, pool)); + + // Compute 2*c0*c1 in base q + for (size_t i = 0; i < coeff_mod_count; i++) + { + dyadic_product_coeffmod( + copy_encrypted_ntt_coeff_mod.get() + (i * coeff_count), + copy_encrypted_ntt_coeff_mod.get() + (i * coeff_count) + encrypted_ptr_increment, + coeff_count, coeff_modulus[i], + tmp_second_mul_coeff_base.get() + (i * coeff_count)); + add_poly_poly_coeffmod( + tmp_second_mul_coeff_base.get() + (i * coeff_count), + tmp_second_mul_coeff_base.get() + (i * coeff_count), + coeff_count, coeff_modulus[i], + tmp_des_coeff_base.get() + (i * coeff_count) + encrypted_ptr_increment); + } + + auto tmp_second_mul_bsk_base(allocate_poly(coeff_count, bsk_base_mod_count, pool)); + + // Compute 2*c0*c1 in base bsk + for (size_t i = 0; i < bsk_base_mod_count; i++) + { + dyadic_product_coeffmod( + copy_encrypted_ntt_bsk_base_mod.get() + (i * coeff_count), + copy_encrypted_ntt_bsk_base_mod.get() + (i * coeff_count) + encrypted_bsk_ptr_increment, + coeff_count, bsk_modulus[i], + tmp_second_mul_bsk_base.get() + (i * coeff_count)); + add_poly_poly_coeffmod( + tmp_second_mul_bsk_base.get() + (i * coeff_count), + tmp_second_mul_bsk_base.get() + (i * coeff_count), + coeff_count, bsk_modulus[i], + tmp_des_bsk_base.get() + (i * coeff_count) + encrypted_bsk_ptr_increment); + } + + // Convert back outputs from NTT form + for (size_t i = 0; i < dest_count; i++) + { + for (size_t j = 0; j < coeff_mod_count; j++) + { + inverse_ntt_negacyclic_harvey_lazy( + tmp_des_coeff_base.get() + (i * (encrypted_ptr_increment)) + (j * coeff_count), + coeff_small_ntt_tables[j]); + } + for (size_t j = 0; j < bsk_base_mod_count; j++) + { + inverse_ntt_negacyclic_harvey_lazy( + tmp_des_bsk_base.get() + (i * (encrypted_bsk_ptr_increment)) + + (j * coeff_count), bsk_small_ntt_tables[j]); + } + } + + // Now we multiply plain modulus to both results in base q and Bsk and + // allocate them together in one container as (te0)q(te'0)Bsk | ... |te count)q (te' count)Bsk + // to make it ready for fast_floor + auto tmp_coeff_bsk_together(allocate_poly( + coeff_count, dest_count * (coeff_mod_count + bsk_base_mod_count), pool)); + uint64_t *tmp_coeff_bsk_together_ptr = tmp_coeff_bsk_together.get(); + + // Base q + for (size_t i = 0; i < dest_count; i++) + { + for (size_t j = 0; j < coeff_mod_count; j++) + { + multiply_poly_scalar_coeffmod( + tmp_des_coeff_base.get() + (j * coeff_count) + (i * encrypted_ptr_increment), + coeff_count, plain_modulus, coeff_modulus[j], + tmp_coeff_bsk_together_ptr + (j * coeff_count)); + } + tmp_coeff_bsk_together_ptr += encrypted_ptr_increment; + + for (size_t k = 0; k < bsk_base_mod_count; k++) + { + multiply_poly_scalar_coeffmod( + tmp_des_bsk_base.get() + (k * coeff_count) + (i * encrypted_bsk_ptr_increment), + coeff_count, plain_modulus, bsk_modulus[k], + tmp_coeff_bsk_together_ptr + (k * coeff_count)); + } + tmp_coeff_bsk_together_ptr += encrypted_bsk_ptr_increment; + } + + // Allocate a new poly for fast floor result in Bsk + auto tmp_result_bsk(allocate_poly(coeff_count, dest_count * bsk_base_mod_count, pool)); + for (size_t i = 0; i < dest_count; i++) + { + // Step 3: fast floor from q U {Bsk} to Bsk + base_converter->fast_floor( + tmp_coeff_bsk_together.get() + (i * (encrypted_ptr_increment + encrypted_bsk_ptr_increment)), + tmp_result_bsk.get() + (i * encrypted_bsk_ptr_increment), pool); + + // Step 4: fast base convert from Bsk to q + base_converter->fastbconv_sk( + tmp_result_bsk.get() + (i * encrypted_bsk_ptr_increment), encrypted.data(i), pool); + } + } + + void Evaluator::ckks_square(Ciphertext &encrypted, MemoryPoolHandle pool) + { + if (!encrypted.is_ntt_form()) + { + throw invalid_argument("encrypted must be in NTT form"); + } + + // Extract encryption parameters. + auto &context_data = *context_->context_data(encrypted.parms_id()); + auto &parms = context_data.parms(); + auto &coeff_modulus = parms.coeff_modulus(); + size_t coeff_count = parms.poly_modulus_degree(); + size_t coeff_mod_count = coeff_modulus.size(); + size_t encrypted_size = encrypted.size(); + + double new_scale = encrypted.scale() * encrypted.scale(); + + // Check that scale is positive and not too large + if (new_scale <= 0 || (static_cast(log2(new_scale)) >= + context_data.total_coeff_modulus_bit_count())) + { + throw invalid_argument("scale out of bounds"); + } + + // Determine destination.size() + // Default is 3 (c_0, c_1, c_2) + size_t dest_count = sub_safe(add_safe(encrypted_size, encrypted_size), size_t(1)); + + // Size check + if (!product_fits_in(dest_count, coeff_count, coeff_mod_count)) + { + throw logic_error("invalid parameters"); + } + + // Prepare destination + encrypted.resize(context_, parms.parms_id(), dest_count); + + //pointer increment to switch to a next polynomial + size_t encrypted_ptr_increment = coeff_count * coeff_mod_count; + + //Step 1: naive multiplication modulo the coefficient modulus + //First allocate two temp polys : + //one for results in base q. This need to be zero + //for the arbitrary size multiplication; not for 2x2 though + auto tmp_des(allocate_zero_poly( + coeff_count * dest_count, coeff_mod_count, pool)); + + //Allocate tmp polys for NTT multiplication results in base q + auto tmp1_poly(allocate_poly(coeff_count, coeff_mod_count, pool)); + auto tmp2_poly(allocate_poly(coeff_count, coeff_mod_count, pool)); + + // First convert all the inputs into NTT form + auto copy_encrypted_ntt(allocate_poly( + coeff_count * encrypted_size, coeff_mod_count, pool)); + set_poly_poly(encrypted.data(), coeff_count * encrypted_size, + coeff_mod_count, copy_encrypted_ntt.get()); + + // The simplest case when the ciphertext dimension is 2 + if (encrypted_size == 2) + { + //Compute c0^2, 2*c0 + c1 and c1^2 modulo q + //tmp poly to keep 2 * c0 * c1 + auto tmp_second_mul(allocate_poly(coeff_count, coeff_mod_count, pool)); + + for (size_t i = 0; i < coeff_mod_count; i++) + { + //Des[0] = c0^2 in NTT + dyadic_product_coeffmod( + copy_encrypted_ntt.get() + (i * coeff_count), + copy_encrypted_ntt.get() + (i * coeff_count), + coeff_count, coeff_modulus[i], + tmp_des.get() + (i * coeff_count)); + + //Des[1] = 2 * c0 * c1 + dyadic_product_coeffmod( + copy_encrypted_ntt.get() + (i * coeff_count), + copy_encrypted_ntt.get() + (i * coeff_count) + encrypted_ptr_increment, + coeff_count, coeff_modulus[i], + tmp_second_mul.get() + (i * coeff_count)); + add_poly_poly_coeffmod( + tmp_second_mul.get() + (i * coeff_count), + tmp_second_mul.get() + (i * coeff_count), + coeff_count, coeff_modulus[i], + tmp_des.get() + (i * coeff_count) + encrypted_ptr_increment); + + //Des[2] = c1^2 in NTT + dyadic_product_coeffmod( + copy_encrypted_ntt.get() + (i * coeff_count) + encrypted_ptr_increment, + copy_encrypted_ntt.get() + (i * coeff_count) + encrypted_ptr_increment, + coeff_count, coeff_modulus[i], + tmp_des.get() + (i * coeff_count) + (2 * encrypted_ptr_increment)); + } + } + else + { + // Perform multiplication on arbitrary size ciphertexts + + // Loop over encrypted1 components [i], seeing if a match exists with an encrypted2 + // component [j] such that [i+j]=[secret_power_index] + // Only need to check encrypted1 components up to and including [secret_power_index], + // and strictly less than [encrypted_array.size()] + + // Number of encrypted1 components to check + size_t current_encrypted_limit = 0; + + for (size_t secret_power_index = 0; secret_power_index < dest_count; secret_power_index++) + { + current_encrypted_limit = min(encrypted_size, secret_power_index + 1); + + for (size_t encrypted1_index = 0; encrypted1_index < current_encrypted_limit; + encrypted1_index++) + { + // check if a corresponding component in encrypted2 exists + if (encrypted_size > secret_power_index - encrypted1_index) + { + size_t encrypted2_index = secret_power_index - encrypted1_index; + + // NTT Multiplication and addition for results in q + for (size_t i = 0; i < coeff_mod_count; i++) + { + // ci * dj + dyadic_product_coeffmod( + copy_encrypted_ntt.get() + (i * coeff_count) + + (encrypted_ptr_increment * encrypted1_index), + copy_encrypted_ntt.get() + (i * coeff_count) + + (encrypted_ptr_increment * encrypted2_index), + coeff_count, coeff_modulus[i], + tmp1_poly.get() + (i * coeff_count)); + // Dest[i+j] + add_poly_poly_coeffmod( + tmp1_poly.get() + (i * coeff_count), + tmp_des.get() + (i * coeff_count) + + (secret_power_index * coeff_count * coeff_mod_count), + coeff_count, coeff_modulus[i], + tmp_des.get() + (i * coeff_count) + + (secret_power_index * coeff_count * coeff_mod_count)); + } + } + } + } + } + + // Set the final result + set_poly_poly(tmp_des.get(), coeff_count * dest_count, coeff_mod_count, encrypted.data()); + + // Set the scale + encrypted.scale() = new_scale; + } + + void Evaluator::relinearize_internal(Ciphertext &encrypted, + const RelinKeys &relin_keys, size_t destination_size, + MemoryPoolHandle pool) + { + // Verify parameters. + auto context_data_ptr = context_->context_data(encrypted.parms_id()); + if (!context_data_ptr) + { + throw invalid_argument("encrypted is not valid for encryption parameters"); + } + if (relin_keys.parms_id() != context_->first_parms_id()) + { + throw invalid_argument("parameter mismatch"); + } + if (!pool) + { + throw invalid_argument("pool is uninitialized"); + } + + // Extract encryption parameters. + auto &context_data = *context_data_ptr; + auto &parms = context_data.parms(); + size_t encrypted_size = encrypted.size(); + + // Verify parameters. + if (destination_size < 2 || destination_size > encrypted_size) + { + throw invalid_argument("destination_size must be at least 2 and less than or equal to current count"); + } + if (relin_keys.size() < sub_safe(encrypted_size, size_t(2))) + { + throw invalid_argument("not enough relinearization keys"); + } + + // If encrypted is already at the desired level, return + if (destination_size == encrypted_size) + { + return; + } + + // Calculate number of relinearize_one_step calls needed + size_t relins_needed = encrypted_size - destination_size; + + // Update temp to store the current result after relinearization + switch (context_data_ptr->parms().scheme()) + { + case scheme_type::BFV: + { + if (encrypted.is_ntt_form()) + { + throw invalid_argument("BFV encrypted cannot be in NTT form"); + } + for (size_t i = 0; i < relins_needed; i++) + { + bfv_relinearize_one_step(encrypted.data(), encrypted_size, + context_data, relin_keys, pool); + encrypted_size--; + } + break; + } + + case scheme_type::CKKS: + { + if (!encrypted.is_ntt_form()) + { + throw invalid_argument("CKKS encrypted must be in NTT form"); + } + for (size_t i = 0; i < relins_needed; i++) + { + ckks_relinearize_one_step(encrypted.data(), encrypted_size, + context_data, relin_keys, pool); + encrypted_size--; + } + break; + } + + default: + throw invalid_argument("unsupported scheme"); + } + + // Put the output of final relinearization into destination. + // Prepare destination only at this point because we are resizing down + encrypted.resize(context_, parms.parms_id(), destination_size); + } + + void Evaluator::bfv_relinearize_one_step(uint64_t *encrypted, + size_t encrypted_size, const SEALContext::ContextData &context_data, + const RelinKeys &relin_keys, MemoryPool &pool) + { + // Extract encryption parameters. + // Parameters corresponding to the ciphertext level + auto &parms = context_data.parms(); + + // q_l corresponding to the ciphertext level + auto &coeff_modulus = parms.coeff_modulus(); + size_t coeff_count = parms.poly_modulus_degree(); + + // number of factors in q_l + size_t coeff_mod_count = coeff_modulus.size(); + + // Size test + if (!product_fits_in(encrypted_size, coeff_count, coeff_mod_count)) + { + throw logic_error("invalid parameters"); + } + + // n * number of factors in q_l + size_t rns_poly_uint64_count = coeff_count * coeff_mod_count; +#ifdef SEAL_DEBUG + if (encrypted == nullptr) + { + throw invalid_argument("encrypted cannot be null"); + } + if (encrypted_size <= 2) + { + throw invalid_argument("encrypted_size must be at least 3"); + } + if (relin_keys.size() < sub_safe(encrypted_size, size_t(2))) + { + throw invalid_argument("not enough relinearization keys"); + } +#endif + // q/qi mod qi + auto &first_context_data = *context_->context_data(); + auto &inv_coeff_products_mod_coeff_array = + first_context_data.base_converter()->get_inv_coeff_mod_coeff_array(); + auto &coeff_small_ntt_tables = first_context_data.small_ntt_tables(); + + // Decompose encrypted_array[count-1] into base w + // Want to create an array of polys, each of whose components i is + // (encrypted_array[count-1])^(i) - in the notation of FV paper. + // This allocation stores one of the decomposed factors modulo one of the primes. + auto decomp_encrypted_last(allocate_uint(coeff_count, pool)); + + // Lazy reduction + auto wide_innerresult0(allocate_zero_poly(coeff_count, 2 * coeff_mod_count, pool)); + auto wide_innerresult1(allocate_zero_poly(coeff_count, 2 * coeff_mod_count, pool)); + auto innerresult(allocate_poly(coeff_count, coeff_mod_count, pool)); + auto temp_decomp_coeff(allocate_uint(coeff_count, pool)); + + /* + For lazy reduction to work here, we need to ensure that the 128-bit accumulators + (wide_innerresult0 and wide_innerresult1) do not overflow. Since the modulus primes + are at most 60 bits, if the total number of summands is K, then the size of the + total sum of products (without reduction) is at most 62 + 60 + bit_length(K). + We need this to be at most 128, thus we need bit_length(K) <= 6. Thus, we need K <= 63. + In this case, this means sum_i relin_keys.data()[encrypted_size - 3][i].size() / 2 <= 63. + */ + const uint64_t *encrypted_coeff = encrypted + (encrypted_size - 1) * rns_poly_uint64_count; + auto encrypted_coeff_prod_inv_coeff(allocate_uint(coeff_count, pool)); + + for (size_t i = 0; i < coeff_mod_count; i++) + { + multiply_poly_scalar_coeffmod( + encrypted_coeff + (i * coeff_count), coeff_count, + inv_coeff_products_mod_coeff_array[i], coeff_modulus[i], + encrypted_coeff_prod_inv_coeff.get()); + + int shift = 0; + auto &key_component_ref = relin_keys.data()[encrypted_size - 3][i]; + size_t keys_size = key_component_ref.size(); + for (size_t k = 0; k < keys_size; k += 2) + { + const uint64_t *key_ptr_0 = key_component_ref.data(k); + const uint64_t *key_ptr_1 = key_component_ref.data(k + 1); + + // Decompose here + int decomposition_bit_count = relin_keys.decomposition_bit_count(); + for (size_t coeff_index = 0; coeff_index < coeff_count; coeff_index++) + { + decomp_encrypted_last[coeff_index] = + encrypted_coeff_prod_inv_coeff[coeff_index] >> shift; + decomp_encrypted_last[coeff_index] &= + (uint64_t(1) << decomposition_bit_count) - 1; + } + + uint64_t *wide_innerresult0_ptr = wide_innerresult0.get(); + uint64_t *wide_innerresult1_ptr = wide_innerresult1.get(); + for (size_t j = 0; j < coeff_mod_count; j++) + { + uint64_t *temp_decomp_coeff_ptr = temp_decomp_coeff.get(); + set_uint_uint(decomp_encrypted_last.get(), coeff_count, temp_decomp_coeff_ptr); + + // We don't reduce here, so might get up to two extra bits. Thus 62 bits at most. + ntt_negacyclic_harvey_lazy(temp_decomp_coeff_ptr, coeff_small_ntt_tables[j]); + + // Lazy reduction + unsigned long long wide_innerproduct[2]; + unsigned long long temp; + for (size_t m = 0; m < coeff_count; m++, wide_innerresult0_ptr += 2) + { + multiply_uint64(*temp_decomp_coeff_ptr++, *key_ptr_0++, wide_innerproduct); + unsigned char carry = add_uint64(wide_innerresult0_ptr[0], + wide_innerproduct[0], &temp); + wide_innerresult0_ptr[0] = temp; + wide_innerresult0_ptr[1] += wide_innerproduct[1] + carry; + } + + temp_decomp_coeff_ptr = temp_decomp_coeff.get(); + for (size_t m = 0; m < coeff_count; m++, wide_innerresult1_ptr += 2) + { + multiply_uint64(*temp_decomp_coeff_ptr++, *key_ptr_1++, wide_innerproduct); + unsigned char carry = add_uint64(wide_innerresult1_ptr[0], + wide_innerproduct[0], &temp); + wide_innerresult1_ptr[0] = temp; + wide_innerresult1_ptr[1] += wide_innerproduct[1] + carry; + } + } + shift += decomposition_bit_count; + } + } + + uint64_t *innerresult_poly_ptr = innerresult.get(); + uint64_t *wide_innerresult_poly_ptr = wide_innerresult0.get(); + uint64_t *encrypted_ptr = encrypted; + uint64_t *innerresult_coeff_ptr = innerresult_poly_ptr; + uint64_t *wide_innerresult_coeff_ptr = wide_innerresult_poly_ptr; + for (size_t i = 0; i < coeff_mod_count; i++, innerresult_poly_ptr += coeff_count, + wide_innerresult_poly_ptr += 2 * coeff_count, encrypted_ptr += coeff_count) + { + for (size_t m = 0; m < coeff_count; m++, wide_innerresult_coeff_ptr += 2) + { + *innerresult_coeff_ptr++ = barrett_reduce_128( + wide_innerresult_coeff_ptr, coeff_modulus[i]); + } + inverse_ntt_negacyclic_harvey(innerresult_poly_ptr, coeff_small_ntt_tables[i]); + add_poly_poly_coeffmod(encrypted_ptr, innerresult_poly_ptr, coeff_count, + coeff_modulus[i], encrypted_ptr); + } + + innerresult_poly_ptr = innerresult.get(); + wide_innerresult_poly_ptr = wide_innerresult1.get(); + encrypted_ptr = encrypted + rns_poly_uint64_count; + innerresult_coeff_ptr = innerresult_poly_ptr; + wide_innerresult_coeff_ptr = wide_innerresult_poly_ptr; + for (size_t i = 0; i < coeff_mod_count; i++, innerresult_poly_ptr += coeff_count, + wide_innerresult_poly_ptr += 2 * coeff_count, encrypted_ptr += coeff_count) + { + for (size_t m = 0; m < coeff_count; m++, wide_innerresult_coeff_ptr += 2) + { + *innerresult_coeff_ptr++ = barrett_reduce_128( + wide_innerresult_coeff_ptr, coeff_modulus[i]); + } + inverse_ntt_negacyclic_harvey(innerresult_poly_ptr, coeff_small_ntt_tables[i]); + add_poly_poly_coeffmod(encrypted_ptr, innerresult_poly_ptr, coeff_count, + coeff_modulus[i], encrypted_ptr); + } + } + + void Evaluator::ckks_relinearize_one_step(uint64_t *encrypted, + size_t encrypted_size, const SEALContext::ContextData &context_data, + const RelinKeys &relin_keys, MemoryPool &pool) + { + // Extract encryption parameters. + // Parameters corresponding to the ciphertext level + auto &parms = context_data.parms(); + + // q_l corresponding to the ciphertext level + auto &coeff_modulus = parms.coeff_modulus(); + size_t coeff_count = parms.poly_modulus_degree(); + + // number of factors in q_l + size_t coeff_mod_count = coeff_modulus.size(); + + // Size test + if (!product_fits_in(encrypted_size, coeff_count, coeff_mod_count)) + { + throw logic_error("invalid parameters"); + } + + // n * number of factors in q_l + size_t rns_poly_uint64_count = coeff_count * coeff_mod_count; +#ifdef SEAL_DEBUG + if (encrypted == nullptr) + { + throw invalid_argument("encrypted cannot be null"); + } + if (encrypted_size <= 2) + { + throw invalid_argument("encrypted_size must be at least 3"); + } + if (relin_keys.size() < sub_safe(encrypted_size, size_t(2))) + { + throw invalid_argument("not enough evaluation keys"); + } +#endif + // q/qi mod qi + auto &first_context_data = *context_->context_data(); + auto &inv_coeff_products_mod_coeff_array = + first_context_data.base_converter()->get_inv_coeff_mod_coeff_array(); + auto &coeff_small_ntt_tables = first_context_data.small_ntt_tables(); + + // Decompose encrypted_array[count-1] into base w + // Want to create an array of polys, each of whose components i is + // (encrypted_array[count-1])^(i) - in the notation of FV paper. + // This allocation stores one of the decomposed factors modulo one of the primes. + auto decomp_encrypted_last(allocate_uint(coeff_count, pool)); + + // Lazy reduction + auto wide_innerresult0(allocate_zero_poly(coeff_count, 2 * coeff_mod_count, pool)); + auto wide_innerresult1(allocate_zero_poly(coeff_count, 2 * coeff_mod_count, pool)); + auto innerresult(allocate_poly(coeff_count, coeff_mod_count, pool)); + auto temp_decomp_coeff(allocate_uint(coeff_count, pool)); + + /* + For lazy reduction to work here, we need to ensure that the 128-bit accumulators + (wide_innerresult0 and wide_innerresult1) do not overflow. Since the modulus primes + are at most 60 bits, if the total number of summands is K, then the size of the + total sum of products (without reduction) is at most 62 + 60 + bit_length(K). + We need this to be at most 128, thus we need bit_length(K) <= 6. Thus, we need K <= 63. + In this case, this means sum_i evaluation_keys.data()[encrypted_size - 3][i].size() / 2 <= 63. + */ + uint64_t *encrypted_coeff = encrypted + (encrypted_size - 1) * rns_poly_uint64_count; + auto encrypted_coeff_prod_inv_coeff(allocate_uint(coeff_count, pool)); + + // inner product of evaluation keys and the bit-decomposition of the last ciphertext polynomial + for (size_t i = 0; i < coeff_mod_count; i++) + { + // Convert the last polynomial of encrypted from NTT to create a bit-decomposition + inverse_ntt_negacyclic_harvey( + encrypted_coeff + (i * coeff_count), coeff_small_ntt_tables[i]); + // c*(q_i/q) mod q_i + multiply_poly_scalar_coeffmod( + encrypted_coeff + (i * coeff_count), coeff_count, + inv_coeff_products_mod_coeff_array[i], coeff_modulus[i], + encrypted_coeff_prod_inv_coeff.get()); + + int shift = 0; + auto &key_component_ref = relin_keys.data()[encrypted_size - 3][i]; + size_t keys_size = key_component_ref.size(); + for (size_t k = 0; k < keys_size; k += 2) + { + const uint64_t *key_ptr_0 = key_component_ref.data(k); + const uint64_t *key_ptr_1 = key_component_ref.data(k + 1); + + // Decompose here + int decomposition_bit_count = relin_keys.decomposition_bit_count(); + for (size_t coeff_index = 0; coeff_index < coeff_count; coeff_index++) + { + decomp_encrypted_last[coeff_index] = + encrypted_coeff_prod_inv_coeff[coeff_index] >> shift; + decomp_encrypted_last[coeff_index] &= + (uint64_t(1) << decomposition_bit_count) - 1; + } + + uint64_t *wide_innerresult0_ptr = wide_innerresult0.get(); + uint64_t *wide_innerresult1_ptr = wide_innerresult1.get(); + for (size_t j = 0; j < coeff_mod_count; j++) + { + uint64_t *temp_decomp_coeff_ptr = temp_decomp_coeff.get(); + set_uint_uint(decomp_encrypted_last.get(), coeff_count, temp_decomp_coeff_ptr); + + // We don't reduce here, so might get up to two extra bits. Thus 62 bits at most. + ntt_negacyclic_harvey_lazy(temp_decomp_coeff_ptr, coeff_small_ntt_tables[j]); + + // Lazy reduction + unsigned long long wide_innerproduct[2]; + unsigned long long temp; + for (size_t m = 0; m < coeff_count; m++, wide_innerresult0_ptr += 2) + { + multiply_uint64(*temp_decomp_coeff_ptr++, *key_ptr_0++, wide_innerproduct); + unsigned char carry = add_uint64(wide_innerresult0_ptr[0], + wide_innerproduct[0], &temp); + wide_innerresult0_ptr[0] = temp; + wide_innerresult0_ptr[1] += wide_innerproduct[1] + carry; + } + + temp_decomp_coeff_ptr = temp_decomp_coeff.get(); + for (size_t m = 0; m < coeff_count; m++, wide_innerresult1_ptr += 2) + { + multiply_uint64(*temp_decomp_coeff_ptr++, *key_ptr_1++, wide_innerproduct); + unsigned char carry = add_uint64(wide_innerresult1_ptr[0], + wide_innerproduct[0], &temp); + wide_innerresult1_ptr[0] = temp; + wide_innerresult1_ptr[1] += wide_innerproduct[1] + carry; + } + } + shift += decomposition_bit_count; + } + } + + uint64_t *innerresult_poly_ptr = innerresult.get(); + uint64_t *wide_innerresult_poly_ptr = wide_innerresult0.get(); + uint64_t *encrypted_ptr = encrypted; + uint64_t *innerresult_coeff_ptr = innerresult_poly_ptr; + uint64_t *wide_innerresult_coeff_ptr = wide_innerresult_poly_ptr; + for (size_t i = 0; i < coeff_mod_count; i++, innerresult_poly_ptr += coeff_count, + wide_innerresult_poly_ptr += 2 * coeff_count, encrypted_ptr += coeff_count) + { + for (size_t m = 0; m < coeff_count; m++, wide_innerresult_coeff_ptr += 2) + { + *innerresult_coeff_ptr++ = barrett_reduce_128( + wide_innerresult_coeff_ptr, coeff_modulus[i]); + } + add_poly_poly_coeffmod(encrypted_ptr, innerresult_poly_ptr, coeff_count, + coeff_modulus[i], encrypted_ptr); + } + + innerresult_poly_ptr = innerresult.get(); + wide_innerresult_poly_ptr = wide_innerresult1.get(); + encrypted_ptr = encrypted + rns_poly_uint64_count; + innerresult_coeff_ptr = innerresult_poly_ptr; + wide_innerresult_coeff_ptr = wide_innerresult_poly_ptr; + for (size_t i = 0; i < coeff_mod_count; i++, innerresult_poly_ptr += coeff_count, + wide_innerresult_poly_ptr += 2 * coeff_count, encrypted_ptr += coeff_count) + { + for (size_t m = 0; m < coeff_count; m++, wide_innerresult_coeff_ptr += 2) + { + *innerresult_coeff_ptr++ = barrett_reduce_128( + wide_innerresult_coeff_ptr, coeff_modulus[i]); + } + add_poly_poly_coeffmod(encrypted_ptr, innerresult_poly_ptr, coeff_count, + coeff_modulus[i], encrypted_ptr); + } + } + + void Evaluator::mod_switch_scale_to_next(const Ciphertext &encrypted, + Ciphertext &destination, MemoryPoolHandle pool) + { + // Verify parameters. + auto context_data_ptr = context_->context_data(encrypted.parms_id()); + if (context_data_ptr->parms().scheme() == scheme_type::BFV && + encrypted.is_ntt_form()) + { + throw invalid_argument("BFV encrypted cannot be in NTT form"); + } + if (context_data_ptr->parms().scheme() == scheme_type::CKKS && + !encrypted.is_ntt_form()) + { + throw invalid_argument("CKKS encrypted must be in NTT form"); + } + if (!pool) + { + throw invalid_argument("pool is uninitialized"); + } + + // Extract encryption parameters. + auto &context_data = *context_data_ptr; + auto &next_parms = context_data.next_context_data()->parms(); + + // q_1,...,q_{k-1} + auto &next_coeff_modulus = next_parms.coeff_modulus(); + size_t next_coeff_mod_count = next_coeff_modulus.size(); + size_t coeff_count = next_parms.poly_modulus_degree(); + size_t encrypted_size = encrypted.size(); + auto &inv_last_coeff_mod_array = + context_data.base_converter()->get_inv_last_coeff_mod_array(); + + // Size test + if (!product_fits_in(coeff_count, encrypted_size, next_coeff_mod_count)) + { + throw logic_error("invalid parameters"); + } + + // In CKKS need to transform away from NTT form + Ciphertext encrypted_copy(pool); + encrypted_copy = encrypted; + if (next_parms.scheme() == scheme_type::CKKS) + { + transform_from_ntt_inplace(encrypted_copy); + } + + auto temp1(allocate_uint(coeff_count, pool)); + + // Allocate enough room for the result + auto temp2(allocate_poly(coeff_count * encrypted_size, next_coeff_mod_count, pool)); + auto temp2_ptr = temp2.get(); + + for (size_t poly_index = 0; poly_index < encrypted_size; poly_index++) + { + // Set temp1 to ct mod qk + set_uint_uint(encrypted_copy.data(poly_index) + next_coeff_mod_count * coeff_count, + coeff_count, temp1.get()); + for (size_t mod_index = 0; mod_index < next_coeff_mod_count; mod_index++, + temp2_ptr += coeff_count) + { + // (ct mod qk) mod qi + modulo_poly_coeffs(temp1.get(), coeff_count, + next_coeff_modulus[mod_index], temp2_ptr); + // (-(ct mod qk)) mod qi + negate_poly_coeffmod(temp2_ptr, coeff_count, + next_coeff_modulus[mod_index], temp2_ptr); + // ((ct mod qi) - (ct mod qk)) mod qi + add_poly_poly_coeffmod( + encrypted_copy.data(poly_index) + mod_index * coeff_count, temp2_ptr, + coeff_count, next_coeff_modulus[mod_index], temp2_ptr); + // qk^(-1) * ((ct mod qi) - (ct mod qk)) mod qi + multiply_poly_scalar_coeffmod(temp2_ptr, coeff_count, + inv_last_coeff_mod_array[mod_index], + next_coeff_modulus[mod_index], temp2_ptr); + } + } + + // Resize destination + destination.resize(context_, next_parms.parms_id(), encrypted_size); + destination.is_ntt_form() = false; + + set_poly_poly(temp2.get(), coeff_count * encrypted_size, next_coeff_mod_count, + destination.data()); + + // In CKKS need to transform back to NTT form + if (next_parms.scheme() == scheme_type::CKKS) + { + transform_to_ntt_inplace(destination); + + // Also change the scale + destination.scale() = encrypted.scale() / + static_cast(context_data.parms().coeff_modulus().back().value()); + } + } + + void Evaluator::mod_switch_drop_to_next(const Ciphertext &encrypted, + Ciphertext &destination) + { + // Verify parameters. + auto context_data_ptr = context_->context_data(encrypted.parms_id()); + if (context_data_ptr->parms().scheme() == scheme_type::CKKS && + !encrypted.is_ntt_form()) + { + throw invalid_argument("CKKS encrypted must be in NTT form"); + } + + // Extract encryption parameters. + auto &next_context_data = *context_data_ptr->next_context_data(); + auto &next_parms = next_context_data.parms(); + + // Check that scale is positive and not too large + if (encrypted.scale() <= 0 || (static_cast(log2(encrypted.scale())) >= + next_context_data.total_coeff_modulus_bit_count())) + { + throw invalid_argument("scale out of bounds"); + } + + // q_1,...,q_{k-1} + size_t next_coeff_mod_count = next_parms.coeff_modulus().size(); + size_t coeff_count = next_parms.poly_modulus_degree(); + size_t encrypted_size = encrypted.size(); + + // Size check + if (!product_fits_in(encrypted_size, coeff_count, next_coeff_mod_count)) + { + throw logic_error("invalid parameters"); + } + + size_t rns_poly_total_count = next_coeff_mod_count * coeff_count; + + for (size_t i = 0; i < encrypted_size; i++) + { + for (size_t j = 0; j < next_coeff_mod_count; j++) + { + set_uint_uint(encrypted.data(i) + (j * coeff_count), coeff_count, + destination.data() + (i * rns_poly_total_count) + (j * coeff_count)); + } + } + + // Resize destination + destination.resize(context_, next_parms.parms_id(), encrypted_size); + destination.is_ntt_form() = true; + destination.scale() = encrypted.scale(); + } + + void Evaluator::mod_switch_drop_to_next(Plaintext &plain) + { + // Verify parameters. + auto context_data_ptr = context_->context_data(plain.parms_id()); + if (!context_data_ptr) + { + throw invalid_argument("plain is not valid for encryption parameters"); + } + if (!plain.is_ntt_form()) + { + throw invalid_argument("plain is not in NTT form"); + } + if (!context_data_ptr->next_context_data()) + { + throw invalid_argument("end of modulus switching chain reached"); + } + + // Extract encryption parameters. + auto &next_context_data = *context_data_ptr->next_context_data(); + auto &next_parms = context_data_ptr->next_context_data()->parms(); + + // Check that scale is positive and not too large + if (plain.scale() <= 0 || (static_cast(log2(plain.scale())) >= + next_context_data.total_coeff_modulus_bit_count())) + { + throw invalid_argument("scale out of bounds"); + } + + // q_1,...,q_{k-1} + auto &next_coeff_modulus = next_parms.coeff_modulus(); + size_t next_coeff_mod_count = next_coeff_modulus.size(); + size_t coeff_count = next_parms.poly_modulus_degree(); + + // Compute destination size first for exception safety + auto dest_size = mul_safe(next_coeff_mod_count, coeff_count); + + plain.parms_id() = parms_id_zero; + plain.resize(dest_size); + plain.parms_id() = next_parms.parms_id(); + } + + void Evaluator::mod_switch_to_next(const Ciphertext &encrypted, + Ciphertext &destination, MemoryPoolHandle pool) + { + auto context_data_ptr = context_->context_data(encrypted.parms_id()); + if (!context_data_ptr) + { + throw invalid_argument("encrypted is not valid for encryption parameters"); + } + if (context_->last_parms_id() == encrypted.parms_id()) + { + throw invalid_argument("end of modulus switching chain reached"); + } + if (!pool) + { + throw invalid_argument("pool is uninitialized"); + } + if (encrypted.size() > 2) + { + throw invalid_argument("encrypted size must be 2"); + } + + switch (context_->context_data()->parms().scheme()) + { + case scheme_type::BFV: + // Modulus switching with scaling + mod_switch_scale_to_next(encrypted, destination, move(pool)); + return; + + case scheme_type::CKKS: + // Modulus switching without scaling + mod_switch_drop_to_next(encrypted, destination); + return; + + default: + throw invalid_argument("unsupported scheme"); + } + } + + void Evaluator::mod_switch_to_inplace(Ciphertext &encrypted, + parms_id_type parms_id, MemoryPoolHandle pool) + { + // Verify parameters. + auto context_data_ptr = context_->context_data(encrypted.parms_id()); + auto target_context_data_ptr = context_->context_data(parms_id); + if (!context_data_ptr) + { + throw invalid_argument("encrypted is not valid for encryption parameters"); + } + if (!target_context_data_ptr) + { + throw invalid_argument("parms_id is not valid for encryption parameters"); + } + if (context_data_ptr->chain_index() < target_context_data_ptr->chain_index()) + { + throw invalid_argument("cannot switch to higher level modulus"); + } + + while (encrypted.parms_id() != parms_id) + { + mod_switch_to_next_inplace(encrypted, pool); + } + } + + void Evaluator::mod_switch_to_inplace(Plaintext &plain, parms_id_type parms_id) + { + // Verify parameters. + auto context_data_ptr = context_->context_data(plain.parms_id()); + auto target_context_data_ptr = context_->context_data(parms_id); + if (!context_data_ptr) + { + throw invalid_argument("plain is not valid for encryption parameters"); + } + if (!context_->context_data(parms_id)) + { + throw invalid_argument("parms_id is not valid for encryption parameters"); + } + if (!plain.is_ntt_form()) + { + throw invalid_argument("plain is not in NTT form"); + } + if (context_data_ptr->chain_index() < target_context_data_ptr->chain_index()) + { + throw invalid_argument("cannot switch to higher level modulus"); + } + + while (plain.parms_id() != parms_id) + { + mod_switch_to_next_inplace(plain); + } + } + + void Evaluator::rescale_to_next(const Ciphertext &encrypted, Ciphertext &destination, + MemoryPoolHandle pool) + { + auto context_data_ptr = context_->context_data(encrypted.parms_id()); + if (!context_data_ptr) + { + throw invalid_argument("encrypted is not valid for encryption parameters"); + } + if (context_->last_parms_id() == encrypted.parms_id()) + { + throw invalid_argument("end of modulus switching chain reached"); + } + if (!pool) + { + throw invalid_argument("pool is uninitialized"); + } + if (encrypted.size() > 2) + { + throw invalid_argument("encrypted size must be 2"); + } + + switch (context_->context_data()->parms().scheme()) + { + case scheme_type::BFV: + throw invalid_argument("unsupported operation for scheme type"); + + case scheme_type::CKKS: + // Modulus switching with scaling + mod_switch_scale_to_next(encrypted, destination, move(pool)); + return; + + default: + throw invalid_argument("unsupported scheme"); + } + } + + void Evaluator::rescale_to_inplace(Ciphertext &encrypted, parms_id_type parms_id, + MemoryPoolHandle pool) + { + // Verify parameters. + auto context_data_ptr = context_->context_data(encrypted.parms_id()); + auto target_context_data_ptr = context_->context_data(parms_id); + if (!context_data_ptr) + { + throw invalid_argument("encrypted is not valid for encryption parameters"); + } + if (!target_context_data_ptr) + { + throw invalid_argument("parms_id is not valid for encryption parameters"); + } + if (context_data_ptr->chain_index() < target_context_data_ptr->chain_index()) + { + throw invalid_argument("cannot switch to higher level modulus"); + } + if (!pool) + { + throw invalid_argument("pool is uninitialized"); + } + + switch (context_data_ptr->parms().scheme()) + { + case scheme_type::BFV: + throw invalid_argument("unsupported operation for scheme type"); + + case scheme_type::CKKS: + while (encrypted.parms_id() != parms_id) + { + // Modulus switching with scaling + mod_switch_scale_to_next(encrypted, encrypted, move(pool)); + } + return; + + default: + throw invalid_argument("unsupported scheme"); + } + } + + void Evaluator::multiply_many(vector &encrypteds, + const RelinKeys &relin_keys, Ciphertext &destination, + MemoryPoolHandle pool) + { + // Verify parameters. + if (encrypteds.size() == 0) + { + throw invalid_argument("encrypteds vector must not be empty"); + } + if (!pool) + { + throw invalid_argument("pool is uninitialized"); + } + for (size_t i = 0; i < encrypteds.size(); i++) + { + if (&encrypteds[i] == &destination) + { + throw invalid_argument("encrypteds must be different from destination"); + } + } + + // There is at least one ciphertext + auto context_data_ptr = context_->context_data(encrypteds[0].parms_id()); + if (!context_data_ptr) + { + throw invalid_argument("encrypteds is not valid for encryption parameters"); + } + + // Extract encryption parameters. + auto &context_data = *context_data_ptr; + auto &parms = context_data.parms(); + + if (parms.scheme() != scheme_type::BFV) + { + throw logic_error("unsupported scheme"); + } + + // If there is only one ciphertext, return it. + if (encrypteds.size() == 1) + { + destination = encrypteds[0]; + return; + } + + // Repeatedly multiply and add to the back of the vector until the end is reached + Ciphertext product(context_, parms.parms_id(), pool); + for (size_t i = 0; i < encrypteds.size() - 1; i += 2) + { + // We only compare pointers to determine if a faster path can be taken. + // This is under the assumption that if the two pointers are the same and + // the parameter sets match, then it makes no sense for one of the ciphertexts + // to be of different size than the other. More generally, it seems like + // a reasonable assumption that if the pointers are the same, then the + // ciphertexts are the same. + if (encrypteds[i].data() == encrypteds[i + 1].data()) + { + square(encrypteds[i], product); + } + else + { + multiply(encrypteds[i], encrypteds[i + 1], product); + } + relinearize_inplace(product, relin_keys, pool); + encrypteds.emplace_back(product); + } + + destination = encrypteds[encrypteds.size() - 1]; + } + + void Evaluator::exponentiate_inplace(Ciphertext &encrypted, uint64_t exponent, + const RelinKeys &relin_keys, MemoryPoolHandle pool) + { + // Verify parameters. + auto context_data_ptr = context_->context_data(encrypted.parms_id()); + if (!context_data_ptr) + { + throw invalid_argument("encrypted is not valid for encryption parameters"); + } + if (!context_->context_data(relin_keys.parms_id())) + { + throw invalid_argument("relin_keys is not valid for encryption parameters"); + } + if (!pool) + { + throw invalid_argument("pool is uninitialized"); + } + if (exponent == 0) + { + throw invalid_argument("exponent cannot be 0"); + } + + // Fast case + if (exponent == 1) + { + return; + } + + // Create a vector of copies of encrypted + vector exp_vector(exponent, encrypted); + multiply_many(exp_vector, relin_keys, encrypted, move(pool)); + } + + void Evaluator::add_plain_inplace(Ciphertext &encrypted, const Plaintext &plain) + { + // Verify parameters. + auto context_data_ptr = context_->context_data(encrypted.parms_id()); + if (!context_data_ptr) + { + throw invalid_argument("encrypted is not valid for encryption parameters"); + } + if (context_data_ptr->parms().scheme() == scheme_type::BFV && + encrypted.is_ntt_form()) + { + throw invalid_argument("BFV encrypted cannot be in NTT form"); + } + if (context_data_ptr->parms().scheme() == scheme_type::CKKS && + !encrypted.is_ntt_form()) + { + throw invalid_argument("CKKS encrypted must be in NTT form"); + } + if (plain.is_ntt_form() != encrypted.is_ntt_form()) + { + throw invalid_argument("NTT form mismatch"); + } + if (encrypted.is_ntt_form() && + (encrypted.parms_id() != plain.parms_id())) + { + throw invalid_argument("encrypted and plain parameter mismatch"); + } + if (!are_same_scale(encrypted, plain)) + { + throw invalid_argument("scale mismatch"); + } + + // Extract encryption parameters. + auto &context_data = *context_data_ptr; + auto &parms = context_data.parms(); + auto &coeff_modulus = parms.coeff_modulus(); + size_t coeff_count = parms.poly_modulus_degree(); + size_t coeff_mod_count = coeff_modulus.size(); + + // Size check + if (!product_fits_in(coeff_count, coeff_mod_count)) + { + throw logic_error("invalid parameters"); + } + + switch (parms.scheme()) + { + case scheme_type::BFV: + { + // Verify more parameters. + if (plain.coeff_count() > coeff_count) + { + throw invalid_argument("plain is not valid for encryption parameters"); + } +#ifdef SEAL_DEBUG + if (!are_poly_coefficients_less_than(plain.data(), + plain.coeff_count(), parms.plain_modulus().value())) + { + throw invalid_argument("plain is not valid for encryption parameters"); + } +#endif + auto coeff_div_plain_modulus = context_data.coeff_div_plain_modulus(); + auto plain_upper_half_threshold = context_data.plain_upper_half_threshold(); + auto upper_half_increment = context_data.upper_half_increment(); + + for (size_t i = 0; i < plain.coeff_count(); i++) + { + // This is Encryptor::preencrypt + // Multiply plain by scalar coeff_div_plain_modulus and reposition + // if in upper-half. + if (plain[i] >= plain_upper_half_threshold) + { + // Loop over primes + for (size_t j = 0; j < coeff_mod_count; j++) + { + unsigned long long temp[2]{ 0, 0 }; + multiply_uint64(coeff_div_plain_modulus[j], plain[i], temp); + temp[1] += add_uint64(temp[0], upper_half_increment[j], temp); + uint64_t scaled_plain_coeff = barrett_reduce_128(temp, coeff_modulus[j]); + *(encrypted.data() + i + (j * coeff_count)) = add_uint_uint_mod( + *(encrypted.data() + i + (j * coeff_count)), + scaled_plain_coeff, coeff_modulus[j]); + } + } + else + { + for (size_t j = 0; j < coeff_mod_count; j++) + { + uint64_t scaled_plain_coeff = multiply_uint_uint_mod( + coeff_div_plain_modulus[j], plain[i], coeff_modulus[j]); + *(encrypted.data() + i + (j * coeff_count)) = add_uint_uint_mod( + *(encrypted.data() + i + (j * coeff_count)), + scaled_plain_coeff, coeff_modulus[j]); + } + } + } + return; + } + + case scheme_type::CKKS: + { + for (size_t j = 0; j < coeff_mod_count; j++) + { +#ifdef SEAL_DEBUG + if (!are_poly_coefficients_less_than(plain.data(j * coeff_count), + coeff_count, coeff_modulus[j].value())) + { + throw invalid_argument("plain is not valid for encryption parameters"); + } +#endif + add_poly_poly_coeffmod(encrypted.data() + (j * coeff_count), + plain.data() + (j*coeff_count), coeff_count, + coeff_modulus[j], encrypted.data() + (j * coeff_count)); + } + return; + } + + default: + throw invalid_argument("unsupported scheme"); + } + } + + void Evaluator::sub_plain_inplace(Ciphertext &encrypted, const Plaintext &plain) + { + // Verify parameters. + auto context_data_ptr = context_->context_data(encrypted.parms_id()); + if (!context_data_ptr) + { + throw invalid_argument("encrypted is not valid for encryption parameters"); + } + if (context_data_ptr->parms().scheme() == scheme_type::BFV && + encrypted.is_ntt_form()) + { + throw invalid_argument("BFV encrypted cannot be in NTT form"); + } + if (context_data_ptr->parms().scheme() == scheme_type::CKKS && + !encrypted.is_ntt_form()) + { + throw invalid_argument("CKKS encrypted must be in NTT form"); + } + if (plain.is_ntt_form() != encrypted.is_ntt_form()) + { + throw invalid_argument("NTT form mismatch"); + } + if (encrypted.is_ntt_form() && + (encrypted.parms_id() != plain.parms_id())) + { + throw invalid_argument("encrypted and plain parameter mismatch"); + } + if (!are_same_scale(encrypted, plain)) + { + throw invalid_argument("scale mismatch"); + } + + // Extract encryption parameters. + auto &context_data = *context_data_ptr; + auto &parms = context_data.parms(); + auto &coeff_modulus = parms.coeff_modulus(); + size_t coeff_count = parms.poly_modulus_degree(); + size_t coeff_mod_count = coeff_modulus.size(); + + // Size check + if (!product_fits_in(coeff_count, coeff_mod_count)) + { + throw logic_error("invalid parameters"); + } + + switch (parms.scheme()) + { + case scheme_type::BFV: + { + // Verify more parameters. + if (plain.coeff_count() > coeff_count) + { + throw invalid_argument("plain is not valid for encryption parameters"); + } +#ifdef SEAL_DEBUG + if (!are_poly_coefficients_less_than(plain.data(), + plain.coeff_count(), parms.plain_modulus().value())) + { + throw invalid_argument("plain is not valid for encryption parameters"); + } +#endif + auto coeff_div_plain_modulus = context_data.coeff_div_plain_modulus(); + auto plain_upper_half_threshold = context_data.plain_upper_half_threshold(); + auto upper_half_increment = context_data.upper_half_increment(); + + for (size_t i = 0; i < plain.coeff_count(); i++) + { + // This is Encryptor::preencrypt changed to subtract instead + // Multiply plain by scalar coeff_div_plain_modulus and reposition + // if in upper-half. + if (plain[i] >= plain_upper_half_threshold) + { + // Loop over primes + for (size_t j = 0; j < coeff_mod_count; j++) + { + unsigned long long temp[2]{ 0, 0 }; + multiply_uint64(coeff_div_plain_modulus[j], plain[i], temp); + temp[1] += add_uint64(temp[0], upper_half_increment[j], temp); + uint64_t scaled_plain_coeff = barrett_reduce_128(temp, coeff_modulus[j]); + *(encrypted.data() + i + (j * coeff_count)) = sub_uint_uint_mod( + *(encrypted.data() + i + (j * coeff_count)), + scaled_plain_coeff, coeff_modulus[j]); + } + } + else + { + for (size_t j = 0; j < coeff_mod_count; j++) + { + uint64_t scaled_plain_coeff = multiply_uint_uint_mod( + coeff_div_plain_modulus[j], plain[i], coeff_modulus[j]); + *(encrypted.data() + i + (j * coeff_count)) = sub_uint_uint_mod( + *(encrypted.data() + i + (j * coeff_count)), + scaled_plain_coeff, coeff_modulus[j]); + } + } + } + return; + } + + case scheme_type::CKKS: + { + for (size_t j = 0; j < coeff_mod_count; j++) + { + sub_poly_poly_coeffmod(encrypted.data() + (j * coeff_count), + plain.data() + (j * coeff_count), coeff_count, + coeff_modulus[j], encrypted.data() + (j * coeff_count)); + } + return; + } + + default: + throw invalid_argument("unsupported scheme"); + } + } + + void Evaluator::multiply_plain_inplace(Ciphertext &encrypted, + const Plaintext &plain, MemoryPoolHandle pool) + { + // Verify parameters. + if (!context_->context_data(encrypted.parms_id())) + { + throw invalid_argument("encrypted is not valid for encryption parameters"); + } + if (encrypted.is_ntt_form() != plain.is_ntt_form()) + { + throw invalid_argument("NTT form mismatch"); + } + if (!pool) + { + throw invalid_argument("pool is uninitialized"); + } + + if (encrypted.is_ntt_form()) + { + multiply_plain_ntt(encrypted, plain); + } + else + { + multiply_plain_normal(encrypted, plain, move(pool)); + } + } + + void Evaluator::multiply_plain_normal(Ciphertext &encrypted, + const Plaintext &plain, MemoryPool &pool) + { + // Extract encryption parameters. + auto &context_data = *context_->context_data(encrypted.parms_id()); + auto &parms = context_data.parms(); + auto &coeff_modulus = parms.coeff_modulus(); + size_t coeff_count = parms.poly_modulus_degree(); + size_t coeff_mod_count = coeff_modulus.size(); + + auto plain_upper_half_threshold = context_data.plain_upper_half_threshold(); + auto plain_upper_half_increment = context_data.plain_upper_half_increment(); + auto &coeff_small_ntt_tables = context_data.small_ntt_tables(); + + size_t encrypted_size = encrypted.size(); + size_t plain_coeff_count = plain.coeff_count(); + + // Size check + if (!product_fits_in(encrypted_size, coeff_count, coeff_mod_count)) + { + throw logic_error("invalid parameters"); + } + + double new_scale = encrypted.scale() * plain.scale(); + + // Check that scale is positive and not too large + if (new_scale <= 0 || (static_cast(log2(new_scale)) >= + context_data.total_coeff_modulus_bit_count())) + { + throw invalid_argument("scale out of bounds"); + } + + // Verify more parameters. +#ifdef SEAL_THROW_ON_MULTIPLY_PLAIN_BY_ZERO + if (plain.is_zero()) + { + throw invalid_argument("plain cannot be zero"); + } +#endif + if (plain.coeff_count() > coeff_count) + { + throw invalid_argument("plain is not valid for encryption parameters"); + } +#ifdef SEAL_DEBUG + if (!are_poly_coefficients_less_than(plain.data(), + plain.coeff_count(), parms.plain_modulus().value())) + { + throw invalid_argument("plain is not valid for encryption parameters"); + } +#endif + // Set the scale + encrypted.scale() = new_scale; + + // Multiplying just by a constant? + if (plain_coeff_count == 1) + { + if (!context_data.qualifiers().using_fast_plain_lift) + { + auto adjusted_coeff(allocate_uint(coeff_mod_count, pool)); + if (plain[0] >= plain_upper_half_threshold) + { + auto decomposed_coeff(allocate_uint(coeff_mod_count, pool)); + add_uint_uint64(plain_upper_half_increment, plain[0], + coeff_mod_count, adjusted_coeff.get()); + decompose_single_coeff(context_data, adjusted_coeff.get(), + decomposed_coeff.get(), pool); + + for (size_t i = 0; i < encrypted_size; i++) + { + for (size_t j = 0; j < coeff_mod_count; j++) + { + multiply_poly_scalar_coeffmod( + encrypted.data(i) + (j * coeff_count), coeff_count, + decomposed_coeff[j], coeff_modulus[j], + encrypted.data(i) + (j * coeff_count)); + } + } + } + else + { + for (size_t i = 0; i < encrypted_size; i++) + { + for (size_t j = 0; j < coeff_mod_count; j++) + { + multiply_poly_scalar_coeffmod( + encrypted.data(i) + (j * coeff_count), coeff_count, + plain[0], coeff_modulus[j], + encrypted.data(i) + (j * coeff_count)); + } + } + } + return; + } + else + { + // Need for lift plain coefficient in RNS form regarding to each qi + if (plain[0] >= plain_upper_half_threshold) + { + for (size_t i = 0; i < encrypted_size; i++) + { + for (size_t j = 0; j < coeff_mod_count; j++) + { + multiply_poly_scalar_coeffmod( + encrypted.data(i) + (j * coeff_count), coeff_count, + plain[0] + plain_upper_half_increment[j], + coeff_modulus[j], encrypted.data(i) + (j * coeff_count)); + } + } + } + // No need for lifting + else + { + for (size_t i = 0; i < encrypted_size; i++) + { + for (size_t j = 0; j < coeff_mod_count; j++) + { + multiply_poly_scalar_coeffmod( + encrypted.data(i) + (j * coeff_count), coeff_count, + plain[0], coeff_modulus[j], + encrypted.data(i) + (j * coeff_count)); + } + } + } + return; + } + } + + // Generic plain case + auto adjusted_poly(allocate_zero_uint(coeff_count * coeff_mod_count, pool)); + auto decomposed_poly(allocate_uint(coeff_count * coeff_mod_count, pool)); + uint64_t *poly_to_transform = nullptr; + if (!context_data.qualifiers().using_fast_plain_lift) + { + // Reposition coefficients. + const uint64_t *plain_ptr = plain.data(); + uint64_t *adjusted_poly_ptr = adjusted_poly.get(); + for (size_t i = 0; i < plain_coeff_count; i++, plain_ptr++, + adjusted_poly_ptr += coeff_mod_count) + { + if (*plain_ptr >= plain_upper_half_threshold) + { + add_uint_uint64(plain_upper_half_increment, + *plain_ptr, coeff_mod_count, adjusted_poly_ptr); + } + else + { + *adjusted_poly_ptr = *plain_ptr; + } + } + decompose(context_data, adjusted_poly.get(), decomposed_poly.get(), pool); + poly_to_transform = decomposed_poly.get(); + } + else + { + for (size_t j = 0; j < coeff_mod_count; j++) + { + const uint64_t *plain_ptr = plain.data(); + uint64_t *adjusted_poly_ptr = adjusted_poly.get() + (j * coeff_count); + uint64_t current_plain_upper_half_increment = plain_upper_half_increment[j]; + for (size_t i = 0; i < plain_coeff_count; i++, plain_ptr++, adjusted_poly_ptr++) + { + // Need to lift the coefficient in each qi + if (*plain_ptr >= plain_upper_half_threshold) + { + *adjusted_poly_ptr = *plain_ptr + current_plain_upper_half_increment; + } + // No need for lifting + else + { + *adjusted_poly_ptr = *plain_ptr; + } + } + } + poly_to_transform = adjusted_poly.get(); + } + + // Need to multiply each component in encrypted with decomposed_poly (plain poly) + // Transform plain poly only once + for (size_t i = 0; i < coeff_mod_count; i++) + { + ntt_negacyclic_harvey( + poly_to_transform + (i * coeff_count), coeff_small_ntt_tables[i]); + } + + for (size_t i = 0; i < encrypted_size; i++) + { + uint64_t *encrypted_ptr = encrypted.data(i); + for (size_t j = 0; j < coeff_mod_count; j++, encrypted_ptr += coeff_count) + { + // Explicit inline to avoid unnecessary copy + //ntt_multiply_poly_nttpoly(encrypted.data(i) + (j * coeff_count), + //poly_to_transform + (j * coeff_count), + // coeff_small_ntt_tables_[j], encrypted.data(i) + (j * coeff_count), pool); + + // Lazy reduction + ntt_negacyclic_harvey_lazy(encrypted_ptr, coeff_small_ntt_tables[j]); + dyadic_product_coeffmod(encrypted_ptr, poly_to_transform + (j * coeff_count), + coeff_count, coeff_modulus[j], encrypted_ptr); + inverse_ntt_negacyclic_harvey(encrypted_ptr, coeff_small_ntt_tables[j]); + } + } + } + + void Evaluator::multiply_plain_ntt(Ciphertext &encrypted_ntt, + const Plaintext &plain_ntt) + { + // Verify parameters. + if (!plain_ntt.is_ntt_form()) + { + throw invalid_argument("plain_ntt is not in NTT form"); + } + if (encrypted_ntt.parms_id() != plain_ntt.parms_id()) + { + throw invalid_argument("encrypted_ntt and plain_ntt parameter mismatch"); + } + + // Extract encryption parameters. + auto &context_data = *context_->context_data(encrypted_ntt.parms_id()); + auto &parms = context_data.parms(); + auto &coeff_modulus = parms.coeff_modulus(); + size_t coeff_count = parms.poly_modulus_degree(); + size_t coeff_mod_count = coeff_modulus.size(); + size_t encrypted_ntt_size = encrypted_ntt.size(); + + // Size check + if (!product_fits_in(encrypted_ntt_size, coeff_count, coeff_mod_count)) + { + throw logic_error("invalid parameters"); + } + + double new_scale = encrypted_ntt.scale() * plain_ntt.scale(); + + // Check that scale is positive and not too large + if (new_scale <= 0 || (static_cast(log2(new_scale)) >= + context_data.total_coeff_modulus_bit_count())) + { + throw invalid_argument("scale out of bounds"); + } + + // Verify more parameters. + if (plain_ntt.coeff_count() != coeff_count * coeff_mod_count) + { + throw invalid_argument("plain_ntt is not valid for encryption parameters"); + } +#ifdef SEAL_DEBUG + for (size_t i = 0; i < coeff_mod_count; i++) + { + if (poly_infty_norm_coeffmod(plain_ntt.data(i * coeff_count), coeff_count, + coeff_modulus[i]) >= coeff_modulus[i].value()) + { + throw invalid_argument("plain_ntt is not valid for encryption parameters"); + } + } +#endif +#ifdef SEAL_THROW_ON_MULTIPLY_PLAIN_BY_ZERO + if (plain_ntt.is_zero()) + { + throw invalid_argument("plain_ntt cannot be zero"); + } +#endif + for (size_t i = 0; i < encrypted_ntt_size; i++) + { + for (size_t j = 0; j < coeff_mod_count; j++) + { + dyadic_product_coeffmod( + encrypted_ntt.data(i) + (j * coeff_count), + plain_ntt.data() + (j * coeff_count), + coeff_count, coeff_modulus[j], + encrypted_ntt.data(i) + (j * coeff_count)); + } + } + + // Set the scale + encrypted_ntt.scale() = new_scale; + } + + void Evaluator::transform_to_ntt_inplace(Plaintext &plain, + parms_id_type parms_id, MemoryPoolHandle pool) + { + // Verify parameters. + auto context_data_ptr = context_->context_data(parms_id); + if (!context_data_ptr) + { + throw invalid_argument("parms_id is not valid for the current context"); + } + if (plain.is_ntt_form()) + { + throw invalid_argument("plain is already in NTT form"); + } + if (!pool) + { + throw invalid_argument("pool is uninitialized"); + } + + // Extract encryption parameters. + auto &context_data = *context_data_ptr; + auto &parms = context_data.parms(); + auto &coeff_modulus = parms.coeff_modulus(); + size_t coeff_count = parms.poly_modulus_degree(); + size_t coeff_mod_count = coeff_modulus.size(); + size_t plain_coeff_count = plain.coeff_count(); + + auto plain_upper_half_threshold = context_data.plain_upper_half_threshold(); + auto plain_upper_half_increment = context_data.plain_upper_half_increment(); + + // Verify more parameters. + if (plain.coeff_count() > coeff_count) + { + throw invalid_argument("plain is not valid for encryption parameters"); + } +#ifdef SEAL_DEBUG + if (!are_poly_coefficients_less_than(plain.data(), + plain.coeff_count(), parms.plain_modulus().value())) + { + throw invalid_argument("plain is not valid for encryption parameters"); + } +#endif + auto &coeff_small_ntt_tables = context_data.small_ntt_tables(); + + // Size check + if (!product_fits_in(coeff_count, coeff_mod_count)) + { + throw logic_error("invalid parameters"); + } + + // Resize to fit the entire NTT transformed (ciphertext size) polynomial + // Note that the new coefficients are automatically set to 0 + plain.resize(coeff_count * coeff_mod_count); + + // Verify if plain lift is needed + if (!context_data.qualifiers().using_fast_plain_lift) + { + auto adjusted_poly(allocate_zero_uint(coeff_count * coeff_mod_count, pool)); + for (size_t i = 0; i < plain_coeff_count; i++) + { + if (plain[i] >= plain_upper_half_threshold) + { + add_uint_uint64(plain_upper_half_increment, plain[i], + coeff_mod_count, adjusted_poly.get() + (i * coeff_mod_count)); + } + else + { + adjusted_poly[i * coeff_mod_count] = plain[i]; + } + } + decompose(context_data, adjusted_poly.get(), plain.data(), pool); + } + // No need for composed plain lift and decomposition + else + { + for (size_t j = coeff_mod_count; j--; ) + { + const uint64_t *plain_ptr = plain.data(); + uint64_t *adjusted_poly_ptr = plain.data() + (j * coeff_count); + uint64_t current_plain_upper_half_increment = plain_upper_half_increment[j]; + for (size_t i = 0; i < plain_coeff_count; i++, plain_ptr++, adjusted_poly_ptr++) + { + // Need to lift the coefficient in each qi + if (*plain_ptr >= plain_upper_half_threshold) + { + *adjusted_poly_ptr = *plain_ptr + current_plain_upper_half_increment; + } + // No need for lifting + else + { + *adjusted_poly_ptr = *plain_ptr; + } + } + } + } + + // Transform to NTT domain + for (size_t i = 0; i < coeff_mod_count; i++) + { + ntt_negacyclic_harvey( + plain.data() + (i * coeff_count), coeff_small_ntt_tables[i]); + } + + plain.parms_id() = parms_id; + } + + void Evaluator::transform_to_ntt_inplace(Ciphertext &encrypted) + { + // Verify parameters. + auto context_data_ptr = context_->context_data(encrypted.parms_id()); + if (!context_data_ptr) + { + throw invalid_argument("encrypted is not valid for encryption parameters"); + } + if (encrypted.is_ntt_form()) + { + throw invalid_argument("encrypted is already in NTT form"); + } + + // Extract encryption parameters. + auto &context_data = *context_data_ptr; + auto &parms = context_data.parms(); + auto &coeff_modulus = parms.coeff_modulus(); + size_t coeff_count = parms.poly_modulus_degree(); + size_t coeff_mod_count = coeff_modulus.size(); + size_t encrypted_size = encrypted.size(); + + auto &coeff_small_ntt_tables = context_data.small_ntt_tables(); + + // Size check + if (!product_fits_in(coeff_count, coeff_mod_count)) + { + throw logic_error("invalid parameters"); + } + + // Transform each polynomial to NTT domain + for (size_t i = 0; i < encrypted_size; i++) + { + for (size_t j = 0; j < coeff_mod_count; j++) + { + ntt_negacyclic_harvey( + encrypted.data(i) + (j * coeff_count), coeff_small_ntt_tables[j]); + } + } + + // Finally change the is_ntt_transformed flag + encrypted.is_ntt_form() = true; + } + + void Evaluator::transform_from_ntt_inplace(Ciphertext &encrypted_ntt) + { + // Verify parameters. + auto context_data_ptr = context_->context_data(encrypted_ntt.parms_id()); + if (!context_data_ptr) + { + throw invalid_argument("encrypted_ntt is not valid for encryption parameters"); + } + if (!encrypted_ntt.is_ntt_form()) + { + throw invalid_argument("encrypted_ntt is not in NTT form"); + } + + // Extract encryption parameters. + auto &context_data = *context_data_ptr; + auto &parms = context_data.parms(); + size_t coeff_count = parms.poly_modulus_degree(); + size_t coeff_mod_count = parms.coeff_modulus().size(); + size_t encrypted_ntt_size = encrypted_ntt.size(); + + auto &coeff_small_ntt_tables = context_data.small_ntt_tables(); + + // Size check + if (!product_fits_in(coeff_count, coeff_mod_count)) + { + throw logic_error("invalid parameters"); + } + + // Transform each polynomial from NTT domain + for (size_t i = 0; i < encrypted_ntt_size; i++) + { + for (size_t j = 0; j < coeff_mod_count; j++) + { + inverse_ntt_negacyclic_harvey( + encrypted_ntt.data(i) + (j * coeff_count), coeff_small_ntt_tables[j]); + } + } + + // Finally change the is_ntt_transformed flag + encrypted_ntt.is_ntt_form() = false; + } + + void Evaluator::apply_galois_inplace(Ciphertext &encrypted, uint64_t galois_elt, + const GaloisKeys &galois_keys, MemoryPoolHandle pool) + { + // Verify parameters. + auto context_data_ptr = context_->context_data(encrypted.parms_id()); + if (!context_data_ptr) + { + throw invalid_argument("encrypted is not valid for encryption parameters"); + } + if (galois_keys.parms_id() != context_->first_parms_id()) + { + throw invalid_argument("parameter mismatch"); + } + if (context_data_ptr->parms().scheme() == scheme_type::BFV && + encrypted.is_ntt_form()) + { + throw invalid_argument("BFV encrypted cannot be in NTT form"); + } + if (context_data_ptr->parms().scheme() == scheme_type::CKKS && + !encrypted.is_ntt_form()) + { + throw invalid_argument("CKKS encrypted must be in NTT form"); + } + if (!pool) + { + throw invalid_argument("pool is uninitialized"); + } + + // Extract encryption parameters. + auto &context_data = *context_data_ptr; + auto &parms = context_data.parms(); + auto &coeff_modulus = parms.coeff_modulus(); + size_t coeff_count = parms.poly_modulus_degree(); + size_t coeff_mod_count = coeff_modulus.size(); + size_t encrypted_size = encrypted.size(); + + // Size check + if (!product_fits_in(coeff_count, coeff_mod_count)) + { + throw logic_error("invalid parameters"); + } + + uint64_t m = mul_safe(static_cast(coeff_count), uint64_t(2)); + uint64_t subgroup_size = static_cast(coeff_count >> 1); + int n_power_of_two = get_power_of_two(static_cast(coeff_count)); + + // Verify parameters + if (!(galois_elt & 1) || unsigned_geq(galois_elt, m)) + { + throw invalid_argument("galois element is not valid"); + } + if (encrypted_size > 2) + { + throw invalid_argument("encrypted size must be 2"); + } + + auto &first_context_data = *context_->context_data(); + auto &inv_coeff_products_mod_coeff_array = + first_context_data.base_converter()->get_inv_coeff_mod_coeff_array(); + auto &coeff_small_ntt_tables = first_context_data.small_ntt_tables(); + + // Check if Galois key is generated or not. + // If not, attempt a bit decomposition; maybe we have log(n) many keys + if (!galois_keys.has_key(galois_elt)) + { + // galois_elt = 3^order1 * (-1)^order2 + uint64_t order1 = Zmstar_to_generator_.at(galois_elt).first; + uint64_t order2 = Zmstar_to_generator_.at(galois_elt).second; + + // We use either 3 or -3 as our generator, depending on which gives smaller HW + uint64_t two_power_of_gen = 3; + + // Does order1 or n/2-order1 have smaller Hamming weight? + if (hamming_weight(subgroup_size - order1) < hamming_weight(order1)) + { + order1 = subgroup_size - order1; + try_mod_inverse(3, m, two_power_of_gen); + } + + while(order1) + { + if (order1 & 1) + { + if (!galois_keys.has_key(two_power_of_gen)) + { + throw invalid_argument("galois key not present"); + } + apply_galois_inplace(encrypted, two_power_of_gen, galois_keys, pool); + } + two_power_of_gen = mul_safe(two_power_of_gen, two_power_of_gen); + two_power_of_gen &= (m - 1); + order1 >>= 1; + } + if (order2) + { + if (!galois_keys.has_key(m - 1)) + { + throw invalid_argument("galois key not present"); + } + apply_galois_inplace(encrypted, m - 1, galois_keys, pool); + } + return; + } + + auto temp0(allocate_zero_uint(coeff_count * coeff_mod_count, pool)); + auto temp1(allocate_zero_uint(coeff_count * coeff_mod_count, pool)); + + if (context_data_ptr->parms().scheme() == scheme_type::BFV) + { + // Apply Galois for each ciphertext + for (size_t i = 0; i < coeff_mod_count; i++) + { + util::apply_galois(encrypted.data() + (i * coeff_count), n_power_of_two, + galois_elt, coeff_modulus[i], temp0.get() + (i * coeff_count)); + } + for (size_t i = 0; i < coeff_mod_count; i++) + { + util::apply_galois(encrypted.data(1) + (i * coeff_count), n_power_of_two, + galois_elt, coeff_modulus[i], temp1.get() + (i * coeff_count)); + } + } + else if (context_data_ptr->parms().scheme() == scheme_type::CKKS) + { + // Apply Galois for each ciphertext + for (size_t i = 0; i < coeff_mod_count; i++) + { + util::apply_galois_ntt(encrypted.data() + (i * coeff_count), n_power_of_two, + galois_elt, temp0.get() + (i * coeff_count)); + } + for (size_t i = 0; i < coeff_mod_count; i++) + { + util::apply_galois_ntt(encrypted.data(1) + (i * coeff_count), n_power_of_two, + galois_elt, temp1.get() + (i * coeff_count)); + } + + // Transform ct[1] from NTT + for (size_t i = 0; i < coeff_mod_count; i++) + { + inverse_ntt_negacyclic_harvey(temp1.get() + (i * coeff_count), + coeff_small_ntt_tables[i]); + } + } + else + { + throw logic_error("scheme not implemented"); + } + + // Calculate (temp1 * galois_key.first, temp1 * galois_key.second) + (temp0, 0) + const uint64_t *encrypted_coeff = temp1.get(); + auto encrypted_coeff_prod_inv_coeff(allocate_uint(coeff_count, pool)); + + // decompose encrypted_array[count-1] into base w + // want to create an array of polys, each of whose components i is + // (encrypted_array[count-1])^(i) - in the notation of FV paper. + // This allocation stores one of the decomposed factors modulo one of the primes. + auto decomp_encrypted_last(allocate_uint(coeff_count, pool)); + + // Lazy reduction + auto wide_innerresult0(allocate_zero_poly(coeff_count, 2 * coeff_mod_count, pool)); + auto wide_innerresult1(allocate_zero_poly(coeff_count, 2 * coeff_mod_count, pool)); + auto innerresult(allocate_poly(coeff_count, coeff_mod_count, pool)); + auto temp_decomp_coeff(allocate_uint(coeff_count, pool)); + + /* + For lazy reduction to work here, we need to ensure that the 128-bit accumulators + (wide_innerresult0 and wide_innerresult1) do not overflow. Since the modulus primes + are at most 60 bits, if the total number of summands is K, then the size of the + total sum of products (without reduction) is at most 62 + 60 + bit_length(K). + We need this to be at most 128, thus we need bit_length(K) <= 6. Thus, we need K <= 63. + In this case, this means sum_i galois_keys.key(galois_elt)[i].size() / 2 <= 63. + */ + for (size_t i = 0; i < coeff_mod_count; i++) + { + multiply_poly_scalar_coeffmod( + encrypted_coeff + (i * coeff_count), coeff_count, + inv_coeff_products_mod_coeff_array[i], coeff_modulus[i], + encrypted_coeff_prod_inv_coeff.get()); + + int shift = 0; + auto &key_component_ref = galois_keys.key(galois_elt)[i]; + size_t keys_size = key_component_ref.size(); + for (size_t k = 0; k < keys_size; k += 2) + { + const uint64_t *key_ptr_0 = key_component_ref.data(k); + const uint64_t *key_ptr_1 = key_component_ref.data(k + 1); + + // Decompose here + int decomposition_bit_count = galois_keys.decomposition_bit_count(); + for (size_t coeff_index = 0; coeff_index < coeff_count; coeff_index++) + { + decomp_encrypted_last[coeff_index] = + encrypted_coeff_prod_inv_coeff[coeff_index] >> shift; + decomp_encrypted_last[coeff_index] &= + (uint64_t(1) << decomposition_bit_count) - 1; + } + + uint64_t *wide_innerresult0_ptr = wide_innerresult0.get(); + uint64_t *wide_innerresult1_ptr = wide_innerresult1.get(); + for (size_t j = 0; j < coeff_mod_count; j++) + { + uint64_t *temp_decomp_coeff_ptr = temp_decomp_coeff.get(); + set_uint_uint(decomp_encrypted_last.get(), coeff_count, temp_decomp_coeff_ptr); + + // We don't reduce here, so might get up to two extra bits. Thus 62 bits at most. + ntt_negacyclic_harvey_lazy(temp_decomp_coeff_ptr, coeff_small_ntt_tables[j]); + + // Lazy reduction + unsigned long long wide_innerproduct[2]; + unsigned long long temp; + for (size_t l = 0; l < coeff_count; l++, wide_innerresult0_ptr += 2) + { + multiply_uint64(*temp_decomp_coeff_ptr++, *key_ptr_0++, wide_innerproduct); + unsigned char carry = add_uint64(wide_innerresult0_ptr[0], + wide_innerproduct[0], &temp); + wide_innerresult0_ptr[0] = temp; + wide_innerresult0_ptr[1] += wide_innerproduct[1] + carry; + } + + temp_decomp_coeff_ptr = temp_decomp_coeff.get(); + for (size_t l = 0; l < coeff_count; l++, wide_innerresult1_ptr += 2) + { + multiply_uint64(*temp_decomp_coeff_ptr++, *key_ptr_1++, wide_innerproduct); + unsigned char carry = add_uint64(wide_innerresult1_ptr[0], + wide_innerproduct[0], &temp); + wide_innerresult1_ptr[0] = temp; + wide_innerresult1_ptr[1] += wide_innerproduct[1] + carry; + } + } + shift += decomposition_bit_count; + } + } + + uint64_t *temp_ptr = temp0.get(); + uint64_t *innerresult_poly_ptr = innerresult.get(); + uint64_t *wide_innerresult_poly_ptr = wide_innerresult0.get(); + uint64_t *encrypted_ptr = encrypted.data(); + uint64_t *innerresult_coeff_ptr = innerresult_poly_ptr; + uint64_t *wide_innerresult_coeff_ptr = wide_innerresult_poly_ptr; + for (size_t i = 0; i < coeff_mod_count; i++, innerresult_poly_ptr += coeff_count, + wide_innerresult_poly_ptr += 2 * coeff_count, encrypted_ptr += coeff_count, + temp_ptr += coeff_count) + { + for (size_t k = 0; k < coeff_count; + k++, wide_innerresult_coeff_ptr += 2, innerresult_coeff_ptr++) + { + *innerresult_coeff_ptr = barrett_reduce_128( + wide_innerresult_coeff_ptr, coeff_modulus[i]); + } + if (context_data_ptr->parms().scheme() == scheme_type::BFV) + { + inverse_ntt_negacyclic_harvey(innerresult_poly_ptr, + coeff_small_ntt_tables[i]); + } + add_poly_poly_coeffmod(temp_ptr, innerresult_poly_ptr, coeff_count, + coeff_modulus[i], encrypted_ptr); + } + + innerresult_poly_ptr = innerresult.get(); + wide_innerresult_poly_ptr = wide_innerresult1.get(); + encrypted_ptr = encrypted.data(1); + wide_innerresult_coeff_ptr = wide_innerresult_poly_ptr; + for (size_t i = 0; i < coeff_mod_count; i++, innerresult_poly_ptr += coeff_count, + wide_innerresult_poly_ptr += 2 * coeff_count, encrypted_ptr += coeff_count) + { + innerresult_coeff_ptr = encrypted_ptr; + for (size_t k = 0; k < coeff_count; + k++, wide_innerresult_coeff_ptr += 2, innerresult_coeff_ptr++) + { + *innerresult_coeff_ptr = barrett_reduce_128( + wide_innerresult_coeff_ptr, coeff_modulus[i]); + } + if (context_data_ptr->parms().scheme() == scheme_type::BFV) + { + inverse_ntt_negacyclic_harvey(encrypted_ptr, coeff_small_ntt_tables[i]); + } + } + + // If CKKS, mark encrypted as NTT form + if (context_data_ptr->parms().scheme() == scheme_type::CKKS) + { + encrypted.is_ntt_form() = true; + } + } + + void Evaluator::rotate_internal(Ciphertext &encrypted, int steps, + const GaloisKeys &galois_keys, MemoryPoolHandle pool) + { + // Verify parameters. + auto context_data_ptr = context_->context_data(encrypted.parms_id()); + if (!context_data_ptr) + { + throw invalid_argument("encrypted is not valid for encryption parameters"); + } + + // Extract encryption parameters. + auto &context_data = *context_data_ptr; + if (!context_data.qualifiers().using_batching) + { + throw logic_error("encryption parameters do not support batching"); + } + + // Is there anything to do? + if (steps == 0) + { + return; + } + + auto &parms = context_data.parms(); + size_t coeff_count = parms.poly_modulus_degree(); + + // Perform rotation and key switching + apply_galois_inplace(encrypted, + steps_to_galois_elt(steps, coeff_count), + galois_keys, move(pool)); + } +} diff --git a/src/seal/evaluator.h b/src/seal/evaluator.h new file mode 100644 index 000000000..172a99c19 --- /dev/null +++ b/src/seal/evaluator.h @@ -0,0 +1,1461 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include +#include +#include +#include "seal/context.h" +#include "seal/relinkeys.h" +#include "seal/smallmodulus.h" +#include "seal/memorymanager.h" +#include "seal/ciphertext.h" +#include "seal/plaintext.h" +#include "seal/galoiskeys.h" +#include "seal/util/pointer.h" +#include "seal/secretkey.h" +#include "seal/util/uintarithsmallmod.h" +#include "seal/util/common.h" + +namespace seal +{ + /** + Provides operations on ciphertexts. Due to the properties of the encryption + scheme, the arithmetic operations pass through the encryption layer to the + underlying plaintext, changing it according to the type of the operation. Since + the plaintext elements are fundamentally polynomials in the polynomial quotient + ring Z_T[x]/(X^N+1), where T is the plaintext modulus and X^N+1 is the polynomial + modulus, this is the ring where the arithmetic operations will take place. + BatchEncoder (batching) provider an alternative possibly more convenient view + of the plaintext elements as 2-by-(N2/2) matrices of integers modulo the plaintext + modulus. In the batching view the arithmetic operations act on the matrices + element-wise. Some of the operations only apply in the batching view, such as + matrix row and column rotations. Other operations such as relinearization have + no semantic meaning but are necessary for performance reasons. + + @par Arithmetic Operations + The core operations are arithmetic operations, in particular multiplication + and addition of ciphertexts. In addition to these, we also provide negation, + subtraction, squaring, exponentiation, and multiplication and addition of + several ciphertexts for convenience. in many cases some of the inputs to a + computation are plaintext elements rather than ciphertexts. For this we + provide fast "plain" operations: plain addition, plain subtraction, and plain + multiplication. + + @par Relinearization + One of the most important non-arithmetic operations is relinearization, which + takes as input a ciphertext of size K+1 and relinearization keys (at least K-1 + keys are needed), and changes the size of the ciphertext down to 2 (minimum size). + For most use-cases only one relinearization key suffices, in which case + relinearization should be performed after every multiplication. Homomorphic + multiplication of ciphertexts of size K+1 and L+1 outputs a ciphertext of size + K+L+1, and the computational cost of multiplication is proportional to K*L. + Plain multiplication and addition operations of any type do not change the + size. The performance of relinearization is determined by the decomposition + bit count that the relinearization keys were generated with. + + @par Rotations + When batching is enabled, we provide operations for rotating the plaintext matrix + rows cyclically left or right, and for rotating the columns (swapping the rows). + Rotations require Galois keys to have been generated, and their performance + depends on the decomposition bit count that the Galois keys were generated with. + + @par Other Operations + We also provide operations for transforming ciphertexts to NTT form and back, + and for transforming plaintext polynomials to NTT form. These can be used in + a very fast plain multiplication variant, that assumes the inputs to be in NTT + form. Since the NTT has to be done in any case in plain multiplication, this + function can be used when e.g. one plaintext input is used in several plain + multiplication, and transforming it several times would not make sense. + + @par NTT form + When using the BFV scheme (scheme_type::BFV), all plaintexts and ciphertexts + should remain by default in the usual coefficient representation, i.e. not + in NTT form. When using the CKKS scheme (scheme_type::CKKS), all plaintexts + and ciphertexts should remain by default in NTT form. We call these scheme- + specific NTT states the "default NTT form". Some functions, such as add, work + even if the inputs are not in the default state, but others, such as multiply, + will throw an exception. The output of all evaluation functions will be in + the same state as the input(s), with the exception of the transform_to_ntt + and transform_from_ntt functions, which change the state. Ideally, unless these + two functions are called, all other functions should "just work". + + @see EncryptionParameters for more details on encryption parameters. + @see BatchEncoder for more details on batching + @see RelinKeys for more details on relinearization keys. + @see GaloisKeys for more details on Galois keys. + */ + class Evaluator + { + public: + /** + Creates an Evaluator instance initialized with the specified SEALContext. + + @param[in] context The SEALContext + @throws std::invalid_argument if the context is not set or encryption + parameters are not valid + */ + Evaluator(std::shared_ptr context); + + /** + Negates a ciphertext. + + @param[in] encrypted The ciphertext to negate + @throws std::invalid_argument if encrypted is not valid for the encryption + parameters + */ + void negate_inplace(Ciphertext &encrypted); + + /** + Negates a ciphertext and stores the result in the destination parameter. + + @param[in] encrypted The ciphertext to negate + @param[out] destination The ciphertext to overwrite with the negated result + @throws std::invalid_argument if encrypted is not valid for the encryption + parameters + */ + inline void negate(const Ciphertext &encrypted, Ciphertext &destination) + { + destination = encrypted; + negate_inplace(destination); + } + + /** + Adds two ciphertexts. This function adds together encrypted1 and encrypted2 + and stores the result in encrypted1. + + @param[in] encrypted1 The first ciphertext to add + @param[in] encrypted2 The second ciphertext to add + @throws std::invalid_argument if encrypted1 or encrypted2 is not valid for + the encryption parameters + @throws std::invalid_argument if encrypted1 and encrypted2 are in different + NTT forms + @throws std::invalid_argument if encrypted1 and encrypted2 have different scale + */ + void add_inplace(Ciphertext &encrypted1, const Ciphertext &encrypted2); + + /** + Adds two ciphertexts. This function adds together encrypted1 and encrypted2 + and stores the result in the destination parameter. + + @param[in] encrypted1 The first ciphertext to add + @param[in] encrypted2 The second ciphertext to add + @param[out] destination The ciphertext to overwrite with the addition result + @throws std::invalid_argument if encrypted1 or encrypted2 is not valid for + the encryption parameters + @throws std::invalid_argument if encrypted1 and encrypted2 are in different + NTT forms + @throws std::invalid_argument if encrypted1 and encrypted2 have different scale + */ + inline void add(const Ciphertext &encrypted1, const Ciphertext &encrypted2, + Ciphertext &destination) + { + if (&encrypted2 == &destination) + { + add_inplace(destination, encrypted1); + } + else + { + destination = encrypted1; + add_inplace(destination, encrypted2); + } + } + + /** + Adds together a vector of ciphertexts and stores the result in the destination + parameter. + + @param[in] encrypteds The ciphertexts to add + @param[out] destination The ciphertext to overwrite with the addition result + @throws std::invalid_argument if encrypteds is empty + @throws std::invalid_argument if the encrypteds are not valid for the encryption + parameters + @throws std::invalid_argument if encrypteds are in different NTT forms + @throws std::invalid_argument if encrypteds have different scale + @throws std::invalid_argument if destination is one of encrypteds + */ + void add_many(const std::vector &encrypteds, Ciphertext &destination); + + /** + Subtracts two ciphertexts. This function computes the difference of encrypted1 + and encrypted2, and stores the result in encrypted1. + + @param[in] encrypted1 The ciphertext to subtract from + @param[in] encrypted2 The ciphertext to subtract + @throws std::invalid_argument if encrypted1 or encrypted2 is not valid for the + encryption parameters + @throws std::invalid_argument if encrypted1 and encrypted2 are in different + NTT forms + @throws std::invalid_argument if encrypted1 and encrypted2 have different scale + */ + void sub_inplace(Ciphertext &encrypted1, const Ciphertext &encrypted2); + + /** + Subtracts two ciphertexts. This function computes the difference of encrypted1 + and encrypted2 and stores the result in the destination parameter. + + @param[in] encrypted1 The ciphertext to subtract from + @param[in] encrypted2 The ciphertext to subtract + @param[out] destination The ciphertext to overwrite with the subtraction result + @throws std::invalid_argument if encrypted1 or encrypted2 is not valid for the + encryption parameters + @throws std::invalid_argument if encrypted1 and encrypted2 are in different + NTT forms + @throws std::invalid_argument if encrypted1 and encrypted2 have different scale + */ + inline void sub(const Ciphertext &encrypted1, const Ciphertext &encrypted2, + Ciphertext &destination) + { + if (&encrypted2 == &destination) + { + sub_inplace(destination, encrypted1); + negate_inplace(destination); + } + else + { + destination = encrypted1; + sub_inplace(destination, encrypted2); + } + } + + /** + Multiplies two ciphertexts. This functions computes the product of encrypted1 + and encrypted2 and stores the result in encrypted1. Dynamic memory allocations + in the process are allocated from the memory pool pointed to by the given + MemoryPoolHandle. + + @param[in] encrypted1 The first ciphertext to multiply + @param[in] encrypted2 The second ciphertext to multiply + @param[in] pool The MemoryPoolHandle pointing to a valid memory pool + @throws std::invalid_argument if encrypted1 or encrypted2 is not valid for the + encryption parameters + @throws std::invalid_argument if encrypted1 or encrypted2 is not in the default + NTT form + @throws std::invalid_argument if, when using scheme_type::CKKS, the output scale + is too large for the encryption parameters + @throws std::invalid_argument if pool is uninitialized + */ + void multiply_inplace(Ciphertext &encrypted1, const Ciphertext &encrypted2, + MemoryPoolHandle pool = MemoryManager::GetPool()); + + /** + Multiplies two ciphertexts. This functions computes the product of encrypted1 + and encrypted2 and stores the result in the destination parameter. Dynamic + memory allocations in the process are allocated from the memory pool pointed + to by the given MemoryPoolHandle. + + @param[in] encrypted1 The first ciphertext to multiply + @param[in] encrypted2 The second ciphertext to multiply + @param[out] destination The ciphertext to overwrite with the multiplication result + @param[in] pool The MemoryPoolHandle pointing to a valid memory pool + @throws std::invalid_argument if encrypted1 or encrypted2 is not valid for the + encryption parameters + @throws std::invalid_argument if encrypted1 or encrypted2 is not in the default + NTT form + @throws std::invalid_argument if, when using scheme_type::CKKS, the output scale + is too large for the encryption parameters + @throws std::invalid_argument if pool is uninitialized + */ + inline void multiply(const Ciphertext &encrypted1, + const Ciphertext &encrypted2, Ciphertext &destination, + MemoryPoolHandle pool = MemoryManager::GetPool()) + { + if (&encrypted2 == &destination) + { + multiply_inplace(destination, encrypted1, std::move(pool)); + } + else + { + destination = encrypted1; + multiply_inplace(destination, encrypted2, std::move(pool)); + } + } + + /** + Squares a ciphertext. This functions computes the square of encrypted. Dynamic + memory allocations in the process are allocated from the memory pool pointed + to by the given MemoryPoolHandle. + + @param[in] encrypted The ciphertext to square + @param[in] pool The MemoryPoolHandle pointing to a valid memory pool + @throws std::invalid_argument if encrypted is not valid for the encryption + parameters + @throws std::invalid_argument if encrypted is not in the default NTT form + @throws std::invalid_argument if, when using scheme_type::CKKS, the output scale + is too large for the encryption parameters + @throws std::invalid_argument if pool is uninitialized + */ + void square_inplace(Ciphertext &encrypted, + MemoryPoolHandle pool = MemoryManager::GetPool()); + + /** + Squares a ciphertext. This functions computes the square of encrypted and + stores the result in the destination parameter. Dynamic memory allocations + in the process are allocated from the memory pool pointed to by the given + MemoryPoolHandle. + + @param[in] encrypted The ciphertext to square + @param[out] destination The ciphertext to overwrite with the square + @param[in] pool The MemoryPoolHandle pointing to a valid memory pool + @throws std::invalid_argument if encrypted is not valid for the encryption + parameters + @throws std::invalid_argument if encrypted is not in the default NTT form + @throws std::invalid_argument if, when using scheme_type::CKKS, the output scale + is too large for the encryption parameters + @throws std::invalid_argument if pool is uninitialized + */ + inline void square(const Ciphertext &encrypted, Ciphertext &destination, + MemoryPoolHandle pool = MemoryManager::GetPool()) + { + destination = encrypted; + square_inplace(destination, std::move(pool)); + } + + /** + Relinearizes a ciphertext. This functions relinearizes encrypted, reducing + its size down to 2. If the size of encrypted is K+1, the given relinearization + keys need to have size at least K-1. Dynamic memory allocations in the + process are allocated from the memory pool pointed to by the given + MemoryPoolHandle. + + @param[in] encrypted The ciphertext to relinearize + @param[in] relin_keys The relinearization keys + @param[in] pool The MemoryPoolHandle pointing to a valid memory pool + @throws std::invalid_argument if encrypted or relin_keys is not valid for the + encryption parameters + @throws std::invalid_argument if encrypted is not in the default NTT form + @throws std::invalid_argument if relin_keys do not correspond to the top level + parameters in the current context + @throws std::invalid_argument if the size of relin_keys is too small + @throws std::invalid_argument if pool is uninitialized + */ + inline void relinearize_inplace(Ciphertext &encrypted, const RelinKeys &relin_keys, + MemoryPoolHandle pool = MemoryManager::GetPool()) + { + relinearize_internal(encrypted, relin_keys, 2, std::move(pool)); + } + + /** + Relinearizes a ciphertext. This functions relinearizes encrypted, reducing + its size down to 2, and stores the result in the destination parameter. + If the size of encrypted is K+1, the given relinearization keys need to + have size at least K-1. Dynamic memory allocations in the process are allocated + from the memory pool pointed to by the given MemoryPoolHandle. + + @param[in] encrypted The ciphertext to relinearize + @param[in] relin_keys The relinearization keys + @param[out] destination The ciphertext to overwrite with the relinearized result + @param[in] pool The MemoryPoolHandle pointing to a valid memory pool + @throws std::invalid_argument if encrypted or relin_keys is not valid for the + encryption parameters + @throws std::invalid_argument if encrypted is not in the default NTT form + @throws std::invalid_argument if relin_keys do not correspond to the top level + parameters in the current context + @throws std::invalid_argument if the size of relin_keys is too small + @throws std::invalid_argument if pool is uninitialized + */ + inline void relinearize(const Ciphertext &encrypted, + const RelinKeys &relin_keys, Ciphertext &destination, + MemoryPoolHandle pool = MemoryManager::GetPool()) + { + destination = encrypted; + relinearize_inplace(destination, relin_keys, std::move(pool)); + } + + /** + Given a ciphertext encrypted modulo q_1...q_k, this function switches the + modulus down to q_1...q_{k-1} and stores the result in the destination + parameter. Dynamic memory allocations in the process are allocated from + the memory pool pointed to by the given MemoryPoolHandle. + + @param[in] encrypted The ciphertext to be switched to a smaller modulus + @param[in] pool The MemoryPoolHandle pointing to a valid memory pool + @param[out] destination The ciphertext to overwrite with the modulus switched result + @throws std::invalid_argument if encrypted is not valid for the encryption parameters + @throws std::invalid_argument if encrypted is not in the default NTT form + @throws std::invalid_argument if encrypted is already at lowest level + @throws std::invalid_argument if encrypted has size larger than 2 + @throws std::invalid_argument if, when using scheme_type::CKKS, the scale is too + large for the new encryption parameters + @throws std::invalid_argument if pool is uninitialized + */ + void mod_switch_to_next(const Ciphertext &encrypted, Ciphertext &destination, + MemoryPoolHandle pool = MemoryManager::GetPool()); + + /** + Given a ciphertext encrypted modulo q_1...q_k, this function switches the + modulus down to q_1...q_{k-1}. Dynamic memory allocations in the process + are allocated from the memory pool pointed to by the given MemoryPoolHandle. + + @param[in] encrypted The ciphertext to be switched to a smaller modulus + @param[in] pool The MemoryPoolHandle pointing to a valid memory pool + @throws std::invalid_argument if encrypted is not valid for the encryption parameters + @throws std::invalid_argument if encrypted is not in the default NTT form + @throws std::invalid_argument if encrypted is already at lowest level + @throws std::invalid_argument if encrypted has size larger than 2 + @throws std::invalid_argument if, when using scheme_type::CKKS, the scale is too + large for the new encryption parameters + @throws std::invalid_argument if pool is uninitialized + */ + inline void mod_switch_to_next_inplace(Ciphertext &encrypted, + MemoryPoolHandle pool = MemoryManager::GetPool()) + { + mod_switch_to_next(encrypted, encrypted, std::move(pool)); + } + + /** + Modulus switches an NTT transformed plaintext from modulo q_1...q_k down + to modulo q_1...q_{k-1}. + + @param[in] plain The plaintext to be switched to a smaller modulus + @throws std::invalid_argument if plain is not in NTT form + @throws std::invalid_argument if plain is not valid for the encryption parameters + @throws std::invalid_argument if plain is already at lowest level + @throws std::invalid_argument if, when using scheme_type::CKKS, the scale is too + large for the new encryption parameters + */ + inline void mod_switch_to_next_inplace(Plaintext &plain) + { + mod_switch_drop_to_next(plain); + } + + /** + Modulus switches an NTT transformed plaintext from modulo q_1...q_k down + to modulo q_1...q_{k-1} and stores the result in the destination parameter. + + @param[in] plain The plaintext to be switched to a smaller modulus + @param[in] pool The MemoryPoolHandle pointing to a valid memory pool + @param[out] destination The plaintext to overwrite with the modulus switched result + @throws std::invalid_argument if plain is not in NTT form + @throws std::invalid_argument if plain is not valid for the encryption parameters + @throws std::invalid_argument if plain is already at lowest level + @throws std::invalid_argument if, when using scheme_type::CKKS, the scale is too + large for the new encryption parameters + @throws std::invalid_argument if pool is uninitialized + */ + inline void mod_switch_to_next(const Plaintext &plain, Plaintext &destination) + { + destination = plain; + mod_switch_to_next_inplace(destination); + } + + /** + Given a ciphertext encrypted modulo q_1...q_k, this function switches the + modulus down until the parameters reach the given parms_id. Dynamic memory + allocations in the process are allocated from the memory pool pointed to + by the given MemoryPoolHandle. + + @param[in] encrypted The ciphertext to be switched to a smaller modulus + @param[in] parms_id The target parms_id + @param[in] pool The MemoryPoolHandle pointing to a valid memory pool + @throws std::invalid_argument if encrypted is not valid for the encryption parameters + @throws std::invalid_argument if encrypted is not in the default NTT form + @throws std::invalid_argument if parms_id is not valid for the encryption parameters + @throws std::invalid_argument if encrypted is already at lower level in modulus chain + than the parameters corresponding to parms_id + @throws std::invalid_argument if encrypted has size larger than 2 + @throws std::invalid_argument if, when using scheme_type::CKKS, the scale is too + large for the new encryption parameters + @throws std::invalid_argument if pool is uninitialized + */ + void mod_switch_to_inplace(Ciphertext &encrypted, parms_id_type parms_id, + MemoryPoolHandle pool = MemoryManager::GetPool()); + + /** + Given a ciphertext encrypted modulo q_1...q_k, this function switches the + modulus down until the parameters reach the given parms_id and stores the + result in the destination parameter. Dynamic memory allocations in the process + are allocated from the memory pool pointed to by the given MemoryPoolHandle. + + @param[in] encrypted The ciphertext to be switched to a smaller modulus + @param[in] parms_id The target parms_id + @param[out] destination The ciphertext to overwrite with the modulus switched result + @param[in] pool The MemoryPoolHandle pointing to a valid memory pool + @throws std::invalid_argument if encrypted is not valid for the encryption parameters + @throws std::invalid_argument if encrypted is not in the default NTT form + @throws std::invalid_argument if parms_id is not valid for the encryption parameters + @throws std::invalid_argument if encrypted is already at lower level in modulus chain + than the parameters corresponding to parms_id + @throws std::invalid_argument if, when using scheme_type::CKKS, the scale is too + large for the new encryption parameters + @throws std::invalid_argument if pool is uninitialized + */ + inline void mod_switch_to(const Ciphertext &encrypted, + parms_id_type parms_id, Ciphertext &destination, + MemoryPoolHandle pool = MemoryManager::GetPool()) + { + destination = encrypted; + mod_switch_to_inplace(destination, parms_id, std::move(pool)); + } + + /** + Given an NTT transformed plaintext modulo q_1...q_k, this function switches + the modulus down until the parameters reach the given parms_id. + + @param[in] plain The plaintext to be switched to a smaller modulus + @param[in] parms_id The target parms_id + @throws std::invalid_argument if plain is not in NTT form + @throws std::invalid_argument if plain is not valid for the encryption parameters + @throws std::invalid_argument if parms_id is not valid for the encryption parameters + @throws std::invalid_argument if plain is already at lower level in modulus chain + than the parameters corresponding to parms_id + @throws std::invalid_argument if, when using scheme_type::CKKS, the scale is too + large for the new encryption parameters + */ + void mod_switch_to_inplace(Plaintext &plain, parms_id_type parms_id); + + /** + Given an NTT transformed plaintext modulo q_1...q_k, this function switches + the modulus down until the parameters reach the given parms_id and stores + the result in the destination parameter. + + @param[in] plain The plaintext to be switched to a smaller modulus + @param[in] parms_id The target parms_id + @param[out] destination The plaintext to overwrite with the modulus switched result + @throws std::invalid_argument if plain is not in NTT form + @throws std::invalid_argument if plain is not valid for the encryption parameters + @throws std::invalid_argument if parms_id is not valid for the encryption parameters + @throws std::invalid_argument if plain is already at lower level in modulus chain + than the parameters corresponding to parms_id + @throws std::invalid_argument if, when using scheme_type::CKKS, the scale is too + large for the new encryption parameters + */ + inline void mod_switch_to(const Plaintext &plain, parms_id_type parms_id, + Plaintext &destination) + { + destination = plain; + mod_switch_to_inplace(destination, parms_id); + } + + /** + Given a ciphertext encrypted modulo q_1...q_k, this function switches the + modulus down to q_1...q_{k-1}, scales the message down accordingly, and + stores the result in the destination parameter. Dynamic memory allocations + in the process are allocated from the memory pool pointed to by the given + MemoryPoolHandle. + + @param[in] encrypted The ciphertext to be switched to a smaller modulus + @param[in] pool The MemoryPoolHandle pointing to a valid memory pool + @param[out] destination The ciphertext to overwrite with the modulus switched result + @throws std::invalid_argument if the scheme is invalid for rescaling + @throws std::invalid_argument if encrypted is not valid for the encryption parameters + @throws std::invalid_argument if encrypted is not in the default NTT form + @throws std::invalid_argument if encrypted is already at lowest level + @throws std::invalid_argument if encrypted has size larger than 2 + @throws std::invalid_argument if pool is uninitialized + */ + void rescale_to_next(const Ciphertext &encrypted, Ciphertext &destination, + MemoryPoolHandle pool = MemoryManager::GetPool()); + + /** + Given a ciphertext encrypted modulo q_1...q_k, this function switches the + modulus down to q_1...q_{k-1} and scales the message down accordingly. Dynamic + memory allocations in the process are allocated from the memory pool pointed + to by the given MemoryPoolHandle. + + @param[in] encrypted The ciphertext to be switched to a smaller modulus + @param[in] pool The MemoryPoolHandle pointing to a valid memory pool + @throws std::invalid_argument if the scheme is invalid for rescaling + @throws std::invalid_argument if encrypted is not valid for the encryption parameters + @throws std::invalid_argument if encrypted is not in the default NTT form + @throws std::invalid_argument if encrypted is already at lowest level + @throws std::invalid_argument if encrypted has size larger than 2 + @throws std::invalid_argument if pool is uninitialized + */ + inline void rescale_to_next_inplace(Ciphertext &encrypted, + MemoryPoolHandle pool = MemoryManager::GetPool()) + { + rescale_to_next(encrypted, encrypted, std::move(pool)); + } + + /** + Given a ciphertext encrypted modulo q_1...q_k, this function switches the + modulus down until the parameters reach the given parms_id and scales the + message down accordingly. Dynamic memory allocations in the process are + allocated from the memory pool pointed to by the given MemoryPoolHandle. + + @param[in] encrypted The ciphertext to be switched to a smaller modulus + @param[in] parms_id The target parms_id + @param[in] pool The MemoryPoolHandle pointing to a valid memory pool + @throws std::invalid_argument if the scheme is invalid for rescaling + @throws std::invalid_argument if encrypted is not valid for the encryption parameters + @throws std::invalid_argument if encrypted is not in the default NTT form + @throws std::invalid_argument if parms_id is not valid for the encryption parameters + @throws std::invalid_argument if encrypted is already at lower level in modulus chain + than the parameters corresponding to parms_id + @throws std::invalid_argument if encrypted has size larger than 2 + @throws std::invalid_argument if pool is uninitialized + */ + void rescale_to_inplace(Ciphertext &encrypted, parms_id_type parms_id, + MemoryPoolHandle pool = MemoryManager::GetPool()); + + /** + Given a ciphertext encrypted modulo q_1...q_k, this function switches the + modulus down until the parameters reach the given parms_id, scales the message + down accordingly, and stores the result in the destination parameter. Dynamic + memory allocations in the process are allocated from the memory pool pointed + to by the given MemoryPoolHandle. + + @param[in] encrypted The ciphertext to be switched to a smaller modulus + @param[in] parms_id The target parms_id + @param[out] destination The ciphertext to overwrite with the modulus switched result + @param[in] pool The MemoryPoolHandle pointing to a valid memory pool + @throws std::invalid_argument if the scheme is invalid for rescaling + @throws std::invalid_argument if encrypted is not valid for the encryption parameters + @throws std::invalid_argument if encrypted is not in the default NTT form + @throws std::invalid_argument if parms_id is not valid for the encryption parameters + @throws std::invalid_argument if encrypted is already at lower level in modulus chain + than the parameters corresponding to parms_id + @throws std::invalid_argument if encrypted has size larger than 2 + @throws std::invalid_argument if pool is uninitialized + */ + inline void rescale_to(const Ciphertext &encrypted, + parms_id_type parms_id, Ciphertext &destination, + MemoryPoolHandle pool = MemoryManager::GetPool()) + { + destination = encrypted; + rescale_to_inplace(destination, parms_id, std::move(pool)); + } + + /** + Multiplies several ciphertexts together. This function computes the product + of several ciphertext given as an std::vector and stores the result in the + destination parameter. The multiplication is done in a depth-optimal order, + and relinearization is performed automatically after every multiplication + in the process. In relinearization the given relinearization keys are used. + Dynamic memory allocations in the process are allocated from the memory + pool pointed to by the given MemoryPoolHandle. + + @param[in] encrypteds The ciphertexts to multiply + @param[in] relin_keys The relinearization keys + @param[out] destination The ciphertext to overwrite with the multiplication result + @param[in] pool The MemoryPoolHandle pointing to a valid memory pool + @throws std::logic_error if scheme is not scheme_type::BFV + @throws std::invalid_argument if encrypteds is empty + @throws std::invalid_argument if the ciphertexts or relin_keys are not valid for + the encryption parameters + @throws std::invalid_argument if encrypteds are not in the default NTT form + @throws std::invalid_argument if, when using scheme_type::CKKS, the output scale + is too large for the encryption parameters + @throws std::invalid_argument if the size of relin_keys is too small + @throws std::invalid_argument if pool is uninitialized + */ + void multiply_many(std::vector &encrypteds, + const RelinKeys &relin_keys, Ciphertext &destination, + MemoryPoolHandle pool = MemoryManager::GetPool()); + + /** + Exponentiates a ciphertext. This functions raises encrypted to a power. + Dynamic memory allocations in the process are allocated from the memory + pool pointed to by the given MemoryPoolHandle. The exponentiation is done + in a depth-optimal order, and relinearization is performed automatically + after every multiplication in the process. In relinearization the given + relinearization keys are used. + + @param[in] encrypted The ciphertext to exponentiate + @param[in] exponent The power to raise the ciphertext to + @param[in] relin_keys The relinearization keys + @param[in] pool The MemoryPoolHandle pointing to a valid memory pool + @throws std::logic_error if scheme is not scheme_type::BFV + @throws std::invalid_argument if encrypted or relin_keys is not valid for the + encryption parameters + @throws std::invalid_argument if encrypted is not in the default NTT form + @throws std::invalid_argument if, when using scheme_type::CKKS, the output scale + is too large for the encryption parameters + @throws std::invalid_argument if exponent is zero + @throws std::invalid_argument if the size of relin_keys is too small + @throws std::invalid_argument if pool is uninitialized + */ + void exponentiate_inplace(Ciphertext &encrypted, + std::uint64_t exponent, const RelinKeys &relin_keys, + MemoryPoolHandle pool = MemoryManager::GetPool()); + + /** + Exponentiates a ciphertext. This functions raises encrypted to a power and + stores the result in the destination parameter. Dynamic memory allocations + in the process are allocated from the memory pool pointed to by the given + MemoryPoolHandle. The exponentiation is done in a depth-optimal order, and + relinearization is performed automatically after every multiplication in + the process. In relinearization the given relinearization keys are used. + + @param[in] encrypted The ciphertext to exponentiate + @param[in] exponent The power to raise the ciphertext to + @param[in] relin_keys The relinearization keys + @param[out] destination The ciphertext to overwrite with the power + @param[in] pool The MemoryPoolHandle pointing to a valid memory pool + @throws std::logic_error if scheme is not scheme_type::BFV + @throws std::invalid_argument if encrypted or relin_keys is not valid for the + encryption parameters + @throws std::invalid_argument if encrypted is not in the default NTT form + @throws std::invalid_argument if, when using scheme_type::CKKS, the output scale + is too large for the encryption parameters + @throws std::invalid_argument if exponent is zero + @throws std::invalid_argument if the size of relin_keys is too small + @throws std::invalid_argument if pool is uninitialized + */ + inline void exponentiate(const Ciphertext &encrypted, std::uint64_t exponent, + const RelinKeys &relin_keys, Ciphertext &destination, + MemoryPoolHandle pool = MemoryManager::GetPool()) + { + destination = encrypted; + exponentiate_inplace(destination, exponent, relin_keys, std::move(pool)); + } + + /** + Adds a ciphertext and a plaintext. This function adds a plaintext to + a ciphertext. For the operation to be valid, the plaintext must have less + than degree(poly_modulus) many non-zero coefficients, and each coefficient + must be less than the plaintext modulus, i.e. the plaintext must be a valid + plaintext under the current encryption parameters. + + @param[in] encrypted The ciphertext to add + @param[in] plain The plaintext to add + @throws std::invalid_argument if encrypted or plain is not valid for the + encryption parameters + @throws std::invalid_argument if encrypted or plain is in NTT form + */ + void add_plain_inplace(Ciphertext &encrypted, const Plaintext &plain); + + /** + Adds a ciphertext and a plaintext. This function adds a plaintext to + a ciphertext and stores the result in the destination parameter. For the + operation to be valid, the plaintext must have less than degree(poly_modulus) + many non-zero coefficients, and each coefficient must be less than the + plaintext modulus, i.e. the plaintext must be a valid plaintext under the + current encryption parameters. + + @param[in] encrypted The ciphertext to add + @param[in] plain The plaintext to add + @param[out] destination The ciphertext to overwrite with the addition result + @throws std::invalid_argument if encrypted or plain is not valid for the + encryption parameters + @throws std::invalid_argument if encrypted or plain is in NTT form + */ + inline void add_plain(const Ciphertext &encrypted, const Plaintext &plain, + Ciphertext &destination) + { + destination = encrypted; + add_plain_inplace(destination, plain); + } + + /** + Subtracts a plaintext from a ciphertext. This function subtracts a plaintext + from a ciphertext. For the operation to be valid, the plaintext must have + less than degree(poly_modulus) many non-zero coefficients, and each coefficient + must be less than the plaintext modulus, i.e. the plaintext must be a valid + plaintext under the current encryption parameters. + + @param[in] encrypted The ciphertext to subtract from + @param[in] plain The plaintext to subtract + @throws std::invalid_argument if encrypted or plain is not valid for the + encryption parameters + @throws std::invalid_argument if encrypted or plain is in NTT form + */ + void sub_plain_inplace(Ciphertext &encrypted, const Plaintext &plain); + + /** + Subtracts a plaintext from a ciphertext. This function subtracts a plaintext + from a ciphertext and stores the result in the destination parameter. For + the operation to be valid, the plaintext must have less than degree(poly_modulus) + many non-zero coefficients, and each coefficient must be less than the plaintext + modulus, i.e. the plaintext must be a valid plaintext under the current + encryption parameters. + + @param[in] encrypted The ciphertext to subtract from + @param[in] plain The plaintext to subtract + @param[out] destination The ciphertext to overwrite with the subtraction result + @throws std::invalid_argument if encrypted or plain is not valid for the + encryption parameters + @throws std::invalid_argument if encrypted or plain is in NTT form + */ + inline void sub_plain(const Ciphertext &encrypted, const Plaintext &plain, + Ciphertext &destination) + { + destination = encrypted; + sub_plain_inplace(destination, plain); + } + + /** + Multiplies a ciphertext with a plaintext. This function multiplies a ciphertext + with a plaintext. For the operation to be valid, the plaintext must have + less than degree(poly_modulus) many non-zero coefficients, and each coefficient + must be less than the plaintext modulus, i.e. the plaintext must be a valid + plaintext under the current encryption parameters. Moreover, the plaintext + cannot be identially 0. Dynamic memory allocations in the process are allocated + from the memory pool pointed to by the given MemoryPoolHandle. + + @param[in] encrypted The ciphertext to multiply + @param[in] plain The plaintext to multiply + @param[in] pool The MemoryPoolHandle pointing to a valid memory pool + @throws std::invalid_argument if the encrypted or plain is not valid for + the encryption parameters + @throws std::invalid_argument if encrypted and plain are in different NTT forms + @throws std::invalid_argument if plain is zero + @throws std::invalid_argument if, when using scheme_type::CKKS, the output + scale is too large for the encryption parameters + @throws std::invalid_argument if pool is uninitialized + */ + void multiply_plain_inplace(Ciphertext &encrypted, const Plaintext &plain, + MemoryPoolHandle pool = MemoryManager::GetPool()); + + /** + Multiplies a ciphertext with a plaintext. This function multiplies + a ciphertext with a plaintext and stores the result in the destination + parameter. For the operation to be valid, the plaintext must have less + than degree (poly_modulus) many non-zero coefficients, and each coefficient + must be less than the plaintext modulus, i.e. the plaintext must be a valid + plaintext under the current encryption parameters. Moreover, the plaintext + cannot be identially 0. Dynamic memory allocations in the process are allocated + from the memory pool pointed to by the given MemoryPoolHandle. + + @param[in] encrypted The ciphertext to multiply + @param[in] plain The plaintext to multiply + @param[out] destination The ciphertext to overwrite with the multiplication result + @param[in] pool The MemoryPoolHandle pointing to a valid memory pool + @throws std::invalid_argument if the encrypted or plain is not valid for + the encryption parameters + @throws std::invalid_argument if encrypted and plain are in different NTT forms + @throws std::invalid_argument if plain is zero + @throws std::invalid_argument if, when using scheme_type::CKKS, the output + scale is too large for the encryption parameters + @throws std::invalid_argument if pool is uninitialized + */ + inline void multiply_plain(const Ciphertext &encrypted, + const Plaintext &plain, Ciphertext &destination, + MemoryPoolHandle pool = MemoryManager::GetPool()) + { + destination = encrypted; + multiply_plain_inplace(destination, plain, std::move(pool)); + } + + /** + Transforms a plaintext to NTT domain. This functions applies the Number + Theoretic Transform to a plaintext by first embedding integers modulo the + plaintext modulus to integers modulo the coefficient modulus and then + performing David Harvey's NTT on the resulting polynomial. The transformation + is done with respect to encryption parameters corresponding to a given parms_id. + For the operation to be valid, the plaintext must have degree less than + poly_modulus_degree and each coefficient must be less than the plaintext + modulus, i.e. the plaintext must be a valid plaintext under the current + encryption parameters. Dynamic memory allocations in the process are allocated + from the memory pool pointed to by the given MemoryPoolHandle. + + @param[in] plain The plaintext to transform + @param[in] parms_id The parms_id with respect to which the NTT is done + @param[in] pool The MemoryPoolHandle pointing to a valid memory pool + @throws std::invalid_argument if plain is already in NTT form + @throws std::invalid_argument if plain or parms_id is not valid for the + encryption parameters + @throws std::invalid_argument if pool is uninitialized + */ + void transform_to_ntt_inplace(Plaintext &plain, parms_id_type parms_id, + MemoryPoolHandle pool = MemoryManager::GetPool()); + + /** + Transforms a plaintext to NTT domain. This functions applies the Number + Theoretic Transform to a plaintext by first embedding integers modulo the + plaintext modulus to integers modulo the coefficient modulus and then + performing David Harvey's NTT on the resulting polynomial. The transformation + is done with respect to encryption parameters corresponding to a given + parms_id. The result is stored in the destination_ntt parameter. For the + operation to be valid, the plaintext must have degree less than poly_modulus_degree + and each coefficient must be less than the plaintext modulus, i.e. the plaintext + must be a valid plaintext under the current encryption parameters. Dynamic + memory allocations in the process are allocated from the memory pool pointed + to by the given MemoryPoolHandle. + + @param[in] plain The plaintext to transform + @param[in] parms_id The parms_id with respect to which the NTT is done + @param[out] destinationNTT The plaintext to overwrite with the transformed result + @param[in] pool The MemoryPoolHandle pointing to a valid memory pool + @throws std::invalid_argument if plain is already in NTT form + @throws std::invalid_argument if plain or parms_id is not valid for the + encryption parameters + @throws std::invalid_argument if pool is uninitialized + */ + inline void transform_to_ntt(const Plaintext &plain, + parms_id_type parms_id, Plaintext &destination_ntt, + MemoryPoolHandle pool = MemoryManager::GetPool()) + { + destination_ntt = plain; + transform_to_ntt_inplace(destination_ntt, parms_id, std::move(pool)); + } + + /** + Transforms a ciphertext to NTT domain. This functions applies David Harvey's + Number Theoretic Transform separately to each polynomial of a ciphertext. + + @param[in] encrypted The ciphertext to transform + @throws std::invalid_argument if encrypted is not valid for the encryption + parameters + @throws std::invalid_argument if encrypted is already in NTT form + */ + void transform_to_ntt_inplace(Ciphertext &encrypted); + + /** + Transforms a ciphertext to NTT domain. This functions applies David Harvey's + Number Theoretic Transform separately to each polynomial of a ciphertext. + The result is stored in the destination_ntt parameter. + + @param[in] encrypted The ciphertext to transform + @param[out] destination_ntt The ciphertext to overwrite with the transformed result + @throws std::invalid_argument if encrypted is not valid for the encryption + parameters + @throws std::invalid_argument if encrypted is already in NTT form + */ + inline void transform_to_ntt(const Ciphertext &encrypted, + Ciphertext &destination_ntt) + { + destination_ntt = encrypted; + transform_to_ntt_inplace(destination_ntt); + } + + /** + Transforms a ciphertext back from NTT domain. This functions applies the + inverse of David Harvey's Number Theoretic Transform separately to each + polynomial of a ciphertext. + + @param[in] encrypted_ntt The ciphertext to transform + @throws std::invalid_argument if encrypted_ntt is not valid for the encryption + parameters + @throws std::invalid_argument if encrypted_ntt is not in NTT form + */ + void transform_from_ntt_inplace(Ciphertext &encrypted_ntt); + + /** + Transforms a ciphertext back from NTT domain. This functions applies the + inverse of David Harvey's Number Theoretic Transform separately to each + polynomial of a ciphertext. The result is stored in the destination parameter. + + @param[in] encrypted_ntt The ciphertext to transform + @param[out] destination The ciphertext to overwrite with the transformed result + @throws std::invalid_argument if encrypted_ntt is not valid for the encryption + parameters + @throws std::invalid_argument if encrypted_ntt is not in NTT form + */ + inline void transform_from_ntt(const Ciphertext &encrypted_ntt, + Ciphertext &destination) + { + destination = encrypted_ntt; + transform_from_ntt_inplace(destination); + } + + /** + Applies a Galois automorphism to a ciphertext. To evaluate the Galois + automorphism, an appropriate set of Galois keys must also be provided. + Dynamic memory allocations in the process are allocated from the memory + pool pointed to by the given MemoryPoolHandle. + + + The desired Galois automorphism is given as a Galois element, and must be + an odd integer in the interval [1, M-1], where M = 2*N, and N = degree(poly_modulus). + Used with batching, a Galois element 3^i % M corresponds to a cyclic row + rotation i steps to the left, and a Galois element 3^(N/2-i) % M corresponds + to a cyclic row rotation i steps to the right. The Galois element M-1 corresponds + to a column rotation (row swap) in BFV, and complex conjugation in CKKS. + In the polynomial view (not batching), a Galois automorphism by a Galois + element p changes Enc(plain(x)) to Enc(plain(x^p)). + + @param[in] encrypted The ciphertext to apply the Galois automorphism to + @param[in] galois_elt The Galois element + @param[in] galois_keys The Galois keys + @param[in] pool The MemoryPoolHandle pointing to a valid memory pool + @throws std::invalid_argument if encrypted or galois_keys is not valid for + the encryption parameters + @throws std::invalid_argument if galois_keys do not correspond to the top + level parameters in the current context + @throws std::invalid_argument if encrypted is not in the default NTT form + @throws std::invalid_argument if encrypted has size larger than 2 + @throws std::invalid_argument if the Galois element is not valid + @throws std::invalid_argument if necessary Galois keys are not present + @throws std::invalid_argument if pool is uninitialized + */ + void apply_galois_inplace(Ciphertext &encrypted, + std::uint64_t galois_elt, const GaloisKeys &galois_keys, + MemoryPoolHandle pool = MemoryManager::GetPool()); + + /** + Applies a Galois automorphism to a ciphertext and writes the result to the + destination parameter. To evaluate the Galois automorphism, an appropriate + set of Galois keys must also be provided. Dynamic memory allocations in + the process are allocated from the memory pool pointed to by the given + MemoryPoolHandle. + + The desired Galois automorphism is given as a Galois element, and must be + an odd integer in the interval [1, M-1], where M = 2*N, and N = degree(poly_modulus). + Used with batching, a Galois element 3^i % M corresponds to a cyclic row + rotation i steps to the left, and a Galois element 3^(N/2-i) % M corresponds + to a cyclic row rotation i steps to the right. The Galois element M-1 corresponds + to a column rotation (row swap) in BFV, and complex conjugation in CKKS. + In the polynomial view (not batching), a Galois automorphism by a Galois + element p changes Enc(plain(x)) to Enc(plain(x^p)). + + @param[in] encrypted The ciphertext to apply the Galois automorphism to + @param[in] galois_elt The Galois element + @param[in] galois_keys The Galois keys + @param[out] destination The ciphertext to overwrite with the result + @param[in] pool The MemoryPoolHandle pointing to a valid memory pool + @throws std::invalid_argument if encrypted or galois_keys is not valid for + the encryption parameters + @throws std::invalid_argument if galois_keys do not correspond to the top + level parameters in the current context + @throws std::invalid_argument if encrypted is not in the default NTT form + @throws std::invalid_argument if encrypted has size larger than 2 + @throws std::invalid_argument if the Galois element is not valid + @throws std::invalid_argument if necessary Galois keys are not present + @throws std::invalid_argument if pool is uninitialized + */ + inline void apply_galois(const Ciphertext &encrypted, + std::uint64_t galois_elt, const GaloisKeys &galois_keys, + Ciphertext &destination, + MemoryPoolHandle pool = MemoryManager::GetPool()) + { + destination = encrypted; + apply_galois_inplace(destination, galois_elt, galois_keys, std::move(pool)); + } + + /** + Rotates plaintext matrix rows cyclically. When batching is used with the + BFV scheme, this function rotates the encrypted plaintext matrix rows + cyclically to the left (steps > 0) or to the right (steps < 0). Since + the size of the batched matrix is 2-by-(N/2), where N is the degree of + the polynomial modulus, the number of steps to rotate must have absolute + value at most N/2-1. Dynamic memory allocations in the process are allocated + from the memory pool pointed to by the given MemoryPoolHandle. + + + @param[in] encrypted The ciphertext to rotate + @param[in] steps The number of steps to rotate (negative left, positive right) + @param[in] galois_keys The Galois keys + @param[in] pool The MemoryPoolHandle pointing to a valid memory pool + @throws std::logic_error if scheme is not scheme_type::BFV + @throws std::logic_error if the encryption parameters do not support batching + @throws std::invalid_argument if encrypted or galois_keys is not valid for + the encryption parameters + @throws std::invalid_argument if galois_keys do not correspond to the top + level parameters in the current context + @throws std::invalid_argument if encrypted is not in the default NTT form + @throws std::invalid_argument if encrypted has size larger than 2 + @throws std::invalid_argument if steps has too big absolute value + @throws std::invalid_argument if necessary Galois keys are not present + @throws std::invalid_argument if pool is uninitialized + */ + inline void rotate_rows_inplace(Ciphertext &encrypted, + int steps, const GaloisKeys &galois_keys, + MemoryPoolHandle pool = MemoryManager::GetPool()) + { + if (context_->context_data()->parms().scheme() != scheme_type::BFV) + { + throw std::logic_error("unsupported scheme"); + } + rotate_internal(encrypted, steps, galois_keys, std::move(pool)); + } + + /** + Rotates plaintext matrix rows cyclically. When batching is used with the + BFV scheme, this function rotates the encrypted plaintext matrix rows + cyclically to the left (steps > 0) or to the right (steps < 0) and writes + the result to the destination parameter. Since the size of the batched + matrix is 2-by-(N/2), where N is the degree of the polynomial modulus, + the number of steps to rotate must have absolute value at most N/2-1. Dynamic + memory allocations in the process are allocated from the memory pool pointed + to by the given MemoryPoolHandle. + + @param[in] encrypted The ciphertext to rotate + @param[in] steps The number of steps to rotate (negative left, positive right) + @param[in] galois_keys The Galois keys + @param[out] destination The ciphertext to overwrite with the rotated result + @param[in] pool The MemoryPoolHandle pointing to a valid memory pool + @throws std::logic_error if scheme is not scheme_type::BFV + @throws std::logic_error if the encryption parameters do not support batching + @throws std::invalid_argument if encrypted or galois_keys is not valid for + the encryption parameters + @throws std::invalid_argument if galois_keys do not correspond to the top + level parameters in the current context + @throws std::invalid_argument if encrypted is in NTT form + @throws std::invalid_argument if encrypted has size larger than 2 + @throws std::invalid_argument if steps has too big absolute value + @throws std::invalid_argument if necessary Galois keys are not present + @throws std::invalid_argument if pool is uninitialized + */ + inline void rotate_rows(const Ciphertext &encrypted, int steps, + const GaloisKeys &galois_keys, Ciphertext &destination, + MemoryPoolHandle pool = MemoryManager::GetPool()) + { + destination = encrypted; + rotate_rows_inplace(destination, steps, galois_keys, std::move(pool)); + } + + /** + Rotates plaintext matrix columns cyclically. When batching is used with + the BFV scheme, this function rotates the encrypted plaintext matrix + columns cyclically. Since the size of the batched matrix is 2-by-(N/2), + where N is the degree of the polynomial modulus, this means simply swapping + the two rows. Dynamic memory allocations in the process are allocated from + the memory pool pointed to by the given MemoryPoolHandle. + + + @param[in] encrypted The ciphertext to rotate + @param[in] galois_keys The Galois keys + @param[out] destination The ciphertext to overwrite with the rotated result + @param[in] pool The MemoryPoolHandle pointing to a valid memory pool + @throws std::logic_error if scheme is not scheme_type::BFV + @throws std::logic_error if the encryption parameters do not support batching + @throws std::invalid_argument if encrypted or galois_keys is not valid for + the encryption parameters + @throws std::invalid_argument if galois_keys do not correspond to the top + level parameters in the current context + @throws std::invalid_argument if encrypted is in NTT form + @throws std::invalid_argument if encrypted has size larger than 2 + @throws std::invalid_argument if necessary Galois keys are not present + @throws std::invalid_argument if pool is uninitialized + */ + inline void rotate_columns_inplace(Ciphertext &encrypted, + const GaloisKeys &galois_keys, + MemoryPoolHandle pool = MemoryManager::GetPool()) + { + if (context_->context_data()->parms().scheme() != scheme_type::BFV) + { + throw std::logic_error("unsupported scheme"); + } + conjugate_internal(encrypted, galois_keys, std::move(pool)); + } + + /** + Rotates plaintext matrix columns cyclically. When batching is used with + the BFV scheme, this function rotates the encrypted plaintext matrix columns + cyclically, and writes the result to the destination parameter. Since the + size of the batched matrix is 2-by-(N/2), where N is the degree of the + polynomial modulus, this means simply swapping the two rows. Dynamic memory + allocations in the process are allocated from the memory pool pointed to + by the given MemoryPoolHandle. + + @param[in] encrypted The ciphertext to rotate + @param[in] galois_keys The Galois keys + @param[out] destination The ciphertext to overwrite with the rotated result + @param[in] pool The MemoryPoolHandle pointing to a valid memory pool + @throws std::logic_error if scheme is not scheme_type::BFV + @throws std::logic_error if the encryption parameters do not support batching + @throws std::invalid_argument if encrypted or galois_keys is not valid for + the encryption parameters + @throws std::invalid_argument if galois_keys do not correspond to the top + level parameters in the current context + @throws std::invalid_argument if encrypted is in NTT form + @throws std::invalid_argument if encrypted has size larger than 2 + @throws std::invalid_argument if necessary Galois keys are not present + @throws std::invalid_argument if pool is uninitialized + */ + inline void rotate_columns(const Ciphertext &encrypted, + const GaloisKeys &galois_keys, Ciphertext &destination, + MemoryPoolHandle pool = MemoryManager::GetPool()) + { + destination = encrypted; + rotate_columns_inplace(destination, galois_keys, std::move(pool)); + } + + /** + Rotates plaintext vector cyclically. When using the CKKS scheme, this function + rotates the encrypted plaintext vector cyclically to the left (steps > 0) + or to the right (steps < 0). Since the size of the batched matrix is + 2-by-(N/2), where N is the degree of the polynomial modulus, the number + of steps to rotate must have absolute value at most N/2-1. Dynamic memory + allocations in the process are allocated from the memory pool pointed to + by the given MemoryPoolHandle. + + @param[in] encrypted The ciphertext to rotate + @param[in] steps The number of steps to rotate (negative left, positive right) + @param[in] galois_keys The Galois keys + @param[in] pool The MemoryPoolHandle pointing to a valid memory pool + @throws std::logic_error if scheme is not scheme_type::CKKS + @throws std::invalid_argument if encrypted or galois_keys is not valid for + the encryption parameters + @throws std::invalid_argument if galois_keys do not correspond to the top + level parameters in the current context + @throws std::invalid_argument if encrypted is not in the default NTT form + @throws std::invalid_argument if encrypted has size larger than 2 + @throws std::invalid_argument if steps has too big absolute value + @throws std::invalid_argument if necessary Galois keys are not present + @throws std::invalid_argument if pool is uninitialized + */ + inline void rotate_vector_inplace(Ciphertext &encrypted, + int steps, const GaloisKeys &galois_keys, + MemoryPoolHandle pool = MemoryManager::GetPool()) + { + if (context_->context_data()->parms().scheme() != scheme_type::CKKS) + { + throw std::logic_error("unsupported scheme"); + } + rotate_internal(encrypted, steps, galois_keys, std::move(pool)); + } + + /** + Rotates plaintext vector cyclically. When using the CKKS scheme, this function + rotates the encrypted plaintext vector cyclically to the left (steps > 0) + or to the right (steps < 0) and writes the result to the destination parameter. + Since the size of the batched matrix is 2-by-(N/2), where N is the degree + of the polynomial modulus, the number of steps to rotate must have absolute + value at most N/2-1. Dynamic memory allocations in the process are allocated + from the memory pool pointed to by the given MemoryPoolHandle. + + @param[in] encrypted The ciphertext to rotate + @param[in] steps The number of steps to rotate (negative left, positive right) + @param[in] galois_keys The Galois keys + @param[out] destination The ciphertext to overwrite with the rotated result + @param[in] pool The MemoryPoolHandle pointing to a valid memory pool + @throws std::logic_error if scheme is not scheme_type::CKKS + @throws std::invalid_argument if encrypted or galois_keys is not valid for + the encryption parameters + @throws std::invalid_argument if galois_keys do not correspond to the top + level parameters in the current context + @throws std::invalid_argument if encrypted is in NTT form + @throws std::invalid_argument if encrypted has size larger than 2 + @throws std::invalid_argument if steps has too big absolute value + @throws std::invalid_argument if necessary Galois keys are not present + @throws std::invalid_argument if pool is uninitialized + */ + inline void rotate_vector(const Ciphertext &encrypted, int steps, + const GaloisKeys &galois_keys, Ciphertext &destination, + MemoryPoolHandle pool = MemoryManager::GetPool()) + { + destination = encrypted; + rotate_vector_inplace(destination, steps, galois_keys, std::move(pool)); + } + + /** + Complex conjugates plaintext slot values. When using the CKKS scheme, this + function complex conjugates all values in the underlying plaintext. Dynamic + memory allocations in the process are allocated from the memory pool pointed + to by the given MemoryPoolHandle. + + @param[in] encrypted The ciphertext to rotate + @param[in] galois_keys The Galois keys + @param[out] destination The ciphertext to overwrite with the rotated result + @param[in] pool The MemoryPoolHandle pointing to a valid memory pool + @throws std::logic_error if scheme is not scheme_type::CKKS + @throws std::invalid_argument if encrypted or galois_keys is not valid for + the encryption parameters + @throws std::invalid_argument if galois_keys do not correspond to the top + level parameters in the current context + @throws std::invalid_argument if encrypted is in NTT form + @throws std::invalid_argument if encrypted has size larger than 2 + @throws std::invalid_argument if necessary Galois keys are not present + @throws std::invalid_argument if pool is uninitialized + */ + inline void complex_conjugate_inplace(Ciphertext &encrypted, + const GaloisKeys &galois_keys, + MemoryPoolHandle pool = MemoryManager::GetPool()) + { + if (context_->context_data()->parms().scheme() != scheme_type::CKKS) + { + throw std::logic_error("unsupported scheme"); + } + conjugate_internal(encrypted, galois_keys, std::move(pool)); + } + + /** + Complex conjugates plaintext slot values. When using the CKKS scheme, this + function complex conjugates all values in the underlying plaintext, and + writes the result to the destination parameter. Dynamic memory allocations + in the process are allocated from the memory pool pointed to by the given + MemoryPoolHandle. + + @param[in] encrypted The ciphertext to rotate + @param[in] galois_keys The Galois keys + @param[out] destination The ciphertext to overwrite with the rotated result + @param[in] pool The MemoryPoolHandle pointing to a valid memory pool + @throws std::logic_error if scheme is not scheme_type::CKKS + @throws std::invalid_argument if encrypted or galois_keys is not valid for + the encryption parameters + @throws std::invalid_argument if galois_keys do not correspond to the top + level parameters in the current context + @throws std::invalid_argument if encrypted is in NTT form + @throws std::invalid_argument if encrypted has size larger than 2 + @throws std::invalid_argument if necessary Galois keys are not present + @throws std::invalid_argument if pool is uninitialized + */ + inline void complex_conjugate(const Ciphertext &encrypted, + const GaloisKeys &galois_keys, Ciphertext &destination, + MemoryPoolHandle pool = MemoryManager::GetPool()) + { + destination = encrypted; + complex_conjugate_inplace(destination, galois_keys, std::move(pool)); + } + + private: + Evaluator(const Evaluator ©) = delete; + + Evaluator(Evaluator &&source) = delete; + + Evaluator &operator =(const Evaluator &assign) = delete; + + Evaluator &operator =(Evaluator &&assign) = delete; + + void bfv_multiply(Ciphertext &encrypted1, const Ciphertext &encrypted2, + MemoryPoolHandle pool); + + void ckks_multiply(Ciphertext &encrypted1, const Ciphertext &encrypted2, + MemoryPoolHandle pool); + + void bfv_square(Ciphertext &encrypted, MemoryPoolHandle pool); + + void ckks_square(Ciphertext &encrypted, MemoryPoolHandle pool); + + void relinearize_internal(Ciphertext &encrypted, const RelinKeys &relin_keys, + std::size_t destination_size, MemoryPoolHandle pool); + + void mod_switch_scale_to_next(const Ciphertext &encrypted, Ciphertext &destination, + MemoryPoolHandle pool); + + void mod_switch_drop_to_next(const Ciphertext &encrypted, Ciphertext &destination); + + void mod_switch_drop_to_next(Plaintext &plain); + + void rotate_internal(Ciphertext &encrypted, int steps, + const GaloisKeys &galois_keys, MemoryPoolHandle pool); + + inline void conjugate_internal(Ciphertext &encrypted, + const GaloisKeys &galois_keys, MemoryPoolHandle pool) + { + // Verify parameters. + auto context_data_ptr = context_->context_data(encrypted.parms_id()); + if (!context_data_ptr) + { + throw std::invalid_argument("encrypted is not valid for encryption parameters"); + } + + // Extract encryption parameters. + auto &context_data = *context_data_ptr; + if (!context_data.qualifiers().using_batching) + { + throw std::logic_error("encryption parameters do not support batching"); + } + + auto &parms = context_data.parms(); + std::size_t coeff_count = parms.poly_modulus_degree(); + + // Perform rotation and key switching + apply_galois_inplace(encrypted, util::steps_to_galois_elt(0, coeff_count), + galois_keys, std::move(pool)); + } + + inline void decompose_single_coeff(const SEALContext::ContextData &context_data, + const std::uint64_t *value, std::uint64_t *destination, util::MemoryPool &pool) + { + auto &parms = context_data.parms(); + auto &coeff_modulus = parms.coeff_modulus(); + std::size_t coeff_mod_count = coeff_modulus.size(); +#ifdef SEAL_DEBUG + if (value == nullptr) + { + throw std::invalid_argument("value cannot be null"); + } + if (destination == nullptr) + { + throw std::invalid_argument("destination cannot be null"); + } + if (destination == value) + { + throw std::invalid_argument("value cannot be the same as destination"); + } +#endif + if (coeff_mod_count == 1) + { + util::set_uint_uint(value, coeff_mod_count, destination); + return; + } + + auto value_copy(util::allocate_uint(coeff_mod_count, pool)); + for (std::size_t j = 0; j < coeff_mod_count; j++) + { + //destination[j] = util::modulo_uint( + // value, coeff_mod_count, coeff_modulus_[j], pool); + + // Manually inlined for efficiency + // Make a fresh copy of value + util::set_uint_uint(value, coeff_mod_count, value_copy.get()); + + // Starting from the top, reduce always 128-bit blocks + for (std::size_t k = coeff_mod_count - 1; k--; ) + { + value_copy[k] = util::barrett_reduce_128( + value_copy.get() + k, coeff_modulus[j]); + } + destination[j] = value_copy[0]; + } + } + + inline void decompose(const SEALContext::ContextData &context_data, + const std::uint64_t *value, std::uint64_t *destination, util::MemoryPool &pool) + { + auto &parms = context_data.parms(); + auto &coeff_modulus = parms.coeff_modulus(); + std::size_t coeff_count = parms.poly_modulus_degree(); + std::size_t coeff_mod_count = coeff_modulus.size(); + std::size_t rns_poly_uint64_count = + util::mul_safe(coeff_mod_count, coeff_count); +#ifdef SEAL_DEBUG + if (value == nullptr) + { + throw std::invalid_argument("value cannot be null"); + } + if (destination == nullptr) + { + throw std::invalid_argument("destination cannot be null"); + } + if (destination == value) + { + throw std::invalid_argument("value cannot be the same as destination"); + } +#endif + if (coeff_mod_count == 1) + { + util::set_uint_uint(value, rns_poly_uint64_count, destination); + return; + } + + auto value_copy(util::allocate_uint(coeff_mod_count, pool)); + for (size_t i = 0; i < coeff_count; i++) + { + for (size_t j = 0; j < coeff_mod_count; j++) + { + //destination[i + (j * coeff_count)] = + // util::modulo_uint(value + (i * coeff_mod_count), + // coeff_mod_count, coeff_modulus_[j], pool); + + // Manually inlined for efficiency + // Make a fresh copy of value + (i * coeff_mod_count) + util::set_uint_uint( + value + (i * coeff_mod_count), coeff_mod_count, value_copy.get()); + + // Starting from the top, reduce always 128-bit blocks + for (std::size_t k = coeff_mod_count - 1; k--; ) + { + value_copy[k] = util::barrett_reduce_128( + value_copy.get() + k, coeff_modulus[j]); + } + destination[i + (j * coeff_count)] = value_copy[0]; + } + } + } + + void bfv_relinearize_one_step(std::uint64_t *encrypted, std::size_t encrypted_size, + const SEALContext::ContextData &context_data, + const RelinKeys &relin_keys, util::MemoryPool &pool); + + void ckks_relinearize_one_step(std::uint64_t *encrypted, std::size_t encrypted_size, + const SEALContext::ContextData &context_data, + const RelinKeys &relin_keys, util::MemoryPool &pool); + + void multiply_plain_normal(Ciphertext &encrypted, const Plaintext &plain, + util::MemoryPool &pool); + + void multiply_plain_ntt(Ciphertext &encrypted_ntt, const Plaintext &plain_ntt); + + void populate_Zmstar_to_generator(); + + std::shared_ptr context_{ nullptr }; + + std::map> Zmstar_to_generator_{}; + }; +} diff --git a/src/seal/galoiskeys.cpp b/src/seal/galoiskeys.cpp new file mode 100644 index 000000000..b93196dd4 --- /dev/null +++ b/src/seal/galoiskeys.cpp @@ -0,0 +1,172 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "seal/galoiskeys.h" +#include "seal/util/common.h" +#include + +using namespace std; +using namespace seal::util; + +namespace seal +{ + GaloisKeys &GaloisKeys::operator =(const GaloisKeys &assign) + { + // Check for self-assignment + if (this == &assign) + { + return *this; + } + + // Copy over fields + parms_id_ = assign.parms_id_; + decomposition_bit_count_ = assign.decomposition_bit_count_; + + // Then copy over keys + keys_.clear(); + size_t keys_dim1 = assign.keys_.size(); + keys_.reserve(keys_dim1); + for (size_t i = 0; i < keys_dim1; i++) + { + size_t keys_dim2 = assign.keys_[i].size(); + keys_.emplace_back(); + keys_[i].reserve(keys_dim2); + for (size_t j = 0; j < keys_dim2; j++) + { + keys_[i].emplace_back(pool_); + keys_[i][j] = assign.keys_[i][j]; + } + } + + return *this; + } + + bool GaloisKeys::is_valid_for(shared_ptr context) const noexcept + { + // Verify parameters + if (!context || !context->parameters_set()) + { + return false; + } + if (parms_id_ != context->first_parms_id()) + { + return false; + } + + for (auto &a : keys_) + { + for (auto &b : a) + { + if (!b.is_valid_for(context) || !b.is_ntt_form() || + b.parms_id() != parms_id_) + { + return false; + } + } + } + + return true; + } + + void GaloisKeys::save(std::ostream &stream) const + { + auto old_except_mask = stream.exceptions(); + try + { + // Throw exceptions on std::ios_base::badbit and std::ios_base::failbit + stream.exceptions(ios_base::badbit | ios_base::failbit); + + int32_t decomposition_bit_count32 = + safe_cast(decomposition_bit_count_); + + // Save the parms_id + stream.write(reinterpret_cast(&parms_id_), + sizeof(parms_id_type)); + + // Save the decomposition bit count + stream.write(reinterpret_cast(&decomposition_bit_count32), + sizeof(int32_t)); + + // Save the size of keys_ + uint64_t keys_dim1 = static_cast(keys_.size()); + stream.write(reinterpret_cast(&keys_dim1), sizeof(uint64_t)); + + // Now loop again over keys_dim1 + for (size_t index = 0; index < keys_dim1; index++) + { + // Save second dimension of keys_ + uint64_t keys_dim2 = static_cast(keys_[index].size()); + stream.write(reinterpret_cast(&keys_dim2), sizeof(uint64_t)); + + // Loop over keys_dim2 and save all (or none) + for (size_t j = 0; j < keys_dim2; j++) + { + // Save the key + keys_[index][j].save(stream); + } + } + } + catch (const exception &) + { + stream.exceptions(old_except_mask); + throw; + } + + stream.exceptions(old_except_mask); + } + + void GaloisKeys::unsafe_load(std::istream &stream) + { + auto old_except_mask = stream.exceptions(); + try + { + // Throw exceptions on std::ios_base::badbit and std::ios_base::failbit + stream.exceptions(ios_base::badbit | ios_base::failbit); + + // Clear current keys + keys_.clear(); + + // Read the parms_id + stream.read(reinterpret_cast(&parms_id_), + sizeof(parms_id_type)); + + // Read the decomposition_bit_count + int32_t decomposition_bit_count32 = 0; + stream.read(reinterpret_cast(&decomposition_bit_count32), + sizeof(int32_t)); + decomposition_bit_count_ = safe_cast(decomposition_bit_count32); + + // Read in the size of keys_ + uint64_t keys_dim1 = 0; + stream.read(reinterpret_cast(&keys_dim1), sizeof(uint64_t)); + + // Reserve first for dimension of keys_ + keys_.reserve(keys_dim1); + + // Loop over the first dimension of keys_ + for (size_t index = 0; index < keys_dim1; index++) + { + // Read the size of the second dimension + uint64_t keys_dim2 = 0; + stream.read(reinterpret_cast(&keys_dim2), sizeof(uint64_t)); + + // Don't resize; only reserve + keys_.emplace_back(); + keys_.back().reserve(keys_dim2); + for (size_t j = 0; j < keys_dim2; j++) + { + Ciphertext new_key(pool_); + new_key.unsafe_load(stream); + keys_[index].emplace_back(move(new_key)); + } + } + } + catch (const exception &) + { + stream.exceptions(old_except_mask); + throw; + } + + stream.exceptions(old_except_mask); + } +} diff --git a/src/seal/galoiskeys.h b/src/seal/galoiskeys.h new file mode 100644 index 000000000..59ab251ed --- /dev/null +++ b/src/seal/galoiskeys.h @@ -0,0 +1,247 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include +#include +#include +#include "seal/ciphertext.h" +#include "seal/memorymanager.h" +#include "seal/encryptionparams.h" + +namespace seal +{ + /** + Class to store Galois keys. + + @par Slot Rotations + Galois keys are used together with batching (BatchEncoder). If the polynomial modulus + is a polynomial of degree N, in batching the idea is to view a plaintext polynomial as + a 2-by-(N/2) matrix of integers modulo plaintext modulus. Normal homomorphic computations + operate on such encrypted matrices element (slot) wise. However, special rotation + operations allow us to also rotate the matrix rows cyclically in either direction, and + rotate the columns (swap the rows). These operations require the Galois keys. + + @par Decomposition Bit Count + Decomposition bit count (dbc) is a parameter that describes a performance trade-off in + the rotation operation. Its function is exactly the same as in relinearization. Namely, + the polynomials in the ciphertexts (with large coefficients) get decomposed into a smaller + base 2^dbc, coefficient-wise. Each of the decomposition factors corresponds to a piece of + data in the Galois keys, so the smaller the dbc is, the larger the Galois keys are. + Moreover, a smaller dbc results in less invariant noise budget being consumed in the + rotation operation. However, using a large dbc is much faster, and often one would want + to optimize the dbc to be as large as possible for performance. The dbc is upper-bounded + by the value of 60, and lower-bounded by the value of 1. + + @par Thread Safety + In general, reading from GaloisKeys is thread-safe as long as no other thread is + concurrently mutating it. This is due to the underlying data structure storing the + Galois keys not being thread-safe. + + @see SecretKey for the class that stores the secret key. + @see PublicKey for the class that stores the public key. + @see RelinKeys for the class that stores the relinearization keys. + @see KeyGenerator for the class that generates the Galois keys. + */ + class GaloisKeys + { + friend class KeyGenerator; + + public: + /** + Creates an empty set of Galois keys. + */ + GaloisKeys() = default; + + /** + Creates a new GaloisKeys instance by copying a given instance. + + @param[in] copy The GaloisKeys to copy from + */ + GaloisKeys(const GaloisKeys ©) = default; + + /** + Creates a new GaloisKeys instance by moving a given instance. + + @param[in] source The GaloisKeys to move from + */ + GaloisKeys(GaloisKeys &&source) = default; + + /** + Copies a given GaloisKeys instance to the current one. + + @param[in] assign The GaloisKeys to copy from + */ + GaloisKeys &operator =(const GaloisKeys &assign); + + /** + Moves a given GaloisKeys instance to the current one. + + @param[in] assign The GaloisKeys to move from + */ + GaloisKeys &operator =(GaloisKeys &&assign) = default; + + /** + Returns the current number of Galois keys. + */ + inline std::size_t size() const + { + return std::accumulate(keys_.begin(), keys_.end(), std::size_t(0), + [](std::size_t current_size, const std::vector &next_key) + { + return current_size + static_cast(next_key.size() > 0); + }); + } + + /* + Returns the decomposition bit count. + */ + inline int decomposition_bit_count() const noexcept + { + return decomposition_bit_count_; + } + + /** + Returns a reference to the Galois keys data. + */ + inline auto &data() noexcept + { + return keys_; + } + + /** + Returns a const reference to the Galois keys data. + */ + inline auto &data() const noexcept + { + return keys_; + } + + /** + Returns a const reference to a Galois key. The returned Galois key corresponds + to the given Galois element. + + @param[in] galois_elt The Galois element + @throw std::invalid_argument if the key corresponding to galois_elt does not exist + */ + inline auto &key(std::uint64_t galois_elt) const + { + if (!has_key(galois_elt)) + { + throw std::invalid_argument("requested key does not exist"); + } + std::uint64_t index = (galois_elt - 1) >> 1; + return keys_[index]; + } + + /** + Returns whether a Galois key corresponding to a given Galois element exists. + + @param[in] galois_elt The Galois element + @throw std::invalid_argument if Galois element is not valid + */ + inline bool has_key(std::uint64_t galois_elt) const + { + // Verify parameters + if (!(galois_elt & 1)) + { + throw std::invalid_argument("galois element is not valid"); + } + std::uint64_t index = (galois_elt - 1) >> 1; + return (index < keys_.size()) && !keys_[index].empty(); + } + + /** + Returns a reference to parms_id. + + @see EncryptionParameters for more information about parms_id. + */ + inline auto &parms_id() noexcept + { + return parms_id_; + } + + /** + Returns a const reference to parms_id. + + @see EncryptionParameters for more information about parms_id. + */ + inline auto &parms_id() const noexcept + { + return parms_id_; + } + + /** + Check whether the current GaloisKeys is valid for a given SEALContext. If + the given SEALContext is not set, the encryption parameters are invalid, + or the GaloisKeys data does not match the SEALContext, this function returns + false. Otherwise, returns true. + + @param[in] context The SEALContext + */ + bool is_valid_for(std::shared_ptr context) const noexcept; + + /** + Saves the GaloisKeys instance to an output stream. The output is in binary + format and not human-readable. The output stream must have the "binary" + flag set. + + @param[in] stream The stream to save the GaloisKeys to + @throws std::exception if the GaloisKeys could not be written to stream + */ + void save(std::ostream &stream) const; + + /** + Loads a GaloisKeys from an input stream overwriting the current GaloisKeys. + No checking of the validity of the GaloisKeys data against encryption + parameters is performed. This function should not be used unless the + GaloisKeys comes from a fully trusted source. + + @param[in] stream The stream to load the GaloisKeys from + @throws std::exception if a valid GaloisKeys could not be read from stream + */ + void unsafe_load(std::istream &stream); + + /** + Loads a GaloisKeys from an input stream overwriting the current GaloisKeys. + The loaded GaloisKeys is verified to be valid for the given SEALContext. + + @param[in] context The SEALContext + @param[in] stream The stream to load the GaloisKeys from + @throws std::invalid_argument if the context is not set or encryption + parameters are not valid + @throws std::exception if a valid GaloisKeys could not be read from stream + @throws std::invalid_argument if the loaded GaloisKeys is invalid for the + context + */ + inline void load(std::shared_ptr context, std::istream &stream) + { + unsafe_load(stream); + if (!is_valid_for(std::move(context))) + { + throw std::invalid_argument("GaloisKeys data is invalid"); + } + } + + /** + Returns the currently used MemoryPoolHandle. + */ + inline MemoryPoolHandle pool() const noexcept + { + return pool_; + } + + private: + MemoryPoolHandle pool_ = MemoryManager::GetPool(); + + parms_id_type parms_id_ = parms_id_zero; + + /** + The vector of Galois keys. + */ + std::vector> keys_{}; + + int decomposition_bit_count_ = 0; + }; +} diff --git a/src/seal/intarray.h b/src/seal/intarray.h new file mode 100644 index 000000000..7168ac7b4 --- /dev/null +++ b/src/seal/intarray.h @@ -0,0 +1,487 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include "seal/memorymanager.h" +#include "seal/util/pointer.h" +#include "seal/util/defines.h" +#include "seal/util/common.h" +#include +#include +#include +#include +#include + +namespace seal +{ + /** + A resizable container for storing an array of integral data types. The + allocations are done from a memory pool. The IntArray class is mainly + intended for internal use and provides the underlying data structure for + Plaintext and Ciphertext classes. + + @par Size and Capacity + IntArray allows the user to pre-allocate memory (capacity) for the array + in cases where the array is known to be resized in the future and memory + moves are to be avoided at the time of resizing. The size of the IntArray + can never exceed its capacity. The capacity and size can be changed using + the reserve and resize functions, respectively. + + @par Thread Safety + In general, reading from IntArray is thread-safe as long as no other thread + is concurrently mutating it. + */ + template::value>> + class IntArray + { + public: + using size_type = std::size_t; + + /** + Creates a new IntArray. No memory is allocated by this constructor. + + @param[in] pool The MemoryPoolHandle pointing to a valid memory pool + @throws std::invalid_argument if pool is uninitialized + */ + IntArray(MemoryPoolHandle pool = MemoryManager::GetPool()) : + pool_(std::move(pool)) + { + if (!pool_) + { + throw std::invalid_argument("pool is uninitialized"); + } + } + + /** + Creates a new IntArray with given size. + + @param[in] size The size of the array + @param[in] pool The MemoryPoolHandle pointing to a valid memory pool + @throws std::invalid_argument if size is less than zero + @throws std::invalid_argument if pool is uninitialized + */ + explicit IntArray(size_type size, + MemoryPoolHandle pool = MemoryManager::GetPool()) : + pool_(std::move(pool)) + { + if (!pool_) + { + throw std::invalid_argument("pool is uninitialized"); + } + + // Reserve memory, resize, and set to zero + resize(size); + } + + /** + Creates a new IntArray with given capacity and size. + + @param[in] capacity The capacity of the array + @param[in] size The size of the array + @param[in] pool The MemoryPoolHandle pointing to a valid memory pool + @throws std::invalid_argument if capacity is less than size + @throws std::invalid_argument if capacity is less than zero + @throws std::invalid_argument if size is less than zero + @throws std::invalid_argument if pool is uninitialized + */ + explicit IntArray(size_type capacity, size_type size, + MemoryPoolHandle pool = MemoryManager::GetPool()) : + pool_(std::move(pool)) + { + if (!pool_) + { + throw std::invalid_argument("pool is uninitialized"); + } + if (capacity < size) + { + throw std::invalid_argument("capacity cannot be smaller than size"); + } + + // Reserve memory, resize, and set to zero + reserve(capacity); + resize(size); + } + + /** + Constructs a new IntArray by copying a given one. + + @param[in] copy The IntArray to copy from + */ + IntArray(const IntArray ©) : + pool_(MemoryManager::GetPool()), + capacity_(copy.size_), + size_(copy.size_), + data_(util::allocate(copy.size_, pool_)) + { + // Copy over value + std::copy_n(copy.cbegin(), copy.size_, begin()); + } + + /** + Constructs a new IntArray by moving a given one. + + @param[in] source The IntArray to move from + */ + IntArray(IntArray &&source) noexcept : + pool_(std::move(source.pool_)), + capacity_(source.capacity_), + size_(source.size_), + data_(std::move(source.data_)) + { + } + + /** + Returns a pointer to the beginning of the array data. + */ + inline T* begin() noexcept + { + return data_.get(); + } + + /** + Returns a constant pointer to the beginning of the array data. + */ + inline const T* cbegin() const noexcept + { + return data_.get(); + } + + /** + Returns a pointer to the end of the array data. + */ + inline T* end() noexcept + { + return size_ ? begin() + size_ : begin(); + } + + /** + Returns a constant pointer to the end of the array data. + */ + inline const T* cend() const noexcept + { + return size_ ? cbegin() + size_ : cbegin(); + } +#ifdef SEAL_USE_MSGSL_SPAN + /** + Returns a span pointing to the beginning of the IntArray. + */ + inline gsl::span span() + { + return gsl::span( + begin(), static_cast(size_)); + } + + /** + Returns a span pointing to the beginning of the IntArray. + */ + inline gsl::span span() const + { + return gsl::span( + cbegin(), static_cast(size_)); + } +#endif + /** + Returns a constant reference to the array element at a given index. + This function performs bounds checking and will throw an error if + the index is out of range. + + @param[in] index The index of the array element + @throws std::out_of_range if index is out of range + */ + inline const T &at(size_type index) const + { + if (index >= size_) + { + throw std::out_of_range("index must be within [0, size)"); + } + return data_[index]; + } + + /** + Returns a reference to the array element at a given index. This + function performs bounds checking and will throw an error if the + index is out of range. + + @param[in] index The index of the array element + @throws std::out_of_range if index is out of range + */ + inline T &at(size_type index) + { + if (index >= size_) + { + throw std::out_of_range("index must be within [0, size)"); + } + return data_[index]; + } + + /** + Returns a constant reference to the array element at a given index. + This function does not perform bounds checking. + + @param[in] index The index of the array element + */ + inline const T &operator [](size_type index) const + { + return data_[index]; + } + + /** + Returns a reference to the array element at a given index. This + function does not perform bounds checking. + + @param[in] index The index of the array element + */ + inline T &operator [](size_type index) + { + return data_[index]; + } + + /** + Returns whether the array has size zero. + */ + inline bool empty() const noexcept + { + return (size_ == 0); + } + + /** + Returns the largest possible array size. + */ + inline size_type max_size() const noexcept + { + return std::numeric_limits::max(); + } + + /** + Returns the size of the array. + */ + inline size_type size() const noexcept + { + return size_; + } + + /** + Returns the capacity of the array. + */ + inline size_type capacity() const noexcept + { + return capacity_; + } + + /** + Swaps the current array with a given array. + */ + inline void swap_with(IntArray &other) noexcept + { + std::swap(pool_, other.pool_); + std::swap(capacity_, other.capacity_); + std::swap(size_, other.size_); + data_.swap_with(other.data_); + } + + /** + Returns the currently used MemoryPoolHandle. + */ + inline MemoryPoolHandle pool() const noexcept + { + return pool_; + } + + /** + Releases any allocated memory to the memory pool and sets the size + and capacity of the array to zero. + */ + inline void release() noexcept + { + capacity_ = 0; + size_ = 0; + data_.release(); + } + + /** + Sets the size of the array to zero. The capacity is not changed. + */ + inline void clear() noexcept + { + size_ = 0; + } + + /** + Allocates enough memory for storing a given number of elements without + changing the size of the array. If the given capacity is smaller than + the current size, the size is automatically set to equal the new capacity. + + @param[in] capacity The capacity of the array + */ + inline void reserve(size_type capacity) + { + size_type copy_size = std::min(capacity, size_); + + // Create new allocation and copy over value + auto new_data(util::allocate(capacity, pool_)); + std::copy_n(cbegin(), copy_size, new_data.get()); + data_.swap_with(new_data); + + // Set the coeff_count and capacity + capacity_ = capacity; + size_ = copy_size; + } + + /** + Reallocates the array so that its capacity exactly matches its size. + */ + inline void shrink_to_fit() + { + reserve(size_); + } + + /** + Resizes the array to given size. When resizing to larger size the data + in the array remains unchanged and any new space is initialized to zero; + when resizing to smaller size the last elements of the array are dropped. + If the capacity is not already large enough to hold the new size, the + array is also reallocated. + + @param[in] size The size of the array + */ + inline void resize(size_type size) + { + if (size <= capacity_) + { + // Are we changing size to bigger within current capacity? + // If so, need to set top terms to zero + if (size > size_) + { + std::fill(end(), begin() + size, T{ 0 }); + } + + // Set the size + size_ = size; + + return; + } + + // At this point we know for sure that size_ <= capacity_ < size so need + // to reallocate to bigger + auto new_data(util::allocate(size, pool_)); + std::copy_n(cbegin(), size_, new_data.get()); + std::fill(new_data.get() + size_, new_data.get() + size, T{ 0 }); + data_.swap_with(new_data); + + // Set the coeff_count and capacity + capacity_ = size; + size_ = size; + } + + /** + Copies a given IntArray to the current one. + + @param[in] assign The IntArray to copy from + */ + inline IntArray &operator =(const IntArray &assign) + { + // Check for self-assignment + if (this == &assign) + { + return *this; + } + + // First resize to correct size + resize(assign.size_); + + // Size is guaranteed to be OK now so copy over + std::copy_n(assign.cbegin(), assign.size_, begin()); + + return *this; + } + + /** + Moves a given IntArray to the current one. + + @param[in] assign The IntArray to move from + */ + IntArray &operator =(IntArray &&assign) noexcept + { + pool_ = std::move(assign.pool_); + capacity_ = assign.capacity_; + size_ = assign.size_; + data_ = std::move(assign.data_); + + return *this; + } + + /** + Saves the IntArray to an output stream. The output is in binary format + and not human-readable. The output stream must have the "binary" flag set. + + @param[in] stream The stream to save the IntArray to + @throws std::exception if the IntArray could not be written to stream + */ + inline void save(std::ostream &stream) const + { + auto old_except_mask = stream.exceptions(); + try + { + // Throw exceptions on std::ios_base::badbit and std::ios_base::failbit + stream.exceptions(std::ios_base::badbit | std::ios_base::failbit); + + std::uint64_t size64 = size_; + stream.write(reinterpret_cast(&size64), sizeof(std::uint64_t)); + stream.write(reinterpret_cast(cbegin()), + util::safe_cast( + util::mul_safe(size_, util::safe_cast(sizeof(T))))); + } + catch (const std::exception &) + { + stream.exceptions(old_except_mask); + throw; + } + + stream.exceptions(old_except_mask); + } + + /** + Loads a IntArray from an input stream overwriting the current IntArray. + + @param[in] stream The stream to load the IntArray from + @throws std::exception if a valid IntArray could not be read from stream + */ + inline void load(std::istream &stream) + { + auto old_except_mask = stream.exceptions(); + try + { + // Throw exceptions on std::ios_base::badbit and std::ios_base::failbit + stream.exceptions(std::ios_base::badbit | std::ios_base::failbit); + + std::uint64_t size64 = 0; + stream.read(reinterpret_cast(&size64), sizeof(std::uint64_t)); + + // Set new size + resize(util::safe_cast(size64)); + + // Read data + stream.read(reinterpret_cast(begin()), + util::safe_cast( + util::mul_safe(size_, util::safe_cast(sizeof(T))))); + } + catch (const std::exception &) + { + stream.exceptions(old_except_mask); + throw; + } + + stream.exceptions(old_except_mask); + } + + private: + MemoryPoolHandle pool_; + + size_type capacity_ = 0; + + size_type size_ = 0; + + util::Pointer data_; + }; +} diff --git a/src/seal/keygenerator.cpp b/src/seal/keygenerator.cpp new file mode 100644 index 000000000..66ba233ff --- /dev/null +++ b/src/seal/keygenerator.cpp @@ -0,0 +1,865 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include +#include "seal/keygenerator.h" +#include "seal/util/common.h" +#include "seal/util/uintcore.h" +#include "seal/util/uintarith.h" +#include "seal/util/uintarithsmallmod.h" +#include "seal/util/polyarithsmallmod.h" +#include "seal/util/randomtostd.h" +#include "seal/util/clipnormal.h" +#include "seal/util/polycore.h" +#include "seal/util/smallntt.h" + +using namespace std; +using namespace seal::util; + +namespace seal +{ + KeyGenerator::KeyGenerator(shared_ptr context) : + context_(move(context)) + { + // Verify parameters + if (!context_) + { + throw invalid_argument("invalid context"); + } + if (!context_->parameters_set()) + { + throw invalid_argument("encryption parameters are not set correctly"); + } + + // Secret key and public key have not been generated + sk_generated_ = false; + pk_generated_ = false; + + // Generate the secret and public key + generate_sk(); + generate_pk(); + } + + KeyGenerator::KeyGenerator(shared_ptr context, + const SecretKey &secret_key) : context_(move(context)) + { + // Verify parameters + if (!context_) + { + throw invalid_argument("invalid context"); + } + if (!context_->parameters_set()) + { + throw invalid_argument("encryption parameters are not set correctly"); + } + if (secret_key.parms_id() != context_->first_parms_id()) + { + throw invalid_argument("secret key is not valid for encryption parameters"); + } + + // Set the secret key + secret_key_ = secret_key; + sk_generated_ = true; + + // Generate the public key + generate_pk(); + } + + KeyGenerator::KeyGenerator(shared_ptr context, + const SecretKey &secret_key, const PublicKey &public_key) : + context_(move(context)) + { + // Verify parameters + if (!context_) + { + throw invalid_argument("invalid context"); + } + if (!context_->parameters_set()) + { + throw invalid_argument("encryption parameters are not set correctly"); + } + if (secret_key.parms_id() != context_->first_parms_id()) + { + throw invalid_argument("secret key is not valid for encryption parameters"); + } + if (public_key.parms_id() != context_->first_parms_id()) + { + throw invalid_argument("public key is not valid for encryption parameters"); + } + + // Extract encryption parameters. + auto &context_data = *context_->context_data(); + auto &parms = context_data.parms(); + auto &coeff_modulus = parms.coeff_modulus(); + size_t coeff_count = parms.poly_modulus_degree(); + size_t coeff_mod_count = coeff_modulus.size(); + + // Set the secret and public keys + public_key_ = public_key; + secret_key_ = secret_key; + + // Set the secret_key_array to have size 1 (first power of secret) + secret_key_array_ = allocate_poly(coeff_count, coeff_mod_count, pool_); + set_poly_poly(secret_key_.data().data(), coeff_count, coeff_mod_count, + secret_key_array_.get()); + secret_key_array_size_ = 1; + + // Secret key and public key are generated + sk_generated_ = true; + pk_generated_ = true; + } + + void KeyGenerator::generate_sk() + { + // Extract encryption parameters. + auto &context_data = *context_->context_data(); + auto &parms = context_data.parms(); + auto &coeff_modulus = parms.coeff_modulus(); + size_t coeff_count = parms.poly_modulus_degree(); + size_t coeff_mod_count = coeff_modulus.size(); + + // Initialize secret key. + secret_key_ = SecretKey(); + sk_generated_ = false; + secret_key_.data().resize(mul_safe(coeff_count, coeff_mod_count)); + + shared_ptr random(parms.random_generator()->create()); + + // Generate secret key + uint64_t *secret_key = secret_key_.data().data(); + set_poly_coeffs_zero_one_negone(context_data, secret_key, random); + + auto &small_ntt_tables = context_data.small_ntt_tables(); + for (size_t i = 0; i < coeff_mod_count; i++) + { + // Transform the secret s into NTT representation. + ntt_negacyclic_harvey(secret_key + (i * coeff_count), small_ntt_tables[i]); + } + + // Set the secret_key_array to have size 1 (first power of secret) + secret_key_array_ = allocate_poly(coeff_count, coeff_mod_count, pool_); + set_poly_poly(secret_key_.data().data(), coeff_count, coeff_mod_count, + secret_key_array_.get()); + secret_key_array_size_ = 1; + + // Set the parms_id for secret key + secret_key_.parms_id() = parms.parms_id(); + + // Secret key has been generated + sk_generated_ = true; + } + + void KeyGenerator::generate_pk() + { + if (!sk_generated_) + { + throw logic_error("cannot generate public key for unspecified secret key"); + } + + // Extract encryption parameters. + auto &context_data = *context_->context_data(); + auto &parms = context_data.parms(); + auto &coeff_modulus = parms.coeff_modulus(); + size_t coeff_count = parms.poly_modulus_degree(); + size_t coeff_mod_count = coeff_modulus.size(); + + // Size check + if (!product_fits_in(coeff_count, coeff_mod_count)) + { + throw logic_error("invalid parameters"); + } + + // Initialize public key. + public_key_ = PublicKey(); + pk_generated_ = false; + public_key_.data().resize(context_, parms.parms_id(), 2); + + // The public key is in NTT form + public_key_.data().is_ntt_form() = true; + + shared_ptr random(parms.random_generator()->create()); + + // Generate public key: (pk[0],pk[1]) = ([-(as+e)]_q, a) + uint64_t *secret_key = secret_key_.data().data(); + + // Sample a uniformly at random + // Set pk[1] = a (we sample the NTT form directly) + uint64_t *public_key_1 = public_key_.data().data(1); + set_poly_coeffs_uniform(context_data, public_key_1, random); + + // calculate a*s + e (mod q) and store in pk[0] + auto &small_ntt_tables = context_data.small_ntt_tables(); + + auto noise(allocate_poly(coeff_count, coeff_mod_count, pool_)); + set_poly_coeffs_normal(context_data, noise.get(), random); + for (size_t i = 0; i < coeff_mod_count; i++) + { + // Transform the noise e into NTT representation. + ntt_negacyclic_harvey( + noise.get() + (i * coeff_count), small_ntt_tables[i]); + + // The inputs are not reduced but that's OK. We are only at most at + // 122 bits and barrett_reduce_128 can deal with that. + dyadic_product_coeffmod( + secret_key + (i * coeff_count), + public_key_1 + (i * coeff_count), coeff_count, + coeff_modulus[i], + public_key_.data().data(0) + (i * coeff_count)); + add_poly_poly_coeffmod( + noise.get() + (i * coeff_count), + public_key_.data().data(0) + (i * coeff_count), + coeff_count, coeff_modulus[i], + public_key_.data().data(0) + (i * coeff_count)); + } + + // Negate and set this value to pk[0] + // pk[0] is now -(as+e) mod q + for (size_t i = 0; i < coeff_mod_count; i++) + { + negate_poly_coeffmod( + public_key_.data().data(0) + (i * coeff_count), coeff_count, + coeff_modulus[i], public_key_.data().data(0) + (i * coeff_count)); + } + + // Set the parms_id for public key + public_key_.parms_id() = parms.parms_id(); + + // Public key has been generated + pk_generated_ = true; + } + + RelinKeys KeyGenerator::relin_keys(int decomposition_bit_count, size_t count) + { + // Check to see if secret key and public key have been generated + if (!sk_generated_) + { + throw logic_error("cannot generate relinearization keys for unspecified secret key"); + } + + // Check that count is in correct interval + if (count < SEAL_RELIN_KEY_COUNT_MIN || + count > SEAL_RELIN_KEY_COUNT_MAX) + { + throw invalid_argument("count out of bounds"); + } + + // Check that decomposition_bit_count is in correct interval + if (decomposition_bit_count < SEAL_DBC_MIN || + decomposition_bit_count > SEAL_DBC_MAX) + { + throw invalid_argument("decomposition_bit_count is not in the valid range"); + } + + // Extract encryption parameters. + auto &context_data = *context_->context_data(); + auto &parms = context_data.parms(); + auto &coeff_modulus = parms.coeff_modulus(); + size_t coeff_count = parms.poly_modulus_degree(); + size_t coeff_mod_count = coeff_modulus.size(); + auto &small_ntt_tables = context_data.small_ntt_tables(); + + // Size check + if (!product_fits_in(coeff_count, coeff_mod_count)) + { + throw logic_error("invalid parameters"); + } + + // Create the RelinKeys object to return + RelinKeys relin_keys; + + // Initialize decomposition_factors + vector> decomposition_factors; + populate_decomposition_factors(context_data, decomposition_bit_count, + decomposition_factors); + + // Initialize the relinearization keys + relin_keys.data().resize(count); + for (size_t i = 0; i < count; i++) + { + relin_keys.data()[i].reserve(coeff_mod_count); + + for (size_t j = 0; j < coeff_mod_count; j++) + { + relin_keys.data()[i].emplace_back( + context_, parms.parms_id(), + 2 * decomposition_factors[j].size(), + relin_keys.pool()); + + // Resize to right size too (above only allocated) + // This is slightly odd use of Ciphertext as a container + relin_keys.data()[i].back().resize( + 2 * decomposition_factors[j].size()); + + // The keys are in NTT form + relin_keys.data()[i].back().is_ntt_form() = true; + } + } + + shared_ptr random(parms.random_generator()->create()); + + // Create relinearization keys. + auto noise(allocate_poly(coeff_count, coeff_mod_count, pool_)); + auto temp(allocate_uint(coeff_count, pool_)); + + // Make sure we have enough secret keys computed + compute_secret_key_array(context_data, count + 1); + + // assume the secret key is already transformed into NTT form. + for (size_t k = 0; k < count; k++) + { + for (size_t l = 0; l < coeff_mod_count; l++) + { + // populate evaluate_keys_[k] + for (size_t i = 0; i < decomposition_factors[l].size(); i++) + { + // generate NTT(a_i) and store in relin_keys_[k][l].second[i] + uint64_t *eval_keys_first = relin_keys.data()[k][l].data(2 * i); + uint64_t *eval_keys_second = relin_keys.data()[k][l].data(2 * i + 1); + + // We sample a_i directly in NTT form + set_poly_coeffs_uniform(context_data, eval_keys_second, random); + + for (size_t j = 0; j < coeff_mod_count; j++) + { + // calculate a_i*s and store in relin_keys_[k].first[i] + dyadic_product_coeffmod(eval_keys_second + (j * coeff_count), + secret_key_.data().data() + (j * coeff_count), + coeff_count, coeff_modulus[j], eval_keys_first + (j * coeff_count)); + } + + // generate NTT(e_i) + set_poly_coeffs_normal(context_data, noise.get(), random); + for (size_t j = 0; j < coeff_mod_count; j++) + { + ntt_negacyclic_harvey(noise.get() + (j * coeff_count), small_ntt_tables[j]); + + // add e_i into relin_keys_[k].first[i] + add_poly_poly_coeffmod( + noise.get() + (j * coeff_count), eval_keys_first + (j * coeff_count), + coeff_count, coeff_modulus[j], eval_keys_first + (j * coeff_count)); + + // negate value in relin_keys_[k].first[i] + negate_poly_coeffmod( + eval_keys_first + (j * coeff_count), coeff_count, coeff_modulus[j], + eval_keys_first + (j * coeff_count)); + + // multiply w^i * s^(k+2) + uint64_t decomposition_factor_mod = decomposition_factors[l][i] & + static_cast(-static_cast(l == j)); + multiply_poly_scalar_coeffmod( + secret_key_array_.get() + (k + 1) * coeff_count * coeff_mod_count + (j * coeff_count), + coeff_count, decomposition_factor_mod, coeff_modulus[j], temp.get()); + + // add w^i . s^(k+2) into relin_keys_[k].first[i] + add_poly_poly_coeffmod(eval_keys_first + (j * coeff_count), temp.get(), coeff_count, + coeff_modulus[j], eval_keys_first + (j * coeff_count)); + } + } + } + } + + // Set decomposition_bit_count + relin_keys.decomposition_bit_count_ = decomposition_bit_count; + + // Set the parms_id + relin_keys.parms_id() = parms.parms_id(); + + return relin_keys; + } + + GaloisKeys KeyGenerator::galois_keys(int decomposition_bit_count, + const vector &galois_elts) + { + // Check to see if secret key and public key have been generated + if (!sk_generated_) + { + throw logic_error("cannot generate galois keys for unspecified secret key"); + } + + // Check that decomposition_bit_count is in correct interval + if (decomposition_bit_count < SEAL_DBC_MIN || + decomposition_bit_count > SEAL_DBC_MAX) + { + throw invalid_argument("decomposition_bit_count is not on the valid range"); + } + + // Extract encryption parameters. + auto &context_data = *context_->context_data(); + auto &parms = context_data.parms(); + auto &coeff_modulus = parms.coeff_modulus(); + size_t coeff_count = parms.poly_modulus_degree(); + size_t coeff_mod_count = coeff_modulus.size(); + int coeff_count_power = get_power_of_two(coeff_count); + auto &small_ntt_tables = context_data.small_ntt_tables(); + + // Size check + if (!product_fits_in(coeff_count, coeff_mod_count, size_t(2))) + { + throw logic_error("invalid parameters"); + } + + // Create the GaloisKeys object to return + GaloisKeys galois_keys; + + // The max number of keys is equal to number of coefficients + galois_keys.data().resize(coeff_count); + + // Initialize decomposition_factors + vector> decomposition_factors; + populate_decomposition_factors(context_data, decomposition_bit_count, + decomposition_factors); + + for (uint64_t galois_elt : galois_elts) + { + // Verify coprime conditions. + if (!(galois_elt & 1) || (galois_elt >= 2 * coeff_count)) + { + throw invalid_argument("galois element is not valid"); + } + + // Do we already have the key? + if (galois_keys.has_key(galois_elt)) + { + continue; + } + + // Rotate secret key for each coeff_modulus + auto rotated_secret_key(allocate_poly(coeff_count, coeff_mod_count, pool_)); + for (size_t i = 0; i < coeff_mod_count; i++) + { + apply_galois_ntt(secret_key_.data().data() + (i * coeff_count), + coeff_count_power, galois_elt, + rotated_secret_key.get() + (i * coeff_count)); + } + + // Initialize galois key + // This is the location in the galois_keys vector + uint64_t index = (galois_elt - 1) >> 1; + galois_keys.data()[index].reserve(coeff_mod_count); + + for (size_t i = 0; i < coeff_mod_count; i++) + { + galois_keys.data()[index].emplace_back( + context_, parms.parms_id(), + 2 * decomposition_factors[i].size(), + galois_keys.pool()); + + // Resize to right size too (above only allocated) + // This is slightly odd use of Ciphertext as a container + galois_keys.data()[index].back().resize( + 2 * decomposition_factors[i].size()); + + // The Galois keys are in NTT form + galois_keys.data()[index].back().is_ntt_form() = true; + } + + shared_ptr random(parms.random_generator()->create()); + + // Create Galois keys. + auto noise(allocate_poly(coeff_count, coeff_mod_count, pool_)); + auto temp(allocate_uint(coeff_count, pool_)); + + for (size_t l = 0; l < coeff_mod_count; l++) + { + // populate galois_keys_[k] + for (size_t i = 0; i < decomposition_factors[l].size(); i++) + { + // generate NTT(a_i) and store in galois_keys_[k][l].second[i] + uint64_t *eval_keys_first = galois_keys.data()[index][l].data(2 * i); + uint64_t *eval_keys_second = galois_keys.data()[index][l].data(2 * i + 1); + + // We sample a_i in NTT form directly + set_poly_coeffs_uniform(context_data, eval_keys_second, random); + for (size_t j = 0; j < coeff_mod_count; j++) + { + // calculate a_i*s and store in galois_keys_[k].first[i] + dyadic_product_coeffmod(eval_keys_second + (j * coeff_count), + secret_key_.data().data() + (j * coeff_count), + coeff_count, coeff_modulus[j], + eval_keys_first + (j * coeff_count)); + } + + // generate NTT(e_i) + set_poly_coeffs_normal(context_data, noise.get(), random); + for (size_t j = 0; j < coeff_mod_count; j++) + { + ntt_negacyclic_harvey( + noise.get() + (j * coeff_count), small_ntt_tables[j]); + + // add NTT(e_i) into galois_keys_[k].first[i] + add_poly_poly_coeffmod(noise.get() + (j * coeff_count), + eval_keys_first + (j * coeff_count), + coeff_count, coeff_modulus[j], + eval_keys_first + (j * coeff_count)); + + // negate value in galois_keys_[k].first[i] + negate_poly_coeffmod( + eval_keys_first + (j * coeff_count), coeff_count, + coeff_modulus[j], eval_keys_first + (j * coeff_count)); + + // multiply w^i * rotated_secret_key + uint64_t decomposition_factor_mod = decomposition_factors[l][i] & + static_cast(-static_cast(l == j)); + multiply_poly_scalar_coeffmod(rotated_secret_key.get() + (j * coeff_count), + coeff_count, decomposition_factor_mod, + coeff_modulus[j], temp.get()); + + // add w^i * rotated_secret_key into galois_keys_[k].first[i] + add_poly_poly_coeffmod(eval_keys_first + (j * coeff_count), temp.get(), + coeff_count, coeff_modulus[j], eval_keys_first + (j * coeff_count)); + } + } + } + } + + // Set decomposition_bit_count + galois_keys.decomposition_bit_count_ = decomposition_bit_count; + + // Set the parms_id + galois_keys.parms_id_ = parms.parms_id(); + + return galois_keys; + } + + GaloisKeys KeyGenerator::galois_keys(int decomposition_bit_count, + const vector &steps) + { + // Check to see if secret key and public key have been generated + if (!sk_generated_) + { + throw logic_error("cannot generate galois keys for unspecified secret key"); + } + + // Check that decomposition_bit_count is in correct interval + if (decomposition_bit_count < SEAL_DBC_MIN || + decomposition_bit_count > SEAL_DBC_MAX) + { + throw invalid_argument("decomposition_bit_count is not on the valid range"); + } + + // Extract encryption parameters. + auto &context_data = *context_->context_data(); + if (!context_data.qualifiers().using_batching) + { + throw logic_error("encryption parameters do not support batching"); + } + + auto &parms = context_data.parms(); + size_t coeff_count = parms.poly_modulus_degree(); + + vector galois_elts; + transform(steps.begin(), steps.end(), back_inserter(galois_elts), + [&](auto s) { return steps_to_galois_elt(s, coeff_count); }); + + return galois_keys(decomposition_bit_count, galois_elts); + } + + GaloisKeys KeyGenerator::galois_keys(int decomposition_bit_count) + { + // Check to see if secret key and public key have been generated + if (!sk_generated_) + { + throw logic_error("cannot generate galois keys for unspecified secret key"); + } + + // Check that decomposition_bit_count is in correct interval + if (decomposition_bit_count < SEAL_DBC_MIN || + decomposition_bit_count > SEAL_DBC_MAX) + { + throw invalid_argument("decomposition_bit_count is not in the valid range"); + } + + size_t coeff_count = context_->context_data()->parms().poly_modulus_degree(); + uint64_t m = coeff_count << 1; + int logn = get_power_of_two(static_cast(coeff_count)); + + vector logn_galois_keys{}; + + // Generate Galois keys for m - 1 (X -> X^{m-1}) + logn_galois_keys.push_back(m - 1); + + // Generate Galois key for power of 3 mod m (X -> X^{3^k}) and + // for negative power of 3 mod m (X -> X^{-3^k}) + uint64_t two_power_of_three = 3; + uint64_t neg_two_power_of_three = 0; + try_mod_inverse(3, m, neg_two_power_of_three); + for (int i = 0; i < logn - 1; i++) + { + logn_galois_keys.push_back(two_power_of_three); + two_power_of_three *= two_power_of_three; + two_power_of_three &= (m - 1); + + logn_galois_keys.push_back(neg_two_power_of_three); + neg_two_power_of_three *= neg_two_power_of_three; + neg_two_power_of_three &= (m - 1); + } + + return galois_keys(decomposition_bit_count, logn_galois_keys); + } + + void KeyGenerator::set_poly_coeffs_zero_one_negone( + const SEALContext::ContextData &context_data, + uint64_t *poly, shared_ptr random) const + { + // Extract encryption parameters. + auto &parms = context_data.parms(); + auto &coeff_modulus = parms.coeff_modulus(); + size_t coeff_count = parms.poly_modulus_degree(); + size_t coeff_mod_count = coeff_modulus.size(); + + RandomToStandardAdapter engine(random); + uniform_int_distribution dist(-1, 1); + + for (size_t i = 0; i < coeff_count; i++) + { + int rand_index = dist(engine); + if (rand_index == 1) + { + for (size_t j = 0; j < coeff_mod_count; j++) + { + poly[i + (j * coeff_count)] = 1; + } + } + else if (rand_index == -1) + { + for (size_t j = 0; j < coeff_mod_count; j++) + { + poly[i + (j * coeff_count)] = coeff_modulus[j].value() - 1; + } + } + else + { + for (size_t j = 0; j < coeff_mod_count; j++) + { + poly[i + (j * coeff_count)] = 0; + } + } + } + } + + void KeyGenerator::set_poly_coeffs_normal( + const SEALContext::ContextData &context_data, uint64_t *poly, + shared_ptr random) const + { + // Extract encryption parameters. + auto &parms = context_data.parms(); + auto &coeff_modulus = parms.coeff_modulus(); + size_t coeff_count = parms.poly_modulus_degree(); + size_t coeff_mod_count = coeff_modulus.size(); + + if (parms.noise_standard_deviation() == 0 || + parms.noise_max_deviation() == 0) + { + set_zero_poly(coeff_count, coeff_mod_count, poly); + return; + } + RandomToStandardAdapter engine(random); + ClippedNormalDistribution dist(0, parms.noise_standard_deviation(), + parms.noise_max_deviation()); + for (size_t i = 0; i < coeff_count; i++) + { + int64_t noise = static_cast(dist(engine)); + if (noise > 0) + { + for (size_t j = 0; j < coeff_mod_count; j++) + { + poly[i + (j * coeff_count)] = static_cast(noise); + } + } + else if (noise < 0) + { + noise = -noise; + for (size_t j = 0; j < coeff_mod_count; j++) + { + poly[i + (j * coeff_count)] = + coeff_modulus[j].value() - static_cast(noise); + } + } + else + { + for (size_t j = 0; j < coeff_mod_count; j++) + { + poly[i + (j * coeff_count)] = 0; + } + } + } + } + + void KeyGenerator::set_poly_coeffs_uniform( + const SEALContext::ContextData &context_data, + uint64_t *poly, shared_ptr random) const + { + // Extract encryption parameters. + auto &parms = context_data.parms(); + auto &coeff_modulus = parms.coeff_modulus(); + size_t coeff_count = parms.poly_modulus_degree(); + size_t coeff_mod_count = coeff_modulus.size(); + + // Set up source of randomness which produces random things of size 32 bit + RandomToStandardAdapter engine(random); + + for (size_t j = 0; j < coeff_mod_count; j++) + { + uint64_t current_modulus = coeff_modulus[j].value(); + for (size_t i = 0; i < coeff_count; i++, poly++) + { + uint64_t new_coeff = (static_cast(engine()) << 32) + + static_cast(engine()); + *poly = new_coeff % current_modulus; + } + } + } + + const SecretKey &KeyGenerator::secret_key() const + { + if (!sk_generated_) + { + throw logic_error("secret key has not been generated"); + } + return secret_key_; + } + + const PublicKey &KeyGenerator::public_key() const + { + if (!pk_generated_) + { + throw logic_error("public key has not been generated"); + } + return public_key_; + } + + void KeyGenerator::compute_secret_key_array( + const SEALContext::ContextData &context_data, size_t max_power) + { + // Extract encryption parameters. + auto &parms = context_data.parms(); + auto &coeff_modulus = parms.coeff_modulus(); + size_t coeff_count = parms.poly_modulus_degree(); + size_t coeff_mod_count = coeff_modulus.size(); + + // Size check + if (!product_fits_in(coeff_count, coeff_mod_count, max_power)) + { + throw logic_error("invalid parameters"); + } + + ReaderLock reader_lock(secret_key_array_locker_.acquire_read()); + + size_t old_size = secret_key_array_size_; + size_t new_size = max(max_power, old_size); + + if (old_size == new_size) + { + return; + } + + reader_lock.unlock(); + + // Need to extend the array + // Compute powers of secret key until max_power + auto new_secret_key_array(allocate_poly( + new_size * coeff_count, coeff_mod_count, pool_)); + set_poly_poly(secret_key_array_.get(), old_size * coeff_count, + coeff_mod_count, new_secret_key_array.get()); + + size_t poly_ptr_increment = coeff_count * coeff_mod_count; + uint64_t *prev_poly_ptr = new_secret_key_array.get() + + (old_size - 1) * poly_ptr_increment; + uint64_t *next_poly_ptr = prev_poly_ptr + poly_ptr_increment; + + // Since all of the key powers in secret_key_array_ are already + // NTT transformed, to get the next one we simply need to compute + // a dyadic product of the last one with the first one + // [which is equal to NTT(secret_key_)]. + for (size_t i = old_size; i < new_size; i++) + { + for (size_t j = 0; j < coeff_mod_count; j++) + { + dyadic_product_coeffmod( + prev_poly_ptr + (j * coeff_count), + new_secret_key_array.get() + (j * coeff_count), + coeff_count, coeff_modulus[j], + next_poly_ptr + (j * coeff_count)); + } + prev_poly_ptr = next_poly_ptr; + next_poly_ptr += poly_ptr_increment; + } + + + // Take writer lock to update array + WriterLock writer_lock(secret_key_array_locker_.acquire_write()); + + // Do we still need to update size? + old_size = secret_key_array_size_; + new_size = max(max_power, secret_key_array_size_); + + if (old_size == new_size) + { + return; + } + + // Acquire new array + secret_key_array_size_ = new_size; + secret_key_array_.acquire(new_secret_key_array); + } + + // decomposition_factors[i][j] = 2^(w*j) * hat-q_i mod q_i + void KeyGenerator::populate_decomposition_factors( + const SEALContext::ContextData &context_data, + int decomposition_bit_count, + vector> &decomposition_factors) const + { + // Extract encryption parameters. + auto &parms = context_data.parms(); + auto &coeff_modulus = parms.coeff_modulus(); + size_t coeff_mod_count = coeff_modulus.size(); + + decomposition_factors.clear(); + + // Initialize decomposition_factors + decomposition_factors.resize(coeff_mod_count); + uint64_t power_of_w = uint64_t(1) << decomposition_bit_count; + + // Compute hat-q_i mod q_i + vector coeff_prod_mod(coeff_mod_count); + for (size_t i = 0; i < coeff_mod_count; i++) + { + coeff_prod_mod[i] = 1; + for (size_t j = 0; j < coeff_mod_count; j++) + { + if (i != j) + { + coeff_prod_mod[i] = multiply_uint_uint_mod(coeff_prod_mod[i], + coeff_modulus[j].value(), coeff_modulus[i]); + } + } + } + + for (size_t i = 0; i < coeff_mod_count; i++) + { + uint64_t current_decomposition_factor = coeff_prod_mod[i]; + uint64_t current_smallmod = coeff_modulus[i].value(); + while (current_smallmod != 0) + { + decomposition_factors[i].emplace_back(current_decomposition_factor); + //multiply 2^w mod q_i + current_decomposition_factor = multiply_uint_uint_mod( + current_decomposition_factor, power_of_w, coeff_modulus[i]); + current_smallmod >>= decomposition_bit_count; + } + } + + // We need to ensure that the total number of decomposition factors does not + // exceed 63 for lazy reduction in relinearization to work + size_t total_ev_factor_count = 0; + for (size_t i = 0; i < coeff_mod_count; i++) + { + total_ev_factor_count = + add_safe(total_ev_factor_count, decomposition_factors[i].size()); + } + if (total_ev_factor_count > 63) + { + throw invalid_argument("decomposition_bit_count is too small"); + } + } +} diff --git a/src/seal/keygenerator.h b/src/seal/keygenerator.h new file mode 100644 index 000000000..9c4db08d0 --- /dev/null +++ b/src/seal/keygenerator.h @@ -0,0 +1,209 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include +#include +#include "seal/context.h" +#include "seal/util/smallntt.h" +#include "seal/memorymanager.h" +#include "seal/publickey.h" +#include "seal/secretkey.h" +#include "seal/relinkeys.h" +#include "seal/galoiskeys.h" +#include "seal/randomgen.h" + +namespace seal +{ + /** + Generates matching secret key and public key. An existing KeyGenerator can + also at any time be used to generate relinearization keys and Galois keys. + Constructing a KeyGenerator requires only a SEALContext. + + @see EncryptionParameters for more details on encryption parameters. + @see SecretKey for more details on secret key. + @see PublicKey for more details on public key. + @see RelinKeys for more details on relinearization keys. + @see GaloisKeys for more details on Galois keys. + */ + class KeyGenerator + { + public: + /** + Creates a KeyGenerator initialized with the specified SEALContext. + + @param[in] context The SEALContext + @throws std::invalid_argument if the context is not set or encryption + parameters are not valid + */ + KeyGenerator(std::shared_ptr context); + + /** + Creates an KeyGenerator instance initialized with the specified SEALContext + and specified previously secret key. This can e.g. be used to increase + the number of relinearization keys from what had earlier been generated, + or to generate Galois keys in case they had not been generated earlier. + + + @param[in] context The SEALContext + @param[in] secret_key A previously generated secret key + @throws std::invalid_argument if encryption parameters are not valid + @throws std::invalid_argument if secret_key or public_key is not valid + for encryption parameters + */ + KeyGenerator(std::shared_ptr context, + const SecretKey &secret_key); + + /** + Creates an KeyGenerator instance initialized with the specified SEALContext + and specified previously secret and public keys. This can e.g. be used + to increase the number of relinearization keys from what had earlier been + generated, or to generate Galois keys in case they had not been generated + earlier. + + @param[in] context The SEALContext + @param[in] secret_key A previously generated secret key + @param[in] public_key A previously generated public key + @throws std::invalid_argument if encryption parameters are not valid + @throws std::invalid_argument if secret_key or public_key is not valid + for encryption parameters + */ + KeyGenerator(std::shared_ptr context, + const SecretKey &secret_key, const PublicKey &public_key); + + /** + Returns a const reference to the secret key. + */ + const SecretKey &secret_key() const; + + /** + Returns a const reference to the public key. + */ + const PublicKey &public_key() const; + + /** + Generates and returns the specified number of relinearization keys. + + @param[in] decomposition_bit_count The decomposition bit count + @param[in] count The number of relinearization keys to generate + @throws std::invalid_argument if decomposition_bit_count is not within [1, 60] + @throws std::invalid_argument if count is zero or too large + */ + RelinKeys relin_keys(int decomposition_bit_count, std::size_t count = 1); + + /** + Generates and returns Galois keys. This function creates specific Galois + keys that can be used to apply specific Galois automorphisms on encrypted + data. The user needs to give as input a vector of Galois elements + corresponding to the keys that are to be created. + + The Galois elements are odd integers in the interval [1, M-1], where + M = 2*N, and N = degree(poly_modulus). Used with batching, a Galois element + 3^i % M corresponds to a cyclic row rotation i steps to the left, and + a Galois element 3^(N/2-i) % M corresponds to a cyclic row rotation i + steps to the right. The Galois element M-1 corresponds to a column rotation + (row swap) in BFV, and complex conjugation in CKKS. In the polynomial view + (not batching), a Galois automorphism by a Galois element p changes Enc(plain(x)) + to Enc(plain(x^p)). + + @param[in] decomposition_bit_count The decomposition bit count + @param[in] galois_elts The Galois elements for which to generate keys + @throws std::invalid_argument if decomposition_bit_count is not within [1, 60] + @throws std::invalid_argument if the Galois elements are not valid + */ + GaloisKeys galois_keys(int decomposition_bit_count, + const std::vector &galois_elts); + + /** + Generates and returns Galois keys. This function creates specific Galois + keys that can be used to apply specific Galois automorphisms on encrypted + data. The user needs to give as input a vector of desired Galois rotation + step counts, where negative step counts correspond to rotations to the + right and positive step counts correspond to rotations to the left. + A step count of zero can be used to indicate a column rotation in the BFV + scheme complex conjugation in the CKKS scheme. + + @param[in] decomposition_bit_count The decomposition bit count + @param[in] galois_elts The rotation step counts for which to generate keys + @throws std::logic_error if the encryption parameters do not support batching + and scheme is scheme_type::BFV + @throws std::invalid_argument if decomposition_bit_count is not within [1, 60] + @throws std::invalid_argument if the step counts are not valid + */ + GaloisKeys galois_keys(int decomposition_bit_count, + const std::vector &steps); + + /** + Generates and returns Galois keys. This function creates logarithmically + many (in degree of the polynomial modulus) Galois keys that is sufficient + to apply any Galois automorphism (e.g. rotations) on encrypted data. Most + users will want to use this overload of the function. + + @param[in] decomposition_bit_count The decomposition bit count + @throws std::invalid_argument if decomposition_bit_count is not within [1, 60] + */ + GaloisKeys galois_keys(int decomposition_bit_count); + + private: + KeyGenerator(const KeyGenerator ©) = delete; + + KeyGenerator &operator =(const KeyGenerator &assign) = delete; + + KeyGenerator(KeyGenerator &&source) = delete; + + KeyGenerator &operator =(KeyGenerator &&assign) = delete; + + void set_poly_coeffs_zero_one_negone( + const SEALContext::ContextData &context_data, std::uint64_t *poly, + std::shared_ptr random) const; + + void set_poly_coeffs_normal( + const SEALContext::ContextData &context_data, std::uint64_t *poly, + std::shared_ptr random) const; + + void set_poly_coeffs_uniform( + const SEALContext::ContextData &context_data, std::uint64_t *poly, + std::shared_ptr random) const; + + void compute_secret_key_array( + const SEALContext::ContextData &context_data, + std::size_t max_power); + + void populate_decomposition_factors( + const SEALContext::ContextData &context_data, + int decomposition_bit_count, + std::vector> &decomposition_factors) const; + + /** + Generates new secret key. + */ + void generate_sk(); + + /** + Generates new public key matching to existing secret key. + */ + void generate_pk(); + + /** + We use a fresh memory pool with `clear_on_destruction' enabled + */ + MemoryPoolHandle pool_ = MemoryManager::GetPool(mm_prof_opt::FORCE_NEW, true); + + std::shared_ptr context_{ nullptr }; + + PublicKey public_key_; + + SecretKey secret_key_; + + std::size_t secret_key_array_size_ = 0; + + util::Pointer secret_key_array_; + + mutable util::ReaderWriterLocker secret_key_array_locker_; + + bool sk_generated_ = false; + + bool pk_generated_ = false; + }; +} diff --git a/src/seal/memorymanager.cpp b/src/seal/memorymanager.cpp new file mode 100644 index 000000000..32d6d56ae --- /dev/null +++ b/src/seal/memorymanager.cpp @@ -0,0 +1,15 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "seal/memorymanager.h" + +namespace seal +{ + std::unique_ptr + MemoryManager::mm_prof_{ new MMProfGlobal }; +#ifndef _M_CEE + std::mutex MemoryManager::switch_mutex_; +#else +#pragma message("WARNING: MemoryManager compiled thread-unsafe and MMProfGuard disabled to support /clr") +#endif +} diff --git a/src/seal/memorymanager.h b/src/seal/memorymanager.h new file mode 100644 index 000000000..6f38f5250 --- /dev/null +++ b/src/seal/memorymanager.h @@ -0,0 +1,822 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include +#include +#include +#include +#include "seal/util/mempool.h" +#include "seal/util/globals.h" + +/* +For .NET Framework wrapper support (C++/CLI) we need to + (1) compile the MemoryManager class as thread-unsafe because C++ + mutexes cannot be brought through C++/CLI layer; + (2) disable thread-safe memory pools. +*/ +#ifndef _M_CEE +#include +#include +#endif + +namespace seal +{ + /** + Manages a shared pointer to a memory pool. SEAL uses memory pools for + improved performance due to the large number of memory allocations needed + by the homomorphic encryption operations, and the underlying polynomial + arithmetic. The library automatically creates a shared global memory pool + that is used for all dynamic allocations by default, and the user can + optionally create any number of custom memory pools to be used instead. + + @Uses in Multi-Threaded Applications + Sometimes the user might want to use specific memory pools for dynamic + allocations in certain functions. For example, in heavily multi-threaded + applications allocating concurrently from a shared memory pool might lead + to significant performance issues due to thread contention. For these cases + SEAL provides overloads of the functions that take a MemoryPoolHandle as an + additional argument, and uses the associated memory pool for all dynamic + allocations inside the function. Whenever these functions are called, the + user can then simply pass a thread-local MemoryPoolHandle to be used. + + @Thread-Unsafe Memory Pools + While memory pools are by default thread-safe, in some cases it suffices + to have a memory pool be thread-unsafe. To get a little extra performance, + the user can optionally create such thread-unsafe memory pools and use them + just as they would use thread-safe memory pools. + + @Initialized and Uninitialized Handles + A MemoryPoolHandle has to be set to point either to the global memory pool, + or to a new memory pool. If this is not done, the MemoryPoolHandle is + said to be uninitialized, and cannot be used. Initialization simple means + assigning MemoryPoolHandle::Global() or MemoryPoolHandle::New() to it. + + @Managing Lifetime + Internally, the MemoryPoolHandle wraps an std::shared_ptr pointing to + a SEAL memory pool class. Thus, as long as a MemoryPoolHandle pointing to + a particular memory pool exists, the pool stays alive. Classes such as + Evaluator and Ciphertext store their own local copies of a MemoryPoolHandle + to guarantee that the pool stays alive as long as the managing object + itself stays alive. The global memory pool is implemented as a global + std::shared_ptr to a memory pool class, and is thus expected to stay + alive for the entire duration of the program execution. Note that it can + be problematic to create other global objects that use the memory pool + e.g. in their constructor, as one would have to ensure the initialization + order of these global variables to be correct (i.e. global memory pool + first). + */ + class MemoryPoolHandle + { + public: + /** + Creates a new uninitialized MemoryPoolHandle. + */ + MemoryPoolHandle() = default; + + /** + Creates a MemoryPoolHandle pointing to a given MemoryPool object. + */ + MemoryPoolHandle(std::shared_ptr pool) noexcept : + pool_(std::move(pool)) + { + } + + /** + Creates a copy of a given MemoryPoolHandle. As a result, the created + MemoryPoolHandle will point to the same underlying memory pool as the + copied instance. + + + @param[in] copy The MemoryPoolHandle to copy from + */ + MemoryPoolHandle(const MemoryPoolHandle ©) noexcept + { + operator =(copy); + } + + /** + Creates a new MemoryPoolHandle by moving a given one. As a result, the + moved MemoryPoolHandle will become uninitialized. + + + @param[in] source The MemoryPoolHandle to move from + */ + MemoryPoolHandle(MemoryPoolHandle &&source) noexcept + { + operator =(std::move(source)); + } + + /** + Overwrites the MemoryPoolHandle instance with the specified instance. As + a result, the current MemoryPoolHandle will point to the same underlying + memory pool as the assigned instance. + + @param[in] assign The MemoryPoolHandle instance to assign to the current + instance + */ + inline MemoryPoolHandle &operator =(const MemoryPoolHandle &assign) noexcept + { + pool_ = assign.pool_; + return *this; + } + + /** + Moves a specified MemoryPoolHandle instance to the current instance. As + a result, the assigned MemoryPoolHandle will become uninitialized. + + @param[in] assign The MemoryPoolHandle instance to assign to the current + instance + */ + inline MemoryPoolHandle &operator =(MemoryPoolHandle &&assign) noexcept + { + pool_ = std::move(assign.pool_); + return *this; + } + + /** + Returns a MemoryPoolHandle pointing to the global memory pool. + */ + inline static MemoryPoolHandle Global() + { + // We return an aliased shared_ptr; a global shared_ptr can cause problems + // with some wrappers. + return MemoryPoolHandle( + std::shared_ptr(std::shared_ptr(), + util::global_variables::global_memory_pool.get())); + } +#ifndef _M_CEE + /** + Returns a MemoryPoolHandle pointing to the thread-local memory pool. Note + that the thread-local memory pool cannot be used to communicate across + different threads. + */ + inline static MemoryPoolHandle ThreadLocal() + { + // We return an aliased shared_ptr; a global shared_ptr can cause problems + // with some wrappers. + return MemoryPoolHandle( + std::shared_ptr(std::shared_ptr(), + util::global_variables::tls_memory_pool.get())); + } +#endif + /** + Returns a MemoryPoolHandle pointing to a new thread-safe memory pool. + + @param[in] clear_on_destruction Indicates whether the memory pool data + should be cleared when destroyed. This can be important when memory pools + are used to store private data. + */ + inline static MemoryPoolHandle New(bool clear_on_destruction = false) + { + return MemoryPoolHandle( + std::make_shared(clear_on_destruction)); + } + + /** + Returns a reference to the internal SEAL memory pool that the MemoryPoolHandle + points to. This function is mainly for internal use. + + @throws std::logic_error if the MemoryPoolHandle is uninitialized + */ + inline operator util::MemoryPool &() const + { + if (!pool_) + { + throw std::logic_error("pool not initialized"); + } + return *pool_.get(); + } + + /** + Returns the number of different allocation sizes. This function returns + the number of different allocation sizes the memory pool pointed to by + the current MemoryPoolHandle has made. For example, if the memory pool has + only allocated two allocations of sizes 128 KB, this function returns 1. + If it has instead allocated one allocation of size 64 KB and one of 128 KB, + this function returns 2. + + @throws std::logic_error if the MemoryPoolHandle is uninitialized + */ + inline std::size_t pool_count() const + { + if (!pool_) + { + throw std::logic_error("pool not initialized"); + } + return pool_->pool_count(); + } + + /** + Returns the size of allocated memory. This functions returns the total + amount of memory (in bytes) allocated by the memory pool pointed to by + the current MemoryPoolHandle. + + + @throws std::logic_error if the MemoryPoolHandle is uninitialized + */ + inline std::size_t alloc_byte_count() const + { + if (!pool_) + { + throw std::logic_error("pool not initialized"); + } + return pool_->alloc_byte_count(); + } + + /** + Returns whether the MemoryPoolHandle is initialized. + */ + inline operator bool () const + { + return pool_.operator bool(); + } + + /** + Compares MemoryPoolHandles. This function returns whether the current + MemoryPoolHandle points to the same memory pool as a given MemoryPoolHandle. + */ + inline bool operator ==(const MemoryPoolHandle &compare) noexcept + { + return pool_ == compare.pool_; + } + + /** + Compares MemoryPoolHandles. This function returns whether the current + MemoryPoolHandle points to a different memory pool than a given + MemoryPoolHandle. + */ + inline bool operator !=(const MemoryPoolHandle &compare) noexcept + { + return pool_ != compare.pool_; + } + + private: + std::shared_ptr pool_ = nullptr; + }; + + using mm_prof_opt_t = std::uint64_t; + + /** + Control options for MemoryManager::GetPool function. These force the MemoryManager + to override the current MMProf and instead return a MemoryPoolHandle pointing + to a memory pool of the indicated type. + */ + enum mm_prof_opt : mm_prof_opt_t + { + DEFAULT = 0x0, + FORCE_GLOBAL = 0x1, + FORCE_NEW = 0x2, + FORCE_THREAD_LOCAL = 0x4 + }; + + /** + The MMProf is a pure virtual class that every profile for the MemoryManager + should inherit from. The only functionality this class implements is the + get_pool(mm_prof_opt_t) function that returns a MemoryPoolHandle pointing + to a pool selected by internal logic optionally using the input parameter + of type mm_prof_opt_t. The returned MemoryPoolHandle must point to a valid + memory pool. + */ + class MMProf + { + public: + /** + Creates a new MMProf. + */ + MMProf() = default; + + /** + Destroys the MMProf. + */ + virtual ~MMProf() noexcept + { + } + + /** + Returns a MemoryPoolHandle pointing to a pool selected by internal logic + in a derived class and by the mm_prof_opt_t input parameter. + + */ + virtual MemoryPoolHandle get_pool(mm_prof_opt_t) = 0; + + private: + }; + + /** + A memory manager profile that always returns a MemoryPoolHandle pointing to + the global memory pool. SEAL uses this memory manager profile by default. + */ + class MMProfGlobal : public MMProf + { + public: + /** + Creates a new MMProfGlobal. + */ + MMProfGlobal() = default; + + /** + Destroys the MMProfGlobal. + */ + virtual ~MMProfGlobal() noexcept override + { + } + + /** + Returns a MemoryPoolHandle pointing to the global memory pool. The + mm_prof_opt_t input parameter has no effect. + */ + inline virtual MemoryPoolHandle + get_pool(mm_prof_opt_t) override + { + return MemoryPoolHandle::Global(); + } + + private: + }; + + /** + A memory manager profile that always returns a MemoryPoolHandle pointing to + the new thread-safe memory pool. This profile should not be used except in + special circumstances, as it does not result in any reuse of allocated memory. + */ + class MMProfNew : public MMProf + { + public: + /** + Creates a new MMProfNew. + */ + MMProfNew() = default; + + /** + Destroys the MMProfNew. + */ + virtual ~MMProfNew() noexcept override + { + } + + /** + Returns a MemoryPoolHandle pointing to a new thread-safe memory pool. The + mm_prof_opt_t input parameter has no effect. + */ + inline virtual MemoryPoolHandle + get_pool(mm_prof_opt_t) override + { + return MemoryPoolHandle::New(); + } + + private: + }; + + /** + A memory manager profile that always returns a MemoryPoolHandle pointing to + specific memory pool. + */ + class MMProfFixed : public MMProf + { + public: + /** + Creates a new MMProfFixed. The MemoryPoolHandle given as argument is returned + by every call to get_pool(mm_prof_opt_t). + + @param[in] pool The MemoryPoolHandle pointing to a valid memory pool + @throws std::invalid_argument if pool is uninitialized + */ + MMProfFixed(MemoryPoolHandle pool) : pool_(std::move(pool)) + { + if (!pool_) + { + throw std::invalid_argument("pool is uninitialized"); + } + } + + /** + Destroys the MMProfFixed. + */ + virtual ~MMProfFixed() noexcept override + { + } + + /** + Returns a MemoryPoolHandle pointing to the stored memory pool. The + mm_prof_opt_t input parameter has no effect. + */ + inline virtual MemoryPoolHandle + get_pool(mm_prof_opt_t) override + { + return pool_; + } + + private: + MemoryPoolHandle pool_; + }; +#ifndef _M_CEE + /** + A memory manager profile that always returns a MemoryPoolHandle pointing to + the thread-local memory pool. This profile should be used with care, as any + memory allocated by it will be released once the thread exits. In other words, + the thread-local memory pool cannot be used to share memory across different + threads. On the other hand, this profile can be useful when a very high number + of threads doing simultaneous allocations would cause contention in the + global memory pool. + */ + class MMProfThreadLocal : public MMProf + { + public: + /** + Creates a new MMProfThreadLocal. + */ + MMProfThreadLocal() = default; + + /** + Destroys the MMProfThreadLocal. + */ + virtual ~MMProfThreadLocal() noexcept override + { + } + + /** + Returns a MemoryPoolHandle pointing to the thread-local memory pool. The + mm_prof_opt_t input parameter has no effect. + */ + inline virtual MemoryPoolHandle + get_pool(mm_prof_opt_t) override + { + return MemoryPoolHandle::ThreadLocal(); + } + + private: + }; +#endif + /** + The MemoryManager class can be used to create instances of MemoryPoolHandle + based on a given "profile". A profile is implemented by inheriting from the + MMProf class (pure virtual) and encapsulates internal logic for deciding which + memory pool to use. + */ + class MemoryManager + { + friend class MMProfGuard; + + public: + MemoryManager() = delete; + + /** + Sets the current profile to a given one and returns a unique_ptr pointing + to the previously set profile. + + @param[in] mm_prof Pointer to a new memory manager profile + @throws std::invalid_argument if mm_prof is nullptr + */ + static inline std::unique_ptr + SwitchProfile(MMProf* &&mm_prof) noexcept + { +#ifndef _M_CEE + std::lock_guard switching_lock(switch_mutex_); +#endif + return SwitchProfileThreadUnsafe(std::move(mm_prof)); + } + + /** + Sets the current profile to a given one and returns a unique_ptr pointing + to the previously set profile. + + @param[in] mm_prof Pointer to a new memory manager profile + @throws std::invalid_argument if mm_prof is nullptr + */ + static inline std::unique_ptr SwitchProfile( + std::unique_ptr &&mm_prof) noexcept + { +#ifndef _M_CEE + std::lock_guard switch_lock(switch_mutex_); +#endif + return SwitchProfileThreadUnsafe(std::move(mm_prof)); + } + + /** + Returns a MemoryPoolHandle according to the currently set memory manager + profile and prof_opt. The following values for prof_opt have an effect + independent of the current profile: + + + mm_prof_opt::FORCE_NEW: return MemoryPoolHandle::New() + mm_prof_opt::FORCE_GLOBAL: return MemoryPoolHandle::Global() + mm_prof_opt::FORCE_THREAD_LOCAL: return MemoryPoolHandle::ThreadLocal() + + Other values for prof_opt are forwarded to the current profile and, depending + on the profile, may or may not have an effect. The value mm_prof_opt::DEFAULT + will always invoke a default behavior for the current profile. + + @param[in] prof_opt A mm_prof_opt_t parameter used to provide additional + instructions to the memory manager profile for internal logic. + */ + template + static inline MemoryPoolHandle GetPool(mm_prof_opt_t prof_opt, Args &&...args) + { + switch (prof_opt) + { + case mm_prof_opt::FORCE_GLOBAL: + return MemoryPoolHandle::Global(); + + case mm_prof_opt::FORCE_NEW: + return MemoryPoolHandle::New(std::forward(args)...); +#ifndef _M_CEE + case mm_prof_opt::FORCE_THREAD_LOCAL: + return MemoryPoolHandle::ThreadLocal(); +#endif + default: +#ifdef SEAL_DEBUG + { + auto pool = mm_prof_->get_pool(prof_opt); + if (!pool) + { + throw std::logic_error("cannot return uninitialized pool"); + } + return pool; + } +#endif + return mm_prof_->get_pool(prof_opt); + } + } + + static inline MemoryPoolHandle GetPool() + { + return GetPool(mm_prof_opt::DEFAULT); + } + + private: + static inline std::unique_ptr + SwitchProfileThreadUnsafe( + MMProf* &&mm_prof) + { + if (!mm_prof) + { + throw std::invalid_argument("mm_prof cannot be nullptr"); + } + auto ret_mm_prof = std::move(mm_prof_); + mm_prof_.reset(mm_prof); + return ret_mm_prof; + } + + static inline std::unique_ptr + SwitchProfileThreadUnsafe( + std::unique_ptr &&mm_prof) + { + if (!mm_prof) + { + throw std::invalid_argument("mm_prof cannot be nullptr"); + } + std::swap(mm_prof_, mm_prof); + return std::move(mm_prof); + } + + static std::unique_ptr mm_prof_; +#ifndef _M_CEE + static std::mutex switch_mutex_; +#endif + }; +#ifndef _M_CEE + /** + Class for a scoped switch of memory manager profile. This class acts as a scoped + "guard" for changing the memory manager profile so that the programmer does + not have to explicitly switch back afterwards and that other threads cannot + change the MMProf. It can also help with exception safety by guaranteeing that + the profile is switched back to the original if a function throws an exception + after changing the profile for local use. + */ + class MMProfGuard + { + public: + /** + Creates a new MMProfGuard. If start_locked is true, this function will + attempt to lock the MemoryManager for profile switch to mm_prof, perform + the switch, and keep the lock until unlocked or destroyed. If start_lock + is false, mm_prof will be stored but the switch will not be performed and + a lock will not be obtained until lock() is explicitly called. + + @param[in] mm_prof Pointer to a new memory manager profile + @param[in] start_locked Bool indicating whether the lock should be + immediately obtained (true by default) + */ + MMProfGuard(std::unique_ptr &&mm_prof, + bool start_locked = true) noexcept : + mm_switch_lock_(MemoryManager::switch_mutex_,std::defer_lock) + { + if (start_locked) + { + lock(std::move(mm_prof)); + } + else + { + old_prof_ = std::move(mm_prof); + } + } + + /** + Creates a new MMProfGuard. If start_locked is true, this function will + attempt to lock the MemoryManager for profile switch to mm_prof, perform + the switch, and keep the lock until unlocked or destroyed. If start_lock + is false, mm_prof will be stored but the switch will not be performed and + a lock will not be obtained until lock() is explicitly called. + + @param[in] mm_prof Pointer to a new memory manager profile + @param[in] start_locked Bool indicating whether the lock should be + immediately obtained (true by default) + */ + MMProfGuard(MMProf* &&mm_prof, + bool start_locked = true) noexcept : + mm_switch_lock_(MemoryManager::switch_mutex_, std::defer_lock) + { + if (start_locked) + { + lock(std::move(mm_prof)); + } + else + { + old_prof_.reset(std::move(mm_prof)); + } + } + + /** + Attempts to lock the MemoryManager for profile switch, perform the switch + to currently stored memory manager profile, store the previously held profile, + and keep the lock until unlocked or destroyed. If the lock cannot be obtained + on the first attempt, the function returns false; otherwise returns true. + + @throws std::runtime_error if the lock is already owned + */ + inline bool try_lock() + { + if (mm_switch_lock_.owns_lock()) + { + throw std::runtime_error("lock is already owned"); + } + if (!mm_switch_lock_.try_lock()) + { + return false; + } + old_prof_ = MemoryManager::SwitchProfileThreadUnsafe( + std::move(old_prof_)); + return true; + } + + /** + Locks the MemoryManager for profile switch, performs the switch to currently + stored memory manager profile, stores the previously held profile, and + keep the lock until unlocked or destroyed. The calling thread will block + until the lock can be obtained. + + @throws std::runtime_error if the lock is already owned + */ + inline void lock() + { + if (mm_switch_lock_.owns_lock()) + { + throw std::runtime_error("lock is already owned"); + } + mm_switch_lock_.lock(); + old_prof_ = MemoryManager::SwitchProfileThreadUnsafe( + std::move(old_prof_)); + } + + /** + Attempts to lock the MemoryManager for profile switch, perform the switch + to the given memory manager profile, store the previously held profile, + and keep the lock until unlocked or destroyed. If the lock cannot be + obtained on the first attempt, the function returns false; otherwise + returns true. + + @param[in] mm_prof Pointer to a new memory manager profile + @throws std::runtime_error if the lock is already owned + */ + inline bool try_lock( + std::unique_ptr &&mm_prof) + { + if (mm_switch_lock_.owns_lock()) + { + throw std::runtime_error("lock is already owned"); + } + if (!mm_switch_lock_.try_lock()) + { + return false; + } + old_prof_ = MemoryManager::SwitchProfileThreadUnsafe( + std::move(mm_prof)); + return true; + } + + /** + Locks the MemoryManager for profile switch, performs the switch to the given + memory manager profile, stores the previously held profile, and keep the + lock until unlocked or destroyed. The calling thread will block until the + lock can be obtained. + + @param[in] mm_prof Pointer to a new memory manager profile + @throws std::runtime_error if the lock is already owned + */ + inline void lock( + std::unique_ptr &&mm_prof) + { + if (mm_switch_lock_.owns_lock()) + { + throw std::runtime_error("lock is already owned"); + } + mm_switch_lock_.lock(); + old_prof_ = MemoryManager::SwitchProfileThreadUnsafe( + std::move(mm_prof)); + } + + /** + Attempts to lock the MemoryManager for profile switch, perform the switch + to the given memory manager profile, store the previously held profile, + and keep the lock until unlocked or destroyed. If the lock cannot be + obtained on the first attempt, the function returns false; otherwise returns + true. + + @param[in] mm_prof Pointer to a new memory manager profile + @throws std::runtime_error if the lock is already owned + */ + inline bool try_lock(MMProf* &&mm_prof) + { + if (mm_switch_lock_.owns_lock()) + { + throw std::runtime_error("lock is already owned"); + } + if (!mm_switch_lock_.try_lock()) + { + return false; + } + old_prof_ = MemoryManager::SwitchProfileThreadUnsafe( + std::move(mm_prof)); + return true; + } + + /** + Locks the MemoryManager for profile switch, performs the switch to the + given memory manager profile, stores the previously held profile, and keep + the lock until unlocked or destroyed. The calling thread will block until + the lock can be obtained. + + @param[in] mm_prof Pointer to a new memory manager profile + @throws std::runtime_error if the lock is already owned + */ + inline void lock(MMProf* &&mm_prof) + { + if (mm_switch_lock_.owns_lock()) + { + throw std::runtime_error("lock is already owned"); + } + mm_switch_lock_.lock(); + old_prof_ = MemoryManager::SwitchProfileThreadUnsafe( + std::move(mm_prof)); + } + + /** + Releases the memory manager profile switch lock for MemoryManager, stores + the current profile, and resets the profile to the one used before locking. + + @throw std::runtime_error if the lock is not owned + */ + inline void unlock() + { + if (!mm_switch_lock_.owns_lock()) + { + throw std::runtime_error("lock is not owned"); + } + old_prof_ = MemoryManager::SwitchProfileThreadUnsafe( + std::move(old_prof_)); + mm_switch_lock_.unlock(); + } + + /** + Destroys the MMProfGuard. If the memory manager profile switch lock is + owned, releases the lock, and resets the profile to the one used before + locking. + */ + ~MMProfGuard() + { + if (mm_switch_lock_.owns_lock()) + { + old_prof_ = MemoryManager::SwitchProfileThreadUnsafe( + std::move(old_prof_)); + mm_switch_lock_.unlock(); + } + } + + /** + Returns whether the current MMProfGuard owns the memory manager profile + switch lock. + */ + inline bool owns_lock() noexcept + { + return mm_switch_lock_.owns_lock(); + } + + private: + std::unique_ptr old_prof_; + + std::unique_lock mm_switch_lock_; + }; +#endif +} diff --git a/src/seal/plaintext.cpp b/src/seal/plaintext.cpp new file mode 100644 index 000000000..be019894e --- /dev/null +++ b/src/seal/plaintext.cpp @@ -0,0 +1,327 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "seal/plaintext.h" +#include "seal/util/common.h" + +using namespace std; +using namespace seal::util; + +namespace seal +{ + namespace + { + bool is_dec_char(char c) + { + return c >= '0' && c <= '9'; + } + + int get_dec_value(char c) + { + return c - '0'; + } + + int get_coeff_length(const char *poly) + { + int length = 0; + while (is_hex_char(*poly)) + { + length++; + poly++; + } + return length; + } + + int get_coeff_power(const char *poly, int *power_length) + { + int length = 0; + if (*poly == '\0') + { + *power_length = 0; + return 0; + } + if (*poly != 'x') + { + return -1; + } + poly++; + length++; + + if (*poly != '^') + { + return -1; + } + poly++; + length++; + + int power = 0; + while (is_dec_char(*poly)) + { + power *= 10; + power += get_dec_value(*poly); + poly++; + length++; + } + *power_length = length; + return power; + } + + int get_plus(const char *poly) + { + if (*poly == '\0') + { + return 0; + } + if (*poly++ != ' ') + { + return -1; + } + if (*poly++ != '+') + { + return -1; + } + if (*poly != ' ') + { + return -1; + } + return 3; + } + } + + Plaintext &Plaintext::operator =(const string &hex_poly) + { + if (is_ntt_form()) + { + throw logic_error("cannot set an NTT transformed Plaintext"); + } + if (unsigned_gt(hex_poly.size(), numeric_limits::max())) + { + throw invalid_argument("hex_poly too long"); + } + int length = safe_cast(hex_poly.size()); + + // Determine size needed to store string coefficient. + int assign_coeff_count = 0; + + int assign_coeff_bit_count = 0; + int pos = 0; + int last_power = safe_cast( + min(data_.max_size(), safe_cast(numeric_limits::max()))); + const char *hex_poly_ptr = hex_poly.data(); + while (pos < length) + { + // Determine length of coefficient starting at pos. + int coeff_length = get_coeff_length(hex_poly_ptr + pos); + if (coeff_length == 0) + { + throw invalid_argument("unable to parse hex_poly"); + } + + // Determine bit length of coefficient. + int coeff_bit_count = + get_hex_string_bit_count(hex_poly_ptr + pos, coeff_length); + if (coeff_bit_count > assign_coeff_bit_count) + { + assign_coeff_bit_count = coeff_bit_count; + } + pos += coeff_length; + + // Extract power-term. + int power_length = 0; + int power = get_coeff_power(hex_poly_ptr + pos, &power_length); + if (power == -1 || power >= last_power) + { + throw invalid_argument("unable to parse hex_poly"); + } + if (assign_coeff_count == 0) + { + assign_coeff_count = power + 1; + } + pos += power_length; + last_power = power; + + // Extract plus (unless it is the end). + int plus_length = get_plus(hex_poly_ptr + pos); + if (plus_length == -1) + { + throw invalid_argument("unable to parse hex_poly"); + } + pos += plus_length; + } + + // If string is empty, then done. + if (assign_coeff_count == 0 || assign_coeff_bit_count == 0) + { + set_zero(); + return *this; + } + + // Resize polynomial. + if (assign_coeff_bit_count > bits_per_uint64) + { + throw invalid_argument("hex_poly has too large coefficients"); + } + resize(safe_cast(assign_coeff_count)); + + // Populate polynomial from string. + pos = 0; + last_power = safe_cast(coeff_count()); + while (pos < length) + { + // Determine length of coefficient starting at pos. + const char *coeff_start = hex_poly_ptr + pos; + int coeff_length = get_coeff_length(coeff_start); + pos += coeff_length; + + // Extract power-term. + int power_length = 0; + int power = get_coeff_power(hex_poly_ptr + pos, &power_length); + pos += power_length; + + // Extract plus (unless it is the end). + int plus_length = get_plus(hex_poly_ptr + pos); + pos += plus_length; + + // Zero coefficients not set by string. + for (int zero_power = last_power - 1; zero_power > power; --zero_power) + { + data_[static_cast(zero_power)] = 0; + } + + // Populate coefficient. + uint64_t *coeff_ptr = data_.begin() + power; + hex_string_to_uint(coeff_start, coeff_length, size_t(1), coeff_ptr); + last_power = power; + } + + // Zero coefficients not set by string. + for (int zero_power = last_power - 1; zero_power >= 0; --zero_power) + { + data_[static_cast(zero_power)] = 0; + } + + return *this; + } + + bool Plaintext::is_valid_for(shared_ptr context) const + { + // Verify parameters + if (!context || !context->parameters_set()) + { + return false; + } + + if (is_ntt_form()) + { + auto context_data_ptr = context->context_data(parms_id_); + if (!context_data_ptr) + { + return false; + } + + auto &coeff_modulus = context_data_ptr->parms().coeff_modulus(); + size_t coeff_mod_count = coeff_modulus.size(); + size_t poly_modulus_degree = context_data_ptr->parms().poly_modulus_degree(); + if (mul_safe(coeff_modulus.size(), poly_modulus_degree) != data_.size()) + { + return false; + } + + const pt_coeff_type *ptr = data(); + for (size_t j = 0; j < coeff_mod_count; j++) + { + uint64_t modulus = coeff_modulus[j].value(); + for (size_t k = 0; k < poly_modulus_degree; k++, ptr++) + { + if (*ptr >= modulus) + { + return false; + } + } + } + } + else + { + auto context_data_ptr = context->context_data(); + if (context_data_ptr->parms().scheme() != scheme_type::BFV) + { + return false; + } + + size_t poly_modulus_degree = context_data_ptr->parms().poly_modulus_degree(); + if (data_.size() > poly_modulus_degree) + { + return false; + } + + uint64_t modulus = context->context_data()->parms().plain_modulus().value(); + const pt_coeff_type *ptr = data(); + for (size_t k = 0; k < poly_modulus_degree; k++, ptr++) + { + if (*ptr >= modulus) + { + return false; + } + } + } + + return true; + } + + void Plaintext::save(ostream &stream) const + { + auto old_except_mask = stream.exceptions(); + try + { + // Throw exceptions on std::ios_base::badbit and std::ios_base::failbit + stream.exceptions(ios_base::badbit | ios_base::failbit); + + stream.write(reinterpret_cast(&parms_id_), sizeof(parms_id_type)); + stream.write(reinterpret_cast(&scale_), sizeof(double)); + data_.save(stream); + } + catch (const exception &) + { + stream.exceptions(old_except_mask); + throw; + } + + stream.exceptions(old_except_mask); + } + + void Plaintext::unsafe_load(istream &stream) + { + auto old_except_mask = stream.exceptions(); + try + { + // Throw exceptions on std::ios_base::badbit and std::ios_base::failbit + stream.exceptions(ios_base::badbit | ios_base::failbit); + + parms_id_type parms_id{}; + stream.read(reinterpret_cast(&parms_id), sizeof(parms_id_type)); + + double scale = 0; + stream.read(reinterpret_cast(&scale), sizeof(double)); + + // Load the data + IntArray new_data(data_.pool()); + new_data.load(stream); + + // Set the parms_id + parms_id_ = parms_id; + + // Set the scale + scale_ = scale; + + // Set the data + data_.swap_with(new_data); + } + catch (const exception &) + { + stream.exceptions(old_except_mask); + throw; + } + + stream.exceptions(old_except_mask); + } +} diff --git a/src/seal/plaintext.h b/src/seal/plaintext.h new file mode 100644 index 000000000..bd06a1f97 --- /dev/null +++ b/src/seal/plaintext.h @@ -0,0 +1,624 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include +#include +#include +#include +#include +#include "seal/util/common.h" +#include "seal/util/polycore.h" +#include "seal/util/defines.h" +#include "seal/memorymanager.h" +#include "seal/encryptionparams.h" +#include "seal/intarray.h" +#include "seal/context.h" + +namespace seal +{ + /** + Class to store a plaintext element. The data for the plaintext is a polynomial + with coefficients modulo the plaintext modulus. The degree of the plaintext + polynomial must be one less than the degree of the polynomial modulus. The + backing array always allocates one 64-bit word per each coefficient of the + polynomial. + + @par Memory Management + The coefficient count of a plaintext refers to the number of word-size + coefficients in the plaintext, whereas its capacity refers to the number of + word-size coefficients that fit in the current memory allocation. In high- + performance applications unnecessary re-allocations should be avoided by + reserving enough memory for the plaintext to begin with either by providing + the desired capacity to the constructor as an extra argument, or by calling + the reserve function at any time. + + When the scheme is scheme_type::BFV each coefficient of a plaintext is a 64-bit + word, but when the scheme is scheme_type::CKKS the plaintext is by default + stored in an NTT transformed form with respect to each of the primes in the + coefficient modulus. Thus, the size of the allocation that is needed is the + size of the coefficient modulus (number of primes) times the degree of the + polynomial modulus. In addition, a valid CKKS plaintext also store the parms_id + for the corresponding encryption parameters. + + @par Thread Safety + In general, reading from plaintext is thread-safe as long as no other thread + is concurrently mutating it. This is due to the underlying data structure + storing the plaintext not being thread-safe. + + @see Ciphertext for the class that stores ciphertexts. + */ + class Plaintext + { + public: + using pt_coeff_type = std::uint64_t; + + using size_type = IntArray::size_type; + + /** + Constructs an empty plaintext allocating no memory. + + @param[in] pool The MemoryPoolHandle pointing to a valid memory pool + @throws std::invalid_argument if pool is uninitialized + */ + Plaintext(MemoryPoolHandle pool = MemoryManager::GetPool()) : + data_(std::move(pool)) + { + } + + /** + Constructs a plaintext representing a constant polynomial 0. The coefficient + count of the polynomial is set to the given value. The capacity is set to + the same value. + + @param[in] coeff_count The number of (zeroed) coefficients in the plaintext + polynomial + @param[in] pool The MemoryPoolHandle pointing to a valid memory pool + @throws std::invalid_argument if coeff_count is negative + @throws std::invalid_argument if pool is uninitialized + */ + explicit Plaintext(size_type coeff_count, + MemoryPoolHandle pool = MemoryManager::GetPool()) : + data_(coeff_count, std::move(pool)) + { + } + + /** + Constructs a plaintext representing a constant polynomial 0. The coefficient + count of the polynomial and the capacity are set to the given values. + + @param[in] capacity The capacity + @param[in] coeff_count The number of (zeroed) coefficients in the plaintext + polynomial + @param[in] pool The MemoryPoolHandle pointing to a valid memory pool + @throws std::invalid_argument if capacity is less than coeff_count + @throws std::invalid_argument if coeff_count is negative + @throws std::invalid_argument if pool is uninitialized + */ + explicit Plaintext(size_type capacity, size_type coeff_count, + MemoryPoolHandle pool = MemoryManager::GetPool()) : + data_(capacity, coeff_count, std::move(pool)) + { + } + + /** + Constructs a plaintext from a given hexadecimal string describing the + plaintext polynomial. + + The string description of the polynomial must adhere to the format returned + by to_string(), + which is of the form "7FFx^3 + 1x^1 + 3" and summarized by the following + rules: + 1. Terms are listed in order of strictly decreasing exponent + 2. Coefficient values are non-negative and in hexadecimal format (upper + and lower case letters are both supported) + 3. Exponents are positive and in decimal format + 4. Zero coefficient terms (including the constant term) may be (but do + not have to be) omitted + 5. Term with the exponent value of one must be exactly written as x^1 + 6. Term with the exponent value of zero (the constant term) must be written + as just a hexadecimal number without exponent + 7. Terms must be separated by exactly + and minus is not + allowed + 8. Other than the +, no other terms should have whitespace + + @param[in] hex_poly The formatted polynomial string specifying the plaintext + polynomial + @param[in] pool The MemoryPoolHandle pointing to a valid memory pool + @throws std::invalid_argument if hex_poly does not adhere to the expected + format + @throws std::invalid_argument if pool is uninitialized + */ + Plaintext(const std::string &hex_poly, + MemoryPoolHandle pool = MemoryManager::GetPool()) : + data_(std::move(pool)) + { + operator =(hex_poly); + } + + /** + Constructs a new plaintext by copying a given one. + + @param[in] copy The plaintext to copy from + */ + Plaintext(const Plaintext ©) = default; + + /** + Constructs a new plaintext by moving a given one. + + @param[in] source The plaintext to move from + */ + Plaintext(Plaintext &&source) = default; + + /** + Allocates enough memory to accommodate the backing array of a plaintext + with given capacity. + + @param[in] capacity The capacity + @throws std::invalid_argument if capacity is negative + @throws std::logic_error if the plaintext is NTT transformed + */ + void reserve(size_type capacity) + { + if (is_ntt_form()) + { + throw std::logic_error("cannot reserve for an NTT transformed Plaintext"); + } + data_.reserve(capacity); + } + + /** + Allocates enough memory to accommodate the backing array of the current + plaintext and copies it over to the new location. This function is meant + to reduce the memory use of the plaintext to smallest possible and can be + particularly important after modulus switching. + */ + inline void shrink_to_fit() + { + data_.shrink_to_fit(); + } + + /** + Resets the plaintext. This function releases any memory allocated by the + plaintext, returning it to the memory pool. + */ + inline void release() noexcept + { + parms_id_ = parms_id_zero; + scale_ = 1.0; + data_.release(); + } + + /** + Resizes the plaintext to have a given coefficient count. The plaintext + is automatically reallocated if the new coefficient count does not fit in + the current capacity. + + @param[in] coeff_count The number of coefficients in the plaintext polynomial + @throws std::invalid_argument if coeff_count is negative + @throws std::logic_error if the plaintext is NTT transformed + */ + inline void resize(size_type coeff_count) + { + if (is_ntt_form()) + { + throw std::logic_error("cannot reserve for an NTT transformed Plaintext"); + } + data_.resize(coeff_count); + } + + /** + Copies a given plaintext to the current one. + + @param[in] assign The plaintext to copy from + */ + Plaintext &operator =(const Plaintext &assign) = default; + + /** + Moves a given plaintext to the current one. + + @param[in] assign The plaintext to move from + */ + Plaintext &operator =(Plaintext &&assign) = default; + + /** + Sets the value of the current plaintext to the polynomial represented by + the a given hexadecimal string. + + The string description of the polynomial must adhere to the format returned + by to_string(), which is of the form "7FFx^3 + 1x^1 + 3" and summarized + by the following rules: + 1. Terms are listed in order of strictly decreasing exponent + 2. Coefficient values are non-negative and in hexadecimal format (upper + and lower case letters are both supported) + 3. Exponents are positive and in decimal format + 4. Zero coefficient terms (including the constant term) may be (but do + not have to be) omitted + 5. Term with the exponent value of one must be exactly written as x^1 + 6. Term with the exponent value of zero (the constant term) must be + written as just a hexadecimal number without exponent + 7. Terms must be separated by exactly + and minus is not + allowed + 8. Other than the +, no other terms should have whitespace + + @param[in] hex_poly The formatted polynomial string specifying the plaintext + polynomial + @throws std::invalid_argument if hex_poly does not adhere to the expected + format + @throws std::invalid_argument if the coefficients of hex_poly are too wide + */ + Plaintext &operator =(const std::string &hex_poly); + + /** + Sets the value of the current plaintext to a given constant polynomial. + The coefficient count is set to one. + + @param[in] const_coeff The constant coefficient + @throws std::logic_error if the plaintext is NTT transformed + */ + Plaintext &operator =(pt_coeff_type const_coeff) + { + data_.resize(1); + data_[0] = const_coeff; + return *this; + } + + /** + Sets a given range of coefficients of a plaintext polynomial to zero; does + nothing if length is zero. + + @param[in] start_coeff The index of the first coefficient to set to zero + @param[in] length The number of coefficients to set to zero + @throws std::out_of_range if start_coeff + length - 1 is not within [0, coeff_count) + */ + inline void set_zero(size_type start_coeff, size_type length) + { + if (!length) + { + return; + } + if (start_coeff + length - 1 >= coeff_count()) + { + throw std::out_of_range("length must be non-negative and start_coeff + length - 1 must be within [0, coeff_count)"); + } + std::fill_n(data_.begin() + start_coeff, length, pt_coeff_type(0)); + } + + /** + Sets the plaintext polynomial coefficients to zero starting at a given index. + + @param[in] start_coeff The index of the first coefficient to set to zero + @throws std::out_of_range if start_coeff is not within [0, coeff_count) + */ + inline void set_zero(size_type start_coeff) + { + if (start_coeff >= coeff_count()) + { + throw std::out_of_range("start_coeff must be within [0, coeff_count)"); + } + std::fill(data_.begin() + start_coeff, data_.end(), pt_coeff_type(0)); + } + + /** + Sets the plaintext polynomial to zero. + */ + inline void set_zero() + { + std::fill(data_.begin(), data_.end(), pt_coeff_type(0)); + } + + /** + Returns a pointer to the beginning of the plaintext polynomial. + */ + inline pt_coeff_type *data() + { + return data_.begin(); + } + + /** + Returns a const pointer to the beginning of the plaintext polynomial. + */ + inline const pt_coeff_type *data() const + { + return data_.cbegin(); + } +#ifdef SEAL_USE_MSGSL_SPAN + /** + Returns a span pointing to the beginning of the text polynomial. + */ + inline gsl::span data_span() + { + return gsl::span(data_.begin(), + static_cast(coeff_count())); + } + + /** + Returns a span pointing to the beginning of the text polynomial. + */ + inline gsl::span data_span() const + { + return gsl::span(data_.cbegin(), + static_cast(coeff_count())); + } +#endif + /** + Returns a pointer to a given coefficient of the plaintext polynomial. + + @param[in] coeff_index The index of the coefficient in the plaintext polynomial + @throws std::out_of_range if coeff_index is not within [0, coeff_count) + */ + inline pt_coeff_type *data(size_type coeff_index) + { + if (coeff_count() == 0) + { + return nullptr; + } + if (coeff_index >= coeff_count()) + { + throw std::out_of_range("coeff_index must be within [0, coeff_count)"); + } + return data_.begin() + coeff_index; + } + + /** + Returns a const pointer to a given coefficient of the plaintext polynomial. + + @param[in] coeff_index The index of the coefficient in the plaintext polynomial + */ + inline const pt_coeff_type *data(size_type coeff_index) const + { + if (coeff_count() == 0) + { + return nullptr; + } + if (coeff_index >= coeff_count()) + { + throw std::out_of_range("coeff_index must be within [0, coeff_count)"); + } + return data_.cbegin() + coeff_index; + } + + /** + Returns a const reference to a given coefficient of the plaintext polynomial. + + @param[in] coeff_index The index of the coefficient in the plaintext polynomial + @throws std::out_of_range if coeff_index is not within [0, coeff_count) + */ + inline const pt_coeff_type &operator [](size_type coeff_index) const + { + return data_.at(coeff_index); + } + + /** + Returns a reference to a given coefficient of the plaintext polynomial. + + @param[in] coeff_index The index of the coefficient in the plaintext polynomial + @throws std::out_of_range if coeff_index is not within [0, coeff_count) + */ + inline pt_coeff_type &operator [](size_type coeff_index) + { + return data_.at(coeff_index); + } + + /** + Returns whether or not the plaintext has the same semantic value as a given + plaintext. Leading zero coefficients are ignored by the comparison. + + @param[in] compare The plaintext to compare against + */ + inline bool operator ==(const Plaintext &compare) const + { + std::size_t sig_coeff_count = significant_coeff_count(); + std::size_t sig_coeff_count_compare = compare.significant_coeff_count(); + bool parms_id_compare = (is_ntt_form() && compare.is_ntt_form() + && (parms_id_ == compare.parms_id_)) || + (!is_ntt_form() && !compare.is_ntt_form()); + return parms_id_compare + && (sig_coeff_count == sig_coeff_count_compare) + && std::equal(data_.cbegin(), + data_.cbegin() + sig_coeff_count, + compare.data_.cbegin(), + compare.data_.cbegin() + sig_coeff_count) + && std::all_of(data_.cbegin() + sig_coeff_count, + data_.cend(), util::is_zero) + && std::all_of(compare.data_.cbegin() + sig_coeff_count, + compare.data_.cend(), util::is_zero) + && util::are_close(scale_, compare.scale_); + } + + /** + Returns whether or not the plaintext has a different semantic value than + a given plaintext. Leading zero coefficients are ignored by the comparison. + + @param[in] compare The plaintext to compare against + */ + inline bool operator !=(const Plaintext &compare) const + { + return !operator ==(compare); + } + + /** + Returns whether the current plaintext polynomial has all zero coefficients. + */ + inline bool is_zero() const + { + return (coeff_count() == 0) || + std::all_of(data_.cbegin(), data_.cend(), + util::is_zero); + } + + /** + Returns the capacity of the current allocation. + */ + inline size_type capacity() const noexcept + { + return data_.capacity(); + } + + /** + Returns the coefficient count of the current plaintext polynomial. + */ + inline size_type coeff_count() const noexcept + { + return data_.size(); + } + + /** + Returns the significant coefficient count of the current plaintext polynomial. + */ + inline size_type significant_coeff_count() const + { + if (coeff_count() == 0) + { + return 0; + } + return util::get_significant_uint64_count_uint(data_.cbegin(), coeff_count()); + } + + /** + Returns a human-readable string description of the plaintext polynomial. + + The returned string is of the form "7FFx^3 + 1x^1 + 3" with a format + summarized by the following: + 1. Terms are listed in order of strictly decreasing exponent + 2. Coefficient values are non-negative and in hexadecimal format (hexadecimal + letters are in upper-case) + 3. Exponents are positive and in decimal format + 4. Zero coefficient terms (including the constant term) are omitted unless + the polynomial is exactly 0 (see rule 9) + 5. Term with the exponent value of one is written as x^1 + 6. Term with the exponent value of zero (the constant term) is written as + just a hexadecimal number without x or exponent + 7. Terms are separated exactly by + + 8. Other than the +, no other terms have whitespace + 9. If the polynomial is exactly 0, the string "0" is returned + + @throws std::invalid_argument if the plaintext is in NTT transformed form + */ + inline std::string to_string() const + { + if (is_ntt_form()) + { + throw std::invalid_argument("cannot convert NTT transformed plaintext to string"); + } + return util::poly_to_hex_string(data_.cbegin(), coeff_count(), 1); + } + + /** + Check whether the current Plaintext is valid for a given SEALContext. If + the given SEALContext is not set, the encryption parameters are invalid, + or the Plaintext data does not match the SEALContext, this function returns + false. Otherwise, returns true. + + @param[in] context The SEALContext + */ + bool is_valid_for(std::shared_ptr context) const; + + /** + Saves the plaintext to an output stream. The output is in binary format + and not human-readable. The output stream must have the "binary" flag set. + + @param[in] stream The stream to save the plaintext to + @throws std::exception if the plaintext could not be written to stream + */ + void save(std::ostream &stream) const; + + /** + Loads a plaintext from an input stream overwriting the current plaintext. + No checking of the validity of the plaintext data against encryption + parameters is performed. This function should not be used unless the + plaintext comes from a fully trusted source. + + @param[in] stream The stream to load the plaintext from + @throws std::exception if a valid plaintext could not be read from stream + */ + void unsafe_load(std::istream &stream); + + /** + Loads a plaintext from an input stream overwriting the current plaintext. + The loaded plaintext is verified to be valid for the given SEALContext. + + @param[in] context The SEALContext + @param[in] stream The stream to load the plaintext from + @throws std::invalid_argument if the context is not set or encryption + parameters are not valid + @throws std::exception if a valid plaintext could not be read from stream + @throws std::invalid_argument if the loaded plaintext is invalid for the + context + */ + inline void load(std::shared_ptr context, + std::istream &stream) + { + unsafe_load(stream); + if (!is_valid_for(std::move(context))) + { + throw std::invalid_argument("Plaintext data is invalid"); + } + } + + /** + Returns whether the plaintext is in NTT form. + */ + inline bool is_ntt_form() const noexcept + { + return (parms_id_ != parms_id_zero); + } + + /** + Returns a reference to parms_id. The parms_id must remain zero unless the + plaintext polynomial is in NTT form. + + @see EncryptionParameters for more information about parms_id. + */ + inline auto &parms_id() noexcept + { + return parms_id_; + } + + /** + Returns a const reference to parms_id. The parms_id must remain zero unless + the plaintext polynomial is in NTT form. + + @see EncryptionParameters for more information about parms_id. + */ + inline auto &parms_id() const noexcept + { + return parms_id_; + } + + /** + Returns a reference to the scale. This is only needed when using the CKKS + encryption scheme. The user should have little or no reason to ever change + the scale by hand. + */ + inline auto &scale() noexcept + { + return scale_; + } + + /** + Returns a constant reference to the scale. This is only needed when using + the CKKS encryption scheme. + */ + inline auto &scale() const noexcept + { + return scale_; + } + + /** + Returns the currently used MemoryPoolHandle. + */ + inline MemoryPoolHandle pool() const noexcept + { + return data_.pool(); + } + + private: + parms_id_type parms_id_ = parms_id_zero; + + double scale_ = 1.0; + + IntArray data_; + }; +} diff --git a/src/seal/publickey.h b/src/seal/publickey.h new file mode 100644 index 000000000..becba577e --- /dev/null +++ b/src/seal/publickey.h @@ -0,0 +1,175 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include +#include +#include "seal/ciphertext.h" +#include "seal/context.h" + +namespace seal +{ + /** + Class to store a public key. + + @par Thread Safety + In general, reading from PublicKey is thread-safe as long as no other thread + is concurrently mutating it. This is due to the underlying data structure + storing the public key not being thread-safe. + + @see KeyGenerator for the class that generates the public key. + @see SecretKey for the class that stores the secret key. + @see RelinKeys for the class that stores the relinearization keys. + @see GaloisKeys for the class that stores the Galois keys. + */ + class PublicKey + { + friend class KeyGenerator; + + public: + /** + Creates an empty public key. + */ + PublicKey() = default; + + /** + Creates a new PublicKey by copying an old one. + + @param[in] copy The PublicKey to copy from + */ + PublicKey(const PublicKey ©) = default; + + /** + Creates a new PublicKey by moving an old one. + + @param[in] source The PublicKey to move from + */ + PublicKey(PublicKey &&source) = default; + + /** + Copies an old PublicKey to the current one. + + @param[in] assign The PublicKey to copy from + */ + PublicKey &operator =(const PublicKey &assign) = default; + + /** + Moves an old PublicKey to the current one. + + @param[in] assign The PublicKey to move from + */ + PublicKey &operator =(PublicKey &&assign) = default; + + /** + Returns a reference to the underlying data. + */ + inline auto &data() noexcept + { + return pk_; + } + + /** + Returns a const reference to the underlying data. + */ + inline auto &data() const noexcept + { + return pk_; + } + + /** + Check whether the current PublicKey is valid for a given SEALContext. If + the given SEALContext is not set, the encryption parameters are invalid, + or the PublicKey data does not match the SEALContext, this function returns + false. Otherwise, returns true. + + @param[in] context The SEALContext + */ + inline bool is_valid_for(std::shared_ptr context) const noexcept + { + // Verify parameters + if (!context || !context->parameters_set()) + { + return false; + } + auto parms_id = context->first_parms_id(); + return pk_.is_valid_for(std::move(context)) && + pk_.is_ntt_form() && pk_.parms_id() == parms_id; + } + + /** + Saves the PublicKey to an output stream. The output is in binary format + and not human-readable. The output stream must have the "binary" flag set. + + @param[in] stream The stream to save the PublicKey to + @throws std::exception if the PublicKey could not be written to stream + */ + inline void save(std::ostream &stream) const + { + pk_.save(stream); + } + + /** + Loads a PublicKey from an input stream overwriting the current PublicKey. + No checking of the validity of the PublicKey data against encryption + parameters is performed. This function should not be used unless the + PublicKey comes from a fully trusted source. + + @param[in] stream The stream to load the PublicKey from + @throws std::exception if a valid PublicKey could not be read from stream + */ + inline void unsafe_load(std::istream &stream) + { + pk_.unsafe_load(stream); + } + + /** + Loads a PublicKey from an input stream overwriting the current PublicKey. + The loaded PublicKey is verified to be valid for the given SEALContext. + + @param[in] context The SEALContext + @param[in] stream The stream to load the PublicKey from + @throws std::invalid_argument if the context is not set or encryption + parameters are not valid + @throws std::exception if a valid PublicKey could not be read from stream + @throws std::invalid_argument if the loaded PublicKey is invalid for the + context + */ + inline void load(std::shared_ptr context, + std::istream &stream) + { + unsafe_load(stream); + if (!is_valid_for(std::move(context))) + { + throw std::invalid_argument("PublicKey data is invalid"); + } + } + + /** + Returns a reference to parms_id. + */ + inline auto &parms_id() noexcept + { + return pk_.parms_id(); + } + + /** + Returns a const reference to parms_id. + */ + inline auto &parms_id() const noexcept + { + return pk_.parms_id(); + } + + /** + Returns the currently used MemoryPoolHandle. + */ + inline MemoryPoolHandle pool() const noexcept + { + return pk_.pool(); + } + + private: + Ciphertext pk_; + }; +} diff --git a/src/seal/randomgen.cpp b/src/seal/randomgen.cpp new file mode 100644 index 000000000..be96de351 --- /dev/null +++ b/src/seal/randomgen.cpp @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "seal/randomgen.h" + +using namespace std; + +namespace seal +{ + /** + Returns the default random number generator factory. This instance should + not be destroyed. + */ + auto UniformRandomGeneratorFactory::default_factory() + -> const shared_ptr + { + static const shared_ptr + default_factory{ new SEAL_DEFAULT_RNG_FACTORY }; + return default_factory; + } +#ifdef SEAL_USE_AES_NI_PRNG + auto FastPRNGFactory::create() -> shared_ptr + { + if (!(seed_[0] & seed_[1])) + { + return make_shared(random_uint64(), random_uint64()); + } + else + { + return make_shared(seed_[0], seed_[1]); + } + } +#endif +} diff --git a/src/seal/randomgen.h b/src/seal/randomgen.h new file mode 100644 index 000000000..d35575618 --- /dev/null +++ b/src/seal/randomgen.h @@ -0,0 +1,288 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include +#include +#include +#include +#include "seal/util/defines.h" +#include "seal/util/common.h" +#include "seal/util/aes.h" + +namespace seal +{ + /** + Provides the base-class for a uniform random number generator. Instances of + this class are typically returned from the UniformRandomGeneratorFactory class. + This class is meant for users to sub-class to implement their own random number + generators. The implementation should provide a uniform random unsigned 32-bit + value for each call to generate(). Note that the library will never make + concurrent calls to generate() to the same instance (but individual instances + of the same class may have concurrent calls). The uniformity and unpredictability + of the numbers generated is essential for making a secure cryptographic system. + + @see UniformRandomGeneratorFactory for the base-class of a factory class that + generates UniformRandomGenerator instances. + @see StandardRandomAdapter for an implementation of UniformRandomGenerator to + support the C++ standard library's random number generators. + */ + class UniformRandomGenerator + { + public: + /** + Generates a new uniform unsigned 32-bit random number. Note that the + implementation does not need to be thread-safe. + */ + virtual std::uint32_t generate() = 0; + + /** + Destroys the random number generator. + */ + virtual ~UniformRandomGenerator() = default; + }; + + /** + Provides the base-class for a factory instance that creates instances of + UniformRandomGenerator. This class is meant for users to sub-class to implement + their own random number generators. Note that each instance returned may be + used concurrently across separate threads, but each individual instance does + not need to be thread-safe. + + @see UniformRandomGenerator for details relating to the random number generator + instances. + @see StandardRandomAdapterFactory for an implementation of + UniformRandomGeneratorFactory that supports the standard C++ library's + random number generators. + */ + class UniformRandomGeneratorFactory + { + public: + /** + Creates a new uniform random number generator. + */ + virtual auto create() + -> std::shared_ptr = 0; + + /** + Destroys the random number generator factory. + */ + virtual ~UniformRandomGeneratorFactory() = default; + + /** + Returns the default random number generator factory. This instance should + not be destroyed. + */ + static auto default_factory() + -> const std::shared_ptr; + + private: + }; +#ifdef SEAL_USE_AES_NI_PRNG + /** + Provides an implementation of UniformRandomGenerator for using very fast + AES-NI randomness with given 128-bit seed. + */ + class FastPRNG : public UniformRandomGenerator + { + public: + /** + Creates a new FastPRNGFactory instance that initializes every FastPRNG + instance it creates with the given seed. + */ + FastPRNG(std::uint64_t seed_lw, std::uint64_t seed_hw) : + aes_enc_{ seed_lw, seed_hw } + { + refill_buffer(); + } + + /** + Generates a new uniform unsigned 32-bit random number. Note that the + implementation does not need to be thread-safe. + */ + virtual std::uint32_t generate() override + { + std::uint32_t result; + std::copy_n(buffer_head_, util::bytes_per_uint32, + reinterpret_cast(&result)); + buffer_head_ += util::bytes_per_uint32; + if (buffer_head_ == buffer_.cend()) + { + refill_buffer(); + } + return result; + } + + /** + Destroys the random number generator. + */ + virtual ~FastPRNG() override = default; + + private: + AESEncryptor aes_enc_; + + static constexpr std::size_t bytes_per_block_ = + sizeof(aes_block) / sizeof(SEAL_BYTE); + + static constexpr std::size_t buffer_block_size_ = 8; + + static constexpr std::size_t buffer_size_ = + buffer_block_size_ * bytes_per_block_; + + std::array buffer_; + + std::size_t counter_ = 0; + + typename decltype(buffer_)::const_iterator buffer_head_; + + void refill_buffer() + { + // Fill the randomness buffer + aes_block *buffer_ptr = reinterpret_cast(&*buffer_.begin()); + aes_enc_.counter_encrypt(counter_, buffer_block_size_, buffer_ptr); + counter_ += buffer_block_size_; + buffer_head_ = buffer_.cbegin(); + } + }; + + class FastPRNGFactory : public UniformRandomGeneratorFactory + { + public: + /** + Creates a new FastPRNGFactory instance that initializes every FastPRNG + instance it creates with the given seed. A zero seed (default value) + signals that each random number generator created by the factory should + use a different random seed obtained from std::random_device. + + @param[in] seed_lw Low-word for seed for the PRNG + @param[in] seed_hw High-word for seed for the PRNG + */ + FastPRNGFactory(std::uint64_t seed_lw = 0, std::uint64_t seed_hw = 0) : + seed_{ seed_lw, seed_hw } + { + } + + /** + Creates a new uniform random number generator. The caller of create needs + to ensure the returned instance is destroyed once it is no longer in-use + to prevent a memory leak. + */ + virtual auto create() -> std::shared_ptr override; + + /** + Destroys the random number generator factory. + */ + virtual ~FastPRNGFactory() = default; + + private: + std::uint64_t random_uint64() const noexcept + { + std::random_device rd; + return (static_cast(rd()) << 32) + + static_cast(rd()); + } + + std::uint64_t seed_[2]; + }; +#endif //SEAL_USE_AES_NI_PRNG + /** + Provides an implementation of UniformRandomGenerator for the standard C++ + library's uniform random number generators. + + @tparam RNG specifies the type of the standard C++ library's random number + generator (e.g., std::default_random_engine) + */ + template + class StandardRandomAdapter : public UniformRandomGenerator + { + public: + /** + Creates a new random number generator (of type RNG). + */ + StandardRandomAdapter() = default; + + /** + Returns a reference to the random number generator. + */ + inline const RNG &generator() const noexcept + { + return generator_; + } + + /** + Returns a reference to the random number generator. + */ + inline RNG &generator() noexcept + { + return generator_; + } + + /** + Generates a new uniform unsigned 32-bit random number. + */ + std::uint32_t generate() noexcept override + { + SEAL_IF_CONSTEXPR (RNG::min() == 0 && RNG::max() >= UINT32_MAX) + { + return static_cast(generator_()); + } + else SEAL_IF_CONSTEXPR (RNG::max() - RNG::min() >= UINT32_MAX) + { + return static_cast(generator_() - RNG::min()); + } + else SEAL_IF_CONSTEXPR (RNG::min() == 0) + { + std::uint64_t max_value = RNG::max(); + std::uint64_t value = static_cast(generator_()); + std::uint64_t max = max_value; + while (max < UINT32_MAX) + { + value *= max_value; + max *= max_value; + value += static_cast(generator_()); + } + return static_cast(value); + } + else + { + std::uint64_t max_value = RNG::max() - RNG::min(); + std::uint64_t value = static_cast(generator_() - RNG::min()); + std::uint64_t max = max_value; + while (max < UINT32_MAX) + { + value *= max_value; + max *= max_value; + value += static_cast(generator_() - RNG::min()); + } + return static_cast(value); + } + } + + private: + RNG generator_; + }; + + /** + Provides an implementation of UniformRandomGeneratorFactory for the standard + C++ library's random number generators. + + @tparam RNG specifies the type of the standard C++ library's random number + generator (e.g., std::default_random_engine) + */ + template + class StandardRandomAdapterFactory : public UniformRandomGeneratorFactory + { + public: + /** + Creates a new uniform random number generator. + */ + auto create() -> std::shared_ptr override + { + return std::shared_ptr{ + new StandardRandomAdapter() }; + } + + private: + }; +} diff --git a/src/seal/relinkeys.cpp b/src/seal/relinkeys.cpp new file mode 100644 index 000000000..550ca794f --- /dev/null +++ b/src/seal/relinkeys.cpp @@ -0,0 +1,192 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "seal/relinkeys.h" +#include "seal/util/defines.h" +#include + +using namespace std; +using namespace seal::util; + +namespace seal +{ + RelinKeys &RelinKeys::operator =(const RelinKeys &assign) + { + // Check for self-assignment + if (this == &assign) + { + return *this; + } + + // Copy over fields + parms_id_ = assign.parms_id_; + decomposition_bit_count_ = assign.decomposition_bit_count_; + + // Then copy over keys + keys_.clear(); + size_t keys_dim1 = assign.keys_.size(); + keys_.reserve(keys_dim1); + for (size_t i = 0; i < keys_dim1; i++) + { + size_t keys_dim2 = assign.keys_[i].size(); + keys_.emplace_back(); + keys_[i].reserve(keys_dim2); + for (size_t j = 0; j < keys_dim2; j++) + { + keys_[i].emplace_back(pool_); + keys_[i][j] = assign.keys_[i][j]; + } + } + + return *this; + } + + bool RelinKeys::is_valid_for(shared_ptr context) const noexcept + { + // Verify parameters + if (!context || !context->parameters_set()) + { + return false; + } + if (parms_id_ != context->first_parms_id()) + { + return false; + } + + for (auto &a : keys_) + { + for (auto &b : a) + { + if (!b.is_valid_for(context) || !b.is_ntt_form() || + b.parms_id() != parms_id_) + { + return false; + } + } + } + + return true; + } + + void RelinKeys::save(std::ostream &stream) const + { + auto old_except_mask = stream.exceptions(); + try + { + // Throw exceptions on std::ios_base::badbit and std::ios_base::failbit + stream.exceptions(ios_base::badbit | ios_base::failbit); + + uint64_t keys_dim1 = static_cast(keys_.size()); + + // Validate keys_dim1 (relinearization key count) + if (keys_dim1 < SEAL_RELIN_KEY_COUNT_MIN || + keys_dim1 > SEAL_RELIN_KEY_COUNT_MAX) + { + throw invalid_argument("count out of bounds"); + } + + int32_t decomposition_bit_count32 = + safe_cast(decomposition_bit_count_); + + // Save the parms_id + stream.write(reinterpret_cast(&parms_id_), + sizeof(parms_id_type)); + + // Save the decomposition bit count + stream.write(reinterpret_cast(&decomposition_bit_count32), + sizeof(int32_t)); + + // Save the size of keys_ + stream.write(reinterpret_cast(&keys_dim1), sizeof(uint64_t)); + + // Now loop again over keys_dim1 + for (size_t index = 0; index < keys_dim1; index++) + { + // Save second dimension of keys_ + uint64_t keys_dim2 = static_cast(keys_[index].size()); + stream.write(reinterpret_cast(&keys_dim2), sizeof(uint64_t)); + + // Loop over keys_dim2 and save all (or none) + for (size_t j = 0; j < keys_dim2; j++) + { + // Save the key + keys_[index][j].save(stream); + } + } + } + catch (const std::exception &) + { + stream.exceptions(old_except_mask); + throw; + } + + stream.exceptions(old_except_mask); + } + + void RelinKeys::unsafe_load(std::istream &stream) + { + auto old_except_mask = stream.exceptions(); + try + { + // Throw exceptions on std::ios_base::badbit and std::ios_base::failbit + stream.exceptions(ios_base::badbit | ios_base::failbit); + + // Clear current keys + keys_.clear(); + + // Read the parms_id + stream.read(reinterpret_cast(&parms_id_), + sizeof(parms_id_type)); + + // Read and validate the decomposition_bit_count + int32_t decomposition_bit_count32 = 0; + stream.read(reinterpret_cast(&decomposition_bit_count32), + sizeof(int32_t)); + if (decomposition_bit_count32 < SEAL_DBC_MIN || + decomposition_bit_count32 > SEAL_DBC_MAX) + { + throw logic_error("decomposition bit count out of bounds"); + } + decomposition_bit_count_ = safe_cast(decomposition_bit_count32); + + // Read in the size of keys_ + uint64_t keys_dim1 = 0; + stream.read(reinterpret_cast(&keys_dim1), sizeof(uint64_t)); + + // Validate keys_dim1 (relinearization key count) + if (keys_dim1 < SEAL_RELIN_KEY_COUNT_MIN || + keys_dim1 > SEAL_RELIN_KEY_COUNT_MAX) + { + throw invalid_argument("count out of bounds"); + } + + // Reserve first for dimension of keys_ + keys_.reserve(safe_cast(keys_dim1)); + + // Loop over the first dimension of keys_ + for (size_t index = 0; index < keys_dim1; index++) + { + // Read the size of the second dimension + uint64_t keys_dim2 = 0; + stream.read(reinterpret_cast(&keys_dim2), sizeof(uint64_t)); + + // Don't resize; only reserve + keys_.emplace_back(); + keys_.back().reserve(safe_cast(keys_dim2)); + for (size_t j = 0; j < keys_dim2; j++) + { + Ciphertext new_key(pool_); + new_key.unsafe_load(stream); + keys_[index].emplace_back(move(new_key)); + } + } + } + catch (const std::exception &) + { + stream.exceptions(old_except_mask); + throw; + } + + stream.exceptions(old_except_mask); + } +} diff --git a/src/seal/relinkeys.h b/src/seal/relinkeys.h new file mode 100644 index 000000000..42386ad8e --- /dev/null +++ b/src/seal/relinkeys.h @@ -0,0 +1,248 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include +#include +#include +#include "seal/ciphertext.h" +#include "seal/memorymanager.h" +#include "seal/encryptionparams.h" + +namespace seal +{ + /** + Class to store relinearization keys. An relinearization key has type + std::vector. + An instance of the RelinKeys class stores internally an std::vector of + relinearization keys. + + @par Relinearization + Concretely, an relinearization key corresponding to a power K of the secret + key can be used in the relinearization operation to change a ciphertext of size + K+1 to size K. Recall that the smallest possible size for a ciphertext is 2, + so the first relinearization key is corresponds to the square of the secret + key. The second relinearization key corresponds to the cube of the secret key, + and so on. For example, to relinearize a ciphertext of size 7 back to size 2, + one would need 5 relinearization keys, although it is hard to imagine a situation + where it makes sense to have size 7 ciphertexts, as operating on such objects + would be very slow. Most commonly only one relinearization key is needed, and + relinearization is performed after every multiplication. + + @par Decomposition Bit Count + Decomposition bit count (dbc) is a parameter that describes a performance + trade-off in the relinearization process. Namely, in the relinearization process + the polynomials in the ciphertexts (with large coefficients) get decomposed + into a smaller base 2^dbc, coefficient-wise. Each of the decomposition factors + corresponds to a piece of data in the relinearization key, so the smaller the + dbc is, the larger the relinearization keys are. Moreover, a smaller dbc results + in less invariant noise budget being consumed in the relinearization process. + However, using a large dbc is much faster, and often one would want to optimize + the dbc to be as large as possible for performance. The dbc is upper-bounded + by the value of 60, and lower-bounded by the value of 1. + + @par Thread Safety + In general, reading from RelinKeys is thread-safe as long as no other thread + is concurrently mutating it. This is due to the underlying data structure + storing the relinearization keys not being thread-safe. + + + @see SecretKey for the class that stores the secret key. + @see PublicKey for the class that stores the public key. + @see GaloisKeys for the class that stores the Galois keys. + @see KeyGenerator for the class that generates the relinearization keys. + */ + class RelinKeys + { + friend class KeyGenerator; + + public: + /** + Creates an empty set of relinearization keys. + */ + RelinKeys() = default; + + /** + Creates a new RelinKeys instance by copying a given instance. + + @param[in] copy The RelinKeys to copy from + */ + RelinKeys(const RelinKeys ©) = default; + + /** + Creates a new RelinKeys instance by moving a given instance. + + @param[in] source The RelinKeys to move from + */ + RelinKeys(RelinKeys &&source) = default; + + /** + Copies a given RelinKeys instance to the current one. + + @param[in] assign The RelinKeys to copy from + */ + RelinKeys &operator =(const RelinKeys &assign); + + /** + Moves a given RelinKeys instance to the current one. + + @param[in] assign The RelinKeys to move from + */ + RelinKeys &operator =(RelinKeys &&assign) = default; + + /** + Returns the current number of relinearization keys. + */ + inline std::size_t size() const + { + return keys_.size(); + } + + /** + Returns the decomposition bit count. + */ + inline int decomposition_bit_count() const noexcept + { + return decomposition_bit_count_; + } + + /** + Returns a reference to the relinearization keys data. + */ + inline auto &data() noexcept + { + return keys_; + } + + /** + Returns a const reference to the relinearization keys data. + */ + inline auto &data() const noexcept + { + return keys_; + } + + /** + Returns a const reference to an relinearization key. The returned + relinearization key corresponds to the given power of the secret key. + + + @param[in] key_power The power of the secret key + @throw std::invalid_argument if the key corresponding to key_power does + not exist + */ + inline auto &key(std::size_t key_power) const + { + if (!has_key(key_power)) + { + throw std::invalid_argument("requested key does not exist"); + } + return keys_[key_power - 2]; + } + + /** + Returns whether an relinearization key corresponding to a given power of + the secret key exists. + + @param[in] key_power The power of the secret key + */ + inline bool has_key(std::size_t key_power) const noexcept + { + return (key_power >= 2) && (keys_.size() >= key_power - 1); + } + + /** + Returns a reference to parms_id. + + @see EncryptionParameters for more information about parms_id. + */ + inline auto &parms_id() noexcept + { + return parms_id_; + } + + /** + Returns a const reference to parms_id. + + @see EncryptionParameters for more information about parms_id. + */ + inline auto &parms_id() const noexcept + { + return parms_id_; + } + + /** + Check whether the current RelinKeys is valid for a given SEALContext. If + the given SEALContext is not set, the encryption parameters are invalid, + or the RelinKeys data does not match the SEALContext, this function returns + false. Otherwise, returns true. + + @param[in] context The SEALContext + */ + bool is_valid_for(std::shared_ptr context) const noexcept; + + /** + Saves the RelinKeys instance to an output stream. The output is in binary + format and not human-readable. The output stream must have the "binary" + flag set. + + @param[in] stream The stream to save the RelinKeys to + @throws std::exception if the RelinKeys could not be written to stream + */ + void save(std::ostream &stream) const; + + /** + Loads a RelinKeys from an input stream overwriting the current RelinKeys. + No checking of the validity of the RelinKeys data against encryption + parameters is performed. This function should not be used unless the + RelinKeys comes from a fully trusted source. + + @param[in] stream The stream to load the RelinKeys from + @throws std::exception if a valid RelinKeys could not be read from stream + */ + void unsafe_load(std::istream &stream); + + /** + Loads a RelinKeys from an input stream overwriting the current RelinKeys. + The loaded RelinKeys is verified to be valid for the given SEALContext. + + @param[in] context The SEALContext + @param[in] stream The stream to load the RelinKeys from + @throws std::invalid_argument if the context is not set or encryption + parameters are not valid + @throws std::exception if a valid RelinKeys could not be read from stream + @throws std::invalid_argument if the loaded RelinKeys is invalid for the + context + */ + inline void load(std::shared_ptr context, + std::istream &stream) + { + unsafe_load(stream); + if (!is_valid_for(std::move(context))) + { + throw std::invalid_argument("RelinKeys data is invalid"); + } + } + + /** + Returns the currently used MemoryPoolHandle. + */ + inline MemoryPoolHandle pool() const noexcept + { + return pool_; + } + + private: + MemoryPoolHandle pool_ = MemoryManager::GetPool(); + + parms_id_type parms_id_ = parms_id_zero; + + /** + The vector of relinearization keys. + */ + std::vector> keys_{}; + + int decomposition_bit_count_ = 0; + }; +} diff --git a/src/seal/seal.h b/src/seal/seal.h new file mode 100644 index 000000000..03fb4bfed --- /dev/null +++ b/src/seal/seal.h @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include "seal/biguint.h" +#include "seal/ciphertext.h" +#include "seal/ckks.h" +#include "seal/context.h" +#include "seal/decryptor.h" +#include "seal/defaultparams.h" +#include "seal/encoder.h" +#include "seal/encryptionparams.h" +#include "seal/encryptor.h" +#include "seal/evaluator.h" +#include "seal/intarray.h" +#include "seal/keygenerator.h" +#include "seal/memorymanager.h" +#include "seal/plaintext.h" +#include "seal/batchencoder.h" +#include "seal/publickey.h" +#include "seal/randomgen.h" +#include "seal/relinkeys.h" +#include "seal/secretkey.h" +#include "seal/smallmodulus.h" diff --git a/src/seal/secretkey.h b/src/seal/secretkey.h new file mode 100644 index 000000000..829ee45d8 --- /dev/null +++ b/src/seal/secretkey.h @@ -0,0 +1,219 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include "seal/randomgen.h" +#include "seal/plaintext.h" +#include "seal/memorymanager.h" +#include "seal/util/common.h" +#include +#include +#include +#include +#include + +namespace seal +{ + /** + Class to store a secret key. + + @par Thread Safety + In general, reading from SecretKey is thread-safe as long as no other thread + is concurrently mutating it. This is due to the underlying data structure + storing the secret key not being thread-safe. + + + @see KeyGenerator for the class that generates the secret key. + @see PublicKey for the class that stores the public key. + @see RelinKeys for the class that stores the relinearization keys. + @see GaloisKeys for the class that stores the Galois keys. + */ + class SecretKey + { + friend class KeyGenerator; + + public: + /** + Creates an empty secret key. + */ + SecretKey() = default; + + /** + Overwrites the key data by random data and destroys the SecretKey object. + */ + ~SecretKey() noexcept + { + // We use a default factory from std::random_device to make sure + // randomize_key does not throw. + static std::unique_ptr random_factory( + std::make_unique>()); + randomize_secret(random_factory->create()); + } + + /** + Creates a new SecretKey by copying an old one. + + @param[in] copy The SecretKey to copy from + */ + SecretKey(const SecretKey ©) = default; + + /** + Creates a new SecretKey by moving an old one. + + @param[in] source The SecretKey to move from + */ + SecretKey(SecretKey &&source) = default; + + /** + Copies an old SecretKey to the current one. + + @param[in] assign The SecretKey to copy from + */ + SecretKey &operator =(const SecretKey &assign) + { + sk_ = assign.sk_; + return *this; + } + + /** + Moves an old SecretKey to the current one. + + @param[in] assign The SecretKey to move from + */ + SecretKey &operator =(SecretKey &&assign) = default; + + /** + Returns a reference to the underlying polynomial. + */ + inline auto &data() noexcept + { + return sk_; + } + + /** + Returns a const reference to the underlying polynomial. + */ + inline auto &data() const noexcept + { + return sk_; + } + + /** + Check whether the current SecretKey is valid for a given SEALContext. If + the given SEALContext is not set, the encryption parameters are invalid, + or the SecretKey data does not match the SEALContext, this function returns + false. Otherwise, returns true. + + @param[in] context The SEALContext + */ + inline bool is_valid_for(std::shared_ptr context) const + { + // Verify parameters + if (!context || !context->parameters_set()) + { + return false; + } + auto parms_id = context->first_parms_id(); + return sk_.is_valid_for(std::move(context)) && + sk_.is_ntt_form() && sk_.parms_id() == parms_id; + } + + /** + Saves the SecretKey to an output stream. The output is in binary format + and not human-readable. The output stream must have the "binary" flag set. + + @param[in] stream The stream to save the SecretKey to + @throws std::exception if the plaintext could not be written to stream + */ + inline void save(std::ostream &stream) const + { + sk_.save(stream); + } + + /** + Loads a SecretKey from an input stream overwriting the current SecretKey. + No checking of the validity of the SecretKey data against encryption + parameters is performed. This function should not be used unless the + SecretKey comes from a fully trusted source. + + @param[in] stream The stream to load the SecretKey from + @throws std::exception if a valid SecretKey could not be read from stream + */ + inline void unsafe_load(std::istream &stream) + { + sk_.unsafe_load(stream); + } + + /** + Loads a SecretKey from an input stream overwriting the current SecretKey. + The loaded SecretKey is verified to be valid for the given SEALContext. + + @param[in] context The SEALContext + @param[in] stream The stream to load the SecretKey from + @throws std::invalid_argument if the context is not set or encryption + parameters are not valid + @throws std::exception if a valid SecretKey could not be read from stream + @throws std::invalid_argument if the loaded SecretKey is invalid for the + context + */ + inline void load(std::shared_ptr context, + std::istream &stream) + { + unsafe_load(stream); + if (!is_valid_for(std::move(context))) + { + throw std::invalid_argument("SecretKey data is invalid"); + } + } + + /** + Returns a reference to parms_id. + + @see EncryptionParameters for more information about parms_id. + */ + inline auto &parms_id() noexcept + { + return sk_.parms_id(); + } + + /** + Returns a const reference to parms_id. + + @see EncryptionParameters for more information about parms_id. + */ + inline auto &parms_id() const noexcept + { + return sk_.parms_id(); + } + + /** + Returns the currently used MemoryPoolHandle. + */ + inline MemoryPoolHandle pool() const noexcept + { + return sk_.pool(); + } + + private: + inline void randomize_secret( + std::shared_ptr random) noexcept + { + std::size_t capacity = sk_.capacity(); + volatile SEAL_BYTE *data_ptr = reinterpret_cast(sk_.data()); + while (capacity--) + { + std::size_t pt_coeff_byte_count = sizeof(Plaintext::pt_coeff_type); + while (pt_coeff_byte_count--) + { + *data_ptr++ = static_cast(random->generate()); + } + } + } + + /** + We use a fresh memory pool with `clear_on_destruction' enabled + */ + Plaintext sk_{ MemoryManager::GetPool(mm_prof_opt::FORCE_NEW, true) }; + }; +} diff --git a/src/seal/smallmodulus.cpp b/src/seal/smallmodulus.cpp new file mode 100644 index 000000000..b7fffcd67 --- /dev/null +++ b/src/seal/smallmodulus.cpp @@ -0,0 +1,91 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "seal/smallmodulus.h" +#include "seal/util/uintarith.h" +#include "seal/util/common.h" +#include + +using namespace seal::util; +using namespace std; + +namespace seal +{ + void SmallModulus::save(ostream &stream) const + { + auto old_except_mask = stream.exceptions(); + try + { + // Throw exceptions on std::ios_base::badbit and std::ios_base::failbit + stream.exceptions(ios_base::badbit | ios_base::failbit); + + stream.write(reinterpret_cast(&value_), sizeof(uint64_t)); + } + catch (const std::exception &) + { + stream.exceptions(old_except_mask); + throw; + } + + stream.exceptions(old_except_mask); + } + + void SmallModulus::load(istream &stream) + { + auto old_except_mask = stream.exceptions(); + try + { + // Throw exceptions on std::ios_base::badbit and std::ios_base::failbit + stream.exceptions(ios_base::badbit | ios_base::failbit); + + uint64_t value; + stream.read(reinterpret_cast(&value), sizeof(uint64_t)); + set_value(value); + } + catch (const std::exception &) + { + stream.exceptions(old_except_mask); + throw; + } + + stream.exceptions(old_except_mask); + } + + void SmallModulus::set_value(uint64_t value) + { + if (value == 0) + { + // Zero settings + bit_count_ = 0; + uint64_count_ = 1; + value_ = 0; + const_ratio_ = { { 0, 0, 0 } }; + } + else if ((value >> 62 != 0) || (value == uint64_t(0x4000000000000000)) || + (value == 1)) + { + throw invalid_argument("value can be at most 62 bits and cannot be 1"); + } + else + { + // All normal, compute const_ratio and set everything + value_ = value; + bit_count_ = get_significant_bit_count(value_); + + // Compute Barrett ratios for 64-bit words (barrett_reduce_128) + uint64_t numerator[3]{ 0, 0, 1 }; + uint64_t quotient[3]{ 0, 0, 0 }; + + // Use a special method to avoid using memory pool + divide_uint192_uint64_inplace(numerator, value_, quotient); + + const_ratio_[0] = quotient[0]; + const_ratio_[1] = quotient[1]; + + // We store also the remainder + const_ratio_[2] = numerator[0]; + + uint64_count_ = 1; + } + } +} diff --git a/src/seal/smallmodulus.h b/src/seal/smallmodulus.h new file mode 100644 index 000000000..ee60a454e --- /dev/null +++ b/src/seal/smallmodulus.h @@ -0,0 +1,189 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include +#include +#include +#include "seal/util/uintcore.h" +#include "seal/memorymanager.h" + +namespace seal +{ + /** + Represent an integer modulus of up to 62 bits. An instance of the SmallModulus + class represents a non-negative integer modulus up to 62 bits. In particular, + the encryption parameter plain_modulus, and the primes in coeff_modulus, are + represented by instances of SmallModulus. The purpose of this class is to + perform and store the pre-computation required by Barrett reduction. + + @par Thread Safety + In general, reading from SmallModulus is thread-safe as long as no other thread + is concurrently mutating it. + + @see EncryptionParameters for a description of the encryption parameters. + */ + class SmallModulus + { + public: + /** + Creates a SmallModulus instance. The value of the SmallModulus is set to + the given value, or to zero by default. + + @param[in] value The integer modulus + @throws std::invalid_argument if value is 1 or more than 62 bits + */ + SmallModulus(std::uint64_t value = 0) + { + set_value(value); + } + + /** + Creates a new SmallModulus by copying a given one. + + @param[in] copy The SmallModulus to copy from + */ + SmallModulus(const SmallModulus ©) = default; + + /** + Creates a new SmallModulus by copying a given one. + + @param[in] source The SmallModulus to move from + */ + SmallModulus(SmallModulus &&source) = default; + + /** + Copies a given SmallModulus to the current one. + + @param[in] assign The SmallModulus to copy from + */ + SmallModulus &operator =(const SmallModulus &assign) = default; + + /** + Moves a given SmallModulus to the current one. + + @param[in] assign The SmallModulus to move from + */ + SmallModulus &operator =(SmallModulus &&assign) = default; + + /** + Sets the value of the SmallModulus. + + @param[in] value The new integer modulus + @throws std::invalid_argument if value is 1 or more than 62 bits + */ + inline SmallModulus &operator =(std::uint64_t value) + { + set_value(value); + return *this; + } + + /** + Returns the significant bit count of the value of the current SmallModulus. + */ + inline int bit_count() const + { + return bit_count_; + } + + /** + Returns the size (in 64-bit words) of the value of the current SmallModulus. + */ + inline std::size_t uint64_count() const + { + return uint64_count_; + } + + /** + Returns a const pointer to the value of the current SmallModulus. + */ + inline const uint64_t *data() const + { + return &value_; + } + + /** + Returns the value of the current SmallModulus. + */ + inline std::uint64_t value() const + { + return value_; + } + + /** + Returns the Barrett ratio computed for the value of the current SmallModulus. + The first two components of the Barrett ratio are the floor of 2^128/value, + and the third component is the remainder. + */ + inline auto &const_ratio() const + { + return const_ratio_; + } + + /** + Returns whether the value of the current SmallModulus is zero. + */ + inline bool is_zero() const + { + return value_ == 0; + } + + /** + Compares two SmallModulus instances. + + @param[in] compare The SmallModulus to compare against + */ + inline bool operator ==(const SmallModulus &compare) const + { + return value_ == compare.value_; + } + + /** + Compares two SmallModulus instances. + + @param[in] compare The SmallModulus to compare against + */ + inline bool operator !=(const SmallModulus &compare) const + { + return !(value_ == compare.value_); + } + + /** + Saves the SmallModulus to an output stream. The full state of the modulus is + serialized. The output is in binary format and not human-readable. The output + stream must have the "binary" flag set. + + @param[in] stream The stream to save the SmallModulus to + @throws std::exception if the SmallModulus could not be written to stream + */ + void save(std::ostream &stream) const; + + /** + Loads a SmallModulus from an input stream overwriting the current SmallModulus. + + @param[in] stream The stream to load the SmallModulus from + @throws std::exception if a valid SmallModulus could not be read from stream + */ + void load(std::istream &stream); + + private: + SmallModulus(std::uint64_t value, + std::array const_ratio, + int bit_count, std::size_t uint64_count) : + value_(value), const_ratio_(const_ratio), + bit_count_(bit_count), uint64_count_(uint64_count) + { + } + + void set_value(std::uint64_t value); + + std::uint64_t value_ = 0; + + std::array const_ratio_{ { 0, 0, 0 } }; + + int bit_count_ = 0; + + std::size_t uint64_count_ = 0; + }; +} diff --git a/src/seal/util/CMakeLists.txt b/src/seal/util/CMakeLists.txt new file mode 100644 index 000000000..b63c1dea4 --- /dev/null +++ b/src/seal/util/CMakeLists.txt @@ -0,0 +1,56 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +target_sources(seal + PRIVATE + ${CMAKE_CURRENT_LIST_DIR}/aes.cpp + ${CMAKE_CURRENT_LIST_DIR}/baseconverter.cpp + ${CMAKE_CURRENT_LIST_DIR}/clipnormal.cpp + ${CMAKE_CURRENT_LIST_DIR}/globals.cpp + ${CMAKE_CURRENT_LIST_DIR}/hash.cpp + ${CMAKE_CURRENT_LIST_DIR}/mempool.cpp + ${CMAKE_CURRENT_LIST_DIR}/numth.cpp + ${CMAKE_CURRENT_LIST_DIR}/polyarith.cpp + ${CMAKE_CURRENT_LIST_DIR}/polyarithmod.cpp + ${CMAKE_CURRENT_LIST_DIR}/polyarithsmallmod.cpp + ${CMAKE_CURRENT_LIST_DIR}/smallntt.cpp + ${CMAKE_CURRENT_LIST_DIR}/uintarith.cpp + ${CMAKE_CURRENT_LIST_DIR}/uintarithmod.cpp + ${CMAKE_CURRENT_LIST_DIR}/uintarithsmallmod.cpp + ${CMAKE_CURRENT_LIST_DIR}/uintcore.cpp +) + +# Create the config file +configure_file(${CMAKE_CURRENT_LIST_DIR}/config.h.in ${CMAKE_CURRENT_LIST_DIR}/config.h) + +install( + FILES + ${CMAKE_CURRENT_LIST_DIR}/aes.h + ${CMAKE_CURRENT_LIST_DIR}/baseconverter.h + ${CMAKE_CURRENT_LIST_DIR}/clang.h + ${CMAKE_CURRENT_LIST_DIR}/clipnormal.h + ${CMAKE_CURRENT_LIST_DIR}/common.h + ${CMAKE_CURRENT_LIST_DIR}/config.h + ${CMAKE_CURRENT_LIST_DIR}/defines.h + ${CMAKE_CURRENT_LIST_DIR}/gcc.h + ${CMAKE_CURRENT_LIST_DIR}/globals.h + ${CMAKE_CURRENT_LIST_DIR}/hash.h + ${CMAKE_CURRENT_LIST_DIR}/hestdparms.h + ${CMAKE_CURRENT_LIST_DIR}/locks.h + ${CMAKE_CURRENT_LIST_DIR}/mempool.h + ${CMAKE_CURRENT_LIST_DIR}/msvc.h + ${CMAKE_CURRENT_LIST_DIR}/numth.h + ${CMAKE_CURRENT_LIST_DIR}/pointer.h + ${CMAKE_CURRENT_LIST_DIR}/polyarith.h + ${CMAKE_CURRENT_LIST_DIR}/polyarithmod.h + ${CMAKE_CURRENT_LIST_DIR}/polyarithsmallmod.h + ${CMAKE_CURRENT_LIST_DIR}/polycore.h + ${CMAKE_CURRENT_LIST_DIR}/randomtostd.h + ${CMAKE_CURRENT_LIST_DIR}/smallntt.h + ${CMAKE_CURRENT_LIST_DIR}/uintarith.h + ${CMAKE_CURRENT_LIST_DIR}/uintarithmod.h + ${CMAKE_CURRENT_LIST_DIR}/uintarithsmallmod.h + ${CMAKE_CURRENT_LIST_DIR}/uintcore.h + DESTINATION + ${SEAL_INCLUDES_INSTALL_DIR}/seal/util +) diff --git a/src/seal/util/aes.cpp b/src/seal/util/aes.cpp new file mode 100644 index 000000000..3b1b6efc1 --- /dev/null +++ b/src/seal/util/aes.cpp @@ -0,0 +1,139 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "aes.h" + +#ifdef SEAL_USE_AES_NI_PRNG + +namespace seal +{ + namespace + { + __m128i keygen_helper(__m128i key, __m128i key_rcon) + { + key_rcon = _mm_shuffle_epi32(key_rcon, _MM_SHUFFLE(3, 3, 3, 3)); + key = _mm_xor_si128(key, _mm_slli_si128(key, 4)); + key = _mm_xor_si128(key, _mm_slli_si128(key, 4)); + key = _mm_xor_si128(key, _mm_slli_si128(key, 4)); + return _mm_xor_si128(key, key_rcon); + } + } + + void AESEncryptor::set_key(const aes_block &key) + { + round_key_[0] = key.i128; + round_key_[1] = keygen_helper(round_key_[0], _mm_aeskeygenassist_si128(round_key_[0], 0x01)); + round_key_[2] = keygen_helper(round_key_[1], _mm_aeskeygenassist_si128(round_key_[1], 0x02)); + round_key_[3] = keygen_helper(round_key_[2], _mm_aeskeygenassist_si128(round_key_[2], 0x04)); + round_key_[4] = keygen_helper(round_key_[3], _mm_aeskeygenassist_si128(round_key_[3], 0x08)); + round_key_[5] = keygen_helper(round_key_[4], _mm_aeskeygenassist_si128(round_key_[4], 0x10)); + round_key_[6] = keygen_helper(round_key_[5], _mm_aeskeygenassist_si128(round_key_[5], 0x20)); + round_key_[7] = keygen_helper(round_key_[6], _mm_aeskeygenassist_si128(round_key_[6], 0x40)); + round_key_[8] = keygen_helper(round_key_[7], _mm_aeskeygenassist_si128(round_key_[7], 0x80)); + round_key_[9] = keygen_helper(round_key_[8], _mm_aeskeygenassist_si128(round_key_[8], 0x1B)); + round_key_[10] = keygen_helper(round_key_[9], _mm_aeskeygenassist_si128(round_key_[9], 0x36)); + } + + void AESEncryptor::ecb_encrypt(const aes_block &plaintext, aes_block &ciphertext) const + { + ciphertext.i128 = _mm_xor_si128(plaintext.i128, round_key_[0]); + ciphertext.i128 = _mm_aesenc_si128(ciphertext.i128, round_key_[1]); + ciphertext.i128 = _mm_aesenc_si128(ciphertext.i128, round_key_[2]); + ciphertext.i128 = _mm_aesenc_si128(ciphertext.i128, round_key_[3]); + ciphertext.i128 = _mm_aesenc_si128(ciphertext.i128, round_key_[4]); + ciphertext.i128 = _mm_aesenc_si128(ciphertext.i128, round_key_[5]); + ciphertext.i128 = _mm_aesenc_si128(ciphertext.i128, round_key_[6]); + ciphertext.i128 = _mm_aesenc_si128(ciphertext.i128, round_key_[7]); + ciphertext.i128 = _mm_aesenc_si128(ciphertext.i128, round_key_[8]); + ciphertext.i128 = _mm_aesenc_si128(ciphertext.i128, round_key_[9]); + ciphertext.i128 = _mm_aesenclast_si128(ciphertext.i128, round_key_[10]); + } + + void AESEncryptor::ecb_encrypt(const aes_block *plaintext, + size_t aes_block_count, aes_block *ciphertext) const + { + for (; aes_block_count--; ciphertext++, plaintext++) + { + ciphertext->i128 = _mm_xor_si128(plaintext->i128, round_key_[0]); + ciphertext->i128 = _mm_aesenc_si128(ciphertext->i128, round_key_[1]); + ciphertext->i128 = _mm_aesenc_si128(ciphertext->i128, round_key_[2]); + ciphertext->i128 = _mm_aesenc_si128(ciphertext->i128, round_key_[3]); + ciphertext->i128 = _mm_aesenc_si128(ciphertext->i128, round_key_[4]); + ciphertext->i128 = _mm_aesenc_si128(ciphertext->i128, round_key_[5]); + ciphertext->i128 = _mm_aesenc_si128(ciphertext->i128, round_key_[6]); + ciphertext->i128 = _mm_aesenc_si128(ciphertext->i128, round_key_[7]); + ciphertext->i128 = _mm_aesenc_si128(ciphertext->i128, round_key_[8]); + ciphertext->i128 = _mm_aesenc_si128(ciphertext->i128, round_key_[9]); + ciphertext->i128 = _mm_aesenclast_si128(ciphertext->i128, round_key_[10]); + } + } + + void AESEncryptor::counter_encrypt(size_t start_index, + size_t aes_block_count, aes_block *ciphertext) const + { + for (; aes_block_count--; start_index++, ciphertext++) + { + ciphertext->i128 = _mm_xor_si128( + _mm_set_epi64x(0, static_cast(start_index)), round_key_[0]); + ciphertext->i128 = _mm_aesenc_si128(ciphertext->i128, round_key_[1]); + ciphertext->i128 = _mm_aesenc_si128(ciphertext->i128, round_key_[2]); + ciphertext->i128 = _mm_aesenc_si128(ciphertext->i128, round_key_[3]); + ciphertext->i128 = _mm_aesenc_si128(ciphertext->i128, round_key_[4]); + ciphertext->i128 = _mm_aesenc_si128(ciphertext->i128, round_key_[5]); + ciphertext->i128 = _mm_aesenc_si128(ciphertext->i128, round_key_[6]); + ciphertext->i128 = _mm_aesenc_si128(ciphertext->i128, round_key_[7]); + ciphertext->i128 = _mm_aesenc_si128(ciphertext->i128, round_key_[8]); + ciphertext->i128 = _mm_aesenc_si128(ciphertext->i128, round_key_[9]); + ciphertext->i128 = _mm_aesenclast_si128(ciphertext->i128, round_key_[10]); + } + } + + AESDecryptor::AESDecryptor(const aes_block &key) + { + set_key(key); + } + + void AESDecryptor::set_key(const aes_block &key) + { + const __m128i &v0 = key.i128; + const __m128i v1 = keygen_helper(v0, _mm_aeskeygenassist_si128(v0, 0x01)); + const __m128i v2 = keygen_helper(v1, _mm_aeskeygenassist_si128(v1, 0x02)); + const __m128i v3 = keygen_helper(v2, _mm_aeskeygenassist_si128(v2, 0x04)); + const __m128i v4 = keygen_helper(v3, _mm_aeskeygenassist_si128(v3, 0x08)); + const __m128i v5 = keygen_helper(v4, _mm_aeskeygenassist_si128(v4, 0x10)); + const __m128i v6 = keygen_helper(v5, _mm_aeskeygenassist_si128(v5, 0x20)); + const __m128i v7 = keygen_helper(v6, _mm_aeskeygenassist_si128(v6, 0x40)); + const __m128i v8 = keygen_helper(v7, _mm_aeskeygenassist_si128(v7, 0x80)); + const __m128i v9 = keygen_helper(v8, _mm_aeskeygenassist_si128(v8, 0x1B)); + const __m128i v10 = keygen_helper(v9, _mm_aeskeygenassist_si128(v9, 0x36)); + + _mm_storeu_si128(round_key_, v10); + _mm_storeu_si128(round_key_ + 1, _mm_aesimc_si128(v9)); + _mm_storeu_si128(round_key_ + 2, _mm_aesimc_si128(v8)); + _mm_storeu_si128(round_key_ + 3, _mm_aesimc_si128(v7)); + _mm_storeu_si128(round_key_ + 4, _mm_aesimc_si128(v6)); + _mm_storeu_si128(round_key_ + 5, _mm_aesimc_si128(v5)); + _mm_storeu_si128(round_key_ + 6, _mm_aesimc_si128(v4)); + _mm_storeu_si128(round_key_ + 7, _mm_aesimc_si128(v3)); + _mm_storeu_si128(round_key_ + 8, _mm_aesimc_si128(v2)); + _mm_storeu_si128(round_key_ + 9, _mm_aesimc_si128(v1)); + _mm_storeu_si128(round_key_ + 10, v0); + } + + void AESDecryptor::ecb_decrypt(const aes_block &ciphertext, aes_block &plaintext) + { + plaintext.i128 = _mm_xor_si128(ciphertext.i128, round_key_[0]); + plaintext.i128 = _mm_aesdec_si128(plaintext.i128, round_key_[1]); + plaintext.i128 = _mm_aesdec_si128(plaintext.i128, round_key_[2]); + plaintext.i128 = _mm_aesdec_si128(plaintext.i128, round_key_[3]); + plaintext.i128 = _mm_aesdec_si128(plaintext.i128, round_key_[4]); + plaintext.i128 = _mm_aesdec_si128(plaintext.i128, round_key_[5]); + plaintext.i128 = _mm_aesdec_si128(plaintext.i128, round_key_[6]); + plaintext.i128 = _mm_aesdec_si128(plaintext.i128, round_key_[7]); + plaintext.i128 = _mm_aesdec_si128(plaintext.i128, round_key_[8]); + plaintext.i128 = _mm_aesdec_si128(plaintext.i128, round_key_[9]); + plaintext.i128 = _mm_aesdeclast_si128(plaintext.i128, round_key_[10]); + } +} + +#endif diff --git a/src/seal/util/aes.h b/src/seal/util/aes.h new file mode 100644 index 000000000..6576fe257 --- /dev/null +++ b/src/seal/util/aes.h @@ -0,0 +1,87 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include "seal/util/defines.h" + +#ifdef SEAL_USE_AES_NI_PRNG + +#include +#include +#include + +namespace seal +{ + union aes_block + { + std::uint32_t u32[4]; + std::uint64_t u64[2]; + __m128i i128; + }; + + class AESEncryptor + { + public: + AESEncryptor() = default; + + AESEncryptor(const aes_block &key) + { + set_key(key); + } + + AESEncryptor(std::uint64_t key_lw, std::uint64_t key_hw) + { + aes_block key; + key.u64[0] = key_lw; + key.u64[1] = key_hw; + set_key(key); + } + + void set_key(const aes_block &key); + + void ecb_encrypt(const aes_block &plaintext, aes_block &ciphertext) const; + + inline aes_block ecb_encrypt(const aes_block &plaintext) const + { + aes_block ret; + ecb_encrypt(plaintext, ret); + return ret; + } + + // ECB mode encryption + void ecb_encrypt(const aes_block *plaintext, + std::size_t aes_block_count, aes_block *ciphertext) const; + + // Counter Mode encryption: encrypts the counter + void counter_encrypt(std::size_t start_index, + std::size_t aes_block_count, aes_block *ciphertext) const; + + private: + __m128i round_key_[11]; + }; + + class AESDecryptor + { + public: + AESDecryptor() = default; + + AESDecryptor(const aes_block &key); + + void set_key(const aes_block &key); + + void ecb_decrypt(const aes_block &ciphertext, aes_block &plaintext); + + inline aes_block ecb_decrypt(const aes_block &ciphertext) + { + aes_block ret; + ecb_decrypt(ciphertext, ret); + return ret; + } + + private: + __m128i round_key_[11]; + }; +} + +#endif diff --git a/src/seal/util/baseconverter.cpp b/src/seal/util/baseconverter.cpp new file mode 100644 index 000000000..2806b7a47 --- /dev/null +++ b/src/seal/util/baseconverter.cpp @@ -0,0 +1,975 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include +#include +#include +#include "seal/util/defines.h" +#include "seal/util/pointer.h" +#include "seal/util/uintcore.h" +#include "seal/util/baseconverter.h" +#include "seal/util/uintarith.h" +#include "seal/util/uintarithsmallmod.h" +#include "seal/util/uintarithmod.h" +#include "seal/util/smallntt.h" +#include "seal/util/globals.h" +#include "seal/smallmodulus.h" +#include "seal/defaultparams.h" + +using namespace std; + +namespace seal +{ + namespace util + { + BaseConverter::BaseConverter(const std::vector &coeff_base, + size_t coeff_count, const SmallModulus &small_plain_mod, + MemoryPoolHandle pool) : pool_(move(pool)) + { +#ifdef SEAL_DEBUG + if (!pool) + { + throw std::invalid_argument("pool is uninitialized"); + } +#endif + generate(coeff_base, coeff_count, small_plain_mod); + } + + void BaseConverter::generate(const std::vector &coeff_base, + size_t coeff_count, const SmallModulus &small_plain_mod) + { +#ifdef SEAL_DEBUG + if (get_power_of_two(coeff_count) < 0) + { + throw invalid_argument("coeff_count must be a power of 2"); + } + if (coeff_base.size() < SEAL_COEFF_MOD_COUNT_MIN || + coeff_base.size() > SEAL_COEFF_MOD_COUNT_MAX) + { + throw invalid_argument("coeff_base has invalid size"); + } +#endif + int coeff_count_power = get_power_of_two(coeff_count); + + /** + Perform all the required pre-computations and populate the tables + */ + reset(); + + m_sk_ = global_variables::internal_mods::m_sk; + m_tilde_ = global_variables::internal_mods::m_tilde; + gamma_ = global_variables::internal_mods::gamma; + small_plain_mod_ = small_plain_mod; + coeff_count_ = coeff_count; + coeff_base_mod_count_ = coeff_base.size(); + aux_base_mod_count_ = coeff_base.size(); + + // In some cases we might need to increase the size of the aux base by one, namely + // we require K * n * t * q^2 < q * prod_i m_i * m_sk, where K takes into account + // cross terms when larger size ciphertexts are used, and n is the "delta factor" + // for the ring. We reserve 32 bits for K * n. Here the coeff modulus primes q_i + // are bounded to be 60 bits, and all m_i, m_sk are 61 bits. + int total_coeff_bit_count = accumulate(coeff_base.cbegin(), coeff_base.cend(), 0, + [](int result, auto &mod) { return result + mod.bit_count(); }); + + if (32 + small_plain_mod_.bit_count() + total_coeff_bit_count >= + 61 * safe_cast(coeff_base_mod_count_) + 61) + { + aux_base_mod_count_++; + } + + // Base sizes + bsk_base_mod_count_ = aux_base_mod_count_ + 1; + plain_gamma_count_ = 2; + + // We use a reversed order here for performance reasons + coeff_base_products_mod_aux_bsk_array_ = + allocate>(bsk_base_mod_count_, pool_); + generate_n( + coeff_base_products_mod_aux_bsk_array_.get(), + bsk_base_mod_count_, + [&]() { return allocate_uint(coeff_base_mod_count_, pool_); }); + + // We use a reversed order here for performance reasons + aux_base_products_mod_coeff_array_ = + allocate>(coeff_base_mod_count_, pool_); + generate_n( + aux_base_products_mod_coeff_array_.get(), + coeff_base_mod_count_, + [&]() { return allocate_uint(aux_base_mod_count_, pool_); }); + + coeff_products_mod_plain_gamma_array_ = + allocate>(plain_gamma_count_, pool_); + generate_n( + coeff_products_mod_plain_gamma_array_.get(), + plain_gamma_count_, + [&]() { return allocate_uint(coeff_base_mod_count_, pool_); }); + + // Create moduli arrays + coeff_base_array_ = allocate(coeff_base_mod_count_, pool_); + aux_base_array_ = allocate(aux_base_mod_count_, pool_); + bsk_base_array_ = allocate(bsk_base_mod_count_, pool_); + + copy(coeff_base.cbegin(), coeff_base.cend(), coeff_base_array_.get()); + copy_n(global_variables::internal_mods::aux_small_mods.cbegin(), + aux_base_mod_count_, aux_base_array_.get()); + copy_n(aux_base_array_.get(), aux_base_mod_count_, bsk_base_array_.get()); + bsk_base_array_[bsk_base_mod_count_ - 1] = m_sk_; + + // Generate Bsk U {mtilde} small ntt tables which is used in Evaluator + bsk_small_ntt_tables_ = allocate(bsk_base_mod_count_, pool_); + for (size_t i = 0; i < bsk_base_mod_count_; i++) + { + if (!bsk_small_ntt_tables_[i].generate(coeff_count_power, bsk_base_array_[i])) + { + reset(); + return; + } + } + + size_t coeff_products_uint64_count = coeff_base_mod_count_; + size_t aux_products_uint64_count = aux_base_mod_count_; + + // Generate punctured products of coeff moduli + coeff_products_array_ = allocate_zero_uint( + mul_safe(coeff_products_uint64_count, coeff_base_mod_count_), pool_); + auto tmp_coeff(allocate_uint(coeff_products_uint64_count, pool_)); + + for (size_t i = 0; i < coeff_base_mod_count_; i++) + { + coeff_products_array_[i * coeff_products_uint64_count] = 1; + for (size_t j = 0; j < coeff_base_mod_count_; j++) + { + if (i != j) + { + multiply_uint_uint64(coeff_products_array_.get() + + (i * coeff_products_uint64_count), coeff_products_uint64_count, + coeff_base_array_[j].value(), coeff_products_uint64_count, + tmp_coeff.get()); + set_uint_uint(tmp_coeff.get(), coeff_products_uint64_count, + coeff_products_array_.get() + (i * coeff_products_uint64_count)); + } + } + } + + // Generate punctured products of aux moduli + auto aux_products_array(allocate_zero_uint( + mul_safe(aux_products_uint64_count, aux_base_mod_count_), pool_)); + auto tmp_aux(allocate_uint(aux_products_uint64_count, pool_)); + + for (size_t i = 0; i < aux_base_mod_count_; i++) + { + aux_products_array[i * aux_products_uint64_count] = 1; + for (size_t j = 0; j < aux_base_mod_count_; j++) + { + if (i != j) + { + multiply_uint_uint64(aux_products_array.get() + + (i * aux_products_uint64_count), aux_products_uint64_count, + aux_base_array_[j].value(), aux_products_uint64_count, + tmp_aux.get()); + set_uint_uint(tmp_aux.get(), aux_products_uint64_count, + aux_products_array.get() + (i * aux_products_uint64_count)); + } + } + } + + // Compute auxiliary base products mod m_sk + aux_base_products_mod_msk_array_ = allocate_uint(aux_base_mod_count_, pool_); + for (size_t i = 0; i < aux_base_mod_count_; i++) + { + aux_base_products_mod_msk_array_[i] = + modulo_uint(aux_products_array.get() + (i * aux_products_uint64_count), + aux_products_uint64_count, m_sk_, pool_); + } + + // Compute inverse coeff base mod coeff base array (qi^(-1)) mod qi and + // mtilde inv coeff products mod auxiliary moduli (m_tilda*qi^(-1)) mod qi + inv_coeff_base_products_mod_coeff_array_ = + allocate_uint(coeff_base_mod_count_, pool_); + mtilde_inv_coeff_base_products_mod_coeff_array_ = + allocate_uint(coeff_base_mod_count_, pool_); + for (size_t i = 0; i < coeff_base_mod_count_; i++) + { + inv_coeff_base_products_mod_coeff_array_[i] = + modulo_uint(coeff_products_array_.get() + (i * coeff_products_uint64_count), + coeff_products_uint64_count, coeff_base_array_[i], pool_); + if (!try_invert_uint_mod(inv_coeff_base_products_mod_coeff_array_[i], + coeff_base_array_[i], inv_coeff_base_products_mod_coeff_array_[i])) + { + reset(); + return; + } + mtilde_inv_coeff_base_products_mod_coeff_array_[i] = + multiply_uint_uint_mod(inv_coeff_base_products_mod_coeff_array_[i], + m_tilde_.value(), coeff_base_array_[i]); + } + + // Compute inverse auxiliary moduli mod auxiliary moduli (mi^(-1)) mod mi + inv_aux_base_products_mod_aux_array_ = allocate_uint(aux_base_mod_count_, pool_); + for (size_t i = 0; i < aux_base_mod_count_; i++) + { + inv_aux_base_products_mod_aux_array_[i] = + modulo_uint(aux_products_array.get() + (i * aux_products_uint64_count), + aux_products_uint64_count, aux_base_array_[i], pool_); + if (!try_invert_uint_mod(inv_aux_base_products_mod_aux_array_[i], + aux_base_array_[i], inv_aux_base_products_mod_aux_array_[i])) + { + reset(); + return; + } + } + + // Compute coeff modulus products mod mtilde (qi) mod m_tilde_ + coeff_base_products_mod_mtilde_array_ = allocate_uint(coeff_base_mod_count_, pool_); + for (size_t i = 0; i < coeff_base_mod_count_; i++) + { + coeff_base_products_mod_mtilde_array_[i] = + modulo_uint(coeff_products_array_.get() + (i * coeff_products_uint64_count), + coeff_products_uint64_count, m_tilde_, pool_); + } + + // Compute coeff modulus products mod auxiliary moduli (qi) mod mj U {msk} + coeff_base_products_mod_aux_bsk_array_ = + allocate>(bsk_base_mod_count_, pool_); + for (size_t i = 0; i < aux_base_mod_count_; i++) + { + coeff_base_products_mod_aux_bsk_array_[i] = + allocate_uint(coeff_base_mod_count_, pool_); + for (size_t j = 0; j < coeff_base_mod_count_; j++) + { + coeff_base_products_mod_aux_bsk_array_[i][j] = + modulo_uint(coeff_products_array_.get() + (j * coeff_products_uint64_count), + coeff_products_uint64_count, aux_base_array_[i], pool_); + } + } + + // Add qi mod msk at the end of the array + coeff_base_products_mod_aux_bsk_array_[bsk_base_mod_count_ - 1] = + allocate_uint(coeff_base_mod_count_, pool_); + for (size_t i = 0; i < coeff_base_mod_count_; i++) + { + coeff_base_products_mod_aux_bsk_array_[bsk_base_mod_count_ - 1][i] = + modulo_uint(coeff_products_array_.get() + (i * coeff_products_uint64_count), + coeff_products_uint64_count, m_sk_, pool_); + } + + // Compute auxiliary moduli products mod coeff moduli (mj) mod qi + aux_base_products_mod_coeff_array_ = + allocate>(coeff_base_mod_count_, pool_); + for (size_t i = 0; i < coeff_base_mod_count_; i++) + { + aux_base_products_mod_coeff_array_[i] = allocate_uint(aux_base_mod_count_, pool_); + for (size_t j = 0; j < aux_base_mod_count_; j++) + { + aux_base_products_mod_coeff_array_[i][j] = + modulo_uint(aux_products_array.get() + (j * aux_products_uint64_count), + aux_products_uint64_count, coeff_base_array_[i], pool_); + } + } + + // Compute coeff moduli products inverse mod auxiliary mods (qi^(-1)) mod mj U {msk} + auto coeff_products_all(allocate_uint(coeff_base_mod_count_, pool_)); + auto tmp_products_all(allocate_uint(coeff_base_mod_count_, pool_)); + set_uint(1, coeff_base_mod_count_, coeff_products_all.get()); + + // Compute the product of all coeff moduli + for (size_t i = 0; i < coeff_base_mod_count_; i++) + { + multiply_uint_uint64(coeff_products_all.get(), coeff_base_mod_count_, + coeff_base_array_[i].value(), coeff_base_mod_count_, tmp_products_all.get()); + set_uint_uint(tmp_products_all.get(), coeff_base_mod_count_, + coeff_products_all.get()); + } + + // Compute inverses of coeff_products_all modulo aux moduli + inv_coeff_products_all_mod_aux_bsk_array_ = allocate_uint(bsk_base_mod_count_, pool_); + for (size_t i = 0; i < aux_base_mod_count_; i++) + { + inv_coeff_products_all_mod_aux_bsk_array_[i] = modulo_uint(coeff_products_all.get(), + coeff_base_mod_count_, aux_base_array_[i], pool_); + if (!try_invert_uint_mod(inv_coeff_products_all_mod_aux_bsk_array_[i], + aux_base_array_[i], inv_coeff_products_all_mod_aux_bsk_array_[i])) + { + reset(); + return; + } + } + + // Add product of all coeffs mod msk at the end of the array + inv_coeff_products_all_mod_aux_bsk_array_[bsk_base_mod_count_ - 1] = + modulo_uint(coeff_products_all.get(), coeff_base_mod_count_, m_sk_, pool_); + if (!try_invert_uint_mod(inv_coeff_products_all_mod_aux_bsk_array_[bsk_base_mod_count_ - 1], + m_sk_, inv_coeff_products_all_mod_aux_bsk_array_[bsk_base_mod_count_ - 1])) + { + reset(); + return; + } + + // Compute the products of all aux moduli + auto aux_products_all(allocate_uint(aux_base_mod_count_, pool_)); + auto tmp_aux_products_all(allocate_uint(aux_base_mod_count_, pool_)); + set_uint(1, aux_base_mod_count_, aux_products_all.get()); + + for (size_t i = 0; i < aux_base_mod_count_; i++) + { + multiply_uint_uint64(aux_products_all.get(), aux_base_mod_count_, + aux_base_array_[i].value(), aux_base_mod_count_, tmp_aux_products_all.get()); + set_uint_uint(tmp_aux_products_all.get(), aux_base_mod_count_, + aux_products_all.get()); + } + + // Compute the auxiliary products inverse mod m_sk_ (M-1) mod m_sk_ + inv_aux_products_mod_msk_ = modulo_uint(aux_products_all.get(), + aux_base_mod_count_, m_sk_, pool_); + if (!try_invert_uint_mod(inv_aux_products_mod_msk_, m_sk_, + inv_aux_products_mod_msk_)) + { + reset(); + return; + } + + // Compute auxiliary products all mod coefficient moduli + aux_products_all_mod_coeff_array_ = allocate_uint(coeff_base_mod_count_, pool_); + for (size_t i = 0; i < coeff_base_mod_count_; i++) + { + aux_products_all_mod_coeff_array_[i] = modulo_uint(aux_products_all.get(), + aux_base_mod_count_, coeff_base_array_[i], pool_); + } + + // Compute m_tilde inverse mod bsk base + inv_mtilde_mod_bsk_array_ = allocate_uint(bsk_base_mod_count_, pool_); + for (size_t i = 0; i < aux_base_mod_count_; i++) + { + if (!try_invert_uint_mod(m_tilde_.value() % aux_base_array_[i].value(), + aux_base_array_[i], inv_mtilde_mod_bsk_array_[i])) + { + reset(); + return; + } + } + + // Add m_tilde inverse mod msk at the end of the array + if (!try_invert_uint_mod(m_tilde_.value() % m_sk_.value(), m_sk_, + inv_mtilde_mod_bsk_array_[bsk_base_mod_count_ - 1])) + { + reset(); + return; + } + + // Compute coeff moduli products inverse mod m_tilde + inv_coeff_products_mod_mtilde_ = modulo_uint(coeff_products_all.get(), + coeff_base_mod_count_, m_tilde_, pool_); + if (!try_invert_uint_mod(inv_coeff_products_mod_mtilde_, m_tilde_, + inv_coeff_products_mod_mtilde_)) + { + reset(); + return; + } + + // Compute coeff base products all mod Bsk + coeff_products_all_mod_bsk_array_ = allocate_uint(bsk_base_mod_count_, pool_); + for (size_t i = 0; i < aux_base_mod_count_; i++) + { + coeff_products_all_mod_bsk_array_[i] = + modulo_uint(coeff_products_all.get(), coeff_base_mod_count_, + aux_base_array_[i], pool_); + } + + // Add coeff base products all mod m_sk_ at the end of the array + coeff_products_all_mod_bsk_array_[bsk_base_mod_count_ - 1] = + modulo_uint(coeff_products_all.get(), coeff_base_mod_count_, m_sk_, pool_); + + // Compute inverses of last coeff base modulus modulo the first ones for + // modulus switching/rescaling. + inv_last_coeff_mod_array_ = allocate_uint(coeff_base_mod_count_ - 1, pool_); + for (size_t i = 0; i < coeff_base_mod_count_ - 1; i++) + { + if (!try_mod_inverse(coeff_base_array_[coeff_base_mod_count_ - 1].value(), + coeff_base_array_[i].value(), inv_last_coeff_mod_array_[i])) + { + reset(); + return; + } + } + + // Generate plain gamma array of small_plain_mod_ is set to non-zero. + // Otherwise assume we use CKKS and no plain_modulus is needed. + if (!small_plain_mod_.is_zero()) + { + plain_gamma_array_ = allocate(plain_gamma_count_, pool_); + plain_gamma_array_[0] = small_plain_mod_; + plain_gamma_array_[1] = gamma_; + + // Compute coeff moduli products mod plain gamma + coeff_products_mod_plain_gamma_array_ = + allocate>(plain_gamma_count_, pool_); + for (size_t i = 0; i < plain_gamma_count_; i++) + { + coeff_products_mod_plain_gamma_array_[i] = + allocate_uint(coeff_base_mod_count_, pool_); + for (size_t j = 0; j < coeff_base_mod_count_; j++) + { + coeff_products_mod_plain_gamma_array_[i][j] = + modulo_uint( + coeff_products_array_.get() + (j * coeff_products_uint64_count), + coeff_products_uint64_count, plain_gamma_array_[i], pool_ + ); + } + } + + // Compute inverse of all coeff moduli products mod plain gamma + neg_inv_coeff_products_all_mod_plain_gamma_array_ = + allocate_uint(plain_gamma_count_, pool_); + for (size_t i = 0; i < plain_gamma_count_; i++) + { + uint64_t temp = modulo_uint(coeff_products_all.get(), + coeff_base_mod_count_, plain_gamma_array_[i], pool_); + neg_inv_coeff_products_all_mod_plain_gamma_array_[i] = + negate_uint_mod(temp, plain_gamma_array_[i]); + if (!try_invert_uint_mod(neg_inv_coeff_products_all_mod_plain_gamma_array_[i], + plain_gamma_array_[i], neg_inv_coeff_products_all_mod_plain_gamma_array_[i])) + { + reset(); + return; + } + } + + // Compute inverse of gamma mod plain modulus + inv_gamma_mod_plain_ = modulo_uint(gamma_.data(), gamma_.uint64_count(), + small_plain_mod_, pool_); + if (!try_invert_uint_mod( + inv_gamma_mod_plain_, small_plain_mod_, inv_gamma_mod_plain_)) + { + reset(); + return; + } + + // Compute plain_gamma product mod coeff base moduli + plain_gamma_product_mod_coeff_array_ = + allocate_uint(coeff_base_mod_count_, pool_); + for (size_t i = 0; i < coeff_base_mod_count_; i++) + { + plain_gamma_product_mod_coeff_array_[i] = + multiply_uint_uint_mod(small_plain_mod_.value(), gamma_.value(), + coeff_base_array_[i]); + } + } + + // Everything went well + generated_ = true; + } + + void BaseConverter::reset() noexcept + { + generated_ = false; + coeff_base_array_.release(); + aux_base_array_.release(); + bsk_base_array_.release(); + plain_gamma_array_.release(); + coeff_products_array_.release(); + mtilde_inv_coeff_base_products_mod_coeff_array_.release(); + inv_aux_base_products_mod_aux_array_.release(); + inv_coeff_products_all_mod_aux_bsk_array_.release(); + inv_coeff_base_products_mod_coeff_array_.release(); + aux_base_products_mod_coeff_array_.release(); + coeff_base_products_mod_aux_bsk_array_.release(); + coeff_base_products_mod_mtilde_array_.release(); + aux_base_products_mod_msk_array_.release(); + aux_products_all_mod_coeff_array_.release(); + inv_mtilde_mod_bsk_array_.release(); + coeff_products_all_mod_bsk_array_.release(); + coeff_products_mod_plain_gamma_array_.release(); + neg_inv_coeff_products_all_mod_plain_gamma_array_.release(); + plain_gamma_product_mod_coeff_array_.release(); + bsk_small_ntt_tables_.release(); + inv_last_coeff_mod_array_.release(); + inv_coeff_products_mod_mtilde_ = 0; + m_tilde_ = 0; + m_sk_ = 0; + gamma_ = 0; + coeff_count_ = 0; + coeff_base_mod_count_ = 0; + aux_base_mod_count_ = 0; + plain_gamma_count_ = 0; + inv_gamma_mod_plain_ = 0; + } + + void BaseConverter::fastbconv(const uint64_t *input, + uint64_t *destination, MemoryPoolHandle pool) const + { +#ifdef SEAL_DEBUG + if (input == nullptr) + { + throw invalid_argument("input cannot be null"); + } + if(destination == nullptr) + { + throw invalid_argument("destination cannot be null"); + } + if (!pool) + { + throw invalid_argument("pool is not initialied"); + } + if (!generated_) + { + throw logic_error("BaseConverter is not generated"); + } +#endif + /** + Require: Input in q + Ensure: Output in Bsk = {m1,...,ml} U {msk} + */ + auto temp_coeff_transition(allocate_uint( + mul_safe(coeff_count_, coeff_base_mod_count_), pool)); + for (size_t i = 0; i < coeff_base_mod_count_; i++) + { + uint64_t inv_coeff_base_products_mod_coeff_elt = + inv_coeff_base_products_mod_coeff_array_[i]; + SmallModulus coeff_base_array_elt = coeff_base_array_[i]; + for (size_t k = 0; k < coeff_count_; k++, input++) + { + temp_coeff_transition[i + (k * coeff_base_mod_count_)] = + multiply_uint_uint_mod( + *input, + inv_coeff_base_products_mod_coeff_elt, + coeff_base_array_elt + ); + } + } + + for (size_t j = 0; j < bsk_base_mod_count_; j++) + { + uint64_t *temp_coeff_transition_ptr = temp_coeff_transition.get(); + SmallModulus bsk_base_array_elt = bsk_base_array_[j]; + for (size_t k = 0; k < coeff_count_; k++, destination++) + { + const uint64_t *coeff_base_products_mod_aux_bsk_array_ptr = + coeff_base_products_mod_aux_bsk_array_[j].get(); + unsigned long long aux_transition[2]{ 0, 0 }; + for (size_t i = 0; i < coeff_base_mod_count_; + i++, temp_coeff_transition_ptr++, + coeff_base_products_mod_aux_bsk_array_ptr++) + { + // Lazy reduction + unsigned long long temp[2]; + + // Product is 60 bit + 61 bit = 121 bit, so can sum up to 127 of them with no reduction + // Thus need coeff_base_mod_count_ <= 127 to guarantee success + multiply_uint64(*temp_coeff_transition_ptr, + *coeff_base_products_mod_aux_bsk_array_ptr, temp); + unsigned char carry = add_uint64(aux_transition[0], + temp[0], aux_transition); + aux_transition[1] += temp[1] + carry; + } + *destination = barrett_reduce_128(aux_transition, bsk_base_array_elt); + } + } + } + + void BaseConverter::fastbconv_sk(const uint64_t *input, + uint64_t *destination, MemoryPoolHandle pool) const + { +#ifdef SEAL_DEBUG + if (input == nullptr) + { + throw invalid_argument("input cannot be null"); + } + if (destination == nullptr) + { + throw invalid_argument("destination cannot be null"); + } + if (!pool) + { + throw invalid_argument("pool is not initialied"); + } +#endif + /** + Require: Input in base Bsk = M U {msk} + Ensure: Output in base q + */ + + // Fast convert B -> q + auto temp_coeff_transition(allocate_uint( + mul_safe(coeff_count_, aux_base_mod_count_), pool)); + const uint64_t *input_ptr = input; + for (size_t i = 0; i < aux_base_mod_count_; i++) + { + uint64_t inv_aux_base_products_mod_aux_array_elt = + inv_aux_base_products_mod_aux_array_[i]; + SmallModulus aux_base_array_elt = aux_base_array_[i]; + for (size_t k = 0; k < coeff_count_; k++) + { + temp_coeff_transition[i + (k * aux_base_mod_count_)] = + multiply_uint_uint_mod( + *input_ptr++, + inv_aux_base_products_mod_aux_array_elt, + aux_base_array_elt + ); + } + } + + uint64_t *destination_ptr = destination; + uint64_t *temp_ptr; + for (size_t j = 0; j < coeff_base_mod_count_; j++) + { + temp_ptr = temp_coeff_transition.get(); + SmallModulus coeff_base_array_elt = coeff_base_array_[j]; + for (size_t k = 0; k < coeff_count_; k++, destination_ptr++) + { + const uint64_t *aux_base_products_mod_coeff_array_ptr = + aux_base_products_mod_coeff_array_[j].get(); + unsigned long long aux_transition[2]{ 0, 0 }; + for (size_t i = 0; i < aux_base_mod_count_; i++, temp_ptr++, + aux_base_products_mod_coeff_array_ptr++) + { + // Lazy reduction + unsigned long long temp[2]; + + // Product is 61 bit + 60 bit = 121 bit, so can sum up to 127 of them with no reduction + // Thus need aux_base_mod_count_ <= 127, so coeff_base_mod_count_ <= 126 to guarantee success + multiply_uint64(*temp_ptr, *aux_base_products_mod_coeff_array_ptr, temp); + unsigned char carry = add_uint64(aux_transition[0], temp[0], aux_transition); + aux_transition[1] += temp[1] + carry; + } + *destination_ptr = barrett_reduce_128(aux_transition, coeff_base_array_elt); + } + } + + // Compute alpha_sk + // Require: Input is in Bsk + // we only use coefficient in B + // Fast convert B -> m_sk + auto tmp(allocate_uint(coeff_count_, pool)); + destination_ptr = tmp.get(); + temp_ptr = temp_coeff_transition.get(); + for (size_t k = 0; k < coeff_count_; k++, destination_ptr++) + { + unsigned long long msk_transition[2]{ 0, 0 }; + const uint64_t *aux_base_products_mod_msk_array_ptr = + aux_base_products_mod_msk_array_.get(); + for (size_t i = 0; i < aux_base_mod_count_; i++, temp_ptr++, + aux_base_products_mod_msk_array_ptr++) + { + // Lazy reduction + unsigned long long temp[2]; + + // Product is 61 bit + 61 bit = 122 bit, so can sum up to 63 of them with no reduction + // Thus need aux_base_mod_count_ <= 63, so coeff_base_mod_count_ <= 62 to guarantee success + // This gives the strongest restriction on the number of coeff modulus primes + multiply_uint64(*temp_ptr, *aux_base_products_mod_msk_array_ptr, temp); + unsigned char carry = add_uint64(msk_transition[0], temp[0], msk_transition); + msk_transition[1] += temp[1] + carry; + } + *destination_ptr = barrett_reduce_128(msk_transition, m_sk_); + } + + auto alpha_sk(allocate_uint(coeff_count_, pool)); + input_ptr = input + (aux_base_mod_count_ * coeff_count_); + destination_ptr = alpha_sk.get(); + temp_ptr = tmp.get(); + const uint64_t m_sk_value = m_sk_.value(); + // x_sk is allocated in input[aux_base_mod_count_] + for (size_t i = 0; i < coeff_count_; i++, input_ptr++, temp_ptr++, destination_ptr++) + { + // It is not necessary for the negation to be reduced modulo the small prime + uint64_t negated_input = m_sk_value - *input_ptr; + *destination_ptr = multiply_uint_uint_mod(*temp_ptr + negated_input, + inv_aux_products_mod_msk_, m_sk_); + } + + const uint64_t m_sk_div_2 = m_sk_value >> 1; + destination_ptr = destination; + for (size_t i = 0; i < coeff_base_mod_count_; i++) + { + uint64_t aux_products_all_mod_coeff_array_elt = + aux_products_all_mod_coeff_array_[i]; + temp_ptr = alpha_sk.get(); + SmallModulus coeff_base_array_elt = coeff_base_array_[i]; + uint64_t coeff_base_array_elt_value = coeff_base_array_elt.value(); + for (size_t k = 0; k < coeff_count_; k++, temp_ptr++, destination_ptr++) + { + unsigned long long m_alpha_sk[2]; + + // Correcting alpha_sk since it is a centered modulo + if (*temp_ptr > m_sk_div_2) + { + // Lazy reduction + multiply_uint64(aux_products_all_mod_coeff_array_elt, + m_sk_value - *temp_ptr, m_alpha_sk); + m_alpha_sk[1] += add_uint64(m_alpha_sk[0], *destination_ptr, m_alpha_sk); + *destination_ptr = barrett_reduce_128(m_alpha_sk, coeff_base_array_elt); + } + // No correction needed + else + { + // Lazy reduction + // It is not necessary for the negation to be reduced modulo the small prime + multiply_uint64( + coeff_base_array_elt_value - aux_products_all_mod_coeff_array_elt, + *temp_ptr, m_alpha_sk + ); + m_alpha_sk[1] += add_uint64(*destination_ptr, + m_alpha_sk[0], m_alpha_sk); + *destination_ptr = barrett_reduce_128(m_alpha_sk, coeff_base_array_elt); + } + } + } + } + + void BaseConverter::mont_rq(const uint64_t *input, uint64_t *destination) const + { +#ifdef SEAL_DEBUG + if (input == nullptr) + { + throw invalid_argument("input cannot be null"); + } + if (destination == nullptr) + { + throw invalid_argument("destination cannot be null"); + } +#endif + /** + Require: Input should in Bsk U {m_tilde} + Ensure: Destination array in Bsk = m U {msk} + */ + const uint64_t *input_m_tilde_ptr = + input + mul_safe(coeff_count_, bsk_base_mod_count_); + for (size_t k = 0; k < bsk_base_mod_count_; k++) + { + uint64_t coeff_products_all_mod_bsk_array_elt = + coeff_products_all_mod_bsk_array_[k]; + uint64_t inv_mtilde_mod_bsk_array_elt = inv_mtilde_mod_bsk_array_[k]; + SmallModulus bsk_base_array_elt = bsk_base_array_[k]; + const uint64_t *input_m_tilde_ptr_copy = input_m_tilde_ptr; + + // Compute result for aux base + for (size_t i = 0; i < coeff_count_; i++, destination++, + input_m_tilde_ptr_copy++, input++) + { + // Compute r_mtilde + // Duplicate work here: + // This needs to be computed only once per coefficient, not per Bsk prime. + uint64_t r_mtilde = multiply_uint_uint_mod(*input_m_tilde_ptr_copy, + inv_coeff_products_mod_mtilde_, m_tilde_); + r_mtilde = negate_uint_mod(r_mtilde, m_tilde_); + + // Lazy reduction + unsigned long long tmp[2]; + multiply_uint64(coeff_products_all_mod_bsk_array_elt, r_mtilde, tmp); + tmp[1] += add_uint64(tmp[0], *input, tmp); + r_mtilde = barrett_reduce_128(tmp, bsk_base_array_elt); + *destination = multiply_uint_uint_mod( + r_mtilde, inv_mtilde_mod_bsk_array_elt, bsk_base_array_elt); + } + } + } + + void BaseConverter::fast_floor(const uint64_t *input, + uint64_t *destination, MemoryPoolHandle pool) const + { +#ifdef SEAL_DEBUG + if (input == nullptr) + { + throw invalid_argument("input cannot be null"); + } + if (destination == nullptr) + { + throw invalid_argument("destination cannot be null"); + } + if (!pool) + { + throw invalid_argument("pool is not initialied"); + } +#endif + /** + Require: Input in q U m U {msk} + Ensure: Destination array in Bsk + */ + fastbconv(input, destination, pool); //q -> Bsk + + size_t index_msk = mul_safe(coeff_base_mod_count_, coeff_count_); + input += index_msk; + for (size_t i = 0; i < bsk_base_mod_count_; i++) + { + SmallModulus bsk_base_array_elt = bsk_base_array_[i]; + uint64_t bsk_base_array_value = bsk_base_array_elt.value(); + uint64_t inv_coeff_products_all_mod_aux_bsk_array_elt = + inv_coeff_products_all_mod_aux_bsk_array_[i]; + for (size_t k = 0; k < coeff_count_; k++, input++, destination++) + { + // It is not necessary for the negation to be reduced modulo the small prime + //negate_uint_smallmod(base_convert_Bsk.get() + k + (i * coeff_count_), + // bsk_base_array_[i], &negated_base_convert_Bsk); + *destination = multiply_uint_uint_mod( + *input + bsk_base_array_value - *destination, + inv_coeff_products_all_mod_aux_bsk_array_elt, + bsk_base_array_elt + ); + } + } + } + + void BaseConverter::fastbconv_mtilde(const uint64_t *input, + uint64_t *destination, MemoryPoolHandle pool) const + { +#ifdef SEAL_DEBUG + if (input == nullptr) + { + throw invalid_argument("input cannot be null"); + } + if (destination == nullptr) + { + throw invalid_argument("destination cannot be null"); + } + if (!pool) + { + throw invalid_argument("pool is not initialied"); + } +#endif + /** + Require: Input in q + Ensure: Output in Bsk U {m_tilde} + */ + + // Compute in Bsk first; we compute |m_tilde*q^-1i| mod qi + auto temp_coeff_transition(allocate_uint( + mul_safe(coeff_count_, coeff_base_mod_count_), pool)); + for (size_t i = 0; i < coeff_base_mod_count_; i++) + { + SmallModulus coeff_base_array_elt = coeff_base_array_[i]; + uint64_t mtilde_inv_coeff_base_products_mod_coeff_elt = + mtilde_inv_coeff_base_products_mod_coeff_array_[i]; + for (size_t k = 0; k < coeff_count_; k++, input++) + { + temp_coeff_transition[i + (k * coeff_base_mod_count_)] = + multiply_uint_uint_mod( + *input, + mtilde_inv_coeff_base_products_mod_coeff_elt, + coeff_base_array_elt + ); + } + } + + uint64_t *destination_ptr = destination; + for (size_t j = 0; j < bsk_base_mod_count_; j++) + { + const uint64_t *coeff_base_products_mod_aux_bsk_array_ptr = + coeff_base_products_mod_aux_bsk_array_[j].get(); + uint64_t *temp_coeff_transition_ptr = temp_coeff_transition.get(); + SmallModulus bsk_base_array_elt = bsk_base_array_[j]; + for (size_t k = 0; k < coeff_count_; k++, destination_ptr++) + { + unsigned long long aux_transition[2]{ 0, 0 }; + const uint64_t *temp_ptr = coeff_base_products_mod_aux_bsk_array_ptr; + for (size_t i = 0; i < coeff_base_mod_count_; + i++, temp_ptr++, temp_coeff_transition_ptr++) + { + // Lazy reduction + unsigned long long temp[2]{ 0, 0 }; + + // Product is 60 bit + 61 bit = 121 bit, so can sum up to 127 of them with no reduction + // Thus need coeff_base_mod_count_ <= 127 + multiply_uint64(*temp_coeff_transition_ptr, *temp_ptr, temp); + unsigned char carry = add_uint64(aux_transition[0], + temp[0], aux_transition); + aux_transition[1] += temp[1] + carry; + } + *destination_ptr = barrett_reduce_128(aux_transition, bsk_base_array_elt); + } + } + + // Computing the last element (mod m_tilde) and add it at the end of destination array + uint64_t *temp_coeff_transition_ptr = temp_coeff_transition.get(); + destination += bsk_base_mod_count_ * coeff_count_; + for (size_t k = 0; k < coeff_count_; k++, destination++) + { + unsigned long long wide_result[2]{ 0, 0 }; + const uint64_t *coeff_base_products_mod_mtilde_array_ptr = + coeff_base_products_mod_mtilde_array_.get(); + for (size_t i = 0; i < coeff_base_mod_count_; i++, + temp_coeff_transition_ptr++, + coeff_base_products_mod_mtilde_array_ptr++) + { + // Lazy reduction + unsigned long long aux_transition[2]; + + // Product is 60 bit + 33 bit = 93 bit + multiply_uint64(*temp_coeff_transition_ptr, + *coeff_base_products_mod_mtilde_array_ptr, aux_transition); + unsigned char carry = add_uint64(aux_transition[0], + wide_result[0], wide_result); + wide_result[1] += aux_transition[1] + carry; + } + *destination = barrett_reduce_128(wide_result, m_tilde_); + } + } + + void BaseConverter::fastbconv_plain_gamma(const uint64_t *input, + uint64_t *destination, MemoryPoolHandle pool) const + { +#ifdef SEAL_DEBUG + if (small_plain_mod_.is_zero()) + { + throw logic_error("invalid operation"); + } + if (input == nullptr) + { + throw invalid_argument("input cannot be null"); + } + if (destination == nullptr) + { + throw invalid_argument("destination cannot be null"); + } +#endif + /** + Require: Input in q + Ensure: Output in t (plain modulus) U gamma + */ + auto temp_coeff_transition(allocate_uint( + mul_safe(coeff_count_, coeff_base_mod_count_), pool)); + for (size_t i = 0; i < coeff_base_mod_count_; i++) + { + uint64_t inv_coeff_base_products_mod_coeff_elt = + inv_coeff_base_products_mod_coeff_array_[i]; + SmallModulus coeff_base_array_elt = coeff_base_array_[i]; + for (size_t k = 0; k < coeff_count_; k++, input++) + { + temp_coeff_transition[i + (k * coeff_base_mod_count_)] = + multiply_uint_uint_mod( + *input, + inv_coeff_base_products_mod_coeff_elt, + coeff_base_array_elt + ); + } + } + + for (size_t j = 0; j < plain_gamma_count_; j++) + { + SmallModulus plain_gamma_array_elt = plain_gamma_array_[j]; + uint64_t *temp_coeff_transition_ptr = temp_coeff_transition.get(); + const uint64_t *coeff_products_mod_plain_gamma_array_ptr = + coeff_products_mod_plain_gamma_array_[j].get(); + for (size_t k = 0; k < coeff_count_; k++, destination++) + { + unsigned long long wide_result[2]{ 0, 0 }; + const uint64_t *temp_ptr = coeff_products_mod_plain_gamma_array_ptr; + for (size_t i = 0; i < coeff_base_mod_count_; i++, + temp_coeff_transition_ptr++, temp_ptr++) + { + unsigned long long plain_transition[2]; + + // Lazy reduction + // Product is 60 bit + 61 bit = 121 bit, so can sum up to 127 of them with no reduction + // Thus need coeff_base_mod_count_ <= 127 + multiply_uint64(*temp_coeff_transition_ptr, *temp_ptr, plain_transition); + unsigned char carry = add_uint64(plain_transition[0], + wide_result[0], wide_result); + wide_result[1] += plain_transition[1] + carry; + } + *destination = barrett_reduce_128(wide_result, plain_gamma_array_elt); + } + } + } + } +} diff --git a/src/seal/util/baseconverter.h b/src/seal/util/baseconverter.h new file mode 100644 index 000000000..5fb45927c --- /dev/null +++ b/src/seal/util/baseconverter.h @@ -0,0 +1,272 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include +#include +#include +#include "seal/util/pointer.h" +#include "seal/memorymanager.h" +#include "seal/smallmodulus.h" +#include "seal/util/smallntt.h" +#include "seal/biguint.h" + +namespace seal +{ + namespace util + { + class BaseConverter + { + public: + BaseConverter(MemoryPoolHandle pool) : pool_(std::move(pool)) + { + if (!pool_) + { + throw std::invalid_argument("pool is uninitialized"); + } + } + + BaseConverter(const std::vector &coeff_base, + std::size_t coeff_count, const SmallModulus &small_plain_mod, + MemoryPoolHandle pool); + + /** + Generates the pre-computations for the given parameters. + */ + void generate(const std::vector &coeff_base, + std::size_t coeff_count, const SmallModulus &small_plain_mod); + + /** + Fast base converter from q to Bsk + */ + void fastbconv(const std::uint64_t *input, + std::uint64_t *destination, MemoryPoolHandle pool) const; + + /** + Fast base converter from Bsk to q + */ + void fastbconv_sk(const std::uint64_t *input, + std::uint64_t *destination, MemoryPoolHandle pool) const; + + /** + Reduction from Bsk U {m_tilde} to Bsk + */ + void mont_rq(const std::uint64_t *input, + std::uint64_t *destination) const; + + /** + Fast base converter from q U Bsk to Bsk + */ + void fast_floor(const std::uint64_t *input, + std::uint64_t *destination, MemoryPoolHandle pool) const; + + /** + Fast base converter from q to Bsk U {m_tilde} + */ + void fastbconv_mtilde(const std::uint64_t *input, + std::uint64_t *destination, MemoryPoolHandle pool) const; + + /** + Fast base converter from q to plain_modulus U {gamma} + */ + void fastbconv_plain_gamma(const std::uint64_t *input, + std::uint64_t *destination, MemoryPoolHandle pool) const; + + void reset() noexcept; + + inline auto is_generated() const noexcept + { + return generated_; + } + + inline auto coeff_base_mod_count() const noexcept + { + return coeff_base_mod_count_; + } + + inline auto aux_base_mod_count() const noexcept + { + return aux_base_mod_count_; + } + + inline auto &get_plain_gamma_product() const noexcept + { + return plain_gamma_product_mod_coeff_array_; + } + + inline auto &get_neg_inv_coeff() const noexcept + { + return neg_inv_coeff_products_all_mod_plain_gamma_array_; + } + + inline auto &get_plain_gamma_array() const noexcept + { + return plain_gamma_array_; + } + + inline const std::uint64_t *get_coeff_products_array() const noexcept + { + return coeff_products_array_.get(); + } + + inline std::uint64_t get_inv_gamma() const noexcept + { + return inv_gamma_mod_plain_; + } + + inline auto &get_bsk_small_ntt_tables() const noexcept + { + return bsk_small_ntt_tables_; + } + + inline auto bsk_base_mod_count() const noexcept + { + return bsk_base_mod_count_; + } + + inline auto &get_bsk_mod_array() const noexcept + { + return bsk_base_array_; + } + + inline auto &get_msk() const noexcept + { + return m_sk_; + } + + inline auto &get_m_tilde() const noexcept + { + return m_tilde_; + } + + inline auto &get_mtilde_inv_coeff_products_mod_coeff() const noexcept + { + return mtilde_inv_coeff_base_products_mod_coeff_array_; + } + + inline auto &get_inv_coeff_mod_mtilde() const noexcept + { + return inv_coeff_products_mod_mtilde_; + } + + inline auto &get_inv_coeff_mod_coeff_array() const noexcept + { + return inv_coeff_base_products_mod_coeff_array_; + } + + inline auto &get_inv_last_coeff_mod_array() const noexcept + { + return inv_last_coeff_mod_array_; + } + + inline auto &get_coeff_base_products_mod_msk() const noexcept + { + return coeff_base_products_mod_aux_bsk_array_[bsk_base_mod_count_ - 1]; + } + + private: + BaseConverter(const BaseConverter ©) = delete; + + BaseConverter(BaseConverter &&source) = delete; + + BaseConverter &operator =(const BaseConverter &assign) = delete; + + BaseConverter &operator =(BaseConverter &&assign) = delete; + + MemoryPoolHandle pool_; + + bool generated_ = false; + + std::size_t coeff_base_mod_count_ = 0; + + std::size_t aux_base_mod_count_ = 0; + + std::size_t bsk_base_mod_count_ = 0; + + std::size_t coeff_count_ = 0; + + std::size_t plain_gamma_count_ = 0; + + // Array of coefficient small moduli + Pointer coeff_base_array_; + + // Array of auxiliary moduli + Pointer aux_base_array_; + + // Array of auxiliary U {m_sk_} moduli + Pointer bsk_base_array_; + + // Array of plain modulus U gamma + Pointer plain_gamma_array_; + + // Punctured products of the coeff moduli + Pointer coeff_products_array_; + + // Matrix which contains the products of coeff moduli mod aux + Pointer> coeff_base_products_mod_aux_bsk_array_; + + // Array of inverse coeff modulus products mod each small coeff mods + Pointer inv_coeff_base_products_mod_coeff_array_; + + // Array of coeff moduli products mod m_tilde + Pointer coeff_base_products_mod_mtilde_array_; + + // Array of coeff modulus products times m_tilda mod each coeff modulus + Pointer mtilde_inv_coeff_base_products_mod_coeff_array_; + + // Matrix of the inversion of coeff modulus products mod each auxiliary mods + Pointer inv_coeff_products_all_mod_aux_bsk_array_; + + // Matrix of auxiliary mods products mod each coeff modulus + Pointer> aux_base_products_mod_coeff_array_; + + // Array of inverse auxiliary mod products mod each auxiliary mods + Pointer inv_aux_base_products_mod_aux_array_; + + // Array of auxiliary bases products mod m_sk_ + Pointer aux_base_products_mod_msk_array_; + + // Coeff moduli products inverse mod m_tilde + std::uint64_t inv_coeff_products_mod_mtilde_ = 0; + + // Auxiliary base products mod m_sk_ (m1*m2*...*ml)-1 mod m_sk + std::uint64_t inv_aux_products_mod_msk_ = 0; + + // Gamma inverse mod plain modulus + std::uint64_t inv_gamma_mod_plain_ = 0; + + // Auxiliary base products mod coeff moduli (m1*m2*...*ml) mod qi + Pointer aux_products_all_mod_coeff_array_; + + // Array of m_tilde inverse mod Bsk = m U {msk} + Pointer inv_mtilde_mod_bsk_array_; + + // Array of all coeff base products mod Bsk + Pointer coeff_products_all_mod_bsk_array_; + + // Matrix of coeff base product mod plain modulus and gamma + Pointer> coeff_products_mod_plain_gamma_array_; + + // Array of negative inverse all coeff base product mod plain modulus and gamma + Pointer neg_inv_coeff_products_all_mod_plain_gamma_array_; + + // Array of plain_gamma_product mod coeff base moduli + Pointer plain_gamma_product_mod_coeff_array_; + + // Array of small NTT tables for moduli in Bsk + Pointer bsk_small_ntt_tables_; + + // For modulus switching: inverses of the last coeff base modulus + Pointer inv_last_coeff_mod_array_; + + SmallModulus m_tilde_; + + SmallModulus m_sk_; + + SmallModulus small_plain_mod_; + + SmallModulus gamma_; + }; + } +} diff --git a/src/seal/util/clang.h b/src/seal/util/clang.h new file mode 100644 index 000000000..46f1ed08d --- /dev/null +++ b/src/seal/util/clang.h @@ -0,0 +1,57 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#if SEAL_COMPILER == SEAL_COMPILER_CLANG + +// We require clang >= 5 +#if (__clang_major__ < 5) || not defined(__cplusplus) +#error "SEAL requires __clang_major__ >= 5" +#endif + +// Read in config.h +#include "seal/util/config.h" + +// Are we using MSGSL? +#ifdef SEAL_USE_MSGSL +#include +#endif + +// Are intrinsics enabled? +#ifdef SEAL_USE_INTRIN +#include + +#ifdef SEAL_USE___BUILTIN_CLZLL +#define SEAL_MSB_INDEX_UINT64(result, value) { \ + *result = 63UL - static_cast(__builtin_clzll(value)); \ +} +#endif + +#ifdef SEAL_USE___INT128 +#define SEAL_MULTIPLY_UINT64_HW64(operand1, operand2, hw64) { \ + *hw64 = static_cast( \ + ((static_cast(operand1) \ + * static_cast(operand2)) >> 64)); \ +} + +#define SEAL_MULTIPLY_UINT64(operand1, operand2, result128) { \ + unsigned __int128 product = static_cast(operand1) * operand2;\ + result128[0] = static_cast(product); \ + result128[1] = static_cast(product >> 64); \ +} +#endif + +#ifdef SEAL_USE__ADDCARRY_U64 +#define SEAL_ADD_CARRY_UINT64(operand1, operand2, carry, result) _addcarry_u64( \ + carry, operand1, operand2, result) +#endif + +#ifdef SEAL_USE__SUBBORROW_U64 +#define SEAL_SUB_BORROW_UINT64(operand1, operand2, borrow, result) _subborrow_u64( \ + borrow, operand1, operand2, result) +#endif + +#endif //SEAL_USE_INTRIN + +#endif diff --git a/src/seal/util/clipnormal.cpp b/src/seal/util/clipnormal.cpp new file mode 100644 index 000000000..9f3563e56 --- /dev/null +++ b/src/seal/util/clipnormal.cpp @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include +#include "seal/util/clipnormal.h" + +using namespace std; + +namespace seal +{ + namespace util + { + ClippedNormalDistribution::ClippedNormalDistribution( + result_type mean, + result_type standard_deviation, + result_type max_deviation) : + normal_(mean, standard_deviation), + max_deviation_(max_deviation) + { + // Verify arguments. + if (standard_deviation < 0) + { + throw invalid_argument("standard_deviation"); + } + if (max_deviation < 0) + { + throw invalid_argument("max_deviation"); + } + } + } +} diff --git a/src/seal/util/clipnormal.h b/src/seal/util/clipnormal.h new file mode 100644 index 000000000..f1c1433ae --- /dev/null +++ b/src/seal/util/clipnormal.h @@ -0,0 +1,91 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include +#include + +namespace seal +{ + namespace util + { + class ClippedNormalDistribution + { + public: + using result_type = double; + + using param_type = ClippedNormalDistribution; + + ClippedNormalDistribution(result_type mean, result_type standard_deviation, + result_type max_deviation); + + template + inline result_type operator()(RNG &engine, const param_type &parm) noexcept + { + param(parm); + return operator()(engine); + } + + template + inline result_type operator()(RNG &engine) noexcept + { + result_type mean = normal_.mean(); + while (true) + { + result_type value = normal_(engine); + result_type deviation = std::abs(value - mean); + if (deviation <= max_deviation_) + { + return value; + } + } + } + + inline result_type mean() const noexcept + { + return normal_.mean(); + } + + inline result_type standard_deviation() const noexcept + { + return normal_.stddev(); + } + + inline result_type max_deviation() const noexcept + { + return max_deviation_; + } + + inline result_type min() const noexcept + { + return normal_.mean() - max_deviation_; + } + + inline result_type max() const noexcept + { + return normal_.mean() + max_deviation_; + } + + inline param_type param() const noexcept + { + return *this; + } + + inline void param(const param_type &parm) noexcept + { + *this = parm; + } + + inline void reset() noexcept + { + normal_.reset(); + } + + private: + std::normal_distribution normal_; + + result_type max_deviation_; + }; + } +} diff --git a/src/seal/util/common.h b/src/seal/util/common.h new file mode 100644 index 000000000..a909f0549 --- /dev/null +++ b/src/seal/util/common.h @@ -0,0 +1,575 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include "seal/util/defines.h" + +namespace seal +{ + namespace util + { + template + struct is_uint64 : std::conditional< + std::is_integral::value && + std::is_unsigned::value && + (sizeof(T) == sizeof(std::uint64_t)), + std::true_type, std::false_type>::type + { + }; + + template + struct is_uint64 : std::conditional< + is_uint64::value && + is_uint64::value, + std::true_type, std::false_type>::type + { + }; + + template + constexpr bool is_uint64_v = is_uint64::value; + + template::value>, + typename = std::enable_if_t::value>> + inline constexpr bool unsigned_lt(T in1, S in2) noexcept + { + return static_cast(in1) < static_cast(in2); + } + + template::value>, + typename = std::enable_if_t::value>> + inline constexpr bool unsigned_leq(T in1, S in2) noexcept + { + return static_cast(in1) <= static_cast(in2); + } + + template::value>, + typename = std::enable_if_t::value>> + inline constexpr bool unsigned_gt(T in1, S in2) noexcept + { + return static_cast(in1) > static_cast(in2); + } + + template::value>, + typename = std::enable_if_t::value>> + inline constexpr bool unsigned_geq(T in1, S in2) noexcept + { + return static_cast(in1) >= static_cast(in2); + } + + template::value>, + typename = std::enable_if_t::value>> + inline constexpr bool unsigned_eq(T in1, S in2) noexcept + { + return static_cast(in1) == static_cast(in2); + } + + template::value>, + typename = std::enable_if_t::value>> + inline constexpr bool unsigned_neq(T in1, S in2) noexcept + { + return static_cast(in1) != static_cast(in2); + } + + template::value>> + inline constexpr T mul_safe(T in1) noexcept + { + return in1; + } + + template::value>> + inline constexpr T mul_safe(T in1, T in2) + { + SEAL_IF_CONSTEXPR (std::is_unsigned::value) + { + if (in1 && (in2 > std::numeric_limits::max() / in1)) + { + throw std::out_of_range("unsigned overflow"); + } + } + else + { + // Positive inputs + if ((in1 > 0) && (in2 > 0) && + (in2 > std::numeric_limits::max() / in1)) + { + throw std::out_of_range("signed overflow"); + } +#if (SEAL_COMPILER == SEAL_COMPILER_MSVC) && !defined(SEAL_USE_IF_CONSTEXPR) +#pragma warning(push) +#pragma warning(disable: 4146) +#endif + // Negative inputs + else if ((in1 < 0) && (in2 < 0) && + ((-in2) > std::numeric_limits::max() / (-in1))) + { + throw std::out_of_range("signed overflow"); + } + // Negative in1; positive in2 + else if ((in1 < 0) && (in2 > 0) && + (in2 > std::numeric_limits::max() / (-in1))) + { + throw std::out_of_range("signed underflow"); + } +#if (SEAL_COMPILER == SEAL_COMPILER_MSVC) && !defined(SEAL_USE_IF_CONSTEXPR) +#pragma warning(pop) +#endif + // Positive in1; negative in2 + else if ((in1 > 0) && (in2 < 0) && + (in2 < std::numeric_limits::min() / in1)) + { + throw std::out_of_range("signed underflow"); + } + } + return in1 * in2; + } + + template::value>> + inline constexpr T mul_safe(T in1, T in2, Args &&...args) + { + return mul_safe(mul_safe(in1, in2), mul_safe(std::forward(args)...)); + } + + template::value>> + inline constexpr T add_safe(T in1) noexcept + { + return in1; + } + + template::value>> + inline constexpr T add_safe(T in1, T in2) + { + SEAL_IF_CONSTEXPR (std::is_unsigned::value) + { + T result = in1 + in2; + if (result < in1) + { + throw std::out_of_range("unsigned overflow"); + } + return result; + } + else + { + if (in1 > 0 && (in2 > std::numeric_limits::max() - in1)) + { + throw std::out_of_range("signed overflow"); + } + else if (in1 < 0 && + (in2 < std::numeric_limits::min() - in1)) + { + throw std::out_of_range("signed underflow"); + } + return in1 + in2; + } + } + + template::value>> + inline constexpr T add_safe(T in1, T in2, Args &&...args) + { + return add_safe(add_safe(in1, in2), add_safe(std::forward(args)...)); + } + + template::value>> + inline T sub_safe(T in1, T in2) + { + SEAL_IF_CONSTEXPR (std::is_unsigned::value) + { + T result = in1 - in2; + if (result > in1) + { + throw std::out_of_range("unsigned underflow"); + } + return result; + } + else + { + if (in1 < 0 && (in2 > std::numeric_limits::max() + in1)) + { + throw std::out_of_range("signed underflow"); + } + else if (in1 > 0 && + (in2 < std::numeric_limits::min() + in1)) + { + throw std::out_of_range("signed overflow"); + } + return in1 - in2; + } + } + + template::value>, + typename = std::enable_if_t::value>> + inline constexpr bool fits_in(S value SEAL_MAYBE_UNUSED) noexcept + { + SEAL_IF_CONSTEXPR (std::is_same::value) + { + // Same type + return true; + } + + SEAL_IF_CONSTEXPR (sizeof(S) <= sizeof(T)) + { + // Converting to bigger type + SEAL_IF_CONSTEXPR (std::is_integral::value && std::is_integral::value) + { + // Converting to at least equally big integer type + SEAL_IF_CONSTEXPR ((std::is_unsigned::value && std::is_unsigned::value) + || (!std::is_unsigned::value && !std::is_unsigned::value)) + { + // Both either signed or unsigned + return true; + } + else SEAL_IF_CONSTEXPR (std::is_unsigned::value + && std::is_signed::value) + { + // Converting from signed to at least equally big unsigned type + return value >= 0; + } + } + else SEAL_IF_CONSTEXPR (std::is_floating_point::value + && std::is_floating_point::value) + { + // Both floating-point + return true; + } + + // Still need to consider integer-float conversions and all + // unsigned to signed conversions + } + + SEAL_IF_CONSTEXPR (std::is_integral::value && std::is_integral::value) + { + // Both integer types + if (value >= 0) + { + // Non-negative number; compare as std::uint64_t + // Cannot use unsigned_leq with C++14 for lack of `if constexpr' + return static_cast(value) <= + static_cast(std::numeric_limits::max()); + } + else + { + // Negative number; compare as std::int64_t + return (static_cast(value) >= + static_cast(std::numeric_limits::min())); + } + } + else SEAL_IF_CONSTEXPR (std::is_floating_point::value) + { + // Converting to floating-point + return (static_cast(value) <= + static_cast(std::numeric_limits::max())) && + (static_cast(value) >= + -static_cast(std::numeric_limits::max())); + } + else + { + // Converting from floating-point + return (static_cast(value) <= + static_cast(std::numeric_limits::max())) && + (static_cast(value) >= + static_cast(std::numeric_limits::min())); + } + } + + template::value>> + inline constexpr bool sum_fits_in(Args &&...args) + { + return fits_in(add_safe(std::forward(args)...)); + } + + template::value>> + inline constexpr bool sum_fits_in(T in1, Args &&...args) + { + return fits_in(add_safe(in1, std::forward(args)...)); + } + + template::value>> + inline constexpr bool product_fits_in(Args &&...args) + { + return fits_in(mul_safe(std::forward(args)...)); + } + + template::value>> + inline constexpr bool product_fits_in(T in1, Args &&...args) + { + return fits_in(mul_safe(in1, std::forward(args)...)); + } + + template::value>, + typename = std::enable_if_t::value>> + inline T safe_cast(S value) + { + SEAL_IF_CONSTEXPR (!std::is_same::value) + { + if(!fits_in(value)) + { + throw std::out_of_range("cast failed"); + } + } + return static_cast(value); + } + + constexpr int bytes_per_uint64 = sizeof(std::uint64_t); + + constexpr int bytes_per_uint32 = sizeof(std::uint32_t); + + constexpr int uint32_per_uint64 = 2; + + constexpr int bits_per_nibble = 4; + + constexpr int bits_per_byte = 8; + + constexpr int bits_per_uint64 = bytes_per_uint64 * bits_per_byte; + + constexpr int bits_per_uint32 = bytes_per_uint32 * bits_per_byte; + + constexpr int nibbles_per_byte = 2; + + constexpr int nibbles_per_uint64 = bytes_per_uint64 * nibbles_per_byte; + + constexpr std::uint64_t uint64_high_bit = std::uint64_t(1) << (bits_per_uint64 - 1); + + inline constexpr std::uint32_t reverse_bits(std::uint32_t operand) noexcept + { + operand = (((operand & 0xaaaaaaaa) >> 1) | ((operand & 0x55555555) << 1)); + operand = (((operand & 0xcccccccc) >> 2) | ((operand & 0x33333333) << 2)); + operand = (((operand & 0xf0f0f0f0) >> 4) | ((operand & 0x0f0f0f0f) << 4)); + operand = (((operand & 0xff00ff00) >> 8) | ((operand & 0x00ff00ff) << 8)); + return((operand >> 16) | (operand << 16)); + } + + template>> + inline constexpr T reverse_bits(T operand) noexcept + { + return static_cast(reverse_bits(static_cast(operand >> 32))) | + (static_cast(reverse_bits(static_cast(operand & T(0xFFFFFFFF)))) << 32); + } + + inline std::uint32_t reverse_bits(std::uint32_t operand, int bit_count) + { +#ifdef SEAL_DEBUG + if (bit_count < 0 || bit_count > 32) + { + throw std::invalid_argument("bit_count"); + } +#endif + // We need shift by 32 to return zero so convert to uint64_t in-between + return static_cast( + (static_cast(reverse_bits(operand)) >> (32 - bit_count))); + } + + template>> + inline T reverse_bits(T operand, int bit_count) + { +#ifdef SEAL_DEBUG + if (bit_count < 0 || bit_count > 64) + { + throw std::invalid_argument("bit_count"); + } +#endif + // Need return zero on shift by 64 + return (bit_count == 0) ? 0 : (reverse_bits(operand) >> (64 - bit_count)); + } + + inline void get_msb_index_generic(unsigned long *result, std::uint64_t value) + { +#ifdef SEAL_DEBUG + if (result == nullptr) + { + throw std::invalid_argument("result"); + } +#endif + static const unsigned long deBruijnTable64[64] = { + 63, 0, 58, 1, 59, 47, 53, 2, + 60, 39, 48, 27, 54, 33, 42, 3, + 61, 51, 37, 40, 49, 18, 28, 20, + 55, 30, 34, 11, 43, 14, 22, 4, + 62, 57, 46, 52, 38, 26, 32, 41, + 50, 36, 17, 19, 29, 10, 13, 21, + 56, 45, 25, 31, 35, 16, 9, 12, + 44, 24, 15, 8, 23, 7, 6, 5 + }; + + value |= value >> 1; + value |= value >> 2; + value |= value >> 4; + value |= value >> 8; + value |= value >> 16; + value |= value >> 32; + + *result = deBruijnTable64[((value - (value >> 1)) * std::uint64_t(0x07EDD5E59A4E28C2)) >> 58]; + } + + inline int get_significant_bit_count(std::uint64_t value) + { + if (value == 0) + { + return 0; + } + + unsigned long result; + SEAL_MSB_INDEX_UINT64(&result, value); + return static_cast(result + 1); + } + + inline bool is_hex_char(char hex) + { + if (hex >= '0' && hex <= '9') + { + return true; + } + if (hex >= 'A' && hex <= 'F') + { + return true; + } + if (hex >= 'a' && hex <= 'f') + { + return true; + } + return false; + } + + inline char nibble_to_upper_hex(int nibble) + { +#ifdef SEAL_DEBUG + if (nibble < 0 || nibble > 15) + { + throw std::invalid_argument("nibble"); + } +#endif + if (nibble < 10) + { + return static_cast(nibble + static_cast('0')); + } + return static_cast(nibble + static_cast('A') - 10); + } + + inline int hex_to_nibble(char hex) + { + if (hex >= '0' && hex <= '9') + { + return static_cast(hex) - static_cast('0'); + } + if (hex >= 'A' && hex <= 'F') + { + return static_cast(hex) - static_cast('A') + 10; + } + if (hex >= 'a' && hex <= 'f') + { + return static_cast(hex) - static_cast('a') + 10; + } +#ifdef SEAL_DEBUG + throw std::invalid_argument("hex"); +#endif + return -1; + } + + inline SEAL_BYTE *get_uint64_byte(std::uint64_t *value, std::size_t byte_index) + { +#ifdef SEAL_DEBUG + if (value == nullptr) + { + throw std::invalid_argument("value"); + } +#endif + return reinterpret_cast(value) + byte_index; + } + + inline const SEAL_BYTE *get_uint64_byte(const std::uint64_t *value, std::size_t byte_index) + { +#ifdef SEAL_DEBUG + if (value == nullptr) + { + throw std::invalid_argument("value"); + } +#endif + return reinterpret_cast(value) + byte_index; + } + + inline int get_hex_string_bit_count(const char *hex_string, int char_count) + { +#ifdef SEAL_DEBUG + if (hex_string == nullptr && char_count > 0) + { + throw std::invalid_argument("hex_string"); + } + if (char_count < 0) + { + throw std::invalid_argument("char_count"); + } +#endif + for (int i = 0; i < char_count; i++) + { + char hex = *hex_string++; + int nibble = hex_to_nibble(hex); + if (nibble != 0) + { + int nibble_bits = get_significant_bit_count(static_cast(nibble)); + int remaining_nibbles = (char_count - i - 1) * bits_per_nibble; + return nibble_bits + remaining_nibbles; + } + } + return 0; + } + + template::value>> + inline T divide_round_up(T value, T divisor) + { +#ifdef SEAL_DEBUG + if (value < 0) + { + throw std::invalid_argument("value"); + } + if (divisor <= 0) + { + throw std::invalid_argument("divisor"); + } +#endif + return (add_safe(value, divisor - 1)) / divisor; + } + + template + constexpr double epsilon = std::numeric_limits::epsilon(); + + template::value>> + constexpr bool are_close(T value1, T value2) noexcept + { + double scale_factor = std::max({ std::fabs(value1), std::fabs(value2), T{ 1.0 } }); + return std::fabs(value1 - value2) < epsilon * scale_factor; + } + + template::value>> + constexpr bool is_zero(T value) noexcept + { + return value == T{ 0 }; + } + } +} diff --git a/src/seal/util/config.h.in b/src/seal/util/config.h.in new file mode 100644 index 000000000..6c7478a01 --- /dev/null +++ b/src/seal/util/config.h.in @@ -0,0 +1,21 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#define SEAL_VERSION "@SEAL_VERSION@" +#cmakedefine SEAL_DEBUG +#cmakedefine SEAL_USE_IF_CONSTEXPR +#cmakedefine SEAL_USE_MAYBE_UNUSED +#cmakedefine SEAL_USE_STD_BYTE +#cmakedefine SEAL_USE_SHARED_MUTEX +#cmakedefine SEAL_ENFORCE_HE_STD_SECURITY +#cmakedefine SEAL_USE_INTRIN +#cmakedefine SEAL_USE__UMUL128 +#cmakedefine SEAL_USE__BITSCANREVERSE64 +#cmakedefine SEAL_USE___BUILTIN_CLZLL +#cmakedefine SEAL_USE___INT128 +#cmakedefine SEAL_USE__ADDCARRY_U64 +#cmakedefine SEAL_USE__SUBBORROW_U64 +#cmakedefine SEAL_USE_AES_NI_PRNG +#cmakedefine SEAL_USE_MSGSL +#cmakedefine SEAL_USE_MSGSL_SPAN +#cmakedefine SEAL_USE_MSGSL_MULTISPAN diff --git a/src/seal/util/defines.h b/src/seal/util/defines.h new file mode 100644 index 000000000..bdc435cae --- /dev/null +++ b/src/seal/util/defines.h @@ -0,0 +1,150 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +// Debugging help +#define SEAL_ASSERT(condition) { if(!(condition)){ std::cerr << "ASSERT FAILED: " \ + << #condition << " @ " << __FILE__ << " (" << __LINE__ << ")" << std::endl; } } + +// String expansion +#define _SEAL_STRINGIZE(x) #x +#define SEAL_STRINGIZE(x) _SEAL_STRINGIZE(x) + +// Check that double is 64 bits +static_assert(sizeof(double) == 8, "Require sizeof(double) == 8"); + +// Check that int is 32 bits +static_assert(sizeof(int) == 4, "Require sizeof(int) == 4"); + +// Check that unsigned long long is 64 bits +static_assert(sizeof(unsigned long long) == 8, "Require sizeof(unsigned long long) == 8"); + +// Bounds for bit-length of user-defined coefficient moduli +#define SEAL_USER_MOD_BIT_COUNT_MAX 60 +#define SEAL_USER_MOD_BIT_COUNT_MIN 1 + +// Bounds for number of coefficient moduli +#define SEAL_COEFF_MOD_COUNT_MAX 62 +#define SEAL_COEFF_MOD_COUNT_MIN 1 + +// Bounds for polynomial modulus degree +#define SEAL_POLY_MOD_DEGREE_MAX 32768 +#define SEAL_POLY_MOD_DEGREE_MIN 2 + +// Bounds for the plaintext modulus +#define SEAL_PLAIN_MOD_MIN 2 +#define SEAL_PLAIN_MOD_MAX (std::uint64_t(1) << SEAL_USER_MOD_BIT_COUNT_MAX) - 1 + +// Upper bound on the size of a ciphertext +#define SEAL_CIPHERTEXT_SIZE_MIN 2 +#define SEAL_CIPHERTEXT_SIZE_MAX 32768 + +// Bounds for decomposition bit count +#define SEAL_DBC_MAX 60 +#define SEAL_DBC_MIN 1 + +// Bounds for number of relinearization keys +#define SEAL_RELIN_KEY_COUNT_MAX 8 +#define SEAL_RELIN_KEY_COUNT_MIN 1 + +// Use std::byte as byte type +#if defined(SEAL_USE_STD_BYTE) +#include +namespace seal +{ + using SEAL_BYTE = std::byte; +} +#else +namespace seal +{ + enum class SEAL_BYTE : unsigned char {}; +} +#endif + +// Detect compiler +#define SEAL_COMPILER_MSVC 1 +#define SEAL_COMPILER_CLANG 2 +#define SEAL_COMPILER_GCC 3 + +#if defined(_MSC_VER) +#define SEAL_COMPILER SEAL_COMPILER_MSVC +#elif defined(__clang__) +#define SEAL_COMPILER SEAL_COMPILER_CLANG +#elif defined(__GNUC__) && !defined(__clang__) +#define SEAL_COMPILER SEAL_COMPILER_GCC +#endif + +// MSVC support +#include "seal/util/msvc.h" + +// clang support +#include "seal/util/clang.h" + +// gcc support +#include "seal/util/gcc.h" + +// Create a true/false value for indicating debug mode +#ifdef SEAL_DEBUG +#define SEAL_DEBUG_V true +#else +#define SEAL_DEBUG_V false +#endif + +// Use `if constexpr' from C++17 +#ifdef SEAL_USE_IF_CONSTEXPR +#define SEAL_IF_CONSTEXPR if constexpr +#else +#define SEAL_IF_CONSTEXPR if +#endif + +// Use [[maybe_unused]] from C++17 +#ifdef SEAL_USE_MAYBE_UNUSED +#define SEAL_MAYBE_UNUSED [[maybe_unused]] +#else +#define SEAL_MAYBE_UNUSED +#endif + +// Which random number generator factory to use by default +#ifdef SEAL_USE_AES_NI_PRNG +// AES-PRNG with seed from std::random_device +#define SEAL_DEFAULT_RNG_FACTORY FastPRNGFactory() +#else +// std::random_device +#define SEAL_DEFAULT_RNG_FACTORY StandardRandomAdapterFactory +#endif + +// Use generic functions as (slower) fallback +#ifndef SEAL_ADD_CARRY_UINT64 +#define SEAL_ADD_CARRY_UINT64(operand1, operand2, carry, result) add_uint64_generic(operand1, operand2, carry, result) +#endif + +#ifndef SEAL_SUB_BORROW_UINT64 +#define SEAL_SUB_BORROW_UINT64(operand1, operand2, borrow, result) sub_uint64_generic(operand1, operand2, borrow, result) +#endif + +#ifndef SEAL_MULTIPLY_UINT64 +#define SEAL_MULTIPLY_UINT64(operand1, operand2, result128) { \ + multiply_uint64_generic(operand1, operand2, result128); \ +} +#endif + +#ifndef SEAL_MULTIPLY_UINT64_HW64 +#define SEAL_MULTIPLY_UINT64_HW64(operand1, operand2, hw64) { \ + multiply_uint64_hw64_generic(operand1, operand2, hw64); \ +} +#endif + +#ifndef SEAL_MSB_INDEX_UINT64 +#define SEAL_MSB_INDEX_UINT64(result, value) get_msb_index_generic(result, value) +#endif + +// Multiplication by a plaintext zero should not be allowed, and by default SEAL +// throws an exception in this case. For performance reasons one might want to +// undefine this if appropriate checks are guaranteed to be performed elsewhere. +#define SEAL_THROW_ON_MULTIPLY_PLAIN_BY_ZERO + +// HomomorphicEncryption.org security tables only support dimensions up to 32768 +#ifdef SEAL_ENFORCE_HE_STD_SECURITY + static_assert(SEAL_POLY_MOD_DEGREE_MAX <= 32768, "SEAL_POLY_MOD_DEGREE_MAX too large"); +#endif diff --git a/src/seal/util/gcc.h b/src/seal/util/gcc.h new file mode 100644 index 000000000..490e96a64 --- /dev/null +++ b/src/seal/util/gcc.h @@ -0,0 +1,69 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#if SEAL_COMPILER == SEAL_COMPILER_GCC + +// We require GCC >= 6 +#if (__GNUC__ < 6) || not defined(__cplusplus) +#pragma GCC error "SEAL requires __GNUC__ >= 6" +#endif + +// Read in config.h +#include "seal/util/config.h" + +#if (__GNUC__ == 6) && defined(SEAL_USE_IF_CONSTEXPR) +#pragma GCC error "g++-6 cannot compile SEAL as C++17; set CMake build option `SEAL_USE_CXX17' to OFF" +#endif + +// Are we using MSGSL? +#ifdef SEAL_USE_MSGSL +#include +#endif + +// Are intrinsics enabled? +#ifdef SEAL_USE_INTRIN +#include + +#ifdef SEAL_USE___BUILTIN_CLZLL +#define SEAL_MSB_INDEX_UINT64(result, value) { \ + *result = 63UL - static_cast(__builtin_clzll(value)); \ +} +#endif + +#ifdef SEAL_USE___INT128 +#define SEAL_MULTIPLY_UINT64_HW64(operand1, operand2, hw64) { \ + *hw64 = static_cast( \ + ((static_cast(operand1) \ + * static_cast(operand2)) >> 64)); \ +} + +#define SEAL_MULTIPLY_UINT64(operand1, operand2, result128) { \ + unsigned __int128 product = static_cast(operand1) * operand2;\ + result128[0] = static_cast(product); \ + result128[1] = static_cast(product >> 64); \ +} +#endif + +#ifdef SEAL_USE__ADDCARRY_U64 +#define SEAL_ADD_CARRY_UINT64(operand1, operand2, carry, result) _addcarry_u64( \ + carry, operand1, operand2, result) +#endif + +#ifdef SEAL_USE__SUBBORROW_U64 +#if ((__GNUC__ == 7) && (__GNUC_MINOR__ >= 2)) || (__GNUC__ >= 8) +// The inverted arguments problem was fixed in GCC-7.2 +// (https://patchwork.ozlabs.org/patch/784309/) +#define SEAL_SUB_BORROW_UINT64(operand1, operand2, borrow, result) _subborrow_u64( \ + borrow, operand1, operand2, result) +#else +// Warning: Note the inverted order of operand1 and operand2 +#define SEAL_SUB_BORROW_UINT64(operand1, operand2, borrow, result) _subborrow_u64( \ + borrow, operand2, operand1, result) +#endif //(__GNUC__ == 7) && (__GNUC_MINOR__ >= 2) +#endif + +#endif //SEAL_USE_INTRIN + +#endif diff --git a/src/seal/util/globals.cpp b/src/seal/util/globals.cpp new file mode 100644 index 000000000..ddd3aa319 --- /dev/null +++ b/src/seal/util/globals.cpp @@ -0,0 +1,333 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include +#include "seal/util/globals.h" +#include "seal/smallmodulus.h" + +using namespace std; + +namespace seal +{ + namespace util + { + namespace global_variables + { + std::unique_ptr const global_memory_pool(new MemoryPoolMT); +#ifndef _M_CEE + thread_local std::unique_ptr const tls_memory_pool(new MemoryPoolST); +#else +#pragma message("WARNING: Thread-local memory pools disabled to support /clr") +#endif + const map> default_coeff_modulus_128 + { + /* + Polynomial modulus: 1x^1024 + 1 + Modulus count: 1 + Total bit count: 27 + */ + { 1024,{ + 0x7e00001 + } }, + + /* + Polynomial modulus: 1x^2048 + 1 + Modulus count: 1 + Total bit count: 54 + */ + { 2048,{ + 0x3fffffff000001 + } }, + + /* + Polynomial modulus: 1x^4096 + 1 + Modulus count: 2 + Total bit count: 109 = 55 + 54 + */ + { 4096,{ + 0x7fffffff380001, 0x3fffffff000001 + } }, + + /* + Polynomial modulus: 1x^8192 + 1 + Modulus count: 4 + Total bit count: 218 = 2 * 55 + 2 * 54 + */ + { 8192,{ + 0x7fffffff380001, 0x7ffffffef00001, + 0x3fffffff000001, 0x3ffffffef40001 + } }, + + /* + Polynomial modulus: 1x^16384 + 1 + Modulus count: 8 + Total bit count: 438 = 6 * 55 + 2 * 54 + */ + { 16384,{ + 0x7fffffff380001, 0x7ffffffef00001, + 0x7ffffffeac0001, 0x7ffffffe700001, + 0x7ffffffe600001, 0x7ffffffe4c0001, + 0x3fffffff000001, 0x3ffffffef40001 + } }, + + /* + Polynomial modulus: 1x^32768 + 1 + Modulus count: 15 + Total bit count: 881 = 11 * 59 + 4 * 58 + */ + { 32768,{ + 0x7ffffffffcc0001, 0x7ffffffffb00001, 0x7ffffffff2c0001, + 0x7ffffffff240001, 0x7fffffffe900001, 0x7fffffffe3c0001, + 0x7fffffffe240001, 0x7fffffffddc0001, 0x7fffffffd740001, + 0x7fffffffd640001, 0x7fffffffd080001, 0x3ffffffff040001, + 0x3fffffffed00001, 0x3fffffffeb00001, 0x3fffffffea00001 + } } + }; + + const map> default_coeff_modulus_192 + { + /* + Polynomial modulus: 1x^1024 + 1 + Modulus count: 1 + Total bit count: 19 + */ + { 1024,{ + 0x7f001 + } }, + + /* + Polynomial modulus: 1x^2048 + 1 + Modulus count: 1 + Total bit count: 37 + */ + { 2048,{ + 0x1ffffc0001 + } }, + + /* + Polynomial modulus: 1x^4096 + 1 + Modulus count: 2 + Total bit count: 75 = 38 + 37 + */ + { 4096,{ + 0x3fffe80001, 0x1ffffc0001 + } }, + + /* + Polynomial modulus: 1x^8192 + 1 + Modulus count: 3 + Total bit count: 152 = 2 * 51 + 50 + */ + { 8192,{ + 0x7ffffff9c0001, 0x7ffffff900001, 0x3ffffffb80001 + } }, + + /* + Polynomial modulus: 1x^16384 + 1 + Modulus count: 5 + Total bit count: 300 = 5 * 60 + */ + { 16384,{ + 0xffffffffffc0001, 0xfffffffff840001, + 0xfffffffff240001, 0xffffffffe7c0001, + 0xffffffffe740001 + + } }, + + /* + Polynomial modulus: 1x^32768 + 1 + Modulus count: 10 + Total bit count: 600 = 10 * 60 + */ + { 32768,{ + 0xffffffffffc0001, 0xfffffffff840001, + 0xfffffffff240001, 0xffffffffe7c0001, + 0xffffffffe740001, 0xffffffffe4c0001, + 0xffffffffe440001, 0xffffffffe400001, + 0xffffffffdbc0001, 0xffffffffd840001 + } } + }; + + const map> default_coeff_modulus_256 + { + /* + Polynomial modulus: 1x^1024 + 1 + Modulus count: 1 + Total bit count: 14 + */ + { 1024,{ + 0x3001 + } }, + + /* + Polynomial modulus: 1x^2048 + 1 + Modulus count: 1 + Total bit count: 29 + */ + { 2048,{ + 0x1ffc0001 + } }, + + /* + Polynomial modulus: 1x^4096 + 1 + Modulus count: 1 + Total bit count: 58 + */ + { 4096,{ + 0x3ffffffff040001 + } }, + + /* + Polynomial modulus: 1x^8192 + 1 + Modulus count: 2 + Total bit count: 118 = 2 * 59 + */ + { 8192,{ + 0x7ffffffffcc0001, 0x7ffffffffb00001 + } }, + + /* + Polynomial modulus: 1x^16384 + 1 + Modulus count: 4 + Total bit count: 237 = 60 + 3 * 59 + */ + { 16384,{ + 0xffffffffffc0001, 0x7ffffffffcc0001, + 0x7ffffffffb00001, 0x7ffffffff2c0001 + } }, + + /* + Polynomial modulus: 1x^32768 + 1 + Modulus count: 8 + Total bit count: 476 = 4 * 60 + 4 * 59 + */ + { 32768,{ + 0xffffffffffc0001, 0xfffffffff840001, + 0xfffffffff240001, 0xffffffffe7c0001, + 0x7ffffffffcc0001, 0x7ffffffffb00001, + 0x7ffffffff2c0001, 0x7ffffffff240001 + } } + }; + + const vector small_mods_60bit{ + 0xffffffffffc0001, 0xfffffffff840001, 0xfffffffff240001, 0xffffffffe7c0001, + 0xffffffffe740001, 0xffffffffe4c0001, 0xffffffffe440001, 0xffffffffe400001, + 0xffffffffdbc0001, 0xffffffffd840001, 0xffffffffd680001, 0xffffffffd000001, + 0xffffffffcf00001, 0xffffffffcdc0001, 0xffffffffcc40001, 0xffffffffc300001, + 0xffffffffbf40001, 0xffffffffbdc0001, 0xffffffffb880001, 0xffffffffaec0001, + 0xffffffffa380001, 0xffffffffa200001, 0xffffffffa0c0001, 0xffffffff9600001, + 0xffffffff91c0001, 0xffffffff8f40001, 0xffffffff8680001, 0xffffffff7e40001, + 0xffffffff7bc0001, 0xffffffff76c0001, 0xffffffff7680001, 0xffffffff6fc0001, + 0xffffffff6880001, 0xffffffff6340001, 0xffffffff5d40001, 0xffffffff54c0001, + 0xffffffff4d40001, 0xffffffff4380001, 0xffffffff3e80001, 0xffffffff37c0001, + 0xffffffff36c0001, 0xffffffff2100001, 0xffffffff1d80001, 0xffffffff1cc0001, + 0xffffffff1900001, 0xffffffff1740001, 0xffffffff15c0001, 0xffffffff0e80001, + 0xfffffffeff80001, 0xfffffffeff40001, 0xfffffffeefc0001, 0xfffffffee8c0001, + 0xfffffffede40001, 0xfffffffedcc0001, 0xfffffffed040001, 0xfffffffecf40001, + 0xfffffffecec0001, 0xfffffffecb00001, 0xfffffffec380001, 0xfffffffebb40001, + 0xfffffffeb200001, 0xfffffffeaf40001, 0xfffffffea700001, 0xfffffffea400001 + }; + + const vector small_mods_50bit{ + 0x3ffffffb80001, 0x3fffffec80001, 0x3fffffea40001, 0x3fffffe940001, + 0x3fffffdd40001, 0x3fffffd900001, 0x3fffffd540001, 0x3fffffd500001, + 0x3fffffcc40001, 0x3fffffcb40001, 0x3fffffc600001, 0x3fffffc4c0001, + 0x3fffffc3c0001, 0x3fffffc240001, 0x3fffffc0c0001, 0x3fffffbb00001, + 0x3fffffbac0001, 0x3fffffb800001, 0x3fffffb7c0001, 0x3fffffb580001, + 0x3fffffafc0001, 0x3fffffaf80001, 0x3fffffaf00001, 0x3fffffac00001, + 0x3fffffaa40001, 0x3fffffa440001, 0x3fffffa0c0001, 0x3fffff9a00001, + 0x3fffff9640001, 0x3fffff9300001, 0x3fffff8b80001, 0x3fffff8740001, + 0x3fffff8340001, 0x3fffff7ec0001, 0x3fffff7e40001, 0x3fffff76c0001, + 0x3fffff6e80001, 0x3fffff6900001, 0x3fffff6600001, 0x3fffff6580001, + 0x3fffff6100001, 0x3fffff5d40001, 0x3fffff5ac0001, 0x3fffff55c0001, + 0x3fffff5400001, 0x3fffff5040001, 0x3fffff4b00001, 0x3fffff4680001, + 0x3fffff4080001, 0x3fffff3880001, 0x3fffff3400001, 0x3fffff30c0001, + 0x3fffff2f80001, 0x3fffff2280001, 0x3fffff21c0001, 0x3fffff1e40001, + 0x3fffff1080001, 0x3fffff0fc0001, 0x3fffff0d00001, 0x3fffff07c0001, + 0x3fffff0540001, 0x3fffff00c0001, 0x3fffff0040001, 0x3ffffefd00001 + }; + + const vector small_mods_40bit{ + 0xffffe80001, 0xffffc40001, 0xffff940001, 0xffff780001, + 0xffff580001, 0xffff480001, 0xffff340001, 0xfffeb00001, + 0xfffe680001, 0xfffe2c0001, 0xfffe100001, 0xfffd800001, + 0xfffd080001, 0xfffca00001, 0xfffc940001, 0xfffc880001, + 0xfffc640001, 0xfffc600001, 0xfffc540001, 0xfffbf40001, + 0xfffbdc0001, 0xfffbb80001, 0xfffba00001, 0xfffb340001, + 0xfffaf80001, 0xfffaf00001, 0xfffad80001, 0xfffa800001, + 0xfffa780001, 0xfffa6c0001, 0xfffa5c0001, 0xfffa240001, + 0xfffa140001, 0xfff9a80001, 0xfff9880001, 0xfff9240001, + 0xfff9040001, 0xfff8dc0001, 0xfff8ac0001, 0xfff8a40001, + 0xfff8800001, 0xfff8440001, 0xfff8340001, 0xfff8080001, + 0xfff7ec0001, 0xfff6dc0001, 0xfff6cc0001, 0xfff67c0001, + 0xfff6780001, 0xfff6100001, 0xfff58c0001, 0xfff5440001, + 0xfff51c0001, 0xfff4d40001, 0xfff3c00001, 0xfff3940001, + 0xfff36c0001, 0xfff3400001, 0xfff2c80001, 0xfff2b00001, + 0xfff2680001, 0xfff2440001, 0xfff1e00001, 0xfff1b40001 + }; + + const vector small_mods_30bit{ + 0x3ffc0001, 0x3fac0001, 0x3f540001, 0x3ef80001, + 0x3ef40001, 0x3ed00001, 0x3ebc0001, 0x3eb00001, + 0x3e880001, 0x3e500001, 0x3dd40001, 0x3dcc0001, + 0x3cfc0001, 0x3cc40001, 0x3cb40001, 0x3c840001, + 0x3c600001, 0x3c3c0001, 0x3c100001, 0x3bf80001, + 0x3be80001, 0x3be00001, 0x3b800001, 0x3b580001, + 0x3b340001, 0x3ac00001, 0x3aa40001, 0x3a6c0001, + 0x3a5c0001, 0x3a440001, 0x3a300001, 0x3a200001, + 0x39f00001, 0x39e40001, 0x39c40001, 0x39640001, + 0x39600001, 0x39280001, 0x391c0001, 0x39100001, + 0x38b80001, 0x38a00001, 0x388c0001, 0x38680001, + 0x38400001, 0x38100001, 0x37f00001, 0x37c00001, + 0x379c0001, 0x37300001, 0x37200001, 0x36d00001, + 0x36cc0001, 0x36c00001, 0x367c0001, 0x36700001, + 0x36340001, 0x36240001, 0x361c0001, 0x36180001, + 0x36100001, 0x35d40001, 0x35ac0001, 0x35a00001 + }; + + namespace internal_mods + { + const SmallModulus m_sk(0x1fffffffffe00001); + + const SmallModulus m_tilde(uint64_t(1) << 32); + + const SmallModulus gamma(0x1fffffffffc80001); + + const vector aux_small_mods{ + 0x1fffffffffb40001, 0x1fffffffff500001, 0x1fffffffff380001, 0x1fffffffff000001, + 0x1ffffffffef00001, 0x1ffffffffee80001, 0x1ffffffffeb40001, 0x1ffffffffe780001, + 0x1ffffffffe600001, 0x1ffffffffe4c0001, 0x1ffffffffdf40001, 0x1ffffffffdac0001, + 0x1ffffffffda40001, 0x1ffffffffc680001, 0x1ffffffffc000001, 0x1ffffffffb880001, + 0x1ffffffffb7c0001, 0x1ffffffffb300001, 0x1ffffffffb1c0001, 0x1ffffffffadc0001, + 0x1ffffffffa400001, 0x1ffffffffa140001, 0x1ffffffff9d80001, 0x1ffffffff9140001, + 0x1ffffffff8ac0001, 0x1ffffffff8a80001, 0x1ffffffff81c0001, 0x1ffffffff7800001, + 0x1ffffffff7680001, 0x1ffffffff7080001, 0x1ffffffff6c80001, 0x1ffffffff6140001, + 0x1ffffffff5f40001, 0x1ffffffff5700001, 0x1ffffffff4bc0001, 0x1ffffffff4380001, + 0x1ffffffff3240001, 0x1ffffffff2dc0001, 0x1ffffffff1a40001, 0x1ffffffff11c0001, + 0x1ffffffff0fc0001, 0x1ffffffff0d80001, 0x1ffffffff0c80001, 0x1ffffffff08c0001, + 0x1fffffffefd00001, 0x1fffffffef9c0001, 0x1fffffffef600001, 0x1fffffffeef40001, + 0x1fffffffeed40001, 0x1fffffffeed00001, 0x1fffffffeebc0001, 0x1fffffffed540001, + 0x1fffffffed440001, 0x1fffffffed2c0001, 0x1fffffffed200001, 0x1fffffffec940001, + 0x1fffffffec6c0001, 0x1fffffffebe80001, 0x1fffffffebac0001, 0x1fffffffeba40001, + 0x1fffffffeb4c0001, 0x1fffffffeb280001, 0x1fffffffea780001, 0x1fffffffea440001, + 0x1fffffffe9f40001, 0x1fffffffe97c0001, 0x1fffffffe9300001, 0x1fffffffe8d00001, + 0x1fffffffe8400001, 0x1fffffffe7cc0001, 0x1fffffffe7bc0001, 0x1fffffffe7a80001, + 0x1fffffffe7600001, 0x1fffffffe7500001, 0x1fffffffe6fc0001, 0x1fffffffe6d80001, + 0x1fffffffe6ac0001, 0x1fffffffe6000001, 0x1fffffffe5d40001, 0x1fffffffe5a00001, + 0x1fffffffe5940001, 0x1fffffffe54c0001, 0x1fffffffe5340001, 0x1fffffffe4bc0001, + 0x1fffffffe4a40001, 0x1fffffffe3fc0001, 0x1fffffffe3540001, 0x1fffffffe2b00001, + 0x1fffffffe2680001, 0x1fffffffe0480001, 0x1fffffffe00c0001, 0x1fffffffdfd00001, + 0x1fffffffdfc40001, 0x1fffffffdf700001, 0x1fffffffdf340001, 0x1fffffffdef80001, + 0x1fffffffdea80001, 0x1fffffffde680001, 0x1fffffffde000001, 0x1fffffffdde40001, + 0x1fffffffddd80001, 0x1fffffffddd00001, 0x1fffffffddb40001, 0x1fffffffdd780001, + 0x1fffffffdd4c0001, 0x1fffffffdcb80001, 0x1fffffffdca40001, 0x1fffffffdc380001, + 0x1fffffffdc040001, 0x1fffffffdbb40001, 0x1fffffffdba80001, 0x1fffffffdb9c0001, + 0x1fffffffdb740001, 0x1fffffffdb380001, 0x1fffffffda600001, 0x1fffffffda340001, + 0x1fffffffda180001, 0x1fffffffd9700001, 0x1fffffffd9680001, 0x1fffffffd9440001, + 0x1fffffffd9080001, 0x1fffffffd8c80001, 0x1fffffffd8800001, 0x1fffffffd82c0001, + 0x1fffffffd7cc0001, 0x1fffffffd7b80001, 0x1fffffffd7840001, 0x1fffffffd73c0001 + }; + } + } + } +} diff --git a/src/seal/util/globals.h b/src/seal/util/globals.h new file mode 100644 index 000000000..ea2eadf68 --- /dev/null +++ b/src/seal/util/globals.h @@ -0,0 +1,113 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include +#include +#include +#include +#include +#include "seal/util/defines.h" +#include "seal/util/hestdparms.h" + +namespace seal +{ + class SmallModulus; + + namespace util + { + class MemoryPool; + + namespace global_variables + { + extern std::unique_ptr const global_memory_pool; + +/* +For .NET Framework wrapper support (C++/CLI) we need to + (1) compile the MemoryManager class as thread-unsafe because C++ + mutexes cannot be brought through C++/CLI layer; + (2) disable thread-safe memory pools. +*/ +#ifndef _M_CEE + extern thread_local std::unique_ptr const tls_memory_pool; +#endif + /** + HomomorphicEncryption.org security tables provide an upper bound for the bit-length + of the coefficient modulus up to poly_modulus_degree 65536. + */ + const std::map + max_secure_coeff_modulus_bit_count { SEAL_HE_STD_PARMS_128_TC }; + + /** + Default value for the standard deviation of the noise (error) distribution. + */ + constexpr double default_noise_standard_deviation = + SEAL_HE_STD_PARMS_ERROR_STD_DEV; + + constexpr double noise_distribution_width_multiplier = 6; + + /** + This data structure is a key-value storage that maps degrees of the polynomial modulus + to vectors of SmallModulus elements so that when used with the default value for the + standard deviation of the noise distribution (noise_standard_deviation), the security + level is at least 128 bits according to http://HomomorphicEncryption.org. This makes + it easy for non-expert users to select secure parameters. + */ + extern const std::map> default_coeff_modulus_128; + + /** + This data structure is a key-value storage that maps degrees of the polynomial modulus + to vectors of SmallModulus elements so that when used with the default value for the + standard deviation of the noise distribution (noise_standard_deviation), the security + level is at least 192 bits according to http://HomomorphicEncryption.org. This makes + it easy for non-expert users to select secure parameters. + */ + extern const std::map> default_coeff_modulus_192; + + /** + This data structure is a key-value storage that maps degrees of the polynomial modulus + to vectors of SmallModulus elements so that when used with the default value for the + standard deviation of the noise distribution (noise_standard_deviation), the security + level is at least 256 bits according to http://HomomorphicEncryption.org. This makes + it easy for non-expert users to select secure parameters. + */ + extern const std::map> default_coeff_modulus_256; + + /** + In SEAL the encryption parameter coeff_modulus is a vector of prime numbers + represented by instances of the SmallModulus class. We present here vectors + of pre-selected primes that the user can choose from. These are the largest + 60-bit, 50-bit, 40-bit, 30-bit primes that are congruent to 1 modulo 2^18. + The primes presented here work for poly_modulus up to degree 131072. + + The user can also use their own primes. The only restriction is that they + must be at most 60 bits in length, and need to be congruent to 1 modulo + 2 * poly_modulus_degree. + */ + extern const std::vector small_mods_60bit; + + extern const std::vector small_mods_50bit; + + extern const std::vector small_mods_40bit; + + extern const std::vector small_mods_30bit; + + // For internal use only, do not modify + namespace internal_mods + { + // Prime, 61 bits, and congruent to 1 mod 2^18 + extern const SmallModulus m_sk; + + // 33 bits + extern const SmallModulus m_tilde; + + // Prime, 61 bits, and congruent to 1 mod 2^18 + extern const SmallModulus gamma; + + // For internal use only, all primes 61 bits and congruent to 1 mod 2^18 + extern const std::vector aux_small_mods; + } + } + } +} diff --git a/src/seal/util/hash.cpp b/src/seal/util/hash.cpp new file mode 100644 index 000000000..b7bdef1ac --- /dev/null +++ b/src/seal/util/hash.cpp @@ -0,0 +1,122 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "seal/util/hash.h" +#include "seal/util/common.h" +#include "seal/util/uintcore.h" +#include "seal/util/pointer.h" +#include "seal/util/globals.h" +#include "seal/memorymanager.h" +#include + +using namespace std; + +namespace seal +{ + namespace util + { + // For C++14 compatibility need to define static constexpr + // member variables with no initialization here. + constexpr std::uint64_t HashFunction::sha3_round_consts[sha3_round_count]; + + constexpr std::uint8_t HashFunction::sha3_rho[24]; + + constexpr HashFunction::sha3_block_type HashFunction::sha3_zero_block; + + void HashFunction::keccak_1600(sha3_state_type &state) noexcept + { + for (uint8_t round = 0; round < sha3_round_count; round++) + { + // theta + uint64_t C[5]; + uint64_t D[5]; + for (uint8_t x = 0; x < 5; x++) + { + C[x] = state[x][0]; + for (uint8_t y = 1; y < 5; y++) + { + C[x] ^= state[x][y]; + } + } + for (uint8_t x = 0; x < 5; x++) + { + D[x] = C[(x + 4) % 5] ^ rot(C[(x + 1) % 5], 1); + for (uint8_t y = 0; y < 5; y++) + { + state[x][y] ^= D[x]; + } + } + + // rho and pi + uint64_t ind_x = 1; + uint64_t ind_y = 0; + uint64_t curr = state[ind_x][ind_y]; + for (uint8_t i = 0; i < 24; i++) + { + uint64_t ind_X = ind_y; + uint64_t ind_Y = (2 * ind_x + 3 * ind_y) % 5; + uint64_t temp = state[ind_X][ind_Y]; + state[ind_X][ind_Y] = rot(curr, sha3_rho[i]); + curr = temp; + ind_x = ind_X; + ind_y = ind_Y; + } + + // xi + for (uint8_t y = 0; y < 5; y++) + { + for (uint8_t x = 0; x < 5; x++) + { + C[x] = state[x][y]; + } + for (uint8_t x = 0; x < 5; x++) + { + state[x][y] = C[x] ^ ((~C[(x + 1) % 5]) & C[(x + 2) % 5]); + } + } + + // iota + state[0][0] ^= sha3_round_consts[round]; + } + } + + void HashFunction::sha3_hash(const uint64_t *input, size_t uint64_count, + sha3_block_type &sha3_block) + { +#ifdef SEAL_DEBUG + if (input == nullptr) + { + throw invalid_argument("input cannot be null"); + } +#endif + // Padding + auto pool = MemoryManager::GetPool(); + size_t padded_uint64_count = sha3_rate_uint64_count * ((uint64_count / sha3_rate_uint64_count) + 1); + auto padded_input(allocate_uint(padded_uint64_count, pool)); + set_uint_uint(input, uint64_count, padded_input.get()); + for (size_t i = uint64_count; i < padded_uint64_count; i++) + { + padded_input[i] = 0; + if (i == uint64_count) + { + padded_input[i] |= 0x6; + } + if (i == padded_uint64_count - 1) + { + padded_input[i] |= uint64_t(1) << 63; + } + } + + // Absorb + sha3_state_type state; + memset(state, 0, sha3_state_uint64_count * static_cast(bytes_per_uint64)); + for (size_t i = 0; i < padded_uint64_count; i += sha3_rate_uint64_count) + { + sponge_absorb(padded_input.get() + i, state); + } + + sha3_block = sha3_zero_block; + sponge_squeeze(state, sha3_block); + } + } +} diff --git a/src/seal/util/hash.h b/src/seal/util/hash.h new file mode 100644 index 000000000..8c97243ff --- /dev/null +++ b/src/seal/util/hash.h @@ -0,0 +1,130 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include +#include + +namespace seal +{ + namespace util + { + class HashFunction + { + public: + HashFunction() = delete; + + static constexpr std::size_t sha3_block_uint64_count = 4; + + using sha3_block_type = std::array; + + static constexpr sha3_block_type sha3_zero_block{ { 0, 0, 0, 0 } }; + + static void sha3_hash(const std::uint64_t *input, std::size_t uint64_count, + sha3_block_type &destination); + + inline static void sha3_hash(std::uint64_t input, sha3_block_type &destination) + { + sha3_hash(&input, 1, destination); + } + + private: + static constexpr std::uint8_t sha3_round_count = 24; + + // Rate 1088 = 17 * 64 bits + static constexpr std::uint8_t sha3_rate_uint64_count = 17; + + // Capacity 512 = 8 * 64 bits + static constexpr std::uint8_t sha3_capacity_uint64_count = 8; + + // State size = 1600 = 25 * 64 bits + static constexpr std::uint8_t sha3_state_uint64_count = 25; + + using sha3_state_type = std::uint64_t[5][5]; + + static constexpr std::uint8_t sha3_rho[24]{ + 1, 3, 6, 10, 15, 21, + 28, 36, 45, 55, 2, 14, + 27, 41, 56, 8, 25, 43, + 62, 18, 39, 61, 20, 44 + }; + + static constexpr std::uint64_t sha3_round_consts[sha3_round_count]{ + 0x0000000000000001, 0x0000000000008082, 0x800000000000808a, + 0x8000000080008000, 0x000000000000808b, 0x0000000080000001, + 0x8000000080008081, 0x8000000000008009, 0x000000000000008a, + 0x0000000000000088, 0x0000000080008009, 0x000000008000000a, + 0x000000008000808b, 0x800000000000008b, 0x8000000000008089, + 0x8000000000008003, 0x8000000000008002, 0x8000000000000080, + 0x000000000000800a, 0x800000008000000a, 0x8000000080008081, + 0x8000000000008080, 0x0000000080000001, 0x8000000080008008 + }; + + inline static std::uint64_t rot(std::uint64_t input, std::uint8_t s) + { + return (input << s) | (input >> (64 - s)); + } + + static void keccak_1600(sha3_state_type &state) noexcept; + + inline static void sponge_absorb( + const std::uint64_t sha3_block[sha3_rate_uint64_count], + sha3_state_type &state) noexcept + { + //for (std::uint8_t x = 0; x < 5; x++) + //{ + // for (std::uint8_t y = 0; y < 5; y++) + // { + // std::uint8_t index = 5 * y + x; + // state[x][y] ^= index < sha3_rate_uint64_count ? sha3_block[index] : std::uint64_t(0); + // } + //} + + state[0][0] ^= 0 < sha3_rate_uint64_count ? sha3_block[0] : std::uint64_t(0); + state[0][1] ^= 5 < sha3_rate_uint64_count ? sha3_block[5] : std::uint64_t(0); + state[0][2] ^= 10 < sha3_rate_uint64_count ? sha3_block[10] : std::uint64_t(0); + state[0][3] ^= 15 < sha3_rate_uint64_count ? sha3_block[15] : std::uint64_t(0); + state[0][4] ^= 20 < sha3_rate_uint64_count ? sha3_block[20] : std::uint64_t(0); + + state[1][0] ^= 1 < sha3_rate_uint64_count ? sha3_block[1] : std::uint64_t(0); + state[1][1] ^= 6 < sha3_rate_uint64_count ? sha3_block[6] : std::uint64_t(0); + state[1][2] ^= 11 < sha3_rate_uint64_count ? sha3_block[11] : std::uint64_t(0); + state[1][3] ^= 16 < sha3_rate_uint64_count ? sha3_block[16] : std::uint64_t(0); + state[1][4] ^= 21 < sha3_rate_uint64_count ? sha3_block[21] : std::uint64_t(0); + + state[2][0] ^= 2 < sha3_rate_uint64_count ? sha3_block[2] : std::uint64_t(0); + state[2][1] ^= 7 < sha3_rate_uint64_count ? sha3_block[7] : std::uint64_t(0); + state[2][2] ^= 12 < sha3_rate_uint64_count ? sha3_block[12] : std::uint64_t(0); + state[2][3] ^= 17 < sha3_rate_uint64_count ? sha3_block[17] : std::uint64_t(0); + state[2][4] ^= 22 < sha3_rate_uint64_count ? sha3_block[22] : std::uint64_t(0); + + state[3][0] ^= 3 < sha3_rate_uint64_count ? sha3_block[3] : std::uint64_t(0); + state[3][1] ^= 8 < sha3_rate_uint64_count ? sha3_block[8] : std::uint64_t(0); + state[3][2] ^= 13 < sha3_rate_uint64_count ? sha3_block[13] : std::uint64_t(0); + state[3][3] ^= 18 < sha3_rate_uint64_count ? sha3_block[18] : std::uint64_t(0); + state[3][4] ^= 23 < sha3_rate_uint64_count ? sha3_block[23] : std::uint64_t(0); + + state[4][0] ^= 4 < sha3_rate_uint64_count ? sha3_block[4] : std::uint64_t(0); + state[4][1] ^= 9 < sha3_rate_uint64_count ? sha3_block[9] : std::uint64_t(0); + state[4][2] ^= 14 < sha3_rate_uint64_count ? sha3_block[14] : std::uint64_t(0); + state[4][3] ^= 19 < sha3_rate_uint64_count ? sha3_block[19] : std::uint64_t(0); + state[4][4] ^= 24 < sha3_rate_uint64_count ? sha3_block[24] : std::uint64_t(0); + + keccak_1600(state); + } + + inline static void sponge_squeeze(const sha3_state_type &sha3_state, + sha3_block_type &sha3_block) noexcept + { + // Trivial in this case: we simply output the first blocks of the state + static_assert(sha3_block_uint64_count == 4, "sha3_block_uint64_count must equal 4"); + + sha3_block[0] = sha3_state[0][0]; + sha3_block[1] = sha3_state[1][0]; + sha3_block[2] = sha3_state[2][0]; + sha3_block[3] = sha3_state[3][0]; + } + }; + } +} diff --git a/src/seal/util/hestdparms.h b/src/seal/util/hestdparms.h new file mode 100644 index 000000000..3b80b444c --- /dev/null +++ b/src/seal/util/hestdparms.h @@ -0,0 +1,70 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +/** +Largest allowed bit counts for coeff_modulus based on the security estimates from +HomomorphicEncryption.org security standard. SEAL always samples the secret key +from a ternary {-1, 0, 1} distribution. These tables are used to enforce a minimum +security level when constructing a SEALContext. SEAL_HE_STD_PARMS_128_TC (below) +is used for this purpose by default, but this can easily be changed by editing +seal/util/globals.h if, e.g., higher than 128-bit or post-quantum security levels +should be enforced. +*/ +// Ternary secret; 128 bits classical security +#define SEAL_HE_STD_PARMS_128_TC \ + { std::size_t(1024), 27 }, \ + { std::size_t(2048), 54 }, \ + { std::size_t(4096), 109 }, \ + { std::size_t(8192), 218 }, \ + { std::size_t(16384), 438 }, \ + { std::size_t(32768), 881 } + +// Ternary secret; 192 bits classical security +#define SEAL_HE_STD_PARMS_192_TC \ + { std::size_t(1024), 19 }, \ + { std::size_t(2048), 37 }, \ + { std::size_t(4096), 75 }, \ + { std::size_t(8192), 152 }, \ + { std::size_t(16384), 305 }, \ + { std::size_t(32768), 611 } + +// Ternary secret; 256 bits classical security +#define SEAL_HE_STD_PARMS_256_TC \ + { std::size_t(1024), 14 }, \ + { std::size_t(2048), 29 }, \ + { std::size_t(4096), 58 }, \ + { std::size_t(8192), 118 }, \ + { std::size_t(16384), 237 }, \ + { std::size_t(32768), 476 } + +// Ternary secret; 128 bits quantum security +#define SEAL_HE_STD_PARMS_128_TQ \ + { std::size_t(1024), 25 }, \ + { std::size_t(2048), 51 }, \ + { std::size_t(4096), 101 }, \ + { std::size_t(8192), 202 }, \ + { std::size_t(16384), 411 }, \ + { std::size_t(32768), 827 } + +// Ternary secret; 192 bits quantum security +#define SEAL_HE_STD_PARMS_192_TQ \ + { std::size_t(1024), 17 }, \ + { std::size_t(2048), 35 }, \ + { std::size_t(4096), 70 }, \ + { std::size_t(8192), 141 }, \ + { std::size_t(16384), 284 }, \ + { std::size_t(32768), 571 } + +// Ternary secret; 256 bits quantum security +#define SEAL_HE_STD_PARMS_256_TQ \ + { std::size_t(1024), 13 }, \ + { std::size_t(2048), 27 }, \ + { std::size_t(4096), 54 }, \ + { std::size_t(8192), 109 }, \ + { std::size_t(16384), 220 }, \ + { std::size_t(32768), 443 } + +// Standard deviation for error distribution +#define SEAL_HE_STD_PARMS_ERROR_STD_DEV 3.20 diff --git a/src/seal/util/locks.h b/src/seal/util/locks.h new file mode 100644 index 000000000..dfd7f817d --- /dev/null +++ b/src/seal/util/locks.h @@ -0,0 +1,299 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include "seal/util/defines.h" + +#ifdef SEAL_USE_SHARED_MUTEX +#include + +namespace seal +{ + namespace util + { + using ReaderLock = std::shared_lock; + + using WriterLock = std::unique_lock; + + class ReaderWriterLocker + { + public: + ReaderWriterLocker() = default; + + inline ReaderLock acquire_read() + { + return ReaderLock(rw_lock_mutex_); + } + + inline WriterLock acquire_write() + { + return WriterLock(rw_lock_mutex_); + } + + inline ReaderLock try_acquire_read() noexcept + { + return ReaderLock(rw_lock_mutex_, std::try_to_lock); + } + + inline WriterLock try_acquire_write() noexcept + { + return WriterLock(rw_lock_mutex_, std::try_to_lock); + } + + private: + ReaderWriterLocker(const ReaderWriterLocker ©) = delete; + + ReaderWriterLocker &operator =(const ReaderWriterLocker &assign) = delete; + + std::shared_mutex rw_lock_mutex_; + }; + } +} +#else +#include + +namespace seal +{ + namespace util + { + struct try_to_lock_t + { + }; + + constexpr try_to_lock_t try_to_lock{}; + + class ReaderWriterLocker; + + class ReaderLock + { + public: + ReaderLock() noexcept : locker_(nullptr) + { + } + + ReaderLock(ReaderLock &&move) noexcept : locker_(move.locker_) + { + move.locker_ = nullptr; + } + + ReaderLock(ReaderWriterLocker &locker) noexcept : locker_(nullptr) + { + acquire(locker); + } + + ReaderLock(ReaderWriterLocker &locker, try_to_lock_t) noexcept : + locker_(nullptr) + { + try_acquire(locker); + } + + ~ReaderLock() noexcept + { + unlock(); + } + + inline bool owns_lock() const noexcept + { + return locker_ != nullptr; + } + + void unlock() noexcept; + + inline void swap_with(ReaderLock &lock) noexcept + { + std::swap(locker_, lock.locker_); + } + + inline ReaderLock &operator =(ReaderLock &&lock) noexcept + { + swap_with(lock); + lock.unlock(); + return *this; + } + + private: + void acquire(ReaderWriterLocker &locker) noexcept; + + bool try_acquire(ReaderWriterLocker &locker) noexcept; + + ReaderWriterLocker *locker_; + }; + + class WriterLock + { + public: + WriterLock() noexcept : locker_(nullptr) + { + } + + WriterLock(WriterLock &&move) noexcept : locker_(move.locker_) + { + move.locker_ = nullptr; + } + + WriterLock(ReaderWriterLocker &locker) noexcept : locker_(nullptr) + { + acquire(locker); + } + + WriterLock(ReaderWriterLocker &locker, try_to_lock_t) noexcept : + locker_(nullptr) + { + try_acquire(locker); + } + + ~WriterLock() noexcept + { + unlock(); + } + + inline bool owns_lock() const noexcept + { + return locker_ != nullptr; + } + + void unlock() noexcept; + + inline void swap_with(WriterLock &lock) noexcept + { + std::swap(locker_, lock.locker_); + } + + inline WriterLock &operator =(WriterLock &&lock) noexcept + { + swap_with(lock); + lock.unlock(); + return *this; + } + + private: + void acquire(ReaderWriterLocker &locker) noexcept; + + bool try_acquire(ReaderWriterLocker &locker) noexcept; + + ReaderWriterLocker *locker_; + }; + + class ReaderWriterLocker + { + friend class ReaderLock; + + friend class WriterLock; + + public: + ReaderWriterLocker() noexcept : reader_locks_(0), writer_locked_(false) + { + } + + inline ReaderLock acquire_read() noexcept + { + return ReaderLock(*this); + } + + inline WriterLock acquire_write() noexcept + { + return WriterLock(*this); + } + + inline ReaderLock try_acquire_read() noexcept + { + return ReaderLock(*this, try_to_lock); + } + + inline WriterLock try_acquire_write() noexcept + { + return WriterLock(*this, try_to_lock); + } + + private: + ReaderWriterLocker(const ReaderWriterLocker ©) = delete; + + ReaderWriterLocker &operator =(const ReaderWriterLocker &assign) = delete; + + std::atomic reader_locks_; + + std::atomic writer_locked_; + }; + + inline void ReaderLock::unlock() noexcept + { + if (locker_ == nullptr) + { + return; + } + locker_->reader_locks_.fetch_sub(1, std::memory_order_release); + locker_ = nullptr; + } + + inline void ReaderLock::acquire(ReaderWriterLocker &locker) noexcept + { + unlock(); + do + { + locker.reader_locks_.fetch_add(1, std::memory_order_acquire); + locker_ = &locker; + if (locker.writer_locked_.load(std::memory_order_acquire)) + { + unlock(); + while (locker.writer_locked_.load(std::memory_order_acquire)); + } + } while (locker_ == nullptr); + } + + inline bool ReaderLock::try_acquire(ReaderWriterLocker &locker) noexcept + { + unlock(); + locker.reader_locks_.fetch_add(1, std::memory_order_acquire); + locker_ = &locker; + if (locker.writer_locked_.load(std::memory_order_acquire)) + { + unlock(); + return false; + } + return true; + } + + inline void WriterLock::acquire(ReaderWriterLocker &locker) noexcept + { + unlock(); + bool expected = false; + while (!locker.writer_locked_.compare_exchange_strong( + expected, true, std::memory_order_acquire)) + { + expected = false; + } + locker_ = &locker; + while (locker.reader_locks_.load(std::memory_order_acquire) != 0); + } + + inline bool WriterLock::try_acquire(ReaderWriterLocker &locker) noexcept + { + unlock(); + bool expected = false; + if (!locker.writer_locked_.compare_exchange_strong( + expected, true, std::memory_order_acquire)) + { + return false; + } + locker_ = &locker; + if (locker.reader_locks_.load(std::memory_order_acquire) != 0) + { + unlock(); + return false; + } + return true; + } + + inline void WriterLock::unlock() noexcept + { + if (locker_ == nullptr) + { + return; + } + locker_->writer_locked_.store(false, std::memory_order_release); + locker_ = nullptr; + } + } +} +#endif diff --git a/src/seal/util/mempool.cpp b/src/seal/util/mempool.cpp new file mode 100644 index 000000000..2fd765686 --- /dev/null +++ b/src/seal/util/mempool.cpp @@ -0,0 +1,508 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include +#include +#include +#include +#include +#include "seal/util/mempool.h" +#include "seal/util/common.h" +#include "seal/util/uintarith.h" + +using namespace std; + +namespace seal +{ + namespace util + { + MemoryPoolHeadMT::MemoryPoolHeadMT(size_t item_byte_count, + bool clear_on_destruction) : + clear_on_destruction_(clear_on_destruction), + locked_(false), item_byte_count_(item_byte_count), + item_count_(MemoryPool::first_alloc_count), + first_item_(nullptr) + { + if ((item_byte_count_ == 0) || + (item_byte_count_ > MemoryPool::max_batch_alloc_byte_count) || + (mul_safe(item_byte_count_, MemoryPool::first_alloc_count) > + MemoryPool::max_batch_alloc_byte_count)) + { + throw invalid_argument("invalid allocation size"); + } + + // Initial allocation + allocation new_alloc; + try + { + new_alloc.data_ptr = new SEAL_BYTE[ + mul_safe(MemoryPool::first_alloc_count, item_byte_count_)]; + } + catch (const bad_alloc &) + { + // Allocation failed; rethrow + throw; + } + + new_alloc.size = MemoryPool::first_alloc_count; + new_alloc.free = MemoryPool::first_alloc_count; + new_alloc.head_ptr = new_alloc.data_ptr; + allocs_.clear(); + allocs_.push_back(new_alloc); + } + + MemoryPoolHeadMT::~MemoryPoolHeadMT() noexcept + { + bool expected = false; + while (!locked_.compare_exchange_strong( + expected, true, memory_order_acquire)) + { + expected = false; + } + + // Delete the items (but not the memory) + MemoryPoolItem *curr_item = first_item_; + while (curr_item) + { + MemoryPoolItem *next_item = curr_item->next(); + delete curr_item; + curr_item = next_item; + } + first_item_ = nullptr; + + if (clear_on_destruction_) + { + // Delete the memory + for (auto &alloc : allocs_) + { + // Do we need to clear the memory? + if (clear_on_destruction_) + { + std::size_t curr_alloc_byte_count = mul_safe(item_byte_count_, alloc.size); + volatile SEAL_BYTE *data_ptr = reinterpret_cast(alloc.data_ptr); + while (curr_alloc_byte_count--) + { + *data_ptr++ = static_cast(0); + } + } + + // Delete this allocation + delete[] alloc.data_ptr; + } + } + else + { + // Delete the memory + for (auto &alloc : allocs_) + { + // Delete this allocation + delete[] alloc.data_ptr; + } + } + + allocs_.clear(); + } + + MemoryPoolItem *MemoryPoolHeadMT::get() + { + bool expected = false; + while (!locked_.compare_exchange_strong( + expected, true, memory_order_acquire)) + { + expected = false; + } + MemoryPoolItem *old_first = first_item_; + + // Is pool empty? + if (old_first == nullptr) + { + allocation &last_alloc = allocs_.back(); + MemoryPoolItem *new_item = nullptr; + if (last_alloc.free > 0) + { + // Pool is empty; there is memory + new_item = new MemoryPoolItem(last_alloc.head_ptr); + last_alloc.free--; + last_alloc.head_ptr += item_byte_count_; + } + else + { + // Pool is empty; there is no memory + allocation new_alloc; + + // Increase allocation size unless we are already at max + size_t new_size = safe_cast( + ceil(MemoryPool::alloc_size_multiplier * + static_cast(last_alloc.size))); + size_t new_alloc_byte_count = mul_safe(new_size, item_byte_count_); + if (new_alloc_byte_count > + MemoryPool::max_batch_alloc_byte_count) + { + new_size = last_alloc.size; + new_alloc_byte_count = new_size * item_byte_count_; + } + + try + { + new_alloc.data_ptr = new SEAL_BYTE[new_alloc_byte_count]; + } + catch (const bad_alloc &) + { + // Allocation failed; rethrow + throw; + } + + new_alloc.size = new_size; + new_alloc.free = new_size - 1; + new_alloc.head_ptr = new_alloc.data_ptr + item_byte_count_; + allocs_.push_back(new_alloc); + item_count_ += new_size; + new_item = new MemoryPoolItem(new_alloc.data_ptr); + } + + locked_.store(false, memory_order_release); + return new_item; + } + + // Pool is not empty + first_item_ = old_first->next(); + old_first->next() = nullptr; + locked_.store(false, memory_order_release); + return old_first; + } + + MemoryPoolHeadST::MemoryPoolHeadST(size_t item_byte_count, + bool clear_on_destruction) : + clear_on_destruction_(clear_on_destruction), + item_byte_count_(item_byte_count), + item_count_(MemoryPool::first_alloc_count), + first_item_(nullptr) + { + if ((item_byte_count_ == 0) || + (item_byte_count_ > MemoryPool::max_batch_alloc_byte_count) || + (mul_safe(item_byte_count_, MemoryPool::first_alloc_count) > + MemoryPool::max_batch_alloc_byte_count)) + { + throw invalid_argument("invalid allocation size"); + } + + // Initial allocation + allocation new_alloc; + try + { + new_alloc.data_ptr = new SEAL_BYTE[ + mul_safe(MemoryPool::first_alloc_count, item_byte_count_)]; + } + catch (const bad_alloc &) + { + // Allocation failed; rethrow + throw; + } + + new_alloc.size = MemoryPool::first_alloc_count; + new_alloc.free = MemoryPool::first_alloc_count; + new_alloc.head_ptr = new_alloc.data_ptr; + allocs_.clear(); + allocs_.push_back(new_alloc); + } + + MemoryPoolHeadST::~MemoryPoolHeadST() noexcept + { + // Delete the items (but not the memory) + MemoryPoolItem *curr_item = first_item_; + while(curr_item) + { + MemoryPoolItem *next_item = curr_item->next(); + delete curr_item; + curr_item = next_item; + } + first_item_ = nullptr; + + if (clear_on_destruction_) + { + // Delete the memory + for (auto &alloc : allocs_) + { + // Do we need to clear the memory? + if (clear_on_destruction_) + { + std::size_t curr_alloc_byte_count = mul_safe(item_byte_count_, alloc.size); + volatile SEAL_BYTE *data_ptr = reinterpret_cast(alloc.data_ptr); + while (curr_alloc_byte_count--) + { + *data_ptr++ = static_cast(0); + } + } + + // Delete this allocation + delete[] alloc.data_ptr; + } + } + else + { + // Delete the memory + for (auto &alloc : allocs_) + { + // Delete this allocation + delete[] alloc.data_ptr; + } + } + + allocs_.clear(); + } + + MemoryPoolItem *MemoryPoolHeadST::get() + { + MemoryPoolItem *old_first = first_item_; + + // Is pool empty? + if (old_first == nullptr) + { + allocation &last_alloc = allocs_.back(); + MemoryPoolItem *new_item = nullptr; + if (last_alloc.free > 0) + { + // Pool is empty; there is memory + new_item = new MemoryPoolItem(last_alloc.head_ptr); + last_alloc.free--; + last_alloc.head_ptr += item_byte_count_; + } + else + { + // Pool is empty; there is no memory + allocation new_alloc; + + // Increase allocation size unless we are already at max + size_t new_size = safe_cast( + ceil(MemoryPool::alloc_size_multiplier * + static_cast(last_alloc.size))); + size_t new_alloc_byte_count = mul_safe(new_size, item_byte_count_); + if (new_alloc_byte_count > + MemoryPool::max_batch_alloc_byte_count) + { + new_size = last_alloc.size; + new_alloc_byte_count = new_size * item_byte_count_; + } + + try + { + new_alloc.data_ptr = new SEAL_BYTE[new_alloc_byte_count]; + } + catch (const bad_alloc &) + { + // Allocation failed; rethrow + throw; + } + + new_alloc.size = new_size; + new_alloc.free = new_size - 1; + new_alloc.head_ptr = new_alloc.data_ptr + item_byte_count_; + allocs_.push_back(new_alloc); + item_count_ += new_size; + new_item = new MemoryPoolItem(new_alloc.data_ptr); + } + + return new_item; + } + + // Pool is not empty + first_item_ = old_first->next(); + old_first->next() = nullptr; + return old_first; + } + + const size_t MemoryPool::max_single_alloc_byte_count = + []() -> size_t { + int bit_shift = static_cast( + ceil(log2(MemoryPool::alloc_size_multiplier))); + if (bit_shift < 0 || unsigned_geq(bit_shift, + sizeof(size_t) * static_cast(bits_per_byte))) + { + throw logic_error("alloc_size_multiplier too large"); + } + return numeric_limits::max() >> bit_shift; + }(); + + const size_t MemoryPool::max_batch_alloc_byte_count = + []() -> size_t { + int bit_shift = static_cast( + ceil(log2(MemoryPool::alloc_size_multiplier))); + if (bit_shift < 0 || unsigned_geq(bit_shift, + sizeof(size_t) * static_cast(bits_per_byte))) + { + throw logic_error("alloc_size_multiplier too large"); + } + return numeric_limits::max() >> bit_shift; + }(); + + MemoryPoolMT::~MemoryPoolMT() noexcept + { + WriterLock lock(pools_locker_.acquire_write()); + for(MemoryPoolHead *head : pools_) + { + delete head; + } + pools_.clear(); + } + + Pointer MemoryPoolMT::get_for_byte_count(size_t byte_count) + { + if (byte_count > max_single_alloc_byte_count) + { + throw invalid_argument("invalid allocation size"); + } + else if (byte_count == 0) + { + return Pointer(); + } + + // Attempt to find size. + ReaderLock reader_lock(pools_locker_.acquire_read()); + size_t start = 0; + size_t end = pools_.size(); + while (start < end) + { + size_t mid = (start + end) / 2; + MemoryPoolHead *mid_head = pools_[mid]; + size_t mid_byte_count = mid_head->item_byte_count(); + if (byte_count < mid_byte_count) + { + start = mid + 1; + } + else if (byte_count > mid_byte_count) + { + end = mid; + } + else + { + return Pointer(mid_head); + } + } + reader_lock.unlock(); + + // Size was not found, so obtain an exclusive lock and search again. + WriterLock writer_lock(pools_locker_.acquire_write()); + start = 0; + end = pools_.size(); + while (start < end) + { + size_t mid = (start + end) / 2; + MemoryPoolHead *mid_head = pools_[mid]; + size_t mid_byte_count = mid_head->item_byte_count(); + if (byte_count < mid_byte_count) + { + start = mid + 1; + } + else if (byte_count > mid_byte_count) + { + end = mid; + } + else + { + return Pointer(mid_head); + } + } + + // Size was still not found, but we own an exclusive lock so just add it, + // but first check if we are at maximum pool head count already. + if (pools_.size() >= max_pool_head_count) + { + throw runtime_error("maximum pool head count reached"); + } + + MemoryPoolHead *new_head = new MemoryPoolHeadMT(byte_count, clear_on_destruction_); + if (!pools_.empty()) + { + pools_.insert(pools_.begin() + static_cast(start), new_head); + } + else + { + pools_.emplace_back(new_head); + } + + return Pointer(new_head); + } + + size_t MemoryPoolMT::alloc_byte_count() const + { + ReaderLock lock(pools_locker_.acquire_read()); + + return accumulate(pools_.cbegin(), pools_.cend(), size_t(0), + [](size_t byte_count, MemoryPoolHead *head) { + return add_safe(byte_count, + mul_safe(head->item_count(), head->item_byte_count())); + }); + } + + MemoryPoolST::~MemoryPoolST() noexcept + { + for(MemoryPoolHead *head : pools_) + { + delete head; + } + pools_.clear(); + } + + Pointer MemoryPoolST::get_for_byte_count(size_t byte_count) + { + if (byte_count > MemoryPool::max_single_alloc_byte_count) + { + throw invalid_argument("invalid allocation size"); + } + else if (byte_count == 0) + { + return Pointer(); + } + + // Attempt to find size. + size_t start = 0; + size_t end = pools_.size(); + while (start < end) + { + size_t mid = (start + end) / 2; + MemoryPoolHead *mid_head = pools_[mid]; + size_t mid_byte_count = mid_head->item_byte_count(); + if (byte_count < mid_byte_count) + { + start = mid + 1; + } + else if (byte_count > mid_byte_count) + { + end = mid; + } + else + { + return Pointer(mid_head); + } + } + + // Size was not found so just add it, but first check if we are at + // maximum pool head count already. + if (pools_.size() >= max_pool_head_count) + { + throw runtime_error("maximum pool head count reached"); + } + + MemoryPoolHead *new_head = new MemoryPoolHeadST(byte_count, clear_on_destruction_); + if (!pools_.empty()) + { + pools_.insert(pools_.begin() + static_cast(start), new_head); + } + else + { + pools_.emplace_back(new_head); + } + + return Pointer(new_head); + } + + size_t MemoryPoolST::alloc_byte_count() const + { + return accumulate(pools_.cbegin(), pools_.cend(), size_t(0), + [](size_t byte_count, MemoryPoolHead *head) { + return add_safe(byte_count, + mul_safe(head->item_count(), head->item_byte_count())); + }); + } + } +} diff --git a/src/seal/util/mempool.h b/src/seal/util/mempool.h new file mode 100644 index 000000000..937697865 --- /dev/null +++ b/src/seal/util/mempool.h @@ -0,0 +1,298 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "seal/util/defines.h" +#include "seal/util/globals.h" +#include "seal/util/common.h" +#include "seal/util/locks.h" + +namespace seal +{ + namespace util + { + template::value>> + class ConstPointer; + + template<> + class ConstPointer; + + template::value>> + class Pointer; + + class MemoryPoolItem + { + public: + MemoryPoolItem(SEAL_BYTE *data) noexcept : data_(data) + { + } + + inline SEAL_BYTE *data() noexcept + { + return data_; + } + + inline const SEAL_BYTE *data() const noexcept + { + return data_; + } + + inline MemoryPoolItem* &next() noexcept + { + return next_; + } + + inline const MemoryPoolItem *next() const noexcept + { + return next_; + } + + private: + MemoryPoolItem(const MemoryPoolItem ©) = delete; + + MemoryPoolItem &operator =(const MemoryPoolItem &assign) = delete; + + SEAL_BYTE *data_ = nullptr; + + MemoryPoolItem *next_ = nullptr; + }; + + class MemoryPoolHead + { + public: + struct allocation + { + allocation() : + size(0), data_ptr(nullptr), free(0), head_ptr(nullptr) + { + } + + // Size of the allocation (number of items it can hold) + std::size_t size; + + // Pointer to start of the allocation + SEAL_BYTE *data_ptr; + + // How much free space is left (number of items that still fit) + std::size_t free; + + // Pointer to current head of allocation + SEAL_BYTE *head_ptr; + }; + + // The overriding functions are noexcept(false) + virtual ~MemoryPoolHead() = default; + + // Byte size of the allocations (items) owned by this pool + virtual std::size_t item_byte_count() const noexcept = 0; + + // Total number of items allocated + virtual std::size_t item_count() const noexcept = 0; + + virtual MemoryPoolItem *get() = 0; + + // Return item back to this pool + virtual void add(MemoryPoolItem *new_first) noexcept = 0; + }; + + class MemoryPoolHeadMT : public MemoryPoolHead + { + public: + // Creates a new MemoryPoolHeadMT with allocation for one single item. + MemoryPoolHeadMT(std::size_t item_byte_count, + bool clear_on_destruction = false); + + ~MemoryPoolHeadMT() noexcept override; + + // Byte size of the allocations (items) owned by this pool + inline std::size_t item_byte_count() const noexcept override + { + return item_byte_count_; + } + + // Returns the total number of items allocated + inline std::size_t item_count() const noexcept override + { + return item_count_; + } + + MemoryPoolItem *get() override; + + inline void add(MemoryPoolItem *new_first) noexcept override + { + bool expected = false; + while (!locked_.compare_exchange_strong( + expected, true, std::memory_order_acquire)) + { + expected = false; + } + MemoryPoolItem *old_first = first_item_; + new_first->next() = old_first; + first_item_ = new_first; + locked_.store(false, std::memory_order_release); + } + + private: + MemoryPoolHeadMT(const MemoryPoolHeadMT ©) = delete; + + MemoryPoolHeadMT &operator =(const MemoryPoolHeadMT &assign) = delete; + + const bool clear_on_destruction_; + + mutable std::atomic locked_; + + const std::size_t item_byte_count_; + + volatile std::size_t item_count_; + + std::vector allocs_; + + MemoryPoolItem* volatile first_item_; + }; + + class MemoryPoolHeadST : public MemoryPoolHead + { + public: + // Creates a new MemoryPoolHeadST with allocation for one single item. + MemoryPoolHeadST(std::size_t item_byte_count, + bool clear_on_destruction = false); + + ~MemoryPoolHeadST() noexcept override; + + // Byte size of the allocations (items) owned by this pool + inline std::size_t item_byte_count() const noexcept override + { + return item_byte_count_; + } + + // Returns the total number of items allocated + inline std::size_t item_count() const noexcept override + { + return item_count_; + } + + MemoryPoolItem *get() override; + + inline void add(MemoryPoolItem *new_first) noexcept override + { + new_first->next() = first_item_; + first_item_ = new_first; + } + + private: + MemoryPoolHeadST(const MemoryPoolHeadST ©) = delete; + + MemoryPoolHeadST &operator =(const MemoryPoolHeadST &assign) = delete; + + const bool clear_on_destruction_; + + std::size_t item_byte_count_; + + std::size_t item_count_; + + std::vector allocs_; + + MemoryPoolItem *first_item_; + }; + + class MemoryPool + { + public: + static constexpr double alloc_size_multiplier = 1.05; + + // Largest size of single allocation that can be requested from memory pool + static const std::size_t max_single_alloc_byte_count; + + // Number of different size allocations allowed by a single memory pool + static constexpr std::size_t max_pool_head_count = + std::numeric_limits::max(); + + // Largest allowed size of batch allocation + static const std::size_t max_batch_alloc_byte_count; + + static constexpr std::size_t first_alloc_count = 1; + + virtual ~MemoryPool() = default; + + virtual Pointer get_for_byte_count(std::size_t byte_count) = 0; + + virtual std::size_t pool_count() const = 0; + + virtual std::size_t alloc_byte_count() const = 0; + }; + + class MemoryPoolMT : public MemoryPool + { + public: + MemoryPoolMT(bool clear_on_destruction = false) : + clear_on_destruction_(clear_on_destruction) + { + }; + + ~MemoryPoolMT() noexcept override; + + Pointer get_for_byte_count(std::size_t byte_count) override; + + inline std::size_t pool_count() const override + { + ReaderLock lock(pools_locker_.acquire_read()); + return pools_.size(); + } + + std::size_t alloc_byte_count() const override; + + protected: + MemoryPoolMT(const MemoryPoolMT ©) = delete; + + MemoryPoolMT &operator =(const MemoryPoolMT &assign) = delete; + + const bool clear_on_destruction_; + + mutable ReaderWriterLocker pools_locker_; + + std::vector pools_; + }; + + class MemoryPoolST : public MemoryPool + { + public: + MemoryPoolST(bool clear_on_destruction = false) : + clear_on_destruction_(clear_on_destruction) + { + }; + + ~MemoryPoolST() noexcept override; + + Pointer get_for_byte_count(std::size_t byte_count) override; + + inline std::size_t pool_count() const override + { + return pools_.size(); + } + + std::size_t alloc_byte_count() const override; + + protected: + MemoryPoolST(const MemoryPoolST ©) = delete; + + MemoryPoolST &operator =(const MemoryPoolST &assign) = delete; + + const bool clear_on_destruction_; + + std::vector pools_; + }; + } +} diff --git a/src/seal/util/msvc.h b/src/seal/util/msvc.h new file mode 100644 index 000000000..5c524694f --- /dev/null +++ b/src/seal/util/msvc.h @@ -0,0 +1,86 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#if SEAL_COMPILER == SEAL_COMPILER_MSVC + +// Require Visual Studio 2017 version 15.3 or newer +#if (_MSC_VER < 1911) +#error "Microsoft Visual Studio 2017 version 15.3 or newer required" +#endif + +// Read in config.h +#include "seal/util/config.h" + +// Try to check presence of additional headers using __has_include +#ifdef __has_include + +// Check for MSGSL +#if __has_include() +#include +#define SEAL_USE_MSGSL +#else +#undef SEAL_USE_MSGSL +#endif //__has_include() + +#endif + +// Are we compiling with C++17 or newer +#if (__cplusplus >= 201703L) +// Use `if constexpr' +#define SEAL_USE_IF_CONSTEXPR + +// Use [[maybe_unused]] +#define SEAL_USE_MAYBE_UNUSED +#else +#undef SEAL_USE_IF_CONSTEXPR +#undef SEAL_USE_MAYBE_UNUSED +#endif + +// Define SEAL_ENFORCE_HE_STD_SECURITY to enforce at least 128-bit security level +// based on HomomorphicEncryption.org estimates. This is incompatible with the +// unit tests so it is disabled by default. +#undef SEAL_ENFORCE_HE_STD_SECURITY + +// X64 +#ifdef _M_X64 + +#ifdef SEAL_USE_INTRIN +#include + +#ifdef SEAL_USE__UMUL128 +#pragma intrinsic(_umul128) +#define SEAL_MULTIPLY_UINT64_HW64(operand1, operand2, hw64) { \ + _umul128(operand1, operand2, hw64); \ +} + +#define SEAL_MULTIPLY_UINT64(operand1, operand2, result128) { \ + result128[0] = _umul128(operand1, operand2, result128 + 1); \ +} +#endif + +#ifdef SEAL_USE__BITSCANREVERSE64 +#pragma intrinsic(_BitScanReverse64) +#define SEAL_MSB_INDEX_UINT64(result, value) _BitScanReverse64(result, value) +#endif + +#ifdef SEAL_USE__ADDCARRY_U64 +#pragma intrinsic(_addcarry_u64) +#define SEAL_ADD_CARRY_UINT64(operand1, operand2, carry, result) _addcarry_u64( \ + carry, operand1, operand2, result) +#endif + +#ifdef SEAL_USE__SUBBORROW_U64 +#pragma intrinsic(_subborrow_u64) +#define SEAL_SUB_BORROW_UINT64(operand1, operand2, borrow, result) _subborrow_u64( \ + borrow, operand1, operand2, result) +#endif + +#endif +#else +#undef SEAL_USE_INTRIN + +#endif //_M_X64 + +#endif diff --git a/src/seal/util/numth.cpp b/src/seal/util/numth.cpp new file mode 100644 index 000000000..c2637728e --- /dev/null +++ b/src/seal/util/numth.cpp @@ -0,0 +1,151 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "seal/util/numth.h" +#include "seal/util/uintcore.h" + +using namespace std; + +namespace seal +{ + namespace util + { + vector conjugate_classes(uint64_t modulus, + uint64_t subgroup_generator) + { + if (!product_fits_in(modulus, subgroup_generator)) + { + throw invalid_argument("inputs too large"); + } + + vector classes{}; + for (uint64_t i = 0; i < modulus; i++) + { + if (gcd(i, modulus) > 1) + { + classes.push_back(0); + } + else + { + classes.push_back(i); + } + } + for (uint64_t i = 0; i < modulus; i++) + { + if (classes[i] == 0) + { + continue; + } + if (classes[i] < i) + { + // i is not a pivot, updated its pivot + classes[i] = classes[classes[i]]; + continue; + } + // If i is a pivot, update other pivots to point to it + uint64_t j = (i * subgroup_generator) % modulus; + while (classes[j] != i) + { + // Merge the equivalence classes of j and i + // Note: if classes[j] != j then classes[j] will be updated later, + // when we get to i = j and use the code for "i not pivot". + classes[classes[j]] = i; + j = (j * subgroup_generator) % modulus; + } + } + return classes; + } + + vector multiplicative_orders( + vector conjugate_classes, uint64_t modulus) + { + if (!product_fits_in(modulus, modulus)) + { + throw invalid_argument("inputs too large"); + } + + vector orders{}; + orders.push_back(0); + orders.push_back(1); + + for (uint64_t i = 2; i < modulus; i++) + { + if (conjugate_classes[i] <= 1) + { + orders.push_back(conjugate_classes[i]); + continue; + } + if (conjugate_classes[i] < i) + { + orders.push_back(orders[conjugate_classes[i]]); + continue; + } + uint64_t j = (i * i) % modulus; + uint64_t order = 2; + while (conjugate_classes[j] != 1) + { + j = (j * i) % modulus; + order++; + } + orders.push_back(order); + } + return orders; + } + + void babystep_giantstep(uint64_t modulus, + vector &baby_steps, vector &giant_steps) + { + int exponent = get_power_of_two(modulus); + if (exponent < 0) + { + throw invalid_argument("modulus must be a power of 2"); + } + + // Compute square root of modulus (k stores the baby steps) + uint64_t k = uint64_t(1) << (exponent / 2); + uint64_t l = modulus / k; + + baby_steps.clear(); + giant_steps.clear(); + + uint64_t m = mul_safe(modulus, uint64_t(2)); + uint64_t g = 3; // the generator + uint64_t kprime = k >> 1; + uint64_t value = 1; + for (uint64_t i = 0; i < kprime; i++) + { + baby_steps.push_back(value); + baby_steps.push_back(m - value); + value = mul_safe(value, g) % m; + } + + // now value should equal to g**kprime + uint64_t value2 = value; + for (uint64_t j = 0; j < l; j++) + { + giant_steps.push_back(value2); + value2 = mul_safe(value2, value) % m; + } + } + + pair decompose_babystep_giantstep( + uint64_t modulus, uint64_t input, + const vector &baby_steps, + const vector &giant_steps) + { + for (size_t i = 0; i < giant_steps.size(); i++) + { + uint64_t gs = giant_steps[i]; + for (size_t j = 0; j < baby_steps.size(); j++) + { + uint64_t bs = baby_steps[j]; + if (mul_safe(gs, bs) % modulus == input) + { + return { i, j }; + } + } + } + throw logic_error("failed to decompose input"); + } + } +} diff --git a/src/seal/util/numth.h b/src/seal/util/numth.h new file mode 100644 index 000000000..6df151010 --- /dev/null +++ b/src/seal/util/numth.h @@ -0,0 +1,138 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include +#include +#include +#include +#include "seal/util/common.h" + +namespace seal +{ + namespace util + { + inline std::uint64_t gcd(std::uint64_t x, std::uint64_t y) + { +#ifdef SEAL_DEBUG + if (x == 0) + { + std::invalid_argument("x cannot be zero"); + } + if (y == 0) + { + std::invalid_argument("y cannot be zero"); + } +#endif + if (x < y) + { + return gcd(y, x); + } + else if (y == 0) + { + return x; + } + else + { + std::uint64_t f = x % y; + if (f == 0) + { + return y; + } + else + { + return gcd(y, f); + } + } + } + + inline auto xgcd(std::uint64_t x, std::uint64_t y) + -> std::tuple + { + /* Extended GCD: + Returns (gcd, x, y) where gcd is the greatest common divisor of a and b. + The numbers x, y are such that gcd = ax + by. + */ +#ifdef SEAL_DEBUG + if (x == 0) + { + std::invalid_argument("x cannot be zero"); + } + if (y == 0) + { + std::invalid_argument("y cannot be zero"); + } +#endif + std::int64_t prev_a = 1; + std::int64_t a = 0; + std::int64_t prev_b = 0; + std::int64_t b = 1; + + while (y != 0) + { + std::int64_t q = util::safe_cast(x / y); + std::int64_t temp = util::safe_cast(x % y); + x = y; + y = util::safe_cast(temp); + + temp = a; + a = util::sub_safe(prev_a, mul_safe(q, a)); + prev_a = temp; + + temp = b; + b = util::sub_safe(prev_b, mul_safe(q, b)); + prev_b = temp; + } + return std::make_tuple(x, prev_a, prev_b); + } + + inline bool try_mod_inverse(std::uint64_t value, + std::uint64_t modulus, std::uint64_t &result) + { +#ifdef SEAL_DEBUG + if (value == 0) + { + std::invalid_argument("value cannot be zero"); + } + if (modulus <= 1) + { + std::invalid_argument("modulus must be at least 2"); + } +#endif + auto gcd_tuple = xgcd(value, modulus); + if (std::get<0>(gcd_tuple) != 1) + { + return false; + } + else if (std::get<1>(gcd_tuple) < 0) + { + result = static_cast(std::get<1>(gcd_tuple)) + modulus; + return true; + } + else + { + result = static_cast(std::get<1>(gcd_tuple)); + return true; + } + } + + std::vector multiplicative_orders( + std::vector conjugate_classes, + std::uint64_t modulus); + + std::vector conjugate_classes(std::uint64_t modulus, + std::uint64_t subgroup_generator); + + void babystep_giantstep(std::uint64_t modulus, + std::vector &baby_steps, + std::vector &giant_steps); + + auto decompose_babystep_giantstep( + std::uint64_t modulus, + std::uint64_t input, + const std::vector &baby_steps, + const std::vector &giant_steps) + -> std::pair; + } +} diff --git a/src/seal/util/pointer.h b/src/seal/util/pointer.h new file mode 100644 index 000000000..31eb71513 --- /dev/null +++ b/src/seal/util/pointer.h @@ -0,0 +1,1251 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include "seal/util/defines.h" +#include "seal/util/common.h" +#include "seal/util/mempool.h" +#include +#include +#include + +namespace seal +{ + namespace util + { + // Specialization for SEAL_BYTE + template<> + class Pointer + { + friend class MemoryPoolST; + friend class MemoryPoolMT; + + public: + template friend class Pointer; + template friend class ConstPointer; + + Pointer() = default; + + // Move of the same type + Pointer(Pointer &&source) noexcept : + data_(source.data_), head_(source.head_), + item_(source.item_), alias_(source.alias_) + { + source.data_ = nullptr; + source.head_ = nullptr; + source.item_ = nullptr; + source.alias_ = false; + } + + // Move of the same type + Pointer(Pointer &&source, SEAL_BYTE value) : + Pointer(std::move(source)) + { + std::fill_n(data_, head_->item_byte_count(), value); + } + + inline SEAL_BYTE &operator [](std::size_t index) + { + return data_[index]; + } + + inline const SEAL_BYTE &operator [](std::size_t index) const + { + return data_[index]; + } + + inline auto &operator =(Pointer &&assign) noexcept + { + acquire(std::move(assign)); + return *this; + } + + inline bool is_set() const noexcept + { + return data_ != nullptr; + } + + inline SEAL_BYTE *get() noexcept + { + return data_; + } + + inline const SEAL_BYTE *get() const noexcept + { + return data_; + } + + inline SEAL_BYTE *operator ->() noexcept + { + return data_; + } + + inline const SEAL_BYTE *operator ->() const noexcept + { + return data_; + } + + inline SEAL_BYTE &operator *() + { + return *data_; + } + + inline const SEAL_BYTE &operator *() const + { + return *data_; + } + + inline bool is_alias() const noexcept + { + return alias_; + } + + inline void release() noexcept + { + if (head_) + { + // Return the memory to pool + head_->add(item_); + } + else if (data_ && !alias_) + { + // Free the memory + delete[] data_; + } + + data_ = nullptr; + head_ = nullptr; + item_ = nullptr; + alias_ = false; + } + + void acquire(Pointer &other) noexcept + { + if (this == &other) + { + return; + } + + release(); + + data_ = other.data_; + head_ = other.head_; + item_ = other.item_; + alias_ = other.alias_; + other.data_ = nullptr; + other.head_ = nullptr; + other.item_ = nullptr; + other.alias_ = false; + } + + inline void acquire(Pointer &&other) noexcept + { + acquire(other); + } + + inline void swap_with(Pointer &other) noexcept + { + std::swap(data_, other.data_); + std::swap(head_, other.head_); + std::swap(item_, other.item_); + std::swap(alias_, other.alias_); + } + + inline void swap_with(Pointer &&other) noexcept + { + swap_with(other); + } + + ~Pointer() noexcept + { + release(); + } + + operator bool() const noexcept + { + return (data_ != nullptr); + } + + inline static Pointer Owning(SEAL_BYTE *pointer) noexcept + { + return {pointer, false}; + } + + inline static Pointer Aliasing(SEAL_BYTE *pointer) noexcept + { + return {pointer, true}; + } + + private: + Pointer(const Pointer ©) = delete; + + Pointer &operator =(const Pointer &assign) = delete; + + Pointer(SEAL_BYTE *pointer, bool alias) noexcept : + data_(pointer), alias_(alias) + { + } + + Pointer(class MemoryPoolHead *head) + { +#ifdef SEAL_DEBUG + if(!head) + { + throw std::invalid_argument("head cannot be nullptr"); + } +#endif + head_ = head; + item_ = head->get(); + data_ = item_->data(); + } + + SEAL_BYTE *data_ = nullptr; + + MemoryPoolHead *head_ = nullptr; + + MemoryPoolItem *item_ = nullptr; + + bool alias_ = false; + }; + + template + class Pointer + { + friend class MemoryPoolST; + friend class MemoryPoolMT; + + public: + friend class Pointer; + friend class ConstPointer; + friend class ConstPointer; + + Pointer() = default; + + // Move of the same type + Pointer(Pointer &&source) noexcept : + data_(source.data_), head_(source.head_), + item_(source.item_), alias_(source.alias_) + { + source.data_ = nullptr; + source.head_ = nullptr; + source.item_ = nullptr; + source.alias_ = false; + } + + // Move when T is not SEAL_BYTE + Pointer(Pointer &&source) + { + // Cannot acquire a non-pool pointer of different type + if (!source.head_ && source.data_) + { + throw std::invalid_argument("cannot acquire a non-pool pointer of different type"); + } + + head_ = source.head_; + item_ = source.item_; + if (head_) + { + data_ = reinterpret_cast(item_->data()); + SEAL_IF_CONSTEXPR (!std::is_trivially_constructible::value) + { + auto count = head_->item_byte_count() / sizeof(T); + for (auto alloc_ptr = data_; count--; alloc_ptr++) + { + new(alloc_ptr) T; + } + } + } + alias_ = source.alias_; + + source.data_ = nullptr; + source.head_ = nullptr; + source.item_ = nullptr; + source.alias_ = false; + } + + // Move when T is not SEAL_BYTE + template + Pointer(Pointer &&source, Args &&...args) + { + // Cannot acquire a non-pool pointer of different type + if (!source.head_ && source.data_) + { + throw std::invalid_argument("cannot acquire a non-pool pointer of different type"); + } + + head_ = source.head_; + item_ = source.item_; + if (head_) + { + data_ = reinterpret_cast(item_->data()); + auto count = head_->item_byte_count() / sizeof(T); + for (auto alloc_ptr = data_; count--; alloc_ptr++) + { + new(alloc_ptr) T(std::forward(args)...); + } + } + alias_ = source.alias_; + + source.data_ = nullptr; + source.head_ = nullptr; + source.item_ = nullptr; + source.alias_ = false; + } + + inline T &operator [](std::size_t index) + { + return data_[index]; + } + + inline const T &operator [](std::size_t index) const + { + return data_[index]; + } + + inline auto &operator =(Pointer &&assign) noexcept + { + acquire(std::move(assign)); + return *this; + } + + inline auto &operator =(Pointer &&assign) + { + acquire(std::move(assign)); + return *this; + } + + inline bool is_set() const noexcept + { + return data_ != nullptr; + } + + inline T *get() noexcept + { + return data_; + } + + inline const T *get() const noexcept + { + return data_; + } + + inline T *operator ->() noexcept + { + return data_; + } + + inline const T *operator ->() const noexcept + { + return data_; + } + + inline T &operator *() + { + return *data_; + } + + inline const T &operator *() const + { + return *data_; + } + + inline bool is_alias() const noexcept + { + return alias_; + } + + inline void release() noexcept + { + if (head_) + { + SEAL_IF_CONSTEXPR (!std::is_trivially_destructible::value) + { + // Manual destructor calls + auto count = head_->item_byte_count() / sizeof(T); + for (auto alloc_ptr = data_; count--; alloc_ptr++) + { + alloc_ptr->~T(); + } + } + + // Return the memory to pool + head_->add(item_); + } + else if (data_ && !alias_) + { + // Free the memory + delete[] data_; + } + + data_ = nullptr; + head_ = nullptr; + item_ = nullptr; + alias_ = false; + } + + void acquire(Pointer &other) noexcept + { + if (this == &other) + { + return; + } + + release(); + + data_ = other.data_; + head_ = other.head_; + item_ = other.item_; + alias_ = other.alias_; + other.data_ = nullptr; + other.head_ = nullptr; + other.item_ = nullptr; + other.alias_ = false; + } + + inline void acquire(Pointer &&other) noexcept + { + acquire(other); + } + + void acquire(Pointer &other) + { + // Cannot acquire a non-pool pointer of different type + if (!other.head_ && other.data_) + { + throw std::invalid_argument("cannot acquire a non-pool pointer of different type"); + } + + release(); + + head_ = other.head_; + item_ = other.item_; + if (head_) + { + data_ = reinterpret_cast(item_->data()); + SEAL_IF_CONSTEXPR (!std::is_trivially_constructible::value) + { + auto count = head_->item_byte_count() / sizeof(T); + for (auto alloc_ptr = data_; count--; alloc_ptr++) + { + new(alloc_ptr) T; + } + } + } + alias_ = other.alias_; + other.data_ = nullptr; + other.head_ = nullptr; + other.item_ = nullptr; + other.alias_ = false; + } + + inline void acquire(Pointer &&other) + { + acquire(other); + } + + inline void swap_with(Pointer &other) noexcept + { + std::swap(data_, other.data_); + std::swap(head_, other.head_); + std::swap(item_, other.item_); + std::swap(alias_, other.alias_); + } + + inline void swap_with(Pointer &&other) noexcept + { + swap_with(other); + } + + ~Pointer() noexcept + { + release(); + } + + operator bool() const noexcept + { + return (data_ != nullptr); + } + + inline static Pointer Owning(T *pointer) noexcept + { + return {pointer, false}; + } + + inline static Pointer Aliasing(T *pointer) noexcept + { + return {pointer, true}; + } + + private: + Pointer(const Pointer ©) = delete; + + Pointer &operator =(const Pointer &assign) = delete; + + Pointer(T *pointer, bool alias) noexcept : + data_(pointer), alias_(alias) + { + } + + Pointer(class MemoryPoolHead *head) + { +#ifdef SEAL_DEBUG + if(!head) + { + throw std::invalid_argument("head cannot be nullptr"); + } +#endif + head_ = head; + item_ = head->get(); + data_ = reinterpret_cast(item_->data()); + SEAL_IF_CONSTEXPR (!std::is_trivially_constructible::value) + { + auto count = head_->item_byte_count() / sizeof(T); + for (auto alloc_ptr = data_; count--; alloc_ptr++) + { + new(alloc_ptr) T; + } + } + } + + template + Pointer(class MemoryPoolHead *head, Args &&...args) + { +#ifdef SEAL_DEBUG + if(!head) + { + throw std::invalid_argument("head cannot be nullptr"); + } +#endif + head_ = head; + item_ = head->get(); + data_ = reinterpret_cast(item_->data()); + auto count = head_->item_byte_count() / sizeof(T); + for (auto alloc_ptr = data_; count--; alloc_ptr++) + { + new(alloc_ptr) T(std::forward(args)...); + } + } + + T *data_ = nullptr; + + MemoryPoolHead *head_ = nullptr; + + MemoryPoolItem *item_ = nullptr; + + bool alias_ = false; + }; + + // Specialization for SEAL_BYTE + template<> + class ConstPointer + { + friend class MemoryPoolST; + friend class MemoryPoolMT; + + public: + template friend class ConstPointer; + + ConstPointer() = default; + + // Move of the same type + ConstPointer(Pointer &&source) noexcept : + data_(source.data_), head_(source.head_), + item_(source.item_), alias_(source.alias_) + { + source.data_ = nullptr; + source.head_ = nullptr; + source.item_ = nullptr; + source.alias_ = false; + } + + // Move of the same type + ConstPointer(Pointer &&source, SEAL_BYTE value) noexcept : + ConstPointer(std::move(source)) + { + std::fill_n(data_, head_->item_byte_count(), value); + } + + // Move of the same type + ConstPointer(ConstPointer &&source) noexcept : + data_(source.data_), head_(source.head_), + item_(source.item_), alias_(source.alias_) + { + source.data_ = nullptr; + source.head_ = nullptr; + source.item_ = nullptr; + source.alias_ = false; + } + + // Move of the same type + ConstPointer(ConstPointer &&source, SEAL_BYTE value) noexcept : + ConstPointer(std::move(source)) + { + std::fill_n(data_, head_->item_byte_count(), value); + } + + inline auto &operator =(ConstPointer &&assign) noexcept + { + acquire(std::move(assign)); + return *this; + } + + inline auto &operator =(Pointer &&assign) noexcept + { + acquire(std::move(assign)); + return *this; + } + + inline const SEAL_BYTE &operator [](std::size_t index) const + { + return data_[index]; + } + + inline bool is_set() const noexcept + { + return data_ != nullptr; + } + + inline const SEAL_BYTE *get() const noexcept + { + return data_; + } + + inline const SEAL_BYTE *operator ->() const noexcept + { + return data_; + } + + inline const SEAL_BYTE &operator *() const + { + return *data_; + } + + inline void release() noexcept + { + if (head_) + { + // Return the memory to pool + head_->add(item_); + } + else if (data_ && !alias_) + { + // Free the memory + delete[] data_; + } + + data_ = nullptr; + head_ = nullptr; + item_ = nullptr; + alias_ = false; + } + + void acquire(Pointer &other) noexcept + { + release(); + + data_ = other.data_; + head_ = other.head_; + item_ = other.item_; + alias_ = other.alias_; + other.data_ = nullptr; + other.head_ = nullptr; + other.item_ = nullptr; + other.alias_ = false; + } + + inline void acquire(Pointer &&other) noexcept + { + acquire(other); + } + + void acquire(ConstPointer &other) noexcept + { + if (this == &other) + { + return; + } + + release(); + + data_ = other.data_; + head_ = other.head_; + item_ = other.item_; + alias_ = other.alias_; + other.data_ = nullptr; + other.head_ = nullptr; + other.item_ = nullptr; + other.alias_ = false; + } + + void acquire(ConstPointer &&other) noexcept + { + acquire(other); + } + + inline void swap_with(ConstPointer &other) noexcept + { + std::swap(data_, other.data_); + std::swap(head_, other.head_); + std::swap(item_, other.item_); + std::swap(alias_, other.alias_); + } + + + inline void swap_with(ConstPointer &&other) noexcept + { + swap_with(other); + } + + ~ConstPointer() noexcept + { + release(); + } + + operator bool() const + { + return (data_ != nullptr); + } + + inline static auto Owning(SEAL_BYTE *pointer) noexcept + -> ConstPointer + { + return {pointer, false}; + } + + inline static auto Aliasing(const SEAL_BYTE *pointer) noexcept + -> ConstPointer + { + return {const_cast(pointer), true}; + } + + private: + ConstPointer(const ConstPointer ©) = delete; + + ConstPointer &operator =(const ConstPointer &assign) = delete; + + ConstPointer(SEAL_BYTE *pointer, bool alias) noexcept : + data_(pointer), alias_(alias) + { + } + + ConstPointer(class MemoryPoolHead *head) + { +#ifdef SEAL_DEBUG + if(!head) + { + throw std::invalid_argument("head cannot be nullptr"); + } +#endif + head_ = head; + item_ = head->get(); + data_ = item_->data(); + } + + SEAL_BYTE *data_ = nullptr; + + MemoryPoolHead *head_ = nullptr; + + MemoryPoolItem *item_ = nullptr; + + bool alias_ = false; + }; + + template + class ConstPointer + { + friend class MemoryPoolST; + friend class MemoryPoolMT; + + public: + ConstPointer() = default; + + // Move of the same type + ConstPointer(Pointer &&source) noexcept : + data_(source.data_), head_(source.head_), + item_(source.item_), alias_(source.alias_) + { + source.data_ = nullptr; + source.head_ = nullptr; + source.item_ = nullptr; + source.alias_ = false; + } + + // Move when T is not SEAL_BYTE + ConstPointer(Pointer &&source) + { + // Cannot acquire a non-pool pointer of different type + if (!source.head_ && source.data_) + { + throw std::invalid_argument("cannot acquire a non-pool pointer of different type"); + } + + head_ = source.head_; + item_ = source.item_; + if (head_) + { + data_ = reinterpret_cast(item_->data()); + SEAL_IF_CONSTEXPR (!std::is_trivially_constructible::value) + { + auto count = head_->item_byte_count() / sizeof(T); + for (auto alloc_ptr = data_; count--; alloc_ptr++) + { + new(alloc_ptr) T; + } + } + } + alias_ = source.alias_; + + source.data_ = nullptr; + source.head_ = nullptr; + source.item_ = nullptr; + source.alias_ = false; + } + + // Move when T is not SEAL_BYTE + template + ConstPointer(Pointer &&source, Args &&...args) + { + // Cannot acquire a non-pool pointer of different type + if (!source.head_ && source.data_) + { + throw std::invalid_argument("cannot acquire a non-pool pointer of different type"); + } + + head_ = source.head_; + item_ = source.item_; + if (head_) + { + data_ = reinterpret_cast(item_->data()); + auto count = head_->item_byte_count() / sizeof(T); + for (auto alloc_ptr = data_; count--; alloc_ptr++) + { + new(alloc_ptr) T(std::forward(args)...); + } + } + alias_ = source.alias_; + + source.data_ = nullptr; + source.head_ = nullptr; + source.item_ = nullptr; + source.alias_ = false; + } + + // Move of the same type + ConstPointer(ConstPointer &&source) noexcept : + data_(source.data_), head_(source.head_), + item_(source.item_), alias_(source.alias_) + { + source.data_ = nullptr; + source.head_ = nullptr; + source.item_ = nullptr; + source.alias_ = false; + } + + // Move when T is not SEAL_BYTE + ConstPointer(ConstPointer &&source) + { + // Cannot acquire a non-pool pointer of different type + if (!source.head_ && source.data_) + { + throw std::invalid_argument("cannot acquire a non-pool pointer of different type"); + } + + head_ = source.head_; + item_ = source.item_; + if (head_) + { + data_ = reinterpret_cast(item_->data()); + SEAL_IF_CONSTEXPR (!std::is_trivially_constructible::value) + { + auto count = head_->item_byte_count() / sizeof(T); + for (auto alloc_ptr = data_; count--; alloc_ptr++) + { + new(alloc_ptr) T; + } + } + } + alias_ = source.alias_; + + source.data_ = nullptr; + source.head_ = nullptr; + source.item_ = nullptr; + source.alias_ = false; + } + + // Move when T is not SEAL_BYTE + template + ConstPointer(ConstPointer &&source, Args &&...args) + { + // Cannot acquire a non-pool pointer of different type + if (!source.head_ && source.data_) + { + throw std::invalid_argument("cannot acquire a non-pool pointer of different type"); + } + + head_ = source.head_; + item_ = source.item_; + if (head_) + { + data_ = reinterpret_cast(item_->data()); + auto count = head_->item_byte_count() / sizeof(T); + for (auto alloc_ptr = data_; count--; alloc_ptr++) + { + new(alloc_ptr) T(std::forward(args)...); + } + } + alias_ = source.alias_; + + source.data_ = nullptr; + source.head_ = nullptr; + source.item_ = nullptr; + source.alias_ = false; + } + + inline auto &operator =(ConstPointer &&assign) noexcept + { + acquire(std::move(assign)); + return *this; + } + + inline auto &operator =(ConstPointer &&assign) + { + acquire(std::move(assign)); + return *this; + } + + inline auto &operator =(Pointer &&assign) noexcept + { + acquire(std::move(assign)); + return *this; + } + + inline auto &operator =(Pointer &&assign) + { + acquire(std::move(assign)); + return *this; + } + + inline const T &operator [](std::size_t index) const + { + return data_[index]; + } + + inline bool is_set() const noexcept + { + return data_ != nullptr; + } + + inline const T *get() const noexcept + { + return data_; + } + + inline const T *operator ->() const noexcept + { + return data_; + } + + inline const T &operator *() const + { + return *data_; + } + + inline void release() noexcept + { + if (head_) + { + SEAL_IF_CONSTEXPR (!std::is_trivially_destructible::value) + { + // Manual destructor calls + auto count = head_->item_byte_count() / sizeof(T); + for (auto alloc_ptr = data_; count--; alloc_ptr++) + { + alloc_ptr->~T(); + } + } + + // Return the memory to pool + head_->add(item_); + } + else if (data_ && !alias_) + { + // Free the memory + delete[] data_; + } + + data_ = nullptr; + head_ = nullptr; + item_ = nullptr; + alias_ = false; + } + + void acquire(ConstPointer &other) noexcept + { + if (this == &other) + { + return; + } + + release(); + + data_ = other.data_; + head_ = other.head_; + item_ = other.item_; + alias_ = other.alias_; + other.data_ = nullptr; + other.head_ = nullptr; + other.item_ = nullptr; + other.alias_ = false; + } + + inline void acquire(ConstPointer &&other) noexcept + { + acquire(other); + } + + void acquire(ConstPointer &other) + { + // Cannot acquire a non-pool pointer of different type + if (!other.head_ && other.data_) + { + throw std::invalid_argument("cannot acquire a non-pool pointer of different type"); + } + + release(); + + head_ = other.head_; + item_ = other.item_; + if (head_) + { + data_ = reinterpret_cast(item_->data()); + SEAL_IF_CONSTEXPR (!std::is_trivially_constructible::value) + { + auto count = head_->item_byte_count() / sizeof(T); + for (auto alloc_ptr = data_; count--; alloc_ptr++) + { + new(alloc_ptr) T; + } + } + } + alias_ = other.alias_; + other.data_ = nullptr; + other.head_ = nullptr; + other.item_ = nullptr; + other.alias_ = false; + } + + inline void acquire(ConstPointer &&other) + { + acquire(other); + } + + void acquire(Pointer &other) noexcept + { + release(); + + data_ = other.data_; + head_ = other.head_; + item_ = other.item_; + alias_ = other.alias_; + other.data_ = nullptr; + other.head_ = nullptr; + other.item_ = nullptr; + other.alias_ = false; + } + + inline void acquire(Pointer &&other) noexcept + { + acquire(other); + } + + void acquire(Pointer &other) + { + // Cannot acquire a non-pool pointer of different type + if (!other.head_ && other.data_) + { + throw std::invalid_argument("cannot acquire a non-pool pointer of different type"); + } + + release(); + + head_ = other.head_; + item_ = other.item_; + if (head_) + { + data_ = reinterpret_cast(item_->data()); + SEAL_IF_CONSTEXPR (!std::is_trivially_constructible::value) + { + auto count = head_->item_byte_count() / sizeof(T); + for (auto alloc_ptr = data_; count--; alloc_ptr++) + { + new(alloc_ptr) T; + } + } + } + alias_ = other.alias_; + other.data_ = nullptr; + other.head_ = nullptr; + other.item_ = nullptr; + other.alias_ = false; + } + + inline void acquire(Pointer &&other) + { + acquire(other); + } + + inline void swap_with(ConstPointer &other) noexcept + { + std::swap(data_, other.data_); + std::swap(head_, other.head_); + std::swap(item_, other.item_); + std::swap(alias_, other.alias_); + } + + inline void swap_with(ConstPointer &&other) noexcept + { + swap_with(other); + } + + ~ConstPointer() noexcept + { + release(); + } + + operator bool() const noexcept + { + return (data_ != nullptr); + } + + inline static ConstPointer Owning(T *pointer) noexcept + { + return {pointer, false}; + } + + inline static ConstPointer Aliasing(const T *pointer) noexcept + { + return {const_cast(pointer), true}; + } + + private: + ConstPointer(const ConstPointer ©) = delete; + + ConstPointer &operator =(const ConstPointer &assign) = delete; + + ConstPointer(T *pointer, bool alias) noexcept : data_(pointer), alias_(alias) + { + } + + ConstPointer(class MemoryPoolHead *head) + { +#ifdef SEAL_DEBUG + if(!head) + { + throw std::invalid_argument("head cannot be nullptr"); + } +#endif + head_ = head; + item_ = head->get(); + data_ = reinterpret_cast(item_->data()); + SEAL_IF_CONSTEXPR (!std::is_trivially_constructible::value) + { + auto count = head_->item_byte_count() / sizeof(T); + for (auto alloc_ptr = data_; count--; alloc_ptr++) + { + new(alloc_ptr) T; + } + } + } + + template + ConstPointer(class MemoryPoolHead *head, Args &&...args) + { +#ifdef SEAL_DEBUG + if(!head) + { + throw std::invalid_argument("head cannot be nullptr"); + } +#endif + head_ = head; + item_ = head->get(); + data_ = reinterpret_cast(item_->data()); + auto count = head_->item_byte_count() / sizeof(T); + for (auto alloc_ptr = data_; count--; alloc_ptr++) + { + new(alloc_ptr) T(std::forward(args)...); + } + } + + T *data_ = nullptr; + + MemoryPoolHead *head_ = nullptr; + + MemoryPoolItem *item_ = nullptr; + + bool alias_ = false; + }; + + // Allocate single element + template::value>> + inline auto allocate(MemoryPool &pool, Args &&...args) + { + using T_ = typename std::remove_cv::type>::type; + return Pointer(pool.get_for_byte_count(sizeof(T_)), + std::forward(args)...); + } + + // Allocate array of elements + template::value>> + inline auto allocate(std::size_t count, MemoryPool &pool, Args &&...args) + { + using T_ = typename std::remove_cv::type>::type; + return Pointer(pool.get_for_byte_count(util::mul_safe(count, sizeof(T_))), + std::forward(args)...); + } + + template::value>> + inline auto duplicate_if_needed(T *original, std::size_t count, + bool condition, MemoryPool &pool) + { + using T_ = typename std::remove_cv::type>::type; +#ifdef SEAL_DEBUG + if (original == nullptr && count > 0) + { + throw std::invalid_argument("original"); + } +#endif + if (condition == false) + { + return Pointer::Aliasing(original); + } + auto allocation(allocate(count, pool)); + std::copy_n(original, count, allocation.get()); + return allocation; + } + + template::value>> + inline auto duplicate_if_needed(const T *original, + std::size_t count, bool condition, MemoryPool &pool) + { + using T_ = typename std::remove_cv::type>::type; +#ifdef SEAL_DEBUG + if (original == nullptr && count > 0) + { + throw std::invalid_argument("original"); + } +#endif + if (condition == false) + { + return ConstPointer::Aliasing(original); + } + auto allocation(allocate(count, pool)); + std::copy_n(original, count, allocation.get()); + return ConstPointer(std::move(allocation)); + } + } +} diff --git a/src/seal/util/polyarith.cpp b/src/seal/util/polyarith.cpp new file mode 100644 index 000000000..524699d0f --- /dev/null +++ b/src/seal/util/polyarith.cpp @@ -0,0 +1,250 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "seal/util/common.h" +#include "seal/util/uintcore.h" +#include "seal/util/uintarith.h" +#include "seal/util/uintarithmod.h" +#include "seal/util/polycore.h" +#include "seal/util/polyarith.h" + +using namespace std; + +namespace seal +{ + namespace util + { + void multiply_poly_poly(const uint64_t *operand1, + size_t operand1_coeff_count, size_t operand1_coeff_uint64_count, + const uint64_t *operand2, size_t operand2_coeff_count, + size_t operand2_coeff_uint64_count, size_t result_coeff_count, + size_t result_coeff_uint64_count, uint64_t *result, MemoryPool &pool) + { +#ifdef SEAL_DEBUG + if (operand1 == nullptr && operand1_coeff_count > 0 && + operand1_coeff_uint64_count > 0) + { + throw invalid_argument("operand1"); + } + if (operand2 == nullptr && operand2_coeff_count > 0 && + operand2_coeff_uint64_count > 0) + { + throw invalid_argument("operand2"); + } + if (result == nullptr && result_coeff_count > 0 && + result_coeff_uint64_count > 0) + { + throw invalid_argument("result"); + } + if (result != nullptr && + (operand1 == result || operand2 == result)) + { + throw invalid_argument("result cannot point to the same value as operand1 or operand2"); + } + if (!sum_fits_in(operand1_coeff_count, operand2_coeff_count)) + { + throw invalid_argument("operand1 and operand2 too large"); + } +#endif + auto intermediate(allocate_uint(result_coeff_uint64_count, pool)); + + // Clear product. + set_zero_poly(result_coeff_count, result_coeff_uint64_count, result); + + operand1_coeff_count = get_significant_coeff_count_poly( + operand1, operand1_coeff_count, operand1_coeff_uint64_count); + operand2_coeff_count = get_significant_coeff_count_poly( + operand2, operand2_coeff_count, operand2_coeff_uint64_count); + for (size_t operand1_index = 0; + operand1_index < operand1_coeff_count; operand1_index++) + { + const uint64_t *operand1_coeff = get_poly_coeff( + operand1, operand1_index, operand1_coeff_uint64_count); + for (size_t operand2_index = 0; + operand2_index < operand2_coeff_count; operand2_index++) + { + size_t product_coeff_index = operand1_index + operand2_index; + if (product_coeff_index >= result_coeff_count) + { + break; + } + + const uint64_t *operand2_coeff = get_poly_coeff( + operand2, operand2_index, operand2_coeff_uint64_count); + multiply_uint_uint(operand1_coeff, operand1_coeff_uint64_count, + operand2_coeff, operand2_coeff_uint64_count, + result_coeff_uint64_count, intermediate.get()); + uint64_t *result_coeff = get_poly_coeff( + result, product_coeff_index, result_coeff_uint64_count); + add_uint_uint(result_coeff, intermediate.get(), + result_coeff_uint64_count, result_coeff); + } + } + } + + void poly_eval_poly(const uint64_t *poly_to_eval, + size_t poly_to_eval_coeff_count, + size_t poly_to_eval_coeff_uint64_count, + const uint64_t *value, size_t value_coeff_count, + size_t value_coeff_uint64_count, size_t result_coeff_count, + size_t result_coeff_uint64_count, uint64_t *result, MemoryPool &pool) + { +#ifdef SEAL_DEBUG + if (poly_to_eval == nullptr) + { + throw invalid_argument("poly_to_eval"); + } + if (value == nullptr) + { + throw invalid_argument("value"); + } + if (result == nullptr) + { + throw invalid_argument("result"); + } + if (poly_to_eval_coeff_count == 0) + { + throw invalid_argument("poly_to_eval_coeff_count"); + } + if (poly_to_eval_coeff_uint64_count == 0) + { + throw invalid_argument("poly_to_eval_coeff_uint64_count"); + } + if (value_coeff_count == 0) + { + throw invalid_argument("value_coeff_count"); + } + if (value_coeff_uint64_count == 0) + { + throw invalid_argument("value_coeff_uint64_count"); + } + if (result_coeff_count == 0) + { + throw invalid_argument("result_coeff_count"); + } + if (result_coeff_uint64_count == 0) + { + throw invalid_argument("result_coeff_uint64_count"); + } +#endif + // Evaluate poly at value using Horner's method + auto temp1(allocate_poly(result_coeff_count, result_coeff_uint64_count, pool)); + auto temp2(allocate_zero_poly(result_coeff_count, result_coeff_uint64_count, pool)); + uint64_t *productptr = temp1.get(); + uint64_t *intermediateptr = temp2.get(); + + while (poly_to_eval_coeff_count--) + { + multiply_poly_poly(intermediateptr, result_coeff_count, + result_coeff_uint64_count, value, value_coeff_count, + value_coeff_uint64_count, result_coeff_count, + result_coeff_uint64_count, productptr, pool); + const uint64_t *curr_coeff = get_poly_coeff( + poly_to_eval, poly_to_eval_coeff_count, + poly_to_eval_coeff_uint64_count); + add_uint_uint(productptr, result_coeff_uint64_count, curr_coeff, + poly_to_eval_coeff_uint64_count, false, + result_coeff_uint64_count, productptr); + swap(productptr, intermediateptr); + } + set_poly_poly(intermediateptr, result_coeff_count, + result_coeff_uint64_count, result); + } + + void exponentiate_poly(const std::uint64_t *poly, size_t poly_coeff_count, + size_t poly_coeff_uint64_count, const uint64_t *exponent, + size_t exponent_uint64_count, size_t result_coeff_count, + size_t result_coeff_uint64_count, std::uint64_t *result, MemoryPool &pool) + { +#ifdef SEAL_DEBUG + if (poly == nullptr) + { + throw invalid_argument("poly"); + } + if (poly_coeff_count == 0) + { + throw invalid_argument("poly_coeff_count"); + } + if (poly_coeff_uint64_count == 0) + { + throw invalid_argument("poly_coeff_uint64_count"); + } + if (exponent == nullptr) + { + throw invalid_argument("exponent"); + } + if (exponent_uint64_count == 0) + { + throw invalid_argument("exponent_uint64_count"); + } + if (result == nullptr) + { + throw invalid_argument("result"); + } + if (result_coeff_count == 0) + { + throw invalid_argument("result_coeff_count"); + } + if (result_coeff_uint64_count == 0) + { + throw invalid_argument("result_coeff_uint64_count"); + } +#endif + // Fast cases + if (is_zero_uint(exponent, exponent_uint64_count)) + { + set_zero_poly(result_coeff_count, result_coeff_uint64_count, result); + *result = 1; + return; + } + if (is_equal_uint(exponent, exponent_uint64_count, 1)) + { + set_poly_poly(poly, poly_coeff_count, poly_coeff_uint64_count, + result_coeff_count, result_coeff_uint64_count, result); + return; + } + + // Need to make a copy of exponent + auto exponent_copy(allocate_uint(exponent_uint64_count, pool)); + set_uint_uint(exponent, exponent_uint64_count, exponent_copy.get()); + + // Perform binary exponentiation. + auto big_alloc(allocate_uint(mul_safe( + add_safe(result_coeff_count, result_coeff_count, result_coeff_count), + result_coeff_uint64_count), pool)); + + uint64_t *powerptr = big_alloc.get(); + uint64_t *productptr = get_poly_coeff( + powerptr, result_coeff_count, result_coeff_uint64_count); + uint64_t *intermediateptr = get_poly_coeff( + productptr, result_coeff_count, result_coeff_uint64_count); + + set_poly_poly(poly, poly_coeff_count, poly_coeff_uint64_count, result_coeff_count, + result_coeff_uint64_count, powerptr); + set_zero_poly(result_coeff_count, result_coeff_uint64_count, intermediateptr); + *intermediateptr = 1; + + // Initially: power = operand and intermediate = 1, product is not initialized. + while (true) + { + if ((*exponent_copy.get() % 2) == 1) + { + multiply_poly_poly(powerptr, result_coeff_count, result_coeff_uint64_count, + intermediateptr, result_coeff_count, result_coeff_uint64_count, + result_coeff_count, result_coeff_uint64_count, productptr, pool); + swap(productptr, intermediateptr); + } + right_shift_uint(exponent_copy.get(), 1, exponent_uint64_count, exponent_copy.get()); + if (is_zero_uint(exponent_copy.get(), exponent_uint64_count)) + { + break; + } + multiply_poly_poly(powerptr, result_coeff_count, result_coeff_uint64_count, + powerptr, result_coeff_count, result_coeff_uint64_count, + result_coeff_count, result_coeff_uint64_count, productptr, pool); + swap(productptr, powerptr); + } + set_poly_poly(intermediateptr, result_coeff_count, result_coeff_uint64_count, result); + } + } +} diff --git a/src/seal/util/polyarith.h b/src/seal/util/polyarith.h new file mode 100644 index 000000000..f5974d5e4 --- /dev/null +++ b/src/seal/util/polyarith.h @@ -0,0 +1,147 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include +#include "seal/util/uintcore.h" +#include "seal/util/uintarith.h" +#include "seal/util/uintarithmod.h" +#include "seal/util/polycore.h" +#include "seal/util/pointer.h" + +namespace seal +{ + namespace util + { + inline void right_shift_poly_coeffs( + const std::uint64_t *poly, std::size_t coeff_count, + std::size_t coeff_uint64_count, int shift_amount, + std::uint64_t *result) + { +#ifdef SEAL_DEBUG + if (poly == nullptr && coeff_count > 0 && coeff_uint64_count > 0) + { + throw std::invalid_argument("poly"); + } +#endif + while (coeff_count--) + { + right_shift_uint(poly, shift_amount, coeff_uint64_count, result); + poly += coeff_uint64_count; + result += coeff_uint64_count; + } + } + + inline void negate_poly(const std::uint64_t *poly, + std::size_t coeff_count, std::size_t coeff_uint64_count, + std::uint64_t *result) + { +#ifdef SEAL_DEBUG + if (poly == nullptr && coeff_count > 0 && coeff_uint64_count > 0) + { + throw std::invalid_argument("poly"); + } + if (result == nullptr && coeff_count > 0 && coeff_uint64_count > 0) + { + throw std::invalid_argument("result"); + } +#endif + while(coeff_count--) + { + negate_uint(poly, coeff_uint64_count, result); + poly += coeff_uint64_count; + result += coeff_uint64_count; + } + } + + inline void add_poly_poly(const std::uint64_t *operand1, + const std::uint64_t *operand2, std::size_t coeff_count, + std::size_t coeff_uint64_count, std::uint64_t *result) + { +#ifdef SEAL_DEBUG + if (operand1 == nullptr && coeff_count > 0 && coeff_uint64_count > 0) + { + throw std::invalid_argument("operand1"); + } + if (operand2 == nullptr && coeff_count > 0 && coeff_uint64_count > 0) + { + throw std::invalid_argument("operand2"); + } + if (result == nullptr && coeff_count > 0 && coeff_uint64_count > 0) + { + throw std::invalid_argument("result"); + } +#endif + while(coeff_count--) + { + add_uint_uint(operand1, operand2, coeff_uint64_count, result); + operand1 += coeff_uint64_count; + operand2 += coeff_uint64_count; + result += coeff_uint64_count; + } + } + + inline void sub_poly_poly(const std::uint64_t *operand1, + const std::uint64_t *operand2, std::size_t coeff_count, + std::size_t coeff_uint64_count, std::uint64_t *result) + { +#ifdef SEAL_DEBUG + if (operand1 == nullptr && coeff_count > 0 && coeff_uint64_count > 0) + { + throw std::invalid_argument("operand1"); + } + if (operand2 == nullptr && coeff_count > 0 && coeff_uint64_count > 0) + { + throw std::invalid_argument("operand2"); + } + if (result == nullptr && coeff_count > 0 && coeff_uint64_count > 0) + { + throw std::invalid_argument("result"); + } +#endif + while(coeff_count--) + { + sub_uint_uint(operand1, operand2, coeff_uint64_count, result); + operand1 += coeff_uint64_count; + operand2 += coeff_uint64_count; + result += coeff_uint64_count; + } + } + + void multiply_poly_poly( + const std::uint64_t *operand1, std::size_t operand1_coeff_count, + std::size_t operand1_coeff_uint64_count, const std::uint64_t *operand2, + std::size_t operand2_coeff_count, std::size_t operand2_coeff_uint64_count, + std::size_t result_coeff_count, std::size_t result_coeff_uint64_count, + std::uint64_t *result, MemoryPool &pool); + + inline void poly_infty_norm(const std::uint64_t *poly, + std::size_t coeff_count, std::size_t coeff_uint64_count, + std::uint64_t *result) + { + set_zero_uint(coeff_uint64_count, result); + while(coeff_count--) + { + if (is_greater_than_uint_uint(poly, result, coeff_uint64_count)) + { + set_uint_uint(poly, coeff_uint64_count, result); + } + + poly += coeff_uint64_count; + } + } + + void poly_eval_poly(const std::uint64_t *poly_to_eval, + std::size_t poly_to_eval_coeff_count, + std::size_t poly_to_eval_coeff_uint64_count, const std::uint64_t *value, + std::size_t value_coeff_count, std::size_t value_coeff_uint64_count, + std::size_t result_coeff_count, std::size_t result_coeff_uint64_count, + std::uint64_t *result, MemoryPool &pool); + + void exponentiate_poly(const std::uint64_t *poly, std::size_t poly_coeff_count, + std::size_t poly_coeff_uint64_count, const std::uint64_t *exponent, + std::size_t exponent_uint64_count, std::size_t result_coeff_count, + std::size_t result_coeff_uint64_count, std::uint64_t *result, MemoryPool &pool); + } +} diff --git a/src/seal/util/polyarithmod.cpp b/src/seal/util/polyarithmod.cpp new file mode 100644 index 000000000..38f35dd26 --- /dev/null +++ b/src/seal/util/polyarithmod.cpp @@ -0,0 +1,52 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "seal/util/uintcore.h" +#include "seal/util/uintarith.h" +#include "seal/util/uintarithmod.h" +#include "seal/util/polycore.h" +#include "seal/util/polyarith.h" +#include "seal/util/polyarithmod.h" + +using namespace std; + +namespace seal +{ + namespace util + { + void poly_infty_norm_coeffmod(const uint64_t *poly, size_t coeff_count, + size_t coeff_uint64_count, const uint64_t *modulus, uint64_t *result, + MemoryPool &pool) + { + // Construct negative threshold (first negative modulus value) to compute + // absolute values of coeffs. + auto modulus_neg_threshold(allocate_uint(coeff_uint64_count, pool)); + + // Set to value of (modulus + 1) / 2. To prevent overflowing with the +1, just + // add 1 to the result if modulus was odd. + half_round_up_uint(modulus, coeff_uint64_count, modulus_neg_threshold.get()); + + // Mod out the poly coefficients and choose a symmetric representative from + // [-modulus,modulus). Keep track of the max. + set_zero_uint(coeff_uint64_count, result); + auto coeff_abs_value(allocate_uint(coeff_uint64_count, pool)); + for (size_t i = 0; i < coeff_count; i++, poly += coeff_uint64_count) + { + if (is_greater_than_or_equal_uint_uint( + poly, modulus_neg_threshold.get(), coeff_uint64_count)) + { + sub_uint_uint(modulus, poly, coeff_uint64_count, coeff_abs_value.get()); + } + else + { + set_uint_uint(poly, coeff_uint64_count, coeff_abs_value.get()); + } + if (is_greater_than_uint_uint(coeff_abs_value.get(), result, + coeff_uint64_count)) + { + set_uint_uint(coeff_abs_value.get(), coeff_uint64_count, result); + } + } + } + } +} diff --git a/src/seal/util/polyarithmod.h b/src/seal/util/polyarithmod.h new file mode 100644 index 000000000..1899ae0fe --- /dev/null +++ b/src/seal/util/polyarithmod.h @@ -0,0 +1,123 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include +#include "seal/util/pointer.h" +#include "seal/util/polycore.h" +#include "seal/util/uintarithmod.h" + +namespace seal +{ + namespace util + { + inline void negate_poly_coeffmod(const std::uint64_t *poly, + std::size_t coeff_count, const std::uint64_t *coeff_modulus, + std::size_t coeff_uint64_count, std::uint64_t *result) + { +#ifdef SEAL_DEBUG + if (poly == nullptr && coeff_count > 0) + { + throw std::invalid_argument("poly"); + } + if (coeff_modulus == nullptr) + { + throw std::invalid_argument("coeff_modulus"); + } + if (coeff_uint64_count == 0) + { + throw std::invalid_argument("coeff_uint64_count"); + } + if (result == nullptr && coeff_count > 0) + { + throw std::invalid_argument("result"); + } +#endif + for (std::size_t i = 0; i < coeff_count; i++) + { + negate_uint_mod(poly, coeff_modulus, coeff_uint64_count, result); + poly += coeff_uint64_count; + result += coeff_uint64_count; + } + } + + inline void add_poly_poly_coeffmod(const std::uint64_t *operand1, + const std::uint64_t *operand2, std::size_t coeff_count, + const std::uint64_t *coeff_modulus, std::size_t coeff_uint64_count, + std::uint64_t *result) + { +#ifdef SEAL_DEBUG + if (operand1 == nullptr && coeff_count > 0) + { + throw std::invalid_argument("operand1"); + } + if (operand2 == nullptr && coeff_count > 0) + { + throw std::invalid_argument("operand2"); + } + if (coeff_modulus == nullptr) + { + throw std::invalid_argument("coeff_modulus"); + } + if (coeff_uint64_count == 0) + { + throw std::invalid_argument("coeff_uint64_count"); + } + if (result == nullptr && coeff_count > 0) + { + throw std::invalid_argument("result"); + } +#endif + for (std::size_t i = 0; i < coeff_count; i++) + { + add_uint_uint_mod(operand1, operand2, coeff_modulus, + coeff_uint64_count, result); + operand1 += coeff_uint64_count; + operand2 += coeff_uint64_count; + result += coeff_uint64_count; + } + } + + inline void sub_poly_poly_coeffmod(const std::uint64_t *operand1, + const std::uint64_t *operand2, std::size_t coeff_count, + const std::uint64_t *coeff_modulus, std::size_t coeff_uint64_count, + std::uint64_t *result) + { +#ifdef SEAL_DEBUG + if (operand1 == nullptr && coeff_count > 0) + { + throw std::invalid_argument("operand1"); + } + if (operand2 == nullptr && coeff_count > 0) + { + throw std::invalid_argument("operand2"); + } + if (coeff_modulus == nullptr) + { + throw std::invalid_argument("coeff_modulus"); + } + if (coeff_uint64_count == 0) + { + throw std::invalid_argument("coeff_uint64_count"); + } + if (result == nullptr && coeff_count > 0) + { + throw std::invalid_argument("result"); + } +#endif + for (std::size_t i = 0; i < coeff_count; i++) + { + sub_uint_uint_mod(operand1, operand2, coeff_modulus, + coeff_uint64_count, result); + operand1 += coeff_uint64_count; + operand2 += coeff_uint64_count; + result += coeff_uint64_count; + } + } + + void poly_infty_norm_coeffmod(const std::uint64_t *poly, + std::size_t coeff_count, std::size_t coeff_uint64_count, + const std::uint64_t *modulus, std::uint64_t *result, MemoryPool &pool); + } +} diff --git a/src/seal/util/polyarithsmallmod.cpp b/src/seal/util/polyarithsmallmod.cpp new file mode 100644 index 000000000..f11b14e27 --- /dev/null +++ b/src/seal/util/polyarithsmallmod.cpp @@ -0,0 +1,714 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "seal/util/common.h" +#include "seal/util/uintcore.h" +#include "seal/util/uintarithsmallmod.h" +#include "seal/util/uintarith.h" +#include "seal/util/polycore.h" +#include "seal/util/polyarith.h" +#include "seal/util/polyarithsmallmod.h" +#include "seal/util/defines.h" + +using namespace std; + +namespace seal +{ + namespace util + { + void multiply_poly_scalar_coeffmod(const uint64_t *poly, + size_t coeff_count, uint64_t scalar, const SmallModulus &modulus, + uint64_t *result) + { +#ifdef SEAL_DEBUG + if (poly == nullptr && coeff_count > 0) + { + throw invalid_argument("poly"); + } + if (result == nullptr && coeff_count > 0) + { + throw invalid_argument("result"); + } + if (modulus.is_zero()) + { + throw invalid_argument("modulus"); + } +#endif + // Explicit inline + //for (int i = 0; i < coeff_count; i++) + //{ + // *result++ = multiply_uint_uint_mod(*poly++, scalar, modulus); + //} + const uint64_t modulus_value = modulus.value(); + const uint64_t const_ratio_0 = modulus.const_ratio()[0]; + const uint64_t const_ratio_1 = modulus.const_ratio()[1]; + for (; coeff_count--; poly++, result++) + { + unsigned long long z[2], tmp1, tmp2[2], tmp3, carry; + multiply_uint64(*poly, scalar, z); + + // Reduces z using base 2^64 Barrett reduction + + // Multiply input and const_ratio + // Round 1 + multiply_uint64_hw64(z[0], const_ratio_0, &carry); + multiply_uint64(z[0], const_ratio_1, tmp2); + tmp3 = tmp2[1] + add_uint64(tmp2[0], carry, &tmp1); + + // Round 2 + multiply_uint64(z[1], const_ratio_0, tmp2); + carry = tmp2[1] + add_uint64(tmp1, tmp2[0], &tmp1); + + // This is all we care about + tmp1 = z[1] * const_ratio_1 + tmp3 + carry; + + // Barrett subtraction + tmp3 = z[0] - tmp1 * modulus_value; + + // Claim: One more subtraction is enough + *result = tmp3 - (modulus_value & static_cast( + -static_cast(tmp3 >= modulus_value))); + } + } + + void multiply_poly_poly_coeffmod(const uint64_t *operand1, + size_t operand1_coeff_count, const uint64_t *operand2, + size_t operand2_coeff_count, const SmallModulus &modulus, + size_t result_coeff_count, uint64_t *result) + { +#ifdef SEAL_DEBUG + if (operand1 == nullptr && operand1_coeff_count > 0) + { + throw invalid_argument("operand1"); + } + if (operand2 == nullptr && operand2_coeff_count > 0) + { + throw invalid_argument("operand2"); + } + if (result == nullptr && result_coeff_count > 0) + { + throw invalid_argument("result"); + } + if (result != nullptr && (operand1 == result || operand2 == result)) + { + throw invalid_argument("result cannot point to the same value as operand1, operand2, or modulus"); + } + if (modulus.is_zero()) + { + throw invalid_argument("modulus"); + } + if (!sum_fits_in(operand1_coeff_count, operand2_coeff_count)) + { + throw invalid_argument("operand1 and operand2 too large"); + } +#endif + // Clear product. + set_zero_uint(result_coeff_count, result); + + operand1_coeff_count = get_significant_coeff_count_poly( + operand1, operand1_coeff_count, 1); + operand2_coeff_count = get_significant_coeff_count_poly( + operand2, operand2_coeff_count, 1); + for (size_t operand1_index = 0; + operand1_index < operand1_coeff_count; operand1_index++) + { + if (operand1[operand1_index] == 0) + { + // If coefficient is 0, then move on to next coefficient. + continue; + } + // Do expensive add + for (size_t operand2_index = 0; + operand2_index < operand2_coeff_count; operand2_index++) + { + size_t product_coeff_index = operand1_index + operand2_index; + if (product_coeff_index >= result_coeff_count) + { + break; + } + + if (operand2[operand2_index] == 0) + { + // If coefficient is 0, then move on to next coefficient. + continue; + } + + // Lazy reduction + unsigned long long temp[2]; + multiply_uint64(operand1[operand1_index], operand2[operand2_index], temp); + temp[1] += add_uint64(temp[0], result[product_coeff_index], 0, temp); + result[product_coeff_index] = barrett_reduce_128(temp, modulus); + } + } + } + + void multiply_poly_poly_coeffmod(const uint64_t *operand1, + const uint64_t *operand2, size_t coeff_count, + const SmallModulus &modulus, uint64_t *result) + { +#ifdef SEAL_DEBUG + if (operand1 == nullptr && coeff_count > 0) + { + throw invalid_argument("operand1"); + } + if (operand2 == nullptr && coeff_count > 0) + { + throw invalid_argument("operand2"); + } + if (result == nullptr && coeff_count > 0) + { + throw invalid_argument("result"); + } + if (result != nullptr && (operand1 == result || operand2 == result)) + { + throw invalid_argument("result cannot point to the same value as operand1, operand2, or modulus"); + } + if (modulus.is_zero()) + { + throw invalid_argument("modulus"); + } +#endif + size_t result_coeff_count = coeff_count + coeff_count - 1; + + // Clear product. + set_zero_uint(result_coeff_count, result); + + for (size_t operand1_index = 0; operand1_index < coeff_count; operand1_index++) + { + if (operand1[operand1_index] == 0) + { + // If coefficient is 0, then move on to next coefficient. + continue; + } + // Lastly, do more expensive add if other cases don't handle it. + for (size_t operand2_index = 0; operand2_index < coeff_count; operand2_index++) + { + uint64_t operand2_coeff = operand2[operand2_index]; + if (operand2_coeff == 0) + { + // If coefficient is 0, then move on to next coefficient. + continue; + } + + // Lazy reduction + unsigned long long temp[2]; + multiply_uint64(operand1[operand1_index], operand2_coeff, temp); + temp[1] += add_uint64(temp[0], result[operand1_index + operand2_index], 0, temp); + + result[operand1_index + operand2_index] = barrett_reduce_128(temp, modulus); + } + } + } + + void divide_poly_poly_coeffmod_inplace(uint64_t *numerator, + const uint64_t *denominator, size_t coeff_count, + const SmallModulus &modulus, uint64_t *quotient) + { +#ifdef SEAL_DEBUG + if (numerator == nullptr) + { + throw invalid_argument("numerator"); + } + if (denominator == nullptr) + { + throw invalid_argument("denominator"); + } + if (is_zero_poly(denominator, coeff_count, modulus.uint64_count())) + { + throw invalid_argument("denominator"); + } + if (quotient == nullptr) + { + throw invalid_argument("quotient"); + } + if (numerator == quotient || denominator == quotient) + { + throw invalid_argument("quotient cannot point to same value as numerator or denominator"); + } + if (numerator == denominator) + { + throw invalid_argument("numerator cannot point to same value as denominator"); + } + if (modulus.is_zero()) + { + throw invalid_argument("modulus"); + } +#endif + // Clear quotient. + set_zero_uint(coeff_count, quotient); + + // Determine most significant coefficients of numerator and denominator. + size_t numerator_coeffs = get_significant_uint64_count_uint( + numerator, coeff_count); + size_t denominator_coeffs = get_significant_uint64_count_uint( + denominator, coeff_count); + + // If numerator has lesser degree than denominator, then done. + if (numerator_coeffs < denominator_coeffs) + { + return; + } + + // Create scalar to store value that makes denominator monic. + uint64_t monic_denominator_scalar; + + // Create temporary scalars used during calculation of quotient. + // Both are purposely twice as wide to store intermediate product prior to modulo operation. + uint64_t temp_quotient; + uint64_t subtrahend; + + // Determine scalar necessary to make denominator monic. + uint64_t leading_denominator_coeff = denominator[denominator_coeffs - 1]; + if (!try_invert_uint_mod(leading_denominator_coeff, modulus, monic_denominator_scalar)) + { + throw invalid_argument("modulus is not coprime with leading denominator coefficient"); + } + + // Perform coefficient-wise division algorithm. + while (numerator_coeffs >= denominator_coeffs) + { + // Determine leading numerator coefficient. + uint64_t leading_numerator_coeff = numerator[numerator_coeffs - 1]; + + // If leading numerator coefficient is not zero, then need to make zero by subtraction. + if (leading_numerator_coeff) + { + // Determine shift necesarry to bring significant coefficients in alignment. + size_t denominator_shift = numerator_coeffs - denominator_coeffs; + + // Determine quotient's coefficient, which is scalar that makes + // denominator's leading coefficient one multiplied by leading + // coefficient of denominator (which when subtracted will zero + // out the topmost denominator coefficient). + uint64_t "ient_coeff = quotient[denominator_shift]; + temp_quotient = multiply_uint_uint_mod( + monic_denominator_scalar, leading_numerator_coeff, modulus); + quotient_coeff = temp_quotient; + + // Subtract numerator and quotient*denominator (shifted by denominator_shift). + for (size_t denominator_coeff_index = 0; + denominator_coeff_index < denominator_coeffs; denominator_coeff_index++) + { + // Multiply denominator's coefficient by quotient. + uint64_t denominator_coeff = denominator[denominator_coeff_index]; + subtrahend = multiply_uint_uint_mod(temp_quotient, denominator_coeff, modulus); + + // Subtract numerator with resulting product, appropriately shifted by denominator shift. + uint64_t &numerator_coeff = numerator[denominator_coeff_index + denominator_shift]; + numerator_coeff = sub_uint_uint_mod(numerator_coeff, subtrahend, modulus); + } + } + + // Top numerator coefficient must now be zero, so adjust coefficient count. + numerator_coeffs--; + } + } + + void apply_galois(const uint64_t *input, int coeff_count_power, + uint64_t galois_elt, const SmallModulus &modulus, uint64_t *result) + { +#ifdef SEAL_DEBUG + if (input == nullptr) + { + throw invalid_argument("input"); + } + if (result == nullptr) + { + throw invalid_argument("result"); + } + if (input == result) + { + throw invalid_argument("result cannot point to the same value as input"); + } + if (coeff_count_power < get_power_of_two(SEAL_POLY_MOD_DEGREE_MIN) || + coeff_count_power > get_power_of_two(SEAL_POLY_MOD_DEGREE_MAX)) + { + throw invalid_argument("coeff_count_power"); + } + // Verify coprime conditions. + if (!(galois_elt & 1) || + (galois_elt >= 2 * (uint64_t(1) << coeff_count_power))) + { + throw invalid_argument("galois element is not valid"); + } + if (modulus.is_zero()) + { + throw invalid_argument("modulus"); + } +#endif + const uint64_t modulus_value = modulus.value(); + uint64_t coeff_count_minus_one = (uint64_t(1) << coeff_count_power) - 1; + for (uint64_t i = 0; i <= coeff_count_minus_one; i++) + { + uint64_t index_raw = i * galois_elt; + uint64_t index = index_raw & coeff_count_minus_one; + uint64_t result_value = *input++; + if ((index_raw >> coeff_count_power) & 1) + { + // Explicit inline + //result[index] = negate_uint_mod(result[index], modulus); + int64_t non_zero = (result_value != 0); + result_value = (modulus_value - result_value) & + static_cast(-non_zero); + } + result[index] = result_value; + } + } + + void apply_galois_ntt(const uint64_t *input, int coeff_count_power, + uint64_t galois_elt, uint64_t *result) + { +#ifdef SEAL_DEBUG + if (input == nullptr) + { + throw invalid_argument("input"); + } + if (result == nullptr) + { + throw invalid_argument("result"); + } + if (input == result) + { + throw invalid_argument("result cannot point to the same value as input"); + } + if (coeff_count_power <= 0) + { + throw invalid_argument("coeff_count_power"); + } + // Verify coprime conditions. + if (!(galois_elt & 1) || + (galois_elt >= 2 * (uint64_t(1) << coeff_count_power))) + { + throw invalid_argument("galois element is not valid"); + } +#endif + size_t coeff_count = size_t(1) << coeff_count_power; + uint64_t m_minus_one = 2 * coeff_count - 1; + for (size_t i = 0; i < coeff_count; i++) + { + uint64_t reversed = reverse_bits(i, coeff_count_power); + uint64_t index_raw = galois_elt * (2 * reversed + 1); + index_raw &= m_minus_one; + uint64_t index = reverse_bits((index_raw - 1) >> 1, coeff_count_power); + result[i] = input[index]; + } + } + + void dyadic_product_coeffmod(const uint64_t *operand1, + const uint64_t *operand2, size_t coeff_count, + const SmallModulus &modulus, uint64_t *result) + { +#ifdef SEAL_DEBUG + if (operand1 == nullptr) + { + throw invalid_argument("operand1"); + } + if (operand2 == nullptr) + { + throw invalid_argument("operand2"); + } + if (result == nullptr) + { + throw invalid_argument("result"); + } + if (coeff_count == 0) + { + throw invalid_argument("coeff_count"); + } + if (modulus.is_zero()) + { + throw invalid_argument("modulus"); + } +#endif + // Explicit inline + //for (int i = 0; i < coeff_count; i++) + //{ + // *result++ = multiply_uint_uint_mod(*operand1++, *operand2++, modulus); + //} + const uint64_t modulus_value = modulus.value(); + const uint64_t const_ratio_0 = modulus.const_ratio()[0]; + const uint64_t const_ratio_1 = modulus.const_ratio()[1]; + for (; coeff_count--; operand1++, operand2++, result++) + { + // Reduces z using base 2^64 Barrett reduction + unsigned long long z[2], tmp1, tmp2[2], tmp3, carry; + multiply_uint64(*operand1, *operand2, z); + + // Multiply input and const_ratio + // Round 1 + multiply_uint64_hw64(z[0], const_ratio_0, &carry); + multiply_uint64(z[0], const_ratio_1, tmp2); + tmp3 = tmp2[1] + add_uint64(tmp2[0], carry, &tmp1); + + // Round 2 + multiply_uint64(z[1], const_ratio_0, tmp2); + carry = tmp2[1] + add_uint64(tmp1, tmp2[0], &tmp1); + + // This is all we care about + tmp1 = z[1] * const_ratio_1 + tmp3 + carry; + + // Barrett subtraction + tmp3 = z[0] - tmp1 * modulus_value; + + // Claim: One more subtraction is enough + *result = tmp3 - (modulus_value & static_cast( + -static_cast(tmp3 >= modulus_value))); + } + } + + uint64_t poly_infty_norm_coeffmod(const uint64_t *operand, + size_t coeff_count, const SmallModulus &modulus) + { +#ifdef SEAL_DEBUG + if (operand == nullptr && coeff_count > 0) + { + throw invalid_argument("operand"); + } + if (modulus.is_zero()) + { + throw invalid_argument("modulus"); + } +#endif + // Construct negative threshold (first negative modulus value) to compute absolute values of coeffs. + uint64_t modulus_neg_threshold = (modulus.value() + 1) >> 1; + + // Mod out the poly coefficients and choose a symmetric representative from + // [-modulus,modulus). Keep track of the max. + uint64_t result = 0; + for (size_t coeff_index = 0; coeff_index < coeff_count; coeff_index++) + { + uint64_t poly_coeff = operand[coeff_index] % modulus.value(); + if (poly_coeff >= modulus_neg_threshold) + { + poly_coeff = modulus.value() - poly_coeff; + } + if (poly_coeff > result) + { + result = poly_coeff; + } + } + return result; + } + + bool try_invert_poly_coeffmod(const uint64_t *operand, const uint64_t *poly_modulus, + size_t coeff_count, const SmallModulus &modulus, uint64_t *result, MemoryPool &pool) + { +#ifdef SEAL_DEBUG + if (operand == nullptr) + { + throw invalid_argument("operand"); + } + if (poly_modulus == nullptr) + { + throw invalid_argument("poly_modulus"); + } + if (coeff_count == 0) + { + throw invalid_argument("coeff_count"); + } + if (result == nullptr) + { + throw invalid_argument("result"); + } + if (get_significant_uint64_count_uint(operand, coeff_count) >= + get_significant_uint64_count_uint(poly_modulus, coeff_count)) + { + throw out_of_range("operand"); + } + if (modulus.is_zero()) + { + throw invalid_argument("modulus"); + } +#endif + // Cannot invert 0 poly. + if (is_zero_poly(operand, coeff_count, size_t(1))) + { + return false; + } + + // Construct a mutable copy of operand and modulus, with numerator being modulus + // and operand being denominator. Notice that degree(numerator) >= degree(denominator). + auto numerator_anchor(allocate_uint(coeff_count, pool)); + uint64_t *numerator = numerator_anchor.get(); + set_uint_uint(poly_modulus, coeff_count, numerator); + auto denominator_anchor(allocate_uint(coeff_count, pool)); + uint64_t *denominator = denominator_anchor.get(); + set_uint_uint(operand, coeff_count, denominator); + + // Determine most significant coefficients of each. + size_t numerator_coeffs = get_significant_coeff_count_poly( + numerator, coeff_count, size_t(1)); + size_t denominator_coeffs = get_significant_coeff_count_poly( + denominator, coeff_count, size_t(1)); + + // Create poly to store quotient. + auto quotient(allocate_uint(coeff_count, pool)); + + // Create scalar to store value that makes denominator monic. + uint64_t monic_denominator_scalar; + + // Create temporary scalars used during calculation of quotient. + // Both are purposely twice as wide to store intermediate product prior to modulo operation. + uint64_t temp_quotient; + uint64_t subtrahend; + + // Create three polynomials to store inverse. + // Initialize invert_prior to 0 and invert_curr to 1. + auto invert_prior_anchor(allocate_uint(coeff_count, pool)); + uint64_t *invert_prior = invert_prior_anchor.get(); + set_zero_uint(coeff_count, invert_prior); + auto invert_curr_anchor(allocate_uint(coeff_count, pool)); + uint64_t *invert_curr = invert_curr_anchor.get(); + set_zero_uint(coeff_count, invert_curr); + invert_curr[0] = 1; + auto invert_next_anchor(allocate_uint(coeff_count, pool)); + uint64_t *invert_next = invert_next_anchor.get(); + + // Perform extended Euclidean algorithm. + while (true) + { + // NOTE: degree(numerator) >= degree(denominator). + + // Determine scalar necessary to make denominator monic. + uint64_t leading_denominator_coeff = + denominator[denominator_coeffs - 1]; + if (!try_invert_uint_mod(leading_denominator_coeff, modulus, + monic_denominator_scalar)) + { + throw invalid_argument("modulus is not coprime with leading denominator coefficient"); + } + + // Clear quotient. + set_zero_uint(coeff_count, quotient.get()); + + // Perform coefficient-wise division algorithm. + while (numerator_coeffs >= denominator_coeffs) + { + // Determine leading numerator coefficient. + uint64_t leading_numerator_coeff = numerator[numerator_coeffs - 1]; + + // If leading numerator coefficient is not zero, then need to make zero by subtraction. + if (leading_numerator_coeff) + { + // Determine shift necessary to bring significant coefficients in alignment. + size_t denominator_shift = numerator_coeffs - denominator_coeffs; + + // Determine quotient's coefficient, which is scalar that makes + // denominator's leading coefficient one multiplied by leading + // coefficient of denominator (which when subtracted will zero + // out the topmost denominator coefficient). + uint64_t "ient_coeff = quotient[denominator_shift]; + temp_quotient = multiply_uint_uint_mod( + monic_denominator_scalar, leading_numerator_coeff, modulus); + quotient_coeff = temp_quotient; + + // Subtract numerator and quotient*denominator (shifted by denominator_shift). + for (size_t denominator_coeff_index = 0; + denominator_coeff_index < denominator_coeffs; + denominator_coeff_index++) + { + // Multiply denominator's coefficient by quotient. + uint64_t denominator_coeff = denominator[denominator_coeff_index]; + subtrahend = multiply_uint_uint_mod(temp_quotient, denominator_coeff, modulus); + + // Subtract numerator with resulting product, appropriately shifted by + // denominator shift. + uint64_t &numerator_coeff = numerator[denominator_coeff_index + denominator_shift]; + numerator_coeff = sub_uint_uint_mod(numerator_coeff, subtrahend, modulus); + } + } + + // Top numerator coefficient must now be zero, so adjust coefficient count. + numerator_coeffs--; + } + + // Double check that numerator coefficients is correct because possible + // other coefficients are zero. + numerator_coeffs = get_significant_coeff_count_poly( + numerator, coeff_count, size_t(1)); + + // We are done if numerator is zero. + if (numerator_coeffs == 0) + { + break; + } + + // Integrate quotient with invert coefficients. + // Calculate: invert_next = invert_prior + -quotient * invert_curr + multiply_truncate_poly_poly_coeffmod(quotient.get(), invert_curr, + coeff_count, modulus, invert_next); + sub_poly_poly_coeffmod(invert_prior, invert_next, coeff_count, + modulus, invert_next); + + // Swap prior and curr, and then curr and next. + swap(invert_prior, invert_curr); + swap(invert_curr, invert_next); + + // Swap numerator and denominator. + swap(numerator, denominator); + swap(numerator_coeffs, denominator_coeffs); + } + + // Polynomial is invertible only if denominator is just a scalar. + if (denominator_coeffs != 1) + { + return false; + } + + // Determine scalar necessary to make denominator monic. + uint64_t leading_denominator_coeff = denominator[0]; + if (!try_invert_uint_mod(leading_denominator_coeff, modulus, + monic_denominator_scalar)) + { + throw invalid_argument("modulus is not coprime with leading denominator coefficient"); + } + + // Multiply inverse by scalar and done. + multiply_poly_scalar_coeffmod(invert_curr, coeff_count, + monic_denominator_scalar, modulus, result); + return true; + } + + void negacyclic_shift_poly_coeffmod(const uint64_t *operand, + size_t coeff_count, size_t shift, const SmallModulus &modulus, + uint64_t *result) + { +#ifdef SEAL_DEBUG + if (operand == nullptr && coeff_count > 0) + { + throw invalid_argument("operand"); + } + if (result == nullptr && coeff_count > 0) + { + throw invalid_argument("result"); + } + if (operand == result && coeff_count > 0) + { + throw invalid_argument("operand cannot point to the same location as result"); + } + if (modulus.is_zero()) + { + throw invalid_argument("modulus"); + } + if (util::get_power_of_two(static_cast(coeff_count)) < 0) + { + throw invalid_argument("coeff_count"); + } +#endif + uint64_t index_raw = shift; + uint64_t coeff_count_mod_mask = static_cast(coeff_count) - 1; + for (size_t i = 0; i < coeff_count; i++, operand++, index_raw++) + { + uint64_t index = index_raw & coeff_count_mod_mask; + if (!(index_raw & static_cast(coeff_count)) || !*operand) + { + result[index] = *operand; + } + else + { + result[index] = modulus.value() - *operand; + } + } + } + } +} diff --git a/src/seal/util/polyarithsmallmod.h b/src/seal/util/polyarithsmallmod.h new file mode 100644 index 000000000..62e8f0707 --- /dev/null +++ b/src/seal/util/polyarithsmallmod.h @@ -0,0 +1,217 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include +#include +#include +#include "seal/smallmodulus.h" +#include "seal/util/common.h" +#include "seal/util/polycore.h" +#include "seal/util/uintarithsmallmod.h" +#include "seal/util/pointer.h" + +namespace seal +{ + namespace util + { + inline void modulo_poly_coeffs(const std::uint64_t *poly, + std::size_t coeff_count, const SmallModulus &modulus, + std::uint64_t *result) + { +#ifdef SEAL_DEBUG + if (poly == nullptr && coeff_count > 0) + { + throw std::invalid_argument("poly"); + } + if (result == nullptr && coeff_count > 0) + { + throw std::invalid_argument("result"); + } + if (modulus.is_zero()) + { + throw std::invalid_argument("modulus"); + } +#endif + const std::uint64_t modulus_value = modulus.value(); + std::transform(poly, poly + coeff_count, result, + [&](auto coeff) { return coeff % modulus_value; }); + } + + inline void negate_poly_coeffmod(const std::uint64_t *poly, + std::size_t coeff_count, const SmallModulus &modulus, + std::uint64_t *result) + { +#ifdef SEAL_DEBUG + if (poly == nullptr && coeff_count > 0) + { + throw std::invalid_argument("poly"); + } + if (modulus.is_zero()) + { + throw std::invalid_argument("modulus"); + } + if (result == nullptr && coeff_count > 0) + { + throw std::invalid_argument("result"); + } +#endif + const uint64_t modulus_value = modulus.value(); + for (; coeff_count--; poly++, result++) + { + // Explicit inline + //*result = negate_uint_mod(*poly, modulus); +#ifdef SEAL_DEBUG + if (*poly >= modulus_value) + { + throw std::out_of_range("poly"); + } +#endif + std::int64_t non_zero = (*poly != 0); + *result = (modulus_value - *poly) & + static_cast(-non_zero); + } + } + + inline void add_poly_poly_coeffmod(const std::uint64_t *operand1, + const std::uint64_t *operand2, std::size_t coeff_count, + const SmallModulus &modulus, std::uint64_t *result) + { +#ifdef SEAL_DEBUG + if (operand1 == nullptr && coeff_count > 0) + { + throw std::invalid_argument("operand1"); + } + if (operand2 == nullptr && coeff_count > 0) + { + throw std::invalid_argument("operand2"); + } + if (modulus.is_zero()) + { + throw std::invalid_argument("modulus"); + } + if (result == nullptr && coeff_count > 0) + { + throw std::invalid_argument("result"); + } +#endif + const uint64_t modulus_value = modulus.value(); + for (; coeff_count--; result++, operand1++, operand2++) + { + // Explicit inline + //result[i] = add_uint_uint_mod(operand1[i], operand2[i], modulus); +#ifdef SEAL_DEBUG + if (*operand1 >= modulus_value) + { + throw std::invalid_argument("operand1"); + } + if (*operand2 >= modulus_value) + { + throw std::invalid_argument("operand2"); + } +#endif + std::uint64_t sum = *operand1 + *operand2; + *result = sum - (modulus_value & static_cast( + -static_cast(sum >= modulus_value))); + } + } + + inline void sub_poly_poly_coeffmod(const std::uint64_t *operand1, + const std::uint64_t *operand2, std::size_t coeff_count, + const SmallModulus &modulus, std::uint64_t *result) + { +#ifdef SEAL_DEBUG + if (operand1 == nullptr && coeff_count > 0) + { + throw std::invalid_argument("operand1"); + } + if (operand2 == nullptr && coeff_count > 0) + { + throw std::invalid_argument("operand2"); + } + if (modulus.is_zero()) + { + throw std::invalid_argument("modulus"); + } + if (result == nullptr && coeff_count > 0) + { + throw std::invalid_argument("result"); + } +#endif + const uint64_t modulus_value = modulus.value(); + for (; coeff_count--; result++, operand1++, operand2++) + { +#ifdef SEAL_DEBUG + if (*operand1 >= modulus_value) + { + throw std::out_of_range("operand1"); + } + if (*operand2 >= modulus_value) + { + throw std::out_of_range("operand2"); + } +#endif + unsigned long long temp_result; + std::int64_t borrow = sub_uint64(*operand1, *operand2, &temp_result); + *result = temp_result + (modulus_value & static_cast(-borrow)); + } + } + + void multiply_poly_scalar_coeffmod(const std::uint64_t *poly, + std::size_t coeff_count, std::uint64_t scalar, const SmallModulus &modulus, + std::uint64_t *result); + + void multiply_poly_poly_coeffmod(const std::uint64_t *operand1, + std::size_t operand1_coeff_count, const std::uint64_t *operand2, + std::size_t operand2_coeff_count, const SmallModulus &modulus, + std::size_t result_coeff_count, std::uint64_t *result); + + void multiply_poly_poly_coeffmod(const std::uint64_t *operand1, + const std::uint64_t *operand2, std::size_t coeff_count, + const SmallModulus &modulus, std::uint64_t *result); + + inline void multiply_truncate_poly_poly_coeffmod( + const std::uint64_t *operand1, const std::uint64_t *operand2, + std::size_t coeff_count, const SmallModulus &modulus, std::uint64_t *result) + { + multiply_poly_poly_coeffmod(operand1, coeff_count, operand2, coeff_count, + modulus, coeff_count, result); + } + + void divide_poly_poly_coeffmod_inplace(std::uint64_t *numerator, + const std::uint64_t *denominator, std::size_t coeff_count, + const SmallModulus &modulus, std::uint64_t *quotient); + + inline void divide_poly_poly_coeffmod(const std::uint64_t *numerator, + const std::uint64_t *denominator, std::size_t coeff_count, + const SmallModulus &modulus, std::uint64_t *quotient, + std::uint64_t *remainder) + { + set_uint_uint(numerator, coeff_count, remainder); + divide_poly_poly_coeffmod_inplace(remainder, denominator, coeff_count, + modulus, quotient); + } + + void apply_galois(const std::uint64_t *input, int coeff_count_power, + std::uint64_t galois_elt, const SmallModulus &modulus, std::uint64_t *result); + + void apply_galois_ntt(const std::uint64_t *input, int coeff_count_power, + std::uint64_t galois_elt, std::uint64_t *result); + + void dyadic_product_coeffmod(const std::uint64_t *operand1, + const std::uint64_t *operand2, std::size_t coeff_count, + const SmallModulus &modulus, std::uint64_t *result); + + std::uint64_t poly_infty_norm_coeffmod(const std::uint64_t *operand, + std::size_t coeff_count, const SmallModulus &modulus); + + bool try_invert_poly_coeffmod(const std::uint64_t *operand, + const std::uint64_t *poly_modulus, std::size_t coeff_count, + const SmallModulus &modulus, std::uint64_t *result, MemoryPool &pool); + + void negacyclic_shift_poly_coeffmod(const std::uint64_t *operand, + std::size_t coeff_count, std::size_t shift, const SmallModulus &modulus, + std::uint64_t *result); + } +} diff --git a/src/seal/util/polycore.h b/src/seal/util/polycore.h new file mode 100644 index 000000000..5501bcbf9 --- /dev/null +++ b/src/seal/util/polycore.h @@ -0,0 +1,367 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include "seal/util/common.h" +#include "seal/util/uintcore.h" +#include "seal/util/pointer.h" + +namespace seal +{ + namespace util + { + inline std::string poly_to_hex_string(const std::uint64_t *value, + std::size_t coeff_count, std::size_t coeff_uint64_count) + { +#ifdef SEAL_DEBUG + if (coeff_uint64_count && coeff_count && !value) + { + throw std::invalid_argument("value"); + } +#endif + std::ostringstream result; + bool empty = true; + value += util::mul_safe(coeff_count - 1, coeff_uint64_count); + while (coeff_count--) + { + if (is_zero_uint(value, coeff_uint64_count)) + { + value -= coeff_uint64_count; + continue; + } + if (!empty) + { + result << " + "; + } + result << uint_to_hex_string(value, coeff_uint64_count); + if (coeff_count) + { + result << "x^" << coeff_count; + } + empty = false; + value -= coeff_uint64_count; + } + if (empty) + { + result << "0"; + } + return result.str(); + } + + inline std::string poly_to_dec_string(const std::uint64_t *value, + std::size_t coeff_count, std::size_t coeff_uint64_count, + MemoryPool &pool) + { +#ifdef SEAL_DEBUG + if (coeff_uint64_count && coeff_count && !value) + { + throw std::invalid_argument("value"); + } +#endif + std::ostringstream result; + bool empty = true; + value += coeff_count - 1; + while (coeff_count--) + { + if (is_zero_uint(value, coeff_uint64_count)) + { + value -= coeff_uint64_count; + continue; + } + if (!empty) + { + result << " + "; + } + result << uint_to_dec_string(value, coeff_uint64_count, pool); + if (coeff_count) + { + result << "x^" << coeff_count; + } + empty = false; + value -= coeff_uint64_count; + } + if (empty) + { + result << "0"; + } + return result.str(); + } + + inline auto allocate_poly(std::size_t coeff_count, + std::size_t coeff_uint64_count, MemoryPool &pool) + { + return allocate_uint( + util::mul_safe(coeff_count, coeff_uint64_count), pool); + } + + inline void set_zero_poly(std::size_t coeff_count, + std::size_t coeff_uint64_count, std::uint64_t* result) + { +#ifdef SEAL_DEBUG + if (!result && coeff_count && coeff_uint64_count) + { + throw std::invalid_argument("result"); + } +#endif + set_zero_uint(util::mul_safe(coeff_count, coeff_uint64_count), result); + } + + inline auto allocate_zero_poly(std::size_t coeff_count, + std::size_t coeff_uint64_count, MemoryPool &pool) + { + return allocate_zero_uint( + util::mul_safe(coeff_count, coeff_uint64_count), pool); + } + + inline std::uint64_t *get_poly_coeff(std::uint64_t *poly, + std::size_t coeff_index, std::size_t coeff_uint64_count) + { +#ifdef SEAL_DEBUG + if (!poly) + { + throw std::invalid_argument("poly"); + } +#endif + return poly + util::mul_safe(coeff_index, coeff_uint64_count); + } + + inline const std::uint64_t *get_poly_coeff(const std::uint64_t *poly, + std::size_t coeff_index, std::size_t coeff_uint64_count) + { +#ifdef SEAL_DEBUG + if (!poly) + { + throw std::invalid_argument("poly"); + } +#endif + return poly + util::mul_safe(coeff_index, coeff_uint64_count); + } + + inline void set_poly_poly(const std::uint64_t *poly, + std::size_t coeff_count, std::size_t coeff_uint64_count, + std::uint64_t *result) + { +#ifdef SEAL_DEBUG + if (!poly && coeff_count && coeff_uint64_count) + { + throw std::invalid_argument("poly"); + } + if (!result && coeff_count && coeff_uint64_count) + { + throw std::invalid_argument("result"); + } +#endif + set_uint_uint(poly, + util::mul_safe(coeff_count, coeff_uint64_count), result); + } + + inline bool is_zero_poly(const std::uint64_t *poly, + std::size_t coeff_count, std::size_t coeff_uint64_count) + { +#ifdef SEAL_DEBUG + if (!poly && coeff_count && coeff_uint64_count) + { + throw std::invalid_argument("poly"); + } +#endif + return is_zero_uint(poly, + util::mul_safe(coeff_count, coeff_uint64_count)); + } + + inline bool is_equal_poly_poly(const std::uint64_t *operand1, + const std::uint64_t *operand2, std::size_t coeff_count, + std::size_t coeff_uint64_count) + { +#ifdef SEAL_DEBUG + if (!operand1 && coeff_count && coeff_uint64_count) + { + throw std::invalid_argument("operand1"); + } + if (!operand2 && coeff_count && coeff_uint64_count) + { + throw std::invalid_argument("operand2"); + } +#endif + return is_equal_uint_uint(operand1, operand2, + util::mul_safe(coeff_count, coeff_uint64_count)); + } + + inline void set_poly_poly(const std::uint64_t *poly, std::size_t poly_coeff_count, + std::size_t poly_coeff_uint64_count, std::size_t result_coeff_count, + std::size_t result_coeff_uint64_count, std::uint64_t *result) + { +#ifdef SEAL_DEBUG + if (!poly && poly_coeff_count && poly_coeff_uint64_count) + { + throw std::invalid_argument("poly"); + } + if (!result && result_coeff_count && result_coeff_uint64_count) + { + throw std::invalid_argument("result"); + } +#endif + if (!result_coeff_uint64_count || !result_coeff_count) + { + return; + } + + std::size_t min_coeff_count = std::min(poly_coeff_count, result_coeff_count); + for (std::size_t i = 0; i < min_coeff_count; i++, + poly += poly_coeff_uint64_count, result += result_coeff_uint64_count) + { + set_uint_uint(poly, poly_coeff_uint64_count, result_coeff_uint64_count, result); + } + set_zero_uint(util::mul_safe( + result_coeff_count - min_coeff_count, result_coeff_uint64_count), result); + } + + inline bool is_one_zero_one_poly(const std::uint64_t *poly, + std::size_t coeff_count, std::size_t coeff_uint64_count) + { +#ifdef SEAL_DEBUG + if (!poly && coeff_count && coeff_uint64_count) + { + throw std::invalid_argument("poly"); + } +#endif + if (coeff_count == 0 || coeff_uint64_count == 0) + { + return false; + } + if (!is_equal_uint(get_poly_coeff(poly, 0, coeff_uint64_count), + coeff_uint64_count, 1)) + { + return false; + } + if (!is_equal_uint(get_poly_coeff(poly, coeff_count - 1, coeff_uint64_count), + coeff_uint64_count, 1)) + { + return false; + } + if (coeff_count > 2 && + !is_zero_poly(poly + coeff_uint64_count, + coeff_count - 2, coeff_uint64_count)) + { + return false; + } + return true; + } + + inline std::size_t get_significant_coeff_count_poly( + const std::uint64_t *poly, std::size_t coeff_count, + std::size_t coeff_uint64_count) + { +#ifdef SEAL_DEBUG + if (!poly && coeff_count && coeff_uint64_count) + { + throw std::invalid_argument("poly"); + } +#endif + if(coeff_count == 0) + { + return 0; + } + + poly += util::mul_safe(coeff_count - 1, coeff_uint64_count); + for (std::size_t i = coeff_count; i; i--) + { + if (!is_zero_uint(poly, coeff_uint64_count)) + { + return i; + } + poly -= coeff_uint64_count; + } + return 0; + } + + inline auto duplicate_poly_if_needed(const std::uint64_t *poly, + std::size_t coeff_count, std::size_t coeff_uint64_count, + std::size_t new_coeff_count, std::size_t new_coeff_uint64_count, + bool force, MemoryPool &pool) + { +#ifdef SEAL_DEBUG + if (!poly && coeff_count && coeff_uint64_count) + { + throw std::invalid_argument("poly"); + } +#endif + if (!force && coeff_count >= new_coeff_count && + coeff_uint64_count == new_coeff_uint64_count) + { + return ConstPointer::Aliasing(poly); + } + auto allocation(allocate_poly( + new_coeff_count, new_coeff_uint64_count, pool)); + set_poly_poly(poly, coeff_count, coeff_uint64_count, new_coeff_count, + new_coeff_uint64_count, allocation.get()); + return ConstPointer(std::move(allocation)); + } + + inline bool are_poly_coefficients_less_than(const std::uint64_t *poly, + std::size_t coeff_count, std::size_t coeff_uint64_count, + const std::uint64_t *compare, std::size_t compare_uint64_count) + { +#ifdef SEAL_DEBUG + if (!poly && coeff_count && coeff_uint64_count) + { + throw std::invalid_argument("poly"); + } + if (!compare && compare_uint64_count > 0) + { + throw std::invalid_argument("compare"); + } +#endif + if (coeff_count == 0) + { + return true; + } + if (compare_uint64_count == 0) + { + return false; + } + if (coeff_uint64_count == 0) + { + return true; + } + for (; coeff_count--; poly += coeff_uint64_count) + { + if (compare_uint_uint(poly, coeff_uint64_count, compare, + compare_uint64_count) >= 0) + { + return false; + } + } + return true; + } + + inline bool are_poly_coefficients_less_than(const std::uint64_t *poly, + std::size_t coeff_count, std::uint64_t compare) + { +#ifdef SEAL_DEBUG + if (!poly && coeff_count) + { + throw std::invalid_argument("poly"); + } +#endif + if (coeff_count == 0) + { + return true; + } + for (; coeff_count--; poly++) + { + if (*poly >= compare) + { + return false; + } + } + return true; + } + } +} diff --git a/src/seal/util/randomtostd.h b/src/seal/util/randomtostd.h new file mode 100644 index 000000000..1657f01e1 --- /dev/null +++ b/src/seal/util/randomtostd.h @@ -0,0 +1,59 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include +#include +#include +#include "seal/randomgen.h" + +namespace seal +{ + namespace util + { + class RandomToStandardAdapter + { + public: + typedef std::uint32_t result_type; + + RandomToStandardAdapter() : generator_(nullptr) + { + } + + RandomToStandardAdapter( + std::shared_ptr generator) : + generator_(generator) + { + } + + auto generator() const noexcept + { + return generator_; + } + + auto generator() noexcept + { + return generator_; + } + + result_type operator()() + { + return generator_->generate(); + } + + static constexpr result_type min() noexcept + { + return 0; + } + + static constexpr result_type max() noexcept + { + return std::numeric_limits::max(); + } + + private: + std::shared_ptr generator_; + }; + } +} diff --git a/src/seal/util/smallntt.cpp b/src/seal/util/smallntt.cpp new file mode 100644 index 000000000..8c83af779 --- /dev/null +++ b/src/seal/util/smallntt.cpp @@ -0,0 +1,340 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "seal/util/smallntt.h" +#include "seal/util/polyarith.h" +#include "seal/util/uintarith.h" +#include "seal/smallmodulus.h" +#include "seal/util/uintarithsmallmod.h" +#include "seal/util/defines.h" +#include + +using namespace std; + +namespace seal +{ + namespace util + { + SmallNTTTables::SmallNTTTables(int coeff_count_power, + const SmallModulus &modulus, MemoryPoolHandle pool) : + pool_(move(pool)) + { +#ifdef SEAL_DEBUG + if (!pool_) + { + throw invalid_argument("pool is uninitialized"); + } +#endif + generate(coeff_count_power, modulus); + } + + void SmallNTTTables::reset() + { + generated_ = false; + modulus_ = SmallModulus(); + root_ = 0; + root_powers_.release(); + scaled_root_powers_.release(); + inv_root_powers_.release(); + scaled_inv_root_powers_.release(); + inv_root_powers_div_two_.release(); + scaled_inv_root_powers_div_two_.release(); + inv_degree_modulo_ = 0; + coeff_count_power_ = 0; + coeff_count_ = 0; + } + + bool SmallNTTTables::generate(int coeff_count_power, + const SmallModulus &modulus) + { + reset(); + + if ((coeff_count_power < get_power_of_two(SEAL_POLY_MOD_DEGREE_MIN)) || + coeff_count_power > get_power_of_two(SEAL_POLY_MOD_DEGREE_MAX)) + { + throw invalid_argument("coeff_count_power out of range"); + } + + coeff_count_power_ = coeff_count_power; + coeff_count_ = size_t(1) << coeff_count_power_; + + // Allocate memory for the tables + root_powers_ = allocate_uint(coeff_count_, pool_); + inv_root_powers_ = allocate_uint(coeff_count_, pool_); + scaled_root_powers_ = allocate_uint(coeff_count_, pool_); + scaled_inv_root_powers_ = allocate_uint(coeff_count_, pool_); + inv_root_powers_div_two_ = allocate_uint(coeff_count_, pool_); + scaled_inv_root_powers_div_two_ = allocate_uint(coeff_count_, pool_); + modulus_ = modulus; + + // We defer parameter checking to try_minimal_primitive_root(...) + if (!try_minimal_primitive_root(2 * coeff_count_, modulus_, root_)) + { + reset(); + return false; + } + + uint64_t inverse_root; + if (!try_invert_uint_mod(root_, modulus_, inverse_root)) + { + reset(); + return false; + } + + // Populate the tables storing (scaled version of) powers of root + // mod q in bit-scrambled order. + ntt_powers_of_primitive_root(root_, root_powers_.get()); + ntt_scale_powers_of_primitive_root(root_powers_.get(), + scaled_root_powers_.get()); + + // Populate the tables storing (scaled version of) powers of + // (root)^{-1} mod q in bit-scrambled order. + ntt_powers_of_primitive_root(inverse_root, inv_root_powers_.get()); + ntt_scale_powers_of_primitive_root(inv_root_powers_.get(), + scaled_inv_root_powers_.get()); + + // Populate the tables storing (scaled version of ) 2 times + // powers of roots^-1 mod q in bit-scrambled order. + for (size_t i = 0; i < coeff_count_; i++) + { + inv_root_powers_div_two_[i] = + div2_uint_mod(inv_root_powers_[i], modulus_); + } + ntt_scale_powers_of_primitive_root(inv_root_powers_div_two_.get(), + scaled_inv_root_powers_div_two_.get()); + + // Last compute n^(-1) modulo q. + uint64_t degree_uint = static_cast(coeff_count_); + generated_ = try_invert_uint_mod(degree_uint, modulus_, inv_degree_modulo_); + + if (!generated_) + { + reset(); + return false; + } + return true; + } + + void SmallNTTTables::ntt_powers_of_primitive_root(uint64_t root, + uint64_t *destination) const + { + uint64_t *destination_start = destination; + *destination_start = 1; + for (size_t i = 1; i < coeff_count_; i++) + { + uint64_t *next_destination = + destination_start + reverse_bits(i, coeff_count_power_); + *next_destination = + multiply_uint_uint_mod(*destination, root, modulus_); + destination = next_destination; + } + } + + // compute floor ( input * beta /q ), where beta is a 64k power of 2 + // and 0 < q < beta. + void SmallNTTTables::ntt_scale_powers_of_primitive_root( + const uint64_t *input, uint64_t *destination) const + { + for (size_t i = 0; i < coeff_count_; i++, input++, destination++) + { + uint64_t wide_quotient[2]{ 0, 0 }; + uint64_t wide_coeff[2]{ 0, *input }; + divide_uint128_uint64_inplace(wide_coeff, modulus_.value(), wide_quotient); + *destination = wide_quotient[0]; + } + } + + /** + This function computes in-place the negacyclic NTT. The input is + a polynomial a of degree n in R_q, where n is assumed to be a power of + 2 and q is a prime such that q = 1 (mod 2n). + + The output is a vector A such that the following hold: + A[j] = a(psi**(2*bit_reverse(j) + 1)), 0 <= j < n. + + For details, see Michael Naehrig and Patrick Longa. + */ + void ntt_negacyclic_harvey_lazy(uint64_t *operand, + const SmallNTTTables &tables) + { + uint64_t modulus = tables.modulus().value(); + uint64_t two_times_modulus = modulus * 2; + + // Return the NTT in scrambled order + size_t n = size_t(1) << tables.coeff_count_power(); + size_t t = n >> 1; + for (size_t m = 1; m < n; m <<= 1) + { + if (t >= 4) + { + for (size_t i = 0; i < m; i++) + { + size_t j1 = 2 * i * t; + size_t j2 = j1 + t; + const uint64_t W = tables.get_from_root_powers(m + i); + const uint64_t Wprime = tables.get_from_scaled_root_powers(m + i); + + uint64_t *X = operand + j1; + uint64_t *Y = X + t; + uint64_t currX; + unsigned long long Q; + for (size_t j = j1; j < j2; j += 4) + { + currX = *X - (two_times_modulus & static_cast(-static_cast(*X >= two_times_modulus))); + multiply_uint64_hw64(Wprime, *Y, &Q); + Q = *Y * W - Q * modulus; + *X++ = currX + Q; + *Y++ = currX + (two_times_modulus - Q); + + currX = *X - (two_times_modulus & static_cast(-static_cast(*X >= two_times_modulus))); + multiply_uint64_hw64(Wprime, *Y, &Q); + Q = *Y * W - Q * modulus; + *X++ = currX + Q; + *Y++ = currX + (two_times_modulus - Q); + + currX = *X - (two_times_modulus & static_cast(-static_cast(*X >= two_times_modulus))); + multiply_uint64_hw64(Wprime, *Y, &Q); + Q = *Y * W - Q * modulus; + *X++ = currX + Q; + *Y++ = currX + (two_times_modulus - Q); + + currX = *X - (two_times_modulus & static_cast(-static_cast(*X >= two_times_modulus))); + multiply_uint64_hw64(Wprime, *Y, &Q); + Q = *Y * W - Q * modulus; + *X++ = currX + Q; + *Y++ = currX + (two_times_modulus - Q); + } + } + } + else + { + for (size_t i = 0; i < m; i++) + { + size_t j1 = 2 * i * t; + size_t j2 = j1 + t; + const uint64_t W = tables.get_from_root_powers(m + i); + const uint64_t Wprime = tables.get_from_scaled_root_powers(m + i); + + uint64_t *X = operand + j1; + uint64_t *Y = X + t; + uint64_t currX; + unsigned long long Q; + for (size_t j = j1; j < j2; j++) + { + // The Harvey butterfly: assume X, Y in [0, 2p), and return X', Y' in [0, 2p). + // X', Y' = X + WY, X - WY (mod p). + currX = *X - (two_times_modulus & static_cast(-static_cast(*X >= two_times_modulus))); + multiply_uint64_hw64(Wprime, *Y, &Q); + Q = W * *Y - Q * modulus; + *X++ = currX + Q; + *Y++ = currX + (two_times_modulus - Q); + } + } + } + t >>= 1; + } + } + + // Inverse negacyclic NTT using Harvey's butterfly. (See Patrick Longa and Michael Naehrig). + void inverse_ntt_negacyclic_harvey_lazy(uint64_t *operand, const SmallNTTTables &tables) + { + uint64_t modulus = tables.modulus().value(); + uint64_t two_times_modulus = modulus * 2; + + // return the bit-reversed order of NTT. + size_t n = size_t(1) << tables.coeff_count_power(); + size_t t = 1; + + for (size_t m = n; m > 1; m >>= 1) + { + size_t j1 = 0; + size_t h = m >> 1; + if (t >= 4) + { + for (size_t i = 0; i < h; i++) + { + size_t j2 = j1 + t; + // Need the powers of phi^{-1} in bit-reversed order + const uint64_t W = tables.get_from_inv_root_powers_div_two(h + i); + const uint64_t Wprime = tables.get_from_scaled_inv_root_powers_div_two(h + i); + + uint64_t *U = operand + j1; + uint64_t *V = U + t; + uint64_t currU; + uint64_t T; + unsigned long long H; + for (size_t j = j1; j < j2; j += 4) + { + T = two_times_modulus - *V + *U; + currU = *U + *V - (two_times_modulus & static_cast(-static_cast((*U << 1) >= T))); + *U++ = (currU + (modulus & static_cast(-static_cast(T & 1)))) >> 1; + multiply_uint64_hw64(Wprime, T, &H); + *V++ = T * W - H * modulus; + + T = two_times_modulus - *V + *U; + currU = *U + *V - (two_times_modulus & static_cast(-static_cast((*U << 1) >= T))); + *U++ = (currU + (modulus & static_cast(-static_cast(T & 1)))) >> 1; + multiply_uint64_hw64(Wprime, T, &H); + *V++ = T * W - H * modulus; + + T = two_times_modulus - *V + *U; + currU = *U + *V - (two_times_modulus & static_cast(-static_cast((*U << 1) >= T))); + *U++ = (currU + (modulus & static_cast(-static_cast(T & 1)))) >> 1; + multiply_uint64_hw64(Wprime, T, &H); + *V++ = T * W - H * modulus; + + T = two_times_modulus - *V + *U; + currU = *U + *V - (two_times_modulus & static_cast(-static_cast((*U << 1) >= T))); + *U++ = (currU + (modulus & static_cast(-static_cast(T & 1)))) >> 1; + multiply_uint64_hw64(Wprime, T, &H); + *V++ = T * W - H * modulus; + } + j1 += (t << 1); + } + } + else + { + for (size_t i = 0; i < h; i++) + { + size_t j2 = j1 + t; + // Need the powers of phi^{-1} in bit-reversed order + const uint64_t W = tables.get_from_inv_root_powers_div_two(h + i); + const uint64_t Wprime = tables.get_from_scaled_inv_root_powers_div_two(h + i); + + uint64_t *U = operand + j1; + uint64_t *V = U + t; + uint64_t currU; + uint64_t T; + unsigned long long H; + for (size_t j = j1; j < j2; j++) + { + // U = x[i], V = x[i+m] + + // Compute U - V + 2q + T = two_times_modulus - *V + *U; + + // Cleverly check whether currU + currV >= two_times_modulus + currU = *U + *V - (two_times_modulus & static_cast(-static_cast((*U << 1) >= T))); + + // Need to make it so that div2_uint_mod takes values that are > q. + //div2_uint_mod(U, modulusptr, coeff_uint64_count, U); + // We use also the fact that parity of currU is same as parity of T. + // Since our modulus is always so small that currU + masked_modulus < 2^64, + // we never need to worry about wrapping around when adding masked_modulus. + //uint64_t masked_modulus = modulus & static_cast(-static_cast(T & 1)); + //uint64_t carry = add_uint64(currU, masked_modulus, 0, &currU); + //currU += modulus & static_cast(-static_cast(T & 1)); + *U++ = (currU + (modulus & static_cast(-static_cast(T & 1)))) >> 1; + + multiply_uint64_hw64(Wprime, T, &H); + // effectively, the next two multiply perform multiply modulo beta = 2**wordsize. + *V++ = W * T - H * modulus; + } + j1 += (t << 1); + } + } + t <<= 1; + } + } + } +} diff --git a/src/seal/util/smallntt.h b/src/seal/util/smallntt.h new file mode 100644 index 000000000..2a1ec510a --- /dev/null +++ b/src/seal/util/smallntt.h @@ -0,0 +1,274 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include +#include "seal/util/pointer.h" +#include "seal/memorymanager.h" +#include "seal/smallmodulus.h" + +namespace seal +{ + namespace util + { + class SmallNTTTables + { + public: + SmallNTTTables(MemoryPoolHandle pool = MemoryManager::GetPool()) : + pool_(std::move(pool)) + { +#ifdef SEAL_DEBUG + if (!pool_) + { + throw std::invalid_argument("pool is uninitialized"); + } +#endif + } + + SmallNTTTables(int coeff_count_power, const SmallModulus &modulus, + MemoryPoolHandle pool = MemoryManager::GetPool()); + + inline bool is_generated() const + { + return generated_; + } + + bool generate(int coeff_count_power, const SmallModulus &modulus); + + void reset(); + + inline std::uint64_t get_root() const + { +#ifdef SEAL_DEBUG + if (!generated_) + { + throw std::logic_error("tables are not generated"); + } +#endif + return root_; + } + + inline std::uint64_t get_from_root_powers(std::size_t index) const + { +#ifdef SEAL_DEBUG + if (index >= coeff_count_) + { + throw std::out_of_range("index"); + } + if (!generated_) + { + throw std::logic_error("tables are not generated"); + } +#endif + return root_powers_[index]; + } + + inline std::uint64_t get_from_scaled_root_powers(std::size_t index) const + { +#ifdef SEAL_DEBUG + if (index >= coeff_count_) + { + throw std::out_of_range("index"); + } + if (!generated_) + { + throw std::logic_error("tables are not generated"); + } +#endif + return scaled_root_powers_[index]; + } + + inline std::uint64_t get_from_inv_root_powers(std::size_t index) const + { +#ifdef SEAL_DEBUG + if (index >= coeff_count_) + { + throw std::out_of_range("index"); + } + if (!generated_) + { + throw std::logic_error("tables are not generated"); + } +#endif + return inv_root_powers_[index]; + } + + inline std::uint64_t get_from_scaled_inv_root_powers(std::size_t index) const + { +#ifdef SEAL_DEBUG + if (index >= coeff_count_) + { + throw std::out_of_range("index"); + } + if (!generated_) + { + throw std::logic_error("tables are not generated"); + } +#endif + return scaled_inv_root_powers_[index]; + } + + inline std::uint64_t get_from_inv_root_powers_div_two( + std::size_t index) const + { +#ifdef SEAL_DEBUG + if (index >= coeff_count_) + { + throw std::out_of_range("index"); + } + if (!generated_) + { + throw std::logic_error("tables are not generated"); + } +#endif + return inv_root_powers_div_two_[index]; + } + + inline std::uint64_t get_from_scaled_inv_root_powers_div_two( + std::size_t index) const + { +#ifdef SEAL_DEBUG + if (index >= coeff_count_) + { + throw std::out_of_range("index"); + } + if (!generated_) + { + throw std::logic_error("tables are not generated"); + } +#endif + return scaled_inv_root_powers_div_two_[index]; + } + + inline const std::uint64_t *get_inv_degree_modulo() const + { +#ifdef SEAL_DEBUG + if (!generated_) + { + throw std::logic_error("tables are not generated"); + } +#endif + return &inv_degree_modulo_; + } + + inline const SmallModulus &modulus() const + { + return modulus_; + } + + inline int coeff_count_power() const + { + return coeff_count_power_; + } + + inline std::size_t coeff_count() const + { + return coeff_count_; + } + + private: + SmallNTTTables(const SmallNTTTables ©) = delete; + + SmallNTTTables(SmallNTTTables &&source) = delete; + + SmallNTTTables &operator =(const SmallNTTTables &assign) = delete; + + SmallNTTTables &operator =(SmallNTTTables &&assign) = delete; + + // Computed bit-scrambled vector of first 1 << coeff_count_power powers + // of a primitive root. + void ntt_powers_of_primitive_root(std::uint64_t root, + std::uint64_t *destination) const; + + // Scales the elements of a vector returned by powers_of_primitive_root(...) + // by word_size/modulus and rounds down. + void ntt_scale_powers_of_primitive_root(const std::uint64_t *input, + std::uint64_t *destination) const; + + MemoryPoolHandle pool_; + + bool generated_ = false; + + std::uint64_t root_ = 0; + + // Size coeff_count_ + Pointer root_powers_; + + // Size coeff_count_ + Pointer scaled_root_powers_; + + // Size coeff_count_ + Pointer inv_root_powers_div_two_; + + // Size coeff_count_ + Pointer scaled_inv_root_powers_div_two_; + + int coeff_count_power_ = 0; + + std::size_t coeff_count_ = 0; + + SmallModulus modulus_; + + // Size coeff_count_ + Pointer inv_root_powers_; + + // Size coeff_count_ + Pointer scaled_inv_root_powers_; + + std::uint64_t inv_degree_modulo_ = 0; + + }; + + void ntt_negacyclic_harvey_lazy(std::uint64_t *operand, + const SmallNTTTables &tables); + + inline void ntt_negacyclic_harvey(std::uint64_t *operand, + const SmallNTTTables &tables) + { + ntt_negacyclic_harvey_lazy(operand, tables); + + // Finally maybe we need to reduce every coefficient modulo q, but we + // know that they are in the range [0, 4q). + // Since word size is controlled this is fast. + std::uint64_t modulus = tables.modulus().value(); + std::uint64_t two_times_modulus = modulus * 2; + std::size_t n = std::size_t(1) << tables.coeff_count_power(); + + for (; n--; operand++) + { + if (*operand >= two_times_modulus) + { + *operand -= two_times_modulus; + } + if (*operand >= modulus) + { + *operand -= modulus; + } + } + } + + void inverse_ntt_negacyclic_harvey_lazy(std::uint64_t *operand, + const SmallNTTTables &tables); + + inline void inverse_ntt_negacyclic_harvey(std::uint64_t *operand, + const SmallNTTTables &tables) + { + inverse_ntt_negacyclic_harvey_lazy(operand, tables); + + std::uint64_t modulus = tables.modulus().value(); + std::size_t n = std::size_t(1) << tables.coeff_count_power(); + + // Final adjustments; compute a[j] = a[j] * n^{-1} mod q. + // We incorporated the final adjustment in the butterfly. Only need + // to reduce here. + for (; n--; operand++) + { + if (*operand >= modulus) + { + *operand -= modulus; + } + } + } + } +} diff --git a/src/seal/util/uintarith.cpp b/src/seal/util/uintarith.cpp new file mode 100644 index 000000000..f1e62d619 --- /dev/null +++ b/src/seal/util/uintarith.cpp @@ -0,0 +1,726 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "seal/util/uintcore.h" +#include "seal/util/uintarith.h" +#include "seal/util/common.h" +#include +#include +#include + +using namespace std; + +namespace seal +{ + namespace util + { + void multiply_uint_uint(const uint64_t *operand1, + size_t operand1_uint64_count, const uint64_t *operand2, + size_t operand2_uint64_count, size_t result_uint64_count, + uint64_t *result) + { +#ifdef SEAL_DEBUG + if (!operand1 && operand1_uint64_count > 0) + { + throw invalid_argument("operand1"); + } + if (!operand2 && operand2_uint64_count > 0) + { + throw invalid_argument("operand2"); + } + if (!result_uint64_count) + { + throw invalid_argument("result_uint64_count"); + } + if (!result) + { + throw invalid_argument("result"); + } + if (result != nullptr && (operand1 == result || operand2 == result)) + { + throw invalid_argument("result cannot point to the same value as operand1 or operand2"); + } +#endif + // Handle fast cases. + if (!operand1_uint64_count || !operand2_uint64_count) + { + // If either operand is 0, then result is 0. + set_zero_uint(result_uint64_count, result); + return; + } + if (result_uint64_count == 1) + { + *result = *operand1 * *operand2; + return; + } + + // In some cases these improve performance. + operand1_uint64_count = get_significant_uint64_count_uint( + operand1, operand1_uint64_count); + operand2_uint64_count = get_significant_uint64_count_uint( + operand2, operand2_uint64_count); + + // More fast cases + if (operand1_uint64_count == 1) + { + multiply_uint_uint64(operand2, operand2_uint64_count, + *operand1, result_uint64_count, result); + return; + } + if (operand2_uint64_count == 1) + { + multiply_uint_uint64(operand1, operand1_uint64_count, + *operand2, result_uint64_count, result); + return; + } + + // Clear out result. + set_zero_uint(result_uint64_count, result); + + // Multiply operand1 and operand2. + size_t operand1_index_max = min(operand1_uint64_count, + result_uint64_count); + for (size_t operand1_index = 0; + operand1_index < operand1_index_max; operand1_index++) + { + const uint64_t *inner_operand2 = operand2; + uint64_t *inner_result = result++; + uint64_t carry = 0; + size_t operand2_index = 0; + size_t operand2_index_max = min(operand2_uint64_count, + result_uint64_count - operand1_index); + for (; operand2_index < operand2_index_max; operand2_index++) + { + // Perform 64-bit multiplication of operand1 and operand2 + unsigned long long temp_result[2]; + multiply_uint64(*operand1, *inner_operand2++, temp_result); + carry = temp_result[1] + add_uint64(temp_result[0], carry, 0, temp_result); + unsigned long long temp; + carry += add_uint64(*inner_result, temp_result[0], 0, &temp); + *inner_result++ = temp; + } + + // Write carry if there is room in result + if (operand1_index + operand2_index_max < result_uint64_count) + { + *inner_result = carry; + } + + operand1++; + } + } + + void multiply_uint_uint64(const uint64_t *operand1, + size_t operand1_uint64_count, uint64_t operand2, + size_t result_uint64_count, uint64_t *result) + { +#ifdef SEAL_DEBUG + if (!operand1 && operand1_uint64_count > 0) + { + throw invalid_argument("operand1"); + } + if (!result_uint64_count) + { + throw invalid_argument("result_uint64_count"); + } + if (!result) + { + throw invalid_argument("result"); + } + if (result != nullptr && operand1 == result) + { + throw invalid_argument("result cannot point to the same value as operand1"); + } +#endif + // Handle fast cases. + if (!operand1_uint64_count || !operand2) + { + // If either operand is 0, then result is 0. + set_zero_uint(result_uint64_count, result); + return; + } + if (result_uint64_count == 1) + { + *result = *operand1 * operand2; + return; + } + + // More fast cases + //if (result_uint64_count == 2 && operand1_uint64_count > 1) + //{ + // unsigned long long temp_result; + // multiply_uint64(*operand1, operand2, &temp_result); + // *result = temp_result; + // *(result + 1) += *(operand1 + 1) * operand2; + // return; + //} + + // Clear out result. + set_zero_uint(result_uint64_count, result); + + // Multiply operand1 and operand2. + unsigned long long carry = 0; + size_t operand1_index_max = min(operand1_uint64_count, + result_uint64_count); + for (size_t operand1_index = 0; + operand1_index < operand1_index_max; operand1_index++) + { + unsigned long long temp_result[2]; + multiply_uint64(*operand1++, operand2, temp_result); + unsigned long long temp; + carry = temp_result[1] + add_uint64(temp_result[0], carry, 0, &temp); + *result++ = temp; + } + + // Write carry if there is room in result + if (operand1_index_max < result_uint64_count) + { + *result = carry; + } + } + + void divide_uint_uint_inplace(uint64_t *numerator, + const uint64_t *denominator, size_t uint64_count, + uint64_t *quotient, MemoryPool &pool) + { +#ifdef SEAL_DEBUG + if (!numerator && uint64_count > 0) + { + throw invalid_argument("numerator"); + } + if (!denominator && uint64_count > 0) + { + throw invalid_argument("denominator"); + } + if (!quotient && uint64_count > 0) + { + throw invalid_argument("quotient"); + } + if (is_zero_uint(denominator, uint64_count) && uint64_count > 0) + { + throw invalid_argument("denominator"); + } + if (quotient && (numerator == quotient || denominator == quotient)) + { + throw invalid_argument("quotient cannot point to same value as numerator or denominator"); + } +#endif + if (!uint64_count) + { + return; + } + + // Clear quotient. Set it to zero. + set_zero_uint(uint64_count, quotient); + + // Determine significant bits in numerator and denominator. + int numerator_bits = + get_significant_bit_count_uint(numerator, uint64_count); + int denominator_bits = + get_significant_bit_count_uint(denominator, uint64_count); + + // If numerator has fewer bits than denominator, then done. + if (numerator_bits < denominator_bits) + { + return; + } + + // Only perform computation up to last non-zero uint64s. + uint64_count = safe_cast( + divide_round_up(numerator_bits, bits_per_uint64)); + + // Handle fast case. + if (uint64_count == 1) + { + *quotient = *numerator / *denominator; + *numerator -= *quotient * *denominator; + return; + } + + auto alloc_anchor(allocate_uint(uint64_count << 1, pool)); + + // Create temporary space to store mutable copy of denominator. + uint64_t *shifted_denominator = alloc_anchor.get(); + + // Create temporary space to store difference calculation. + uint64_t *difference = shifted_denominator + uint64_count; + + // Shift denominator to bring MSB in alignment with MSB of numerator. + int denominator_shift = numerator_bits - denominator_bits; + left_shift_uint(denominator, denominator_shift, uint64_count, + shifted_denominator); + denominator_bits += denominator_shift; + + // Perform bit-wise division algorithm. + int remaining_shifts = denominator_shift; + while (numerator_bits == denominator_bits) + { + // NOTE: MSBs of numerator and denominator are aligned. + + // Even though MSB of numerator and denominator are aligned, + // still possible numerator < shifted_denominator. + if (sub_uint_uint(numerator, shifted_denominator, + uint64_count, difference)) + { + // numerator < shifted_denominator and MSBs are aligned, + // so current quotient bit is zero and next one is definitely one. + if (remaining_shifts == 0) + { + // No shifts remain and numerator < denominator so done. + break; + } + + // Effectively shift numerator left by 1 by instead adding + // numerator to difference (to prevent overflow in numerator). + add_uint_uint(difference, numerator, uint64_count, difference); + + // Adjust quotient and remaining shifts as a result of + // shifting numerator. + left_shift_uint(quotient, 1, uint64_count, quotient); + remaining_shifts--; + } + // Difference is the new numerator with denominator subtracted. + + // Update quotient to reflect subtraction. + quotient[0] |= 1; + + // Determine amount to shift numerator to bring MSB in alignment + // with denominator. + numerator_bits = get_significant_bit_count_uint(difference, uint64_count); + int numerator_shift = denominator_bits - numerator_bits; + if (numerator_shift > remaining_shifts) + { + // Clip the maximum shift to determine only the integer + // (as opposed to fractional) bits. + numerator_shift = remaining_shifts; + } + + // Shift and update numerator. + if (numerator_bits > 0) + { + left_shift_uint(difference, numerator_shift, uint64_count, numerator); + numerator_bits += numerator_shift; + } + else + { + // Difference is zero so no need to shift, just set to zero. + set_zero_uint(uint64_count, numerator); + } + + // Adjust quotient and remaining shifts as a result of shifting numerator. + left_shift_uint(quotient, numerator_shift, uint64_count, quotient); + remaining_shifts -= numerator_shift; + } + + // Correct numerator (which is also the remainder) for shifting of + // denominator, unless it is just zero. + if (numerator_bits > 0) + { + right_shift_uint(numerator, denominator_shift, uint64_count, numerator); + } + } + + void divide_uint128_uint64_inplace(uint64_t *numerator, + uint64_t denominator, uint64_t *quotient) + { +#ifdef SEAL_DEBUG + if (!numerator) + { + throw invalid_argument("numerator"); + } + if (denominator == 0) + { + throw invalid_argument("denominator"); + } + if (!quotient) + { + throw invalid_argument("quotient"); + } + if (numerator == quotient) + { + throw invalid_argument("quotient cannot point to same value as numerator"); + } +#endif + // We expect 129-bit input + constexpr size_t uint64_count = 2; + + // Clear quotient. Set it to zero. + quotient[0] = 0; + quotient[1] = 0; + + // Determine significant bits in numerator and denominator. + int numerator_bits = get_significant_bit_count_uint(numerator, uint64_count); + int denominator_bits = get_significant_bit_count(denominator); + + // If numerator has fewer bits than denominator, then done. + if (numerator_bits < denominator_bits) + { + return; + } + + // Create temporary space to store mutable copy of denominator. + uint64_t shifted_denominator[uint64_count]{ denominator, 0 }; + + // Create temporary space to store difference calculation. + uint64_t difference[uint64_count]{ 0, 0 }; + + // Shift denominator to bring MSB in alignment with MSB of numerator. + int denominator_shift = numerator_bits - denominator_bits; + + left_shift_uint(shifted_denominator, denominator_shift, + uint64_count, shifted_denominator); + denominator_bits += denominator_shift; + + // Perform bit-wise division algorithm. + int remaining_shifts = denominator_shift; + while (numerator_bits == denominator_bits) + { + // NOTE: MSBs of numerator and denominator are aligned. + + // Even though MSB of numerator and denominator are aligned, + // still possible numerator < shifted_denominator. + if (sub_uint_uint(numerator, shifted_denominator, uint64_count, difference)) + { + // numerator < shifted_denominator and MSBs are aligned, + // so current quotient bit is zero and next one is definitely one. + if (remaining_shifts == 0) + { + // No shifts remain and numerator < denominator so done. + break; + } + + // Effectively shift numerator left by 1 by instead adding + // numerator to difference (to prevent overflow in numerator). + add_uint_uint(difference, numerator, uint64_count, difference); + + // Adjust quotient and remaining shifts as a result of shifting numerator. + quotient[1] = (quotient[1] << 1) | (quotient[0] >> (bits_per_uint64 - 1)); + quotient[0] <<= 1; + remaining_shifts--; + } + // Difference is the new numerator with denominator subtracted. + + // Determine amount to shift numerator to bring MSB in alignment + // with denominator. + numerator_bits = get_significant_bit_count_uint(difference, uint64_count); + + // Clip the maximum shift to determine only the integer + // (as opposed to fractional) bits. + int numerator_shift = min(denominator_bits - numerator_bits, remaining_shifts); + + // Shift and update numerator. + // This may be faster; first set to zero and then update if needed + + // Difference is zero so no need to shift, just set to zero. + numerator[0] = 0; + numerator[1] = 0; + + if (numerator_bits > 0) + { + left_shift_uint(difference, numerator_shift, uint64_count, numerator); + numerator_bits += numerator_shift; + } + + // Update quotient to reflect subtraction. + quotient[0] |= 1; + + // Adjust quotient and remaining shifts as a result of shifting numerator. + left_shift_uint(quotient, numerator_shift, uint64_count, quotient); + remaining_shifts -= numerator_shift; + } + + // Correct numerator (which is also the remainder) for shifting of + // denominator, unless it is just zero. + if (numerator_bits > 0) + { + right_shift_uint(numerator, denominator_shift, uint64_count, numerator); + } + } + + void divide_uint192_uint64_inplace(uint64_t *numerator, + uint64_t denominator, uint64_t *quotient) + { +#ifdef SEAL_DEBUG + if (!numerator) + { + throw invalid_argument("numerator"); + } + if (denominator == 0) + { + throw invalid_argument("denominator"); + } + if (!quotient) + { + throw invalid_argument("quotient"); + } + if (numerator == quotient) + { + throw invalid_argument("quotient cannot point to same value as numerator"); + } +#endif + // We expect 192-bit input + size_t uint64_count = 3; + + // Clear quotient. Set it to zero. + quotient[0] = 0; + quotient[1] = 0; + quotient[2] = 0; + + // Determine significant bits in numerator and denominator. + int numerator_bits = get_significant_bit_count_uint(numerator, uint64_count); + int denominator_bits = get_significant_bit_count(denominator); + + // If numerator has fewer bits than denominator, then done. + if (numerator_bits < denominator_bits) + { + return; + } + + // Only perform computation up to last non-zero uint64s. + uint64_count = safe_cast( + divide_round_up(numerator_bits, bits_per_uint64)); + + // Handle fast case. + if (uint64_count == 1) + { + *quotient = *numerator / denominator; + *numerator -= *quotient * denominator; + return; + } + + // Create temporary space to store mutable copy of denominator. + vector shifted_denominator(uint64_count, 0); + shifted_denominator[0] = denominator; + + // Create temporary space to store difference calculation. + vector difference(uint64_count); + + // Shift denominator to bring MSB in alignment with MSB of numerator. + int denominator_shift = numerator_bits - denominator_bits; + + left_shift_uint(shifted_denominator.data(), denominator_shift, + uint64_count, shifted_denominator.data()); + denominator_bits += denominator_shift; + + // Perform bit-wise division algorithm. + int remaining_shifts = denominator_shift; + while (numerator_bits == denominator_bits) + { + // NOTE: MSBs of numerator and denominator are aligned. + + // Even though MSB of numerator and denominator are aligned, + // still possible numerator < shifted_denominator. + if (sub_uint_uint(numerator, shifted_denominator.data(), + uint64_count, difference.data())) + { + // numerator < shifted_denominator and MSBs are aligned, + // so current quotient bit is zero and next one is definitely one. + if (remaining_shifts == 0) + { + // No shifts remain and numerator < denominator so done. + break; + } + + // Effectively shift numerator left by 1 by instead adding + // numerator to difference (to prevent overflow in numerator). + add_uint_uint(difference.data(), numerator, uint64_count, difference.data()); + + // Adjust quotient and remaining shifts as a result of shifting numerator. + left_shift_uint(quotient, 1, uint64_count, quotient); + remaining_shifts--; + } + // Difference is the new numerator with denominator subtracted. + + // Update quotient to reflect subtraction. + quotient[0] |= 1; + + // Determine amount to shift numerator to bring MSB in alignment with denominator. + numerator_bits = get_significant_bit_count_uint(difference.data(), uint64_count); + int numerator_shift = denominator_bits - numerator_bits; + if (numerator_shift > remaining_shifts) + { + // Clip the maximum shift to determine only the integer + // (as opposed to fractional) bits. + numerator_shift = remaining_shifts; + } + + // Shift and update numerator. + if (numerator_bits > 0) + { + left_shift_uint(difference.data(), numerator_shift, uint64_count, numerator); + numerator_bits += numerator_shift; + } + else + { + // Difference is zero so no need to shift, just set to zero. + set_zero_uint(uint64_count, numerator); + } + + // Adjust quotient and remaining shifts as a result of shifting numerator. + left_shift_uint(quotient, numerator_shift, uint64_count, quotient); + remaining_shifts -= numerator_shift; + } + + // Correct numerator (which is also the remainder) for shifting of + // denominator, unless it is just zero. + if (numerator_bits > 0) + { + right_shift_uint(numerator, denominator_shift, uint64_count, numerator); + } + } + + void exponentiate_uint(const uint64_t *operand, + size_t operand_uint64_count, const uint64_t *exponent, + size_t exponent_uint64_count, size_t result_uint64_count, + uint64_t *result, MemoryPool &pool) + { +#ifdef SEAL_DEBUG + if (!operand) + { + throw invalid_argument("operand"); + } + if (!operand_uint64_count) + { + throw invalid_argument("operand_uint64_count"); + } + if (!exponent) + { + throw invalid_argument("exponent"); + } + if (!exponent_uint64_count) + { + throw invalid_argument("exponent_uint64_count"); + } + if (!result) + { + throw invalid_argument("result"); + } + if (!result_uint64_count) + { + throw invalid_argument("result_uint64_count"); + } +#endif + // Fast cases + if (is_zero_uint(exponent, exponent_uint64_count)) + { + set_uint(1, result_uint64_count, result); + return; + } + if (is_equal_uint(exponent, exponent_uint64_count, 1)) + { + set_uint_uint(operand, operand_uint64_count, result_uint64_count, result); + return; + } + + // Need to make a copy of exponent + auto exponent_copy(allocate_uint(exponent_uint64_count, pool)); + set_uint_uint(exponent, exponent_uint64_count, exponent_copy.get()); + + // Perform binary exponentiation. + auto big_alloc(allocate_uint( + result_uint64_count + result_uint64_count + result_uint64_count, pool)); + + uint64_t *powerptr = big_alloc.get(); + uint64_t *productptr = powerptr + result_uint64_count; + uint64_t *intermediateptr = productptr + result_uint64_count; + + set_uint_uint(operand, operand_uint64_count, result_uint64_count, powerptr); + set_uint(1, result_uint64_count, intermediateptr); + + // Initially: power = operand and intermediate = 1, product is not initialized. + while (true) + { + if ((*exponent_copy.get() % 2) == 1) + { + multiply_truncate_uint_uint(powerptr, intermediateptr, + result_uint64_count, productptr); + swap(productptr, intermediateptr); + } + right_shift_uint(exponent_copy.get(), 1, exponent_uint64_count, + exponent_copy.get()); + if (is_zero_uint(exponent_copy.get(), exponent_uint64_count)) + { + break; + } + multiply_truncate_uint_uint(powerptr, powerptr, result_uint64_count, + productptr); + swap(productptr, powerptr); + } + set_uint_uint(intermediateptr, result_uint64_count, result); + } + + uint64_t exponentiate_uint64_safe(uint64_t operand, uint64_t exponent) + { + // Fast cases + if (exponent == 0) + { + return 1; + } + if (exponent == 1) + { + return operand; + } + + // Perform binary exponentiation. + uint64_t power = operand; + uint64_t product = 0; + uint64_t intermediate = 1; + + // Initially: power = operand and intermediate = 1, product irrelevant. + while (true) + { + if (exponent & 1) + { + product = mul_safe(power, intermediate); + swap(product, intermediate); + } + exponent >>= 1; + if (exponent == 0) + { + break; + } + product = mul_safe(power, power); + swap(product, power); + } + + return intermediate; + } + + uint64_t exponentiate_uint64(uint64_t operand, uint64_t exponent) + { + // Fast cases + if (exponent == 0) + { + return 1; + } + if (exponent == 1) + { + return operand; + } + + // Perform binary exponentiation. + uint64_t power = operand; + uint64_t product = 0; + uint64_t intermediate = 1; + + // Initially: power = operand and intermediate = 1, product irrelevant. + while (true) + { + if (exponent & 1) + { + product = power * intermediate; + swap(product, intermediate); + } + exponent >>= 1; + if (exponent == 0) + { + break; + } + product = power * power; + swap(product, power); + } + + return intermediate; + } + } +} diff --git a/src/seal/util/uintarith.h b/src/seal/util/uintarith.h new file mode 100644 index 000000000..dff9261fe --- /dev/null +++ b/src/seal/util/uintarith.h @@ -0,0 +1,682 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include +#include +#include +#include +#include "seal/util/common.h" +#include "seal/util/uintcore.h" +#include "seal/util/pointer.h" +#include "seal/util/defines.h" + +namespace seal +{ + namespace util + { + template>> + inline unsigned char add_uint64_generic(T operand1, S operand2, + unsigned char carry, unsigned long long *result) + { +#ifdef SEAL_DEBUG + if (!result) + { + throw std::invalid_argument("result cannot be null"); + } +#endif + operand1 += operand2; + *result = operand1 + carry; + return (operand1 < operand2) || (~operand1 < carry); + } + + template>> + inline unsigned char add_uint64(T operand1, S operand2, + unsigned char carry, unsigned long long *result) + { + return SEAL_ADD_CARRY_UINT64(operand1, operand2, carry, result); + } + + template>> + inline unsigned char add_uint64(T operand1, S operand2, R *result) + { + *result = operand1 + operand2; + return static_cast(*result < operand1); + } + + inline unsigned char add_uint_uint(const std::uint64_t *operand1, + std::size_t operand1_uint64_count, const std::uint64_t *operand2, + std::size_t operand2_uint64_count, unsigned char carry, + std::size_t result_uint64_count, std::uint64_t *result) + { +#ifdef SEAL_DEBUG + if (!operand1_uint64_count) + { + throw std::invalid_argument("operand1_uint64_count"); + } + if (!operand2_uint64_count) + { + throw std::invalid_argument("operand2_uint64_count"); + } + if (!result_uint64_count) + { + throw std::invalid_argument("result_uint64_count"); + } + if (!operand1) + { + throw std::invalid_argument("operand1"); + } + if (!operand2) + { + throw std::invalid_argument("operand2"); + } + if (!result) + { + throw std::invalid_argument("result"); + } +#endif + for (std::size_t i = 0; i < result_uint64_count; i++) + { + unsigned long long temp_result; + carry = add_uint64( + (i < operand1_uint64_count) ? *operand1++ : 0, + (i < operand2_uint64_count) ? *operand2++ : 0, + carry, &temp_result); + *result++ = temp_result; + } + return carry; + } + + inline unsigned char add_uint_uint(const std::uint64_t *operand1, + const std::uint64_t *operand2, std::size_t uint64_count, + std::uint64_t *result) + { +#ifdef SEAL_DEBUG + if (!uint64_count) + { + throw std::invalid_argument("uint64_count"); + } + if (!operand1) + { + throw std::invalid_argument("operand1"); + } + if (!operand2) + { + throw std::invalid_argument("operand2"); + } + if (!result) + { + throw std::invalid_argument("result"); + } +#endif + // Unroll first iteration of loop. We assume uint64_count > 0. + unsigned char carry = add_uint64(*operand1++, *operand2++, result++); + + // Do the rest + for(; --uint64_count; operand1++, operand2++, result++) + { + unsigned long long temp_result; + carry = add_uint64(*operand1, *operand2, carry, &temp_result); + *result = temp_result; + } + return carry; + } + + inline unsigned char add_uint_uint64(const std::uint64_t *operand1, + std::uint64_t operand2, std::size_t uint64_count, std::uint64_t *result) + { +#ifdef SEAL_DEBUG + if (!uint64_count) + { + throw std::invalid_argument("uint64_count"); + } + if (!operand1) + { + throw std::invalid_argument("operand1"); + } + if (!result) + { + throw std::invalid_argument("result"); + } +#endif + // Unroll first iteration of loop. We assume uint64_count > 0. + unsigned char carry = add_uint64(*operand1++, operand2, result++); + + // Do the rest + for(; --uint64_count; operand1++, result++) + { + unsigned long long temp_result; + carry = add_uint64(*operand1, std::uint64_t(0), carry, &temp_result); + *result = temp_result; + } + return carry; + } + + template>> + inline unsigned char sub_uint64_generic(T operand1, S operand2, + unsigned char borrow, unsigned long long *result) + { +#ifdef SEAL_DEBUG + if (!result) + { + throw std::invalid_argument("result cannot be null"); + } +#endif + auto diff = operand1 - operand2; + *result = diff - (borrow != 0); + return (diff > operand1) || (diff < borrow); + } + + template>> + inline unsigned char sub_uint64(T operand1, S operand2, + unsigned char borrow, unsigned long long *result) + { + return SEAL_SUB_BORROW_UINT64(operand1, operand2, borrow, result); + } + + template>> + inline unsigned char sub_uint64(T operand1, S operand2, R *result) + { + *result = operand1 - operand2; + return static_cast(operand2 > operand1); + } + + inline unsigned char sub_uint_uint(const std::uint64_t *operand1, + std::size_t operand1_uint64_count, const std::uint64_t *operand2, + std::size_t operand2_uint64_count, unsigned char borrow, + std::size_t result_uint64_count, std::uint64_t *result) + { +#ifdef SEAL_DEBUG + if (!result_uint64_count) + { + throw std::invalid_argument("result_uint64_count"); + } + if (!result) + { + throw std::invalid_argument("result"); + } +#endif + for (std::size_t i = 0; i < result_uint64_count; + i++, operand1++, operand2++, result++) + { + unsigned long long temp_result; + borrow = sub_uint64((i < operand1_uint64_count) ? *operand1 : 0, + (i < operand2_uint64_count) ? *operand2 : 0, borrow, &temp_result); + *result = temp_result; + } + return borrow; + } + + inline unsigned char sub_uint_uint(const std::uint64_t *operand1, + const std::uint64_t *operand2, std::size_t uint64_count, std::uint64_t *result) + { +#ifdef SEAL_DEBUG + if (!uint64_count) + { + throw std::invalid_argument("uint64_count"); + } + if (!operand1) + { + throw std::invalid_argument("operand1"); + } + if (!operand2) + { + throw std::invalid_argument("operand2"); + } + if (!result) + { + throw std::invalid_argument("result"); + } +#endif + // Unroll first iteration of loop. We assume uint64_count > 0. + unsigned char borrow = sub_uint64(*operand1++, *operand2++, result++); + + // Do the rest + for(; --uint64_count; operand1++, operand2++, result++) + { + unsigned long long temp_result; + borrow = sub_uint64(*operand1, *operand2, borrow, &temp_result); + *result = temp_result; + } + return borrow; + } + + inline unsigned char sub_uint_uint64(const std::uint64_t *operand1, + std::uint64_t operand2, std::size_t uint64_count, std::uint64_t *result) + { +#ifdef SEAL_DEBUG + if (!uint64_count) + { + throw std::invalid_argument("uint64_count"); + } + if (!operand1) + { + throw std::invalid_argument("operand1"); + } + if (!result) + { + throw std::invalid_argument("result"); + } +#endif + // Unroll first iteration of loop. We assume uint64_count > 0. + unsigned char borrow = sub_uint64(*operand1++, operand2, result++); + + // Do the rest + for(; --uint64_count; operand1++, operand2++, result++) + { + unsigned long long temp_result; + borrow = sub_uint64(*operand1, std::uint64_t(0), borrow, &temp_result); + *result = temp_result; + } + return borrow; + } + + inline unsigned char increment_uint(const std::uint64_t *operand, + std::size_t uint64_count, std::uint64_t *result) + { +#ifdef SEAL_DEBUG + if (!operand) + { + throw std::invalid_argument("operand"); + } + if (!uint64_count) + { + throw std::invalid_argument("uint64_count"); + } + if (!result) + { + throw std::invalid_argument("result"); + } +#endif + return add_uint_uint64(operand, 1, uint64_count, result); + } + + inline unsigned char decrement_uint(const std::uint64_t *operand, + std::size_t uint64_count, std::uint64_t *result) + { +#ifdef SEAL_DEBUG + if (!operand && uint64_count > 0) + { + throw std::invalid_argument("operand"); + } + if (!result && uint64_count > 0) + { + throw std::invalid_argument("result"); + } +#endif + return sub_uint_uint64(operand, 1, uint64_count, result); + } + + inline void negate_uint(const std::uint64_t *operand, std::size_t uint64_count, + std::uint64_t *result) + { +#ifdef SEAL_DEBUG + if (!operand) + { + throw std::invalid_argument("operand"); + } + if (!uint64_count) + { + throw std::invalid_argument("uint64_count"); + } + if (!result) + { + throw std::invalid_argument("result"); + } +#endif + // Negation is equivalent to inverting bits and adding 1. + unsigned char carry = add_uint64(~*operand++, std::uint64_t(1), result++); + for(; --uint64_count; operand++, result++) + { + unsigned long long temp_result; + carry = add_uint64( + ~*operand, std::uint64_t(0), carry, &temp_result); + *result = temp_result; + } + } + + inline void left_shift_uint(const std::uint64_t *operand, + int shift_amount, std::size_t uint64_count, std::uint64_t *result) + { + std::size_t bits_per_uint64_sz = static_cast(bits_per_uint64); +#ifdef SEAL_DEBUG + if (!operand) + { + throw std::invalid_argument("operand"); + } + if (shift_amount < 0 || + unsigned_gt(shift_amount, + mul_safe(uint64_count, bits_per_uint64_sz))) + { + throw std::invalid_argument("shift_amount"); + } + if (!uint64_count) + { + throw std::invalid_argument("uint64_count"); + } + if (!result) + { + throw std::invalid_argument("result"); + } +#endif + std::size_t uint64_shift_amount = + safe_cast(shift_amount) / bits_per_uint64_sz; + int bit_shift_amount = shift_amount - + safe_cast(mul_safe(uint64_shift_amount, bits_per_uint64_sz)); + int neg_bit_shift_amount = (bits_per_uint64 - bit_shift_amount) & + (static_cast(bit_shift_amount == 0) - 1); + + for (std::size_t i = uint64_count - uint64_shift_amount; i--; ) + { + result[i + uint64_shift_amount] = operand[i]; + } + for (std::size_t i = uint64_shift_amount; i--; ) + { + result[i] = 0; + } + if (neg_bit_shift_amount) + { + for (std::size_t i = uint64_count - 1; + i >= uint64_shift_amount + 1; i--) + { + result[i] = (result[i] << bit_shift_amount) | + (result[i - 1] >> neg_bit_shift_amount); + } + result[uint64_shift_amount] <<= bit_shift_amount; + } + } + + inline void right_shift_uint(const std::uint64_t *operand, + int shift_amount, std::size_t uint64_count, std::uint64_t *result) + { + std::size_t bits_per_uint64_sz = static_cast(bits_per_uint64); +#ifdef SEAL_DEBUG + if (!operand) + { + throw std::invalid_argument("operand"); + } + if (shift_amount < 0 || + unsigned_gt(shift_amount, + mul_safe(uint64_count, bits_per_uint64_sz))) + { + throw std::invalid_argument("shift_amount"); + } + if (!uint64_count) + { + throw std::invalid_argument("uint64_count"); + } + if (!result) + { + throw std::invalid_argument("result"); + } +#endif + std::size_t uint64_shift_amount = + safe_cast(shift_amount) / bits_per_uint64_sz; + int bit_shift_amount = shift_amount - + safe_cast(mul_safe(uint64_shift_amount, bits_per_uint64_sz)); + int neg_bit_shift_amount = (bits_per_uint64 - bit_shift_amount) & + (static_cast(bit_shift_amount == 0) - 1); + + for (std::size_t i = 0; i < uint64_count - uint64_shift_amount; i++) + { + result[i] = operand[i + uint64_shift_amount]; + } + for (std::size_t i = uint64_count - uint64_shift_amount; i < uint64_count; i++) + { + result[i] = 0; + } + if (neg_bit_shift_amount) + { + for (std::size_t i = 0; i < (uint64_count - uint64_shift_amount - 1); i++) + { + result[i] = (result[i] >> bit_shift_amount) | (result[i + 1] << neg_bit_shift_amount); + } + result[uint64_count - uint64_shift_amount - 1] >>= bit_shift_amount; + } + } + + inline void half_round_up_uint(const std::uint64_t *operand, + std::size_t uint64_count, std::uint64_t *result) + { +#ifdef SEAL_DEBUG + if (!operand && uint64_count > 0) + { + throw std::invalid_argument("operand"); + } + if (!result && uint64_count > 0) + { + throw std::invalid_argument("result"); + } +#endif + if (!uint64_count) + { + return; + } + // Set result to (operand + 1) / 2. To prevent overflowing operand, right shift + // and then increment result if low-bit of operand was set. + bool low_bit_set = operand[0] & 1; + + for (std::size_t i = 0; i < uint64_count - 1; i++) + { + result[i] = (operand[i] >> 1) | (operand[i + 1] << (bits_per_uint64 - 1)); + } + result[uint64_count - 1] = operand[uint64_count - 1] >> 1; + + if (low_bit_set) + { + increment_uint(result, uint64_count, result); + } + } + + inline void not_uint(const std::uint64_t *operand, std::size_t uint64_count, + std::uint64_t *result) + { +#ifdef SEAL_DEBUG + if (!operand && uint64_count > 0) + { + throw std::invalid_argument("operand"); + } + if (!result && uint64_count > 0) + { + throw std::invalid_argument("result"); + } +#endif + for (; uint64_count--; result++, operand++) + { + *result = ~*operand; + } + } + + inline void and_uint_uint(const std::uint64_t *operand1, + const std::uint64_t *operand2, std::size_t uint64_count, std::uint64_t *result) + { +#ifdef SEAL_DEBUG + if (!operand1 && uint64_count > 0) + { + throw std::invalid_argument("operand1"); + } + if (!operand2 && uint64_count > 0) + { + throw std::invalid_argument("operand2"); + } + if (!result && uint64_count > 0) + { + throw std::invalid_argument("result"); + } +#endif + for (; uint64_count--; result++, operand1++, operand2++) + { + *result = *operand1 & *operand2; + } + } + + inline void or_uint_uint(const std::uint64_t *operand1, + const std::uint64_t *operand2, std::size_t uint64_count, std::uint64_t *result) + { +#ifdef SEAL_DEBUG + if (!operand1 && uint64_count > 0) + { + throw std::invalid_argument("operand1"); + } + if (!operand2 && uint64_count > 0) + { + throw std::invalid_argument("operand2"); + } + if (!result && uint64_count > 0) + { + throw std::invalid_argument("result"); + } +#endif + for (; uint64_count--; result++, operand1++, operand2++) + { + *result = *operand1 | *operand2; + } + } + + inline void xor_uint_uint(const std::uint64_t *operand1, + const std::uint64_t *operand2, std::size_t uint64_count, std::uint64_t *result) + { +#ifdef SEAL_DEBUG + if (!operand1 && uint64_count > 0) + { + throw std::invalid_argument("operand1"); + } + if (!operand2 && uint64_count > 0) + { + throw std::invalid_argument("operand2"); + } + if (!result && uint64_count > 0) + { + throw std::invalid_argument("result"); + } +#endif + for (; uint64_count--; result++, operand1++, operand2++) + { + *result = *operand1 ^ *operand2; + } + } + + template>> + inline void multiply_uint64_generic(T operand1, S operand2, + unsigned long long *result128) + { +#ifdef SEAL_DEBUG + if (!result128) + { + throw std::invalid_argument("result128 cannot be null"); + } +#endif + auto operand1_coeff_right = operand1 & 0x00000000FFFFFFFF; + auto operand2_coeff_right = operand2 & 0x00000000FFFFFFFF; + operand1 >>= 32; + operand2 >>= 32; + + auto middle1 = operand1 * operand2_coeff_right; + T middle; + auto left = operand1 * operand2 + (static_cast(add_uint64( + middle1, operand2 * operand1_coeff_right, &middle)) << 32); + auto right = operand1_coeff_right * operand2_coeff_right; + auto temp_sum = (right >> 32) + (middle & 0x00000000FFFFFFFF); + + result128[1] = static_cast( + left + (middle >> 32) + (temp_sum >> 32)); + result128[0] = static_cast( + (temp_sum << 32) | (right & 0x00000000FFFFFFFF)); + } + + template>> + inline void multiply_uint64(T operand1, S operand2, + unsigned long long *result128) + { + SEAL_MULTIPLY_UINT64(operand1, operand2, result128); + } + + template>> + inline void multiply_uint64_hw64_generic(T operand1, S operand2, + unsigned long long *hw64) + { +#ifdef SEAL_DEBUG + if (!hw64) + { + throw std::invalid_argument("hw64 cannot be null"); + } +#endif + auto operand1_coeff_right = operand1 & 0x00000000FFFFFFFF; + auto operand2_coeff_right = operand2 & 0x00000000FFFFFFFF; + operand1 >>= 32; + operand2 >>= 32; + + auto middle1 = operand1 * operand2_coeff_right; + T middle; + auto left = operand1 * operand2 + (static_cast(add_uint64( + middle1, operand2 * operand1_coeff_right, &middle)) << 32); + auto right = operand1_coeff_right * operand2_coeff_right; + auto temp_sum = (right >> 32) + (middle & 0x00000000FFFFFFFF); + + *hw64 = static_cast( + left + (middle >> 32) + (temp_sum >> 32)); + } + + template>> + inline void multiply_uint64_hw64(T operand1, S operand2, + unsigned long long *hw64) + { + SEAL_MULTIPLY_UINT64_HW64(operand1, operand2, hw64); + } + + void multiply_uint_uint(const std::uint64_t *operand1, + std::size_t operand1_uint64_count, const std::uint64_t *operand2, + std::size_t operand2_uint64_count, std::size_t result_uint64_count, + std::uint64_t *result); + + inline void multiply_uint_uint(const std::uint64_t *operand1, + const std::uint64_t *operand2, std::size_t uint64_count, std::uint64_t *result) + { + multiply_uint_uint(operand1, uint64_count, operand2, uint64_count, + uint64_count * 2, result); + } + + void multiply_uint_uint64(const std::uint64_t *operand1, + std::size_t operand1_uint64_count, std::uint64_t operand2, + std::size_t result_uint64_count, std::uint64_t *result); + + inline void multiply_truncate_uint_uint(const std::uint64_t *operand1, + const std::uint64_t *operand2, std::size_t uint64_count, std::uint64_t *result) + { + multiply_uint_uint(operand1, uint64_count, operand2, uint64_count, + uint64_count, result); + } + + void divide_uint_uint_inplace(std::uint64_t *numerator, + const std::uint64_t *denominator, std::size_t uint64_count, + std::uint64_t *quotient, MemoryPool &pool); + + inline void divide_uint_uint(const std::uint64_t *numerator, + const std::uint64_t *denominator, std::size_t uint64_count, + std::uint64_t *quotient, std::uint64_t *remainder, MemoryPool &pool) + { + set_uint_uint(numerator, uint64_count, remainder); + divide_uint_uint_inplace(remainder, denominator, uint64_count, quotient, pool); + } + + void divide_uint128_uint64_inplace(std::uint64_t *numerator, + std::uint64_t denominator, std::uint64_t *quotient); + + void divide_uint192_uint64_inplace(std::uint64_t *numerator, + std::uint64_t denominator, std::uint64_t *quotient); + + void exponentiate_uint(const std::uint64_t *operand, + std::size_t operand_uint64_count, const std::uint64_t *exponent, + std::size_t exponent_uint64_count, std::size_t result_uint64_count, + std::uint64_t *result, MemoryPool &pool); + + std::uint64_t exponentiate_uint64_safe(std::uint64_t operand, + std::uint64_t exponent); + + std::uint64_t exponentiate_uint64(std::uint64_t operand, + std::uint64_t exponent); + } +} diff --git a/src/seal/util/uintarithmod.cpp b/src/seal/util/uintarithmod.cpp new file mode 100644 index 000000000..731da0911 --- /dev/null +++ b/src/seal/util/uintarithmod.cpp @@ -0,0 +1,248 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "seal/util/uintcore.h" +#include "seal/util/uintarith.h" +#include "seal/util/uintarithmod.h" +#include "seal/util/common.h" + +using namespace std; + +namespace seal +{ + namespace util + { + bool try_invert_uint_mod(const uint64_t *operand, + const uint64_t *modulus, size_t uint64_count, + uint64_t *result, MemoryPool &pool) + { +#ifdef SEAL_DEBUG + if (!operand) + { + throw invalid_argument("operand"); + } + if (!modulus) + { + throw invalid_argument("modulus"); + } + if (!uint64_count) + { + throw invalid_argument("uint64_count"); + } + if (!result) + { + throw invalid_argument("result"); + } + if (is_greater_than_or_equal_uint_uint(operand, modulus, uint64_count)) + { + throw out_of_range("operand"); + } +#endif + // Cannot invert 0. + int bit_count = get_significant_bit_count_uint(operand, uint64_count); + if (bit_count == 0) + { + return false; + } + + // If it is 1, then its invert is itself. + if (bit_count == 1) + { + set_uint(1, uint64_count, result); + return true; + } + + auto alloc_anchor(allocate_uint(7 * uint64_count, pool)); + + // Construct a mutable copy of operand and modulus, with numerator being modulus + // and operand being denominator. Notice that numerator > denominator. + uint64_t *numerator = alloc_anchor.get(); + set_uint_uint(modulus, uint64_count, numerator); + + uint64_t *denominator = numerator + uint64_count; + set_uint_uint(operand, uint64_count, denominator); + + // Create space to store difference. + uint64_t *difference = denominator + uint64_count; + + // Determine highest bit index of each. + int numerator_bits = get_significant_bit_count_uint(numerator, uint64_count); + int denominator_bits = get_significant_bit_count_uint(denominator, uint64_count); + + // Create space to store quotient. + uint64_t *quotient = difference + uint64_count; + + // Create three sign/magnitude values to store coefficients. + // Initialize invert_prior to +0 and invert_curr to +1. + uint64_t *invert_prior = quotient + uint64_count; + set_zero_uint(uint64_count, invert_prior); + bool invert_prior_positive = true; + + uint64_t *invert_curr = invert_prior + uint64_count; + set_uint(1, uint64_count, invert_curr); + bool invert_curr_positive = true; + + uint64_t *invert_next = invert_curr + uint64_count; + bool invert_next_positive = true; + + // Perform extended Euclidean algorithm. + while (true) + { + // NOTE: Numerator is > denominator. + + // Only perform computation up to last non-zero uint64s. + size_t division_uint64_count = static_cast( + divide_round_up(numerator_bits, bits_per_uint64)); + + // Shift denominator to bring MSB in alignment with MSB of numerator. + int denominator_shift = numerator_bits - denominator_bits; + left_shift_uint(denominator, denominator_shift, + division_uint64_count, denominator); + denominator_bits += denominator_shift; + + // Clear quotient. + set_zero_uint(uint64_count, quotient); + + // Perform bit-wise division algorithm. + int remaining_shifts = denominator_shift; + while (numerator_bits == denominator_bits) + { + // NOTE: MSBs of numerator and denominator are aligned. + + // Even though MSB of numerator and denominator are aligned, + // still possible numerator < denominator. + if (sub_uint_uint(numerator, denominator, + division_uint64_count, difference)) + { + // numerator < denominator and MSBs are aligned, so current + // quotient bit is zero and next one is definitely one. + if (remaining_shifts == 0) + { + // No shifts remain and numerator < denominator so done. + break; + } + + // Effectively shift numerator left by 1 by instead adding + // numerator to difference (to prevent overflow in numerator). + add_uint_uint(difference, numerator, division_uint64_count, difference); + + // Adjust quotient and remaining shifts as a result of shifting numerator. + left_shift_uint(quotient, 1, division_uint64_count, quotient); + remaining_shifts--; + } + // Difference is the new numerator with denominator subtracted. + + // Update quotient to reflect subtraction. + *quotient |= 1; + + // Determine amount to shift numerator to bring MSB in alignment + // with denominator. + numerator_bits = + get_significant_bit_count_uint(difference, division_uint64_count); + int numerator_shift = denominator_bits - numerator_bits; + if (numerator_shift > remaining_shifts) + { + // Clip the maximum shift to determine only the integer + // (as opposed to fractional) bits. + numerator_shift = remaining_shifts; + } + + // Shift and update numerator. + if (numerator_bits > 0) + { + left_shift_uint(difference, numerator_shift, + division_uint64_count, numerator); + numerator_bits += numerator_shift; + } + else + { + // Difference is zero so no need to shift, just set to zero. + set_zero_uint(division_uint64_count, numerator); + } + + // Adjust quotient and remaining shifts as a result of + // shifting numerator. + left_shift_uint(quotient, numerator_shift, + division_uint64_count, quotient); + remaining_shifts -= numerator_shift; + } + + // Correct for shifting of denominator. + right_shift_uint(denominator, denominator_shift, + division_uint64_count, denominator); + denominator_bits -= denominator_shift; + + // We are done if remainder (which is stored in numerator) is zero. + if (numerator_bits == 0) + { + break; + } + + // Correct for shifting of denominator. + right_shift_uint(numerator, denominator_shift, + division_uint64_count, numerator); + numerator_bits -= denominator_shift; + + // Integrate quotient with invert coefficients. + // Calculate: invert_prior + -quotient * invert_curr + multiply_truncate_uint_uint(quotient, invert_curr, + uint64_count, invert_next); + invert_next_positive = !invert_curr_positive; + if (invert_prior_positive == invert_next_positive) + { + // If both sides of add have same sign, then simple add and + // do not need to worry about overflow due to known limits + // on the coefficients proved in the euclidean algorithm. + add_uint_uint(invert_prior, invert_next, uint64_count, invert_next); + } + else + { + // If both sides of add have opposite sign, then subtract + // and check for overflow. + uint64_t borrow = sub_uint_uint(invert_prior, invert_next, + uint64_count, invert_next); + if (borrow == 0) + { + // No borrow means |invert_prior| >= |invert_next|, + // so sign is same as invert_prior. + invert_next_positive = invert_prior_positive; + } + else + { + // Borrow means |invert prior| < |invert_next|, + // so sign is opposite of invert_prior. + invert_next_positive = !invert_prior_positive; + negate_uint(invert_next, uint64_count, invert_next); + } + } + + // Swap prior and curr, and then curr and next. + swap(invert_prior, invert_curr); + swap(invert_prior_positive, invert_curr_positive); + swap(invert_curr, invert_next); + swap(invert_curr_positive, invert_next_positive); + + // Swap numerator and denominator using pointer swings. + swap(numerator, denominator); + swap(numerator_bits, denominator_bits); + } + + if (!is_equal_uint(denominator, uint64_count, 1)) + { + // GCD is not one, so unable to find inverse. + return false; + } + + // Correct coefficient if negative by modulo. + if (!invert_curr_positive && !is_zero_uint(invert_curr, uint64_count)) + { + sub_uint_uint(modulus, invert_curr, uint64_count, invert_curr); + invert_curr_positive = true; + } + + // Set result. + set_uint_uint(invert_curr, uint64_count, result); + return true; + } + } +} diff --git a/src/seal/util/uintarithmod.h b/src/seal/util/uintarithmod.h new file mode 100644 index 000000000..17bce9d89 --- /dev/null +++ b/src/seal/util/uintarithmod.h @@ -0,0 +1,267 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include +#include "seal/util/uintcore.h" +#include "seal/util/uintarith.h" +#include "seal/util/pointer.h" + +namespace seal +{ + namespace util + { + inline void increment_uint_mod(const std::uint64_t *operand, + const std::uint64_t *modulus, std::size_t uint64_count, + std::uint64_t *result) + { +#ifdef SEAL_DEBUG + if (!operand) + { + throw std::invalid_argument("operand"); + } + if (!modulus) + { + throw std::invalid_argument("modulus"); + } + if (!uint64_count) + { + throw std::invalid_argument("uint64_count"); + } + if (!result) + { + throw std::invalid_argument("result"); + } + if (is_greater_than_or_equal_uint_uint(operand, modulus, uint64_count)) + { + throw std::invalid_argument("operand"); + } + if (modulus == result) + { + throw std::invalid_argument("result cannot point to the same value as modulus"); + } +#endif + unsigned char carry = increment_uint(operand, uint64_count, result); + if (carry || + is_greater_than_or_equal_uint_uint(result, modulus, uint64_count)) + { + sub_uint_uint(result, modulus, uint64_count, result); + } + } + + inline void decrement_uint_mod(const std::uint64_t *operand, + const std::uint64_t *modulus, std::size_t uint64_count, + std::uint64_t *result) + { +#ifdef SEAL_DEBUG + if (!operand) + { + throw std::invalid_argument("operand"); + } + if (!modulus) + { + throw std::invalid_argument("modulus"); + } + if (!uint64_count) + { + throw std::invalid_argument("uint64_count"); + } + if (!result) + { + throw std::invalid_argument("result"); + } + if (is_greater_than_or_equal_uint_uint(operand, modulus, uint64_count)) + { + throw std::invalid_argument("operand"); + } + if (modulus == result) + { + throw std::invalid_argument("result cannot point to the same value as modulus"); + } +#endif + if (decrement_uint(operand, uint64_count, result)) + { + add_uint_uint(result, modulus, uint64_count, result); + } + } + + inline void negate_uint_mod(const std::uint64_t *operand, + const std::uint64_t *modulus, std::size_t uint64_count, + std::uint64_t *result) + { +#ifdef SEAL_DEBUG + if (!operand) + { + throw std::invalid_argument("operand"); + } + if (!modulus) + { + throw std::invalid_argument("modulus"); + } + if (!uint64_count) + { + throw std::invalid_argument("uint64_count"); + } + if (!result) + { + throw std::invalid_argument("result"); + } + if (is_greater_than_or_equal_uint_uint(operand, modulus, uint64_count)) + { + throw std::invalid_argument("operand"); + } +#endif + if (is_zero_uint(operand, uint64_count)) + { + // Negation of zero is zero. + set_zero_uint(uint64_count, result); + } + else + { + // Otherwise, we know operand > 0 and < modulus so subtract modulus - operand. + sub_uint_uint(modulus, operand, uint64_count, result); + } + } + + inline void div2_uint_mod(const std::uint64_t *operand, + const std::uint64_t *modulus, std::size_t uint64_count, + std::uint64_t *result) + { +#ifdef SEAL_DEBUG + if (!operand) + { + throw std::invalid_argument("operand"); + } + if (!modulus) + { + throw std::invalid_argument("modulus"); + } + if (!uint64_count) + { + throw std::invalid_argument("uint64_count"); + } + if (!result) + { + throw std::invalid_argument("result"); + } + if (!is_bit_set_uint(modulus, uint64_count, 0)) + { + throw std::invalid_argument("modulus"); + } + if (is_greater_than_or_equal_uint_uint(operand, modulus, uint64_count)) + { + throw std::invalid_argument("operand"); + } +#endif + if (*operand & 1) + { + unsigned char carry = add_uint_uint(operand, modulus, uint64_count, result); + right_shift_uint(result, 1, uint64_count, result); + if (carry) + { + set_bit_uint(result, uint64_count, + static_cast(uint64_count) * bits_per_uint64 - 1); + } + } + else + { + right_shift_uint(operand, 1, uint64_count, result); + } + } + + inline void add_uint_uint_mod(const std::uint64_t *operand1, + const std::uint64_t *operand2, const std::uint64_t *modulus, + std::size_t uint64_count, std::uint64_t *result) + { +#ifdef SEAL_DEBUG + if (!operand1) + { + throw std::invalid_argument("operand1"); + } + if (!operand2) + { + throw std::invalid_argument("operand2"); + } + if (!modulus) + { + throw std::invalid_argument("modulus"); + } + if (!uint64_count) + { + throw std::invalid_argument("uint64_count"); + } + if (!result) + { + throw std::invalid_argument("result"); + } + if (is_greater_than_or_equal_uint_uint(operand1, modulus, uint64_count)) + { + throw std::invalid_argument("operand1"); + } + if (is_greater_than_or_equal_uint_uint(operand2, modulus, uint64_count)) + { + throw std::invalid_argument("operand2"); + } + if (modulus == result) + { + throw std::invalid_argument("result cannot point to the same value as modulus"); + } +#endif + unsigned char carry = add_uint_uint(operand1, operand2, uint64_count, result); + if (carry || + is_greater_than_or_equal_uint_uint(result, modulus, uint64_count)) + { + sub_uint_uint(result, modulus, uint64_count, result); + } + } + + inline void sub_uint_uint_mod(const std::uint64_t *operand1, + const std::uint64_t *operand2, const std::uint64_t *modulus, + std::size_t uint64_count, std::uint64_t *result) + { +#ifdef SEAL_DEBUG + if (!operand1) + { + throw std::invalid_argument("operand1"); + } + if (!operand2) + { + throw std::invalid_argument("operand2"); + } + if (!modulus) + { + throw std::invalid_argument("modulus"); + } + if (!uint64_count) + { + throw std::invalid_argument("uint64_count"); + } + if (!result) + { + throw std::invalid_argument("result"); + } + if (is_greater_than_or_equal_uint_uint(operand1, modulus, uint64_count)) + { + throw std::invalid_argument("operand1"); + } + if (is_greater_than_or_equal_uint_uint(operand2, modulus, uint64_count)) + { + throw std::invalid_argument("operand2"); + } + if (modulus == result) + { + throw std::invalid_argument("result cannot point to the same value as modulus"); + } +#endif + if (sub_uint_uint(operand1, operand2, uint64_count, result)) + { + add_uint_uint(result, modulus, uint64_count, result); + } + } + + bool try_invert_uint_mod(const std::uint64_t *operand, + const std::uint64_t *modulus, std::size_t uint64_count, std::uint64_t *result, + MemoryPool &pool); + } +} diff --git a/src/seal/util/uintarithsmallmod.cpp b/src/seal/util/uintarithsmallmod.cpp new file mode 100644 index 000000000..12ecef00d --- /dev/null +++ b/src/seal/util/uintarithsmallmod.cpp @@ -0,0 +1,266 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "seal/util/uintcore.h" +#include "seal/util/uintarith.h" +#include "seal/util/uintarithmod.h" +#include "seal/util/uintarithsmallmod.h" +#include + +using namespace std; + +namespace seal +{ + namespace util + { + bool is_primitive_root(uint64_t root, uint64_t degree, + const SmallModulus &modulus) + { +#ifdef SEAL_DEBUG + if (modulus.bit_count() < 2) + { + throw invalid_argument("modulus"); + } + if (root >= modulus.value()) + { + throw out_of_range("operand"); + } + if (get_power_of_two(degree) < 1) + { + throw invalid_argument("degree must be a power of two and at least two"); + } +#endif + if (root == 0) + { + return false; + } + + // We check if root is a degree-th root of unity in integers modulo + // modulus, where degree is a power of two. + // It suffices to check that root^(degree/2) is -1 modulo modulus. + return exponentiate_uint_mod( + root, degree >> 1, modulus) == (modulus.value() - 1); + } + + bool try_primitive_root(uint64_t degree, const SmallModulus &modulus, + uint64_t &destination) + { +#ifdef SEAL_DEBUG + if (modulus.bit_count() < 2) + { + throw invalid_argument("modulus"); + } + if (get_power_of_two(degree) < 1) + { + throw invalid_argument("degree must be a power of two and at least two"); + } +#endif + // We need to divide modulus-1 by degree to get the size of the + // quotient group + uint64_t size_entire_group = modulus.value() - 1; + + // Compute size of quotient group + uint64_t size_quotient_group = size_entire_group / degree; + + // size_entire_group must be divisible by degree, or otherwise the + // primitive root does not exist in integers modulo modulus + if (size_entire_group - size_quotient_group * degree != 0) + { + return false; + } + + // For randomness + random_device rd; + + int attempt_counter = 0; + int attempt_counter_max = 100; + do + { + attempt_counter++; + + // Set destination to be a random number modulo modulus + destination = (static_cast(rd()) << 32) | + static_cast(rd()); + destination %= modulus.value(); + + // Raise the random number to power the size of the quotient + // to get rid of irrelevant part + destination = exponentiate_uint_mod( + destination, size_quotient_group, modulus); + } while (!is_primitive_root(destination, degree, modulus) && + (attempt_counter < attempt_counter_max)); + + return is_primitive_root(destination, degree, modulus); + } + + bool try_minimal_primitive_root(uint64_t degree, + const SmallModulus &modulus, uint64_t &destination) + { + uint64_t root; + if (!try_primitive_root(degree, modulus, root)) + { + return false; + } + uint64_t generator_sq = multiply_uint_uint_mod(root, root, modulus); + uint64_t current_generator = root; + + // destination is going to always contain the smallest generator found + for (size_t i = 0; i < degree; i++) + { + // If our current generator is strictly smaller than destination, + // update + if (current_generator < root) + { + root = current_generator; + } + + // Then move on to the next generator + current_generator = multiply_uint_uint_mod( + current_generator, generator_sq, modulus); + } + + destination = root; + return true; + } + + uint64_t exponentiate_uint_mod(uint64_t operand, uint64_t exponent, + const SmallModulus &modulus) + { +#ifdef SEAL_DEBUG + if (modulus.is_zero()) + { + throw invalid_argument("modulus"); + } + if (operand >= modulus.value()) + { + throw invalid_argument("operand"); + } +#endif + // Fast cases + if (exponent == 0) + { + // Result is supposed to be only one digit + return 1; + } + + if (exponent == 1) + { + return operand; + } + + // Perform binary exponentiation. + uint64_t power = operand; + uint64_t product = 0; + uint64_t intermediate = 1; + + // Initially: power = operand and intermediate = 1, product is irrelevant. + while (true) + { + if (exponent & 1) + { + product = multiply_uint_uint_mod(power, intermediate, modulus); + swap(product, intermediate); + } + exponent >>= 1; + if (exponent == 0) + { + break; + } + product = multiply_uint_uint_mod(power, power, modulus); + swap(product, power); + } + return intermediate; + } + + void divide_uint_uint_mod_inplace(uint64_t *numerator, + const SmallModulus &modulus, size_t uint64_count, + uint64_t *quotient, MemoryPool &pool) + { + // Handle base cases + if (uint64_count == 2) + { + divide_uint128_uint64_inplace(numerator, modulus.value(), quotient); + return; + } + else if(uint64_count == 1) + { + *numerator = *numerator % modulus.value(); + *quotient = *numerator / modulus.value(); + return; + } + else + { + // If uint64_count > 2. + // x = numerator = x1 * 2^128 + x2. + // 2^128 = A*value + B. + + auto x1_alloc(allocate_uint(uint64_count - 2 , pool)); + uint64_t *x1 = x1_alloc.get(); + uint64_t x2[2]; + auto quot_alloc(allocate_uint(uint64_count, pool)); + uint64_t *quot = quot_alloc.get(); + auto rem_alloc(allocate_uint(uint64_count, pool)); + uint64_t *rem = rem_alloc.get(); + set_uint_uint(numerator + 2, uint64_count - 2, x1); + set_uint_uint(numerator, 2, x2); // x2 = (num) % 2^128. + + multiply_uint_uint(x1, uint64_count - 2, &modulus.const_ratio()[0], 2, + uint64_count, quot); // x1*A. + multiply_uint_uint64(x1, uint64_count - 2, modulus.const_ratio()[2], + uint64_count - 1, rem); // x1*B + add_uint_uint(rem, uint64_count - 1, x2, 2, 0, uint64_count, rem); // x1*B + x2; + + size_t remainder_uint64_count = get_significant_uint64_count_uint(rem, uint64_count); + divide_uint_uint_mod_inplace(rem, modulus, remainder_uint64_count, quotient, pool); + add_uint_uint(quotient, quot, uint64_count, quotient); + *numerator = rem[0]; + + return; + } + } + + uint64_t steps_to_galois_elt(int steps, size_t coeff_count) + { + uint32_t n = safe_cast(coeff_count); + uint32_t m32 = mul_safe(n, uint32_t(2)); + uint64_t m = static_cast(m32); + + if (steps == 0) + { + return m - 1; + } + else + { + // Extract sign of steps. When steps is positive, the rotation + // is to the left; when steps is negative, it is to the right. + bool sign = steps < 0; + uint32_t pos_steps = safe_cast(abs(steps)); + + if (pos_steps >= (n >> 1)) + { + throw invalid_argument("step count too large"); + } + + pos_steps &= m32 - 1; + if (sign) + { + steps = safe_cast(n >> 1) - safe_cast(pos_steps); + } + else + { + steps = safe_cast(pos_steps); + } + + // Construct Galois element for row rotation + uint64_t gen = 3; + uint64_t galois_elt = 1; + while(steps--) + { + galois_elt *= gen; + galois_elt &= m - 1; + } + return galois_elt; + } + } + } +} diff --git a/src/seal/util/uintarithsmallmod.h b/src/seal/util/uintarithsmallmod.h new file mode 100644 index 000000000..882b2111a --- /dev/null +++ b/src/seal/util/uintarithsmallmod.h @@ -0,0 +1,286 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include +#include +#include "seal/smallmodulus.h" +#include "seal/util/defines.h" +#include "seal/util/pointer.h" +#include "seal/util/numth.h" +#include "seal/util/uintarith.h" + +namespace seal +{ + namespace util + { + inline std::uint64_t increment_uint_mod(std::uint64_t operand, + const SmallModulus &modulus) + { +#ifdef SEAL_DEBUG + if (modulus.is_zero()) + { + throw std::invalid_argument("modulus"); + } + if (operand >= modulus.value()) + { + throw std::out_of_range("operand"); + } +#endif + operand++; + return operand - (modulus.value() & static_cast( + -static_cast(operand >= modulus.value()))); + } + + inline std::uint64_t decrement_uint_mod(std::uint64_t operand, + const SmallModulus &modulus) + { +#ifdef SEAL_DEBUG + if (modulus.is_zero()) + { + throw std::invalid_argument("modulus"); + } + if (operand >= modulus.value()) + { + throw std::out_of_range("operand"); + } +#endif + std::int64_t carry = (operand == 0); + return operand - 1 + (modulus.value() & + static_cast(-carry)); + } + + inline std::uint64_t negate_uint_mod(std::uint64_t operand, + const SmallModulus &modulus) + { +#ifdef SEAL_DEBUG + if (modulus.is_zero()) + { + throw std::invalid_argument("modulus"); + } + if (operand >= modulus.value()) + { + throw std::out_of_range("operand"); + } +#endif + std::int64_t non_zero = (operand != 0); + return (modulus.value() - operand) + & static_cast(-non_zero); + } + + inline std::uint64_t div2_uint_mod(std::uint64_t operand, + const SmallModulus &modulus) + { +#ifdef SEAL_DEBUG + if (modulus.is_zero()) + { + throw std::invalid_argument("modulus"); + } + if (operand >= modulus.value()) + { + throw std::out_of_range("operand"); + } +#endif + if (operand & 1) + { + unsigned long long temp; + int64_t carry = add_uint64(operand, modulus.value(), 0, &temp); + operand = temp >> 1; + if (carry) + { + return operand | (std::uint64_t(1) << (bits_per_uint64 - 1)); + } + return operand; + } + return operand >> 1; + } + + inline std::uint64_t add_uint_uint_mod(std::uint64_t operand1, + std::uint64_t operand2, const SmallModulus &modulus) + { +#ifdef SEAL_DEBUG + if (modulus.is_zero()) + { + throw std::invalid_argument("modulus"); + } + if (operand1 >= modulus.value()) + { + throw std::out_of_range("operand1"); + } + if (operand2 >= modulus.value()) + { + throw std::out_of_range("operand2"); + } +#endif + // Sum of operands modulo SmallModulus can never wrap around 2^64 + operand1 += operand2; + return operand1 - (modulus.value() & static_cast( + -static_cast(operand1 >= modulus.value()))); + } + + inline std::uint64_t sub_uint_uint_mod(std::uint64_t operand1, + std::uint64_t operand2, const SmallModulus &modulus) + { +#ifdef SEAL_DEBUG + if (modulus.is_zero()) + { + throw std::invalid_argument("modulus"); + } + + if (operand1 >= modulus.value()) + { + throw std::out_of_range("operand1"); + } + if (operand2 >= modulus.value()) + { + throw std::out_of_range("operand2"); + } +#endif + unsigned long long temp; + std::int64_t borrow = SEAL_SUB_BORROW_UINT64(operand1, operand2, 0, &temp); + return static_cast(temp) + + (modulus.value() & static_cast(-borrow)); + } + + template>> + inline std::uint64_t barrett_reduce_128(const T *input, + const SmallModulus &modulus) + { +#ifdef SEAL_DEBUG + if (!input) + { + throw std::invalid_argument("input"); + } + if (modulus.is_zero()) + { + throw std::invalid_argument("modulus"); + } +#endif + // Reduces input using base 2^64 Barrett reduction + // input allocation size must be 128 bits + + unsigned long long tmp1, tmp2[2], tmp3, carry; + const std::uint64_t *const_ratio = modulus.const_ratio().data(); + + // Multiply input and const_ratio + // Round 1 + multiply_uint64_hw64(input[0], const_ratio[0], &carry); + + multiply_uint64(input[0], const_ratio[1], tmp2); + tmp3 = tmp2[1] + add_uint64(tmp2[0], carry, 0, &tmp1); + + // Round 2 + multiply_uint64(input[1], const_ratio[0], tmp2); + carry = tmp2[1] + add_uint64(tmp1, tmp2[0], 0, &tmp1); + + // This is all we care about + tmp1 = input[1] * const_ratio[1] + tmp3 + carry; + + // Barrett subtraction + tmp3 = input[0] - tmp1 * modulus.value(); + + // Claim: One more subtraction is enough + return static_cast(tmp3) - + (modulus.value() & static_cast( + -static_cast(tmp3 >= modulus.value()))); + } + + inline std::uint64_t multiply_uint_uint_mod(std::uint64_t operand1, + std::uint64_t operand2, const SmallModulus &modulus) + { +#ifdef SEAL_DEBUG + if (modulus.is_zero()) + { + throw std::invalid_argument("modulus"); + } +#endif + unsigned long long z[2]; + multiply_uint64(operand1, operand2, z); + return barrett_reduce_128(z, modulus); + } + + inline void modulo_uint_inplace(std::uint64_t *value, + std::size_t value_uint64_count, const SmallModulus &modulus) + { +#ifdef SEAL_DEBUG + if (!value && value_uint64_count > 0) + { + throw std::invalid_argument("value"); + } +#endif + if (value_uint64_count == 1) + { + value[0] %= modulus.value(); + return; + } + + // Starting from the top, reduce always 128-bit blocks + for (std::size_t i = value_uint64_count - 1; i--; ) + { + value[i] = barrett_reduce_128(value + i, modulus); + value[i + 1] = 0; + } + } + + inline std::uint64_t modulo_uint(const std::uint64_t *value, + std::size_t value_uint64_count, const SmallModulus &modulus, + MemoryPool &pool) + { +#ifdef SEAL_DEBUG + if (!value && value_uint64_count) + { + throw std::invalid_argument("value"); + } + if (!value_uint64_count) + { + throw std::invalid_argument("value_uint64_count"); + } +#endif + if (value_uint64_count == 1) + { + // If value < modulus no operation is needed + return *value % modulus.value(); + } + + auto value_copy(allocate_uint(value_uint64_count, pool)); + set_uint_uint(value, value_uint64_count, value_copy.get()); + + // Starting from the top, reduce always 128-bit blocks + for (std::size_t i = value_uint64_count - 1; i--; ) + { + value_copy[i] = barrett_reduce_128(value_copy.get() + i, modulus); + } + + return value_copy[0]; + } + + inline bool try_invert_uint_mod(uint64_t operand, + const SmallModulus &modulus, std::uint64_t &result) + { + return try_mod_inverse(operand, modulus.value(), result); + } + + bool is_primitive_root(std::uint64_t root, std::uint64_t degree, + const SmallModulus &prime_modulus); + + // Try to find a primitive degree-th root of unity modulo small prime + // modulus, where degree must be a power of two. + bool try_primitive_root(std::uint64_t degree, + const SmallModulus &prime_modulus, std::uint64_t &destination); + + // Try to find the smallest (as integer) primitive degree-th root of + // unity modulo small prime modulus, where degree must be a power of two. + bool try_minimal_primitive_root(std::uint64_t degree, + const SmallModulus &prime_modulus, std::uint64_t &destination); + + std::uint64_t exponentiate_uint_mod(std::uint64_t operand, + std::uint64_t exponent, const SmallModulus &modulus); + + void divide_uint_uint_mod_inplace(uint64_t *numerator, + const SmallModulus &modulus, std::size_t uint64_count, + uint64_t *quotient, MemoryPool &pool); + + std::uint64_t steps_to_galois_elt(int steps, std::size_t coeff_count); + } +} diff --git a/src/seal/util/uintcore.cpp b/src/seal/util/uintcore.cpp new file mode 100644 index 000000000..967350e1d --- /dev/null +++ b/src/seal/util/uintcore.cpp @@ -0,0 +1,106 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "seal/util/common.h" +#include "seal/util/uintcore.h" +#include "seal/util/uintarith.h" +#include +#include + +using namespace std; + +namespace seal +{ + namespace util + { + string uint_to_hex_string(const uint64_t *value, size_t uint64_count) + { +#ifdef SEAL_DEBUG + if (uint64_count && !value) + { + throw invalid_argument("value"); + } +#endif + // Start with a string with a zero for each nibble in the array. + size_t num_nibbles = + mul_safe(uint64_count, static_cast(nibbles_per_uint64)); + string output(num_nibbles, '0'); + + // Iterate through each uint64 in array and set string with correct nibbles in hex. + size_t nibble_index = num_nibbles; + size_t leftmost_non_zero_pos = num_nibbles; + for (size_t i = 0; i < uint64_count; i++) + { + uint64_t part = *value++; + + // Iterate through each nibble in the current uint64. + for (size_t j = 0; j < nibbles_per_uint64; j++) + { + size_t nibble = safe_cast(part & uint64_t(0x0F)); + size_t pos = --nibble_index; + if (nibble != 0) + { + // If nibble is not zero, then update string and save this pos to determine + // number of leading zeros. + output[pos] = nibble_to_upper_hex(static_cast(nibble)); + leftmost_non_zero_pos = pos; + } + part >>= 4; + } + } + + // Trim string to remove leading zeros. + output.erase(0, leftmost_non_zero_pos); + + // Return 0 if nothing remains. + if (output.empty()) + { + return string("0"); + } + + return output; + } + + string uint_to_dec_string(const uint64_t *value, + size_t uint64_count, MemoryPool &pool) + { +#ifdef SEAL_DEBUG + if (uint64_count && !value) + { + throw invalid_argument("value"); + } +#endif + if (!uint64_count) + { + return string("0"); + } + auto remainder(allocate_uint(uint64_count, pool)); + auto quotient(allocate_uint(uint64_count, pool)); + auto base(allocate_uint(uint64_count, pool)); + uint64_t *remainderptr = remainder.get(); + uint64_t *quotientptr = quotient.get(); + uint64_t *baseptr = base.get(); + set_uint(10, uint64_count, baseptr); + set_uint_uint(value, uint64_count, remainderptr); + string output; + while (!is_zero_uint(remainderptr, uint64_count)) + { + divide_uint_uint_inplace(remainderptr, baseptr, + uint64_count, quotientptr, pool); + char digit = static_cast( + remainderptr[0] + static_cast('0')); + output += digit; + swap(remainderptr, quotientptr); + } + reverse(output.begin(), output.end()); + + // Return 0 if nothing remains. + if (output.empty()) + { + return string("0"); + } + + return output; + } + } +} diff --git a/src/seal/util/uintcore.h b/src/seal/util/uintcore.h new file mode 100644 index 000000000..46fb4ae9d --- /dev/null +++ b/src/seal/util/uintcore.h @@ -0,0 +1,636 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include +#include +#include +#include +#include +#include "seal/util/common.h" +#include "seal/util/pointer.h" +#include "seal/util/defines.h" + +namespace seal +{ + namespace util + { + std::string uint_to_hex_string(const std::uint64_t *value, + std::size_t uint64_count); + + std::string uint_to_dec_string(const std::uint64_t *value, + std::size_t uint64_count, MemoryPool &pool); + + inline void hex_string_to_uint(const char *hex_string, + int char_count, std::size_t uint64_count, std::uint64_t *result) + { +#ifdef SEAL_DEBUG + if (!hex_string && char_count > 0) + { + throw std::invalid_argument("hex_string"); + } + if (uint64_count && !result) + { + throw std::invalid_argument("result"); + } + if (unsigned_gt(get_hex_string_bit_count(hex_string, char_count), + mul_safe(uint64_count, static_cast(bits_per_uint64)))) + { + throw std::invalid_argument("hex_string"); + } +#endif + const char *hex_string_ptr = hex_string + char_count; + for (std::size_t uint64_index = 0; + uint64_index < uint64_count; uint64_index++) + { + std::uint64_t value = 0; + for (int bit_index = 0; bit_index < bits_per_uint64; + bit_index += bits_per_nibble) + { + if (hex_string_ptr == hex_string) + { + break; + } + char hex = *--hex_string_ptr; + int nibble = hex_to_nibble(hex); + if (nibble == -1) + { + throw std::invalid_argument("hex_value"); + } + value |= static_cast(nibble) << bit_index; + } + result[uint64_index] = value; + } + } + + inline auto allocate_uint(std::size_t uint64_count, MemoryPool &pool) + { + return allocate(uint64_count, pool); + } + + inline void set_zero_uint(std::size_t uint64_count, std::uint64_t *result) + { +#ifdef SEAL_DEBUG + if (!result && uint64_count) + { + throw std::invalid_argument("result"); + } +#endif + std::fill_n(result, uint64_count, std::uint64_t(0)); + } + + inline auto allocate_zero_uint(std::size_t uint64_count, MemoryPool &pool) + { + return allocate(uint64_count, pool, std::uint64_t(0)); + } + + inline void set_uint(std::uint64_t value, std::size_t uint64_count, + std::uint64_t *result) + { +#ifdef SEAL_DEBUG + if (!uint64_count) + { + throw std::invalid_argument("uint64_count"); + } + if (!result) + { + throw std::invalid_argument("result"); + } +#endif + *result++ = value; + for (; --uint64_count; result++) + { + *result = 0; + } + } + + inline void set_uint_uint(const std::uint64_t *value, + std::size_t uint64_count, std::uint64_t *result) + { +#ifdef SEAL_DEBUG + if (!value && uint64_count) + { + throw std::invalid_argument("value"); + } + if (!result && uint64_count) + { + throw std::invalid_argument("result"); + } +#endif + if ((value == result) || !uint64_count) + { + return; + } + std::copy_n(value, uint64_count, result); + } + + inline bool is_zero_uint(const std::uint64_t *value, + std::size_t uint64_count) + { +#ifdef SEAL_DEBUG + if (!value && uint64_count) + { + throw std::invalid_argument("value"); + } +#endif + return std::all_of(value, value + uint64_count, + [](auto coeff) -> bool { return !coeff; }); + } + + inline bool is_equal_uint(const std::uint64_t *value, + std::size_t uint64_count, std::uint64_t scalar) + { +#ifdef SEAL_DEBUG + if (!value) + { + throw std::invalid_argument("value"); + } + if (!uint64_count) + { + throw std::invalid_argument("uint64_count"); + } +#endif + if (*value++ != scalar) + { + return false; + } + return std::all_of(value, value + uint64_count - 1, + [](auto coeff) -> bool { return !coeff; }); + } + + inline bool is_high_bit_set_uint(const std::uint64_t *value, + std::size_t uint64_count) + { +#ifdef SEAL_DEBUG + if (!value) + { + throw std::invalid_argument("value"); + } + if (!uint64_count) + { + throw std::invalid_argument("uint64_count"); + } +#endif + return (value[uint64_count - 1] >> (bits_per_uint64 - 1)) != 0; + } + + inline bool is_bit_set_uint(const std::uint64_t *value, + std::size_t uint64_count SEAL_MAYBE_UNUSED, int bit_index) + { +#ifdef SEAL_DEBUG + if (!value) + { + throw std::invalid_argument("value"); + } + if (!uint64_count) + { + throw std::invalid_argument("uint64_count"); + } + if (bit_index < 0 || + static_cast(bit_index) >= + static_cast(uint64_count) * bits_per_uint64) + { + throw std::invalid_argument("bit_index"); + } +#endif + int uint64_index = bit_index / bits_per_uint64; + int sub_bit_index = bit_index - uint64_index * bits_per_uint64; + return ((value[static_cast(uint64_index)] + >> sub_bit_index) & 1) != 0; + } + + inline void set_bit_uint(std::uint64_t *value, + std::size_t uint64_count SEAL_MAYBE_UNUSED, int bit_index) + { +#ifdef SEAL_DEBUG + if (!value) + { + throw std::invalid_argument("value"); + } + if (!uint64_count) + { + throw std::invalid_argument("uint64_count"); + } + if (bit_index < 0 || + static_cast(bit_index) >= + static_cast(uint64_count) * bits_per_uint64) + { + throw std::invalid_argument("bit_index"); + } +#endif + int uint64_index = bit_index / bits_per_uint64; + int sub_bit_index = bit_index % bits_per_uint64; + value[static_cast(uint64_index)] |= + std::uint64_t(1) << sub_bit_index; + } + + inline int get_significant_bit_count_uint( + const std::uint64_t *value, std::size_t uint64_count) + { +#ifdef SEAL_DEBUG + if (!value && uint64_count) + { + throw std::invalid_argument("value"); + } + if (!uint64_count) + { + throw std::invalid_argument("uint64_count"); + } +#endif + if (!uint64_count) + { + return 0; + } + + value += uint64_count - 1; + for (; *value == 0 && uint64_count > 1; uint64_count--) + { + value--; + } + + return static_cast(uint64_count - 1) * bits_per_uint64 + + get_significant_bit_count(*value); + } + + inline std::size_t get_significant_uint64_count_uint( + const std::uint64_t *value, std::size_t uint64_count) + { +#ifdef SEAL_DEBUG + if (!value && uint64_count) + { + throw std::invalid_argument("value"); + } + if (!uint64_count) + { + throw std::invalid_argument("uint64_count"); + } +#endif + value += uint64_count - 1; + for (; *value == 0 && uint64_count; uint64_count--) + { + value--; + } + + return uint64_count; + } + + inline void set_uint_uint(const std::uint64_t *value, + std::size_t value_uint64_count, + std::size_t result_uint64_count, std::uint64_t *result) + { +#ifdef SEAL_DEBUG + if (!value && value_uint64_count) + { + throw std::invalid_argument("value"); + } + if (!result && result_uint64_count) + { + throw std::invalid_argument("result"); + } +#endif + if (value == result || !value_uint64_count) + { + // Fast path to handle self assignment. + std::fill(result + value_uint64_count, + result + result_uint64_count, std::uint64_t(0)); + } + else + { + std::size_t min_uint64_count = + std::min(value_uint64_count, result_uint64_count); + std::copy_n(value, min_uint64_count, result); + std::fill(result + min_uint64_count, + result + result_uint64_count, std::uint64_t(0)); + } + } + + inline int get_power_of_two(std::uint64_t value) + { + if (value == 0 || (value & (value - 1)) != 0) + { + return -1; + } + + unsigned long result = 0; + SEAL_MSB_INDEX_UINT64(&result, value); + return static_cast(result); + } + + inline int get_power_of_two_minus_one(std::uint64_t value) + { + if (value == 0xFFFFFFFFFFFFFFFF) + { + return bits_per_uint64; + } + return get_power_of_two(value + 1); + } + + inline int get_power_of_two_uint(const std::uint64_t *operand, + std::size_t uint64_count) + { +#ifdef SEAL_DEBUG + if (!operand && uint64_count) + { + throw std::invalid_argument("operand"); + } +#endif + operand += uint64_count; + int long_index = safe_cast(uint64_count), local_result = -1; + for (; (long_index >= 1) && (local_result == -1); long_index--) + { + operand--; + local_result = get_power_of_two(*operand); + } + + // If local_result != -1, we've found a power-of-two highest order block, + // in which case need to check that rest are zero. + // If local_result == -1, operand is not power of two. + if (local_result == -1) + { + return -1; + } + + int zeros = 1; + for (int j = long_index; j >= 1; j--) + { + zeros &= (*--operand == 0); + } + + return add_safe(mul_safe(zeros, + add_safe(local_result, + mul_safe(long_index, bits_per_uint64))), zeros, -1); + } + + inline int get_power_of_two_minus_one_uint( + const std::uint64_t *operand, std::size_t uint64_count) + { +#ifdef SEAL_DEBUG + if (!operand && uint64_count) + { + throw std::invalid_argument("operand"); + } + if (unsigned_geq(uint64_count, std::numeric_limits::max())) + { + throw std::invalid_argument("uint64_count"); + } +#endif + operand += uint64_count; + int long_index = safe_cast(uint64_count), local_result = 0; + for (; (long_index >= 1) && (local_result == 0); long_index--) + { + operand--; + local_result = get_power_of_two_minus_one(*operand); + } + + // If local_result != -1, we've found a power-of-two-minus-one highest + // order block, in which case need to check that rest are ~0. + // If local_result == -1, operand is not power of two minus one. + if (local_result == -1) + { + return -1; + } + + int ones = 1; + for (int j = long_index; j >= 1; j--) + { + ones &= (~*--operand == 0); + } + + return add_safe(mul_safe(ones, + add_safe(local_result, + mul_safe(long_index, bits_per_uint64))), ones, -1); + } + + inline void filter_highbits_uint(std::uint64_t *operand, + std::size_t uint64_count, int bit_count) + { + std::size_t bits_per_uint64_sz = static_cast(bits_per_uint64); +#ifdef SEAL_DEBUG + if (!operand && uint64_count) + { + throw std::invalid_argument("operand"); + } + if (bit_count < 0 || unsigned_gt(bit_count, + mul_safe(uint64_count, bits_per_uint64_sz))) + { + throw std::invalid_argument("bit_count"); + } +#endif + if (unsigned_eq(bit_count, mul_safe(uint64_count, bits_per_uint64_sz))) + { + return; + } + int uint64_index = bit_count / bits_per_uint64; + int subbit_index = bit_count - uint64_index * bits_per_uint64; + operand += uint64_index; + *operand++ &= (std::uint64_t(1) << subbit_index) - 1; + for (int long_index = uint64_index + 1; + unsigned_lt(long_index, uint64_count); long_index++) + { + *operand++ = 0; + } + } + + inline auto duplicate_uint_if_needed(const std::uint64_t *input, + std::size_t uint64_count, std::size_t new_uint64_count, + bool force, MemoryPool &pool) + { +#ifdef SEAL_DEBUG + if (!input && uint64_count) + { + throw std::invalid_argument("uint"); + } +#endif + if (!force && uint64_count >= new_uint64_count) + { + return ConstPointer::Aliasing(input); + } + + auto allocation(allocate(new_uint64_count, pool)); + set_uint_uint(input, uint64_count, new_uint64_count, allocation.get()); + return ConstPointer(std::move(allocation)); + } + + inline int compare_uint_uint(const std::uint64_t *operand1, + const std::uint64_t *operand2, std::size_t uint64_count) + { +#ifdef SEAL_DEBUG + if (!operand1 && uint64_count) + { + throw std::invalid_argument("operand1"); + } + if (!operand2 && uint64_count) + { + throw std::invalid_argument("operand2"); + } +#endif + int result = 0; + operand1 += uint64_count - 1; + operand2 += uint64_count - 1; + + for (; (result == 0) && uint64_count--; operand1--, operand2--) + { + result = (*operand1 > *operand2) - (*operand1 < *operand2); + } + return result; + } + + inline int compare_uint_uint(const std::uint64_t *operand1, + std::size_t operand1_uint64_count, const std::uint64_t *operand2, + std::size_t operand2_uint64_count) + { +#ifdef SEAL_DEBUG + if (!operand1 && operand1_uint64_count) + { + throw std::invalid_argument("operand1"); + } + if (!operand2 && operand2_uint64_count) + { + throw std::invalid_argument("operand2"); + } +#endif + int result = 0; + operand1 += operand1_uint64_count - 1; + operand2 += operand2_uint64_count - 1; + + std::size_t min_uint64_count = + std::min(operand1_uint64_count, operand2_uint64_count); + + operand1_uint64_count -= min_uint64_count; + for (; (result == 0) && operand1_uint64_count--; operand1--) + { + result = (*operand1 > 0); + } + + operand2_uint64_count -= min_uint64_count; + for (; (result == 0) && operand2_uint64_count--; operand2--) + { + result = -(*operand2 > 0); + } + + for (; (result == 0) && min_uint64_count--; operand1--, operand2--) + { + result = (*operand1 > *operand2) - (*operand1 < *operand2); + } + return result; + } + + inline bool is_greater_than_uint_uint(const std::uint64_t *operand1, + const std::uint64_t *operand2, std::size_t uint64_count) + { + return compare_uint_uint(operand1, operand2, uint64_count) > 0; + } + + inline bool is_greater_than_or_equal_uint_uint(const std::uint64_t *operand1, + const std::uint64_t *operand2, std::size_t uint64_count) + { + return compare_uint_uint(operand1, operand2, uint64_count) >= 0; + } + + inline bool is_less_than_uint_uint(const std::uint64_t *operand1, + const std::uint64_t *operand2, std::size_t uint64_count) + { + return compare_uint_uint(operand1, operand2, uint64_count) < 0; + } + + inline bool is_less_than_or_equal_uint_uint(const std::uint64_t *operand1, + const std::uint64_t *operand2, std::size_t uint64_count) + { + return compare_uint_uint(operand1, operand2, uint64_count) <= 0; + } + + inline bool is_equal_uint_uint(const std::uint64_t *operand1, + const std::uint64_t *operand2, std::size_t uint64_count) + { + return compare_uint_uint(operand1, operand2, uint64_count) == 0; + } + + inline bool is_not_equal_uint_uint(const std::uint64_t *operand1, + const std::uint64_t *operand2, std::size_t uint64_count) + { + return compare_uint_uint(operand1, operand2, uint64_count) != 0; + } + + inline bool is_greater_than_uint_uint(const std::uint64_t *operand1, + std::size_t operand1_uint64_count, const std::uint64_t *operand2, + std::size_t operand2_uint64_count) + { + return compare_uint_uint(operand1, operand1_uint64_count, operand2, + operand2_uint64_count) > 0; + } + + inline bool is_greater_than_or_equal_uint_uint( + const std::uint64_t *operand1, std::size_t operand1_uint64_count, + const std::uint64_t *operand2, std::size_t operand2_uint64_count) + { + return compare_uint_uint(operand1, operand1_uint64_count, operand2, + operand2_uint64_count) >= 0; + } + + inline bool is_less_than_uint_uint(const std::uint64_t *operand1, + std::size_t operand1_uint64_count, const std::uint64_t *operand2, + std::size_t operand2_uint64_count) + { + return compare_uint_uint(operand1, operand1_uint64_count, operand2, + operand2_uint64_count) < 0; + } + + inline bool is_less_than_or_equal_uint_uint(const std::uint64_t *operand1, + std::size_t operand1_uint64_count, const std::uint64_t *operand2, + std::size_t operand2_uint64_count) + { + return compare_uint_uint(operand1, operand1_uint64_count, operand2, + operand2_uint64_count) <= 0; + } + + inline bool is_equal_uint_uint(const std::uint64_t *operand1, + std::size_t operand1_uint64_count, const std::uint64_t *operand2, + std::size_t operand2_uint64_count) + { + return compare_uint_uint(operand1, operand1_uint64_count, operand2, + operand2_uint64_count) == 0; + } + + inline bool is_not_equal_uint_uint(const std::uint64_t *operand1, + std::size_t operand1_uint64_count, const std::uint64_t *operand2, + std::size_t operand2_uint64_count) + { + return compare_uint_uint(operand1, operand1_uint64_count, operand2, + operand2_uint64_count) != 0; + } + + inline std::uint64_t hamming_weight(std::uint64_t value) + { + std::uint64_t res = 0; + while (value) + { + res++; + value &= value - 1; + } + return res; + } + + inline std::uint64_t hamming_weight_split(std::uint64_t value) + { + std::uint64_t hwx = hamming_weight(value); + std::uint64_t target = (hwx + 1) >> 1; + std::uint64_t now = 0; + std::uint64_t result = 0; + + for (int i = 0; i < bits_per_uint64; i++) + { + std::uint64_t xbit = value & 1; + value = value >> 1; + now += xbit; + result += (xbit << i); + + if (now >= target) + { + break; + } + } + return result; + } + } +} diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt new file mode 100644 index 000000000..ae15119e5 --- /dev/null +++ b/tests/CMakeLists.txt @@ -0,0 +1,30 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +cmake_minimum_required(VERSION 3.10) + +project(SEALTest VERSION 3.1.0 LANGUAGES CXX) + +# Executable will be in ../bin +set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_SOURCE_DIR}/../bin) + +add_executable(sealtest seal/testrunner.cpp) + +# Import SEAL +find_package(SEAL 3.1.0 EXACT REQUIRED) + +if(SEAL_ENFORCE_HE_STD_SECURITY) + message(FATAL_ERROR "SEAL is configured with SEAL_ENFORCE_HE_STD_SECURITY=ON which is incompatible with unit tests") +endif() + +# Import Google target_link_libraries +find_library(GTEST gtest) +if (NOT GTEST) + message(FATAL_ERROR "Failed to find Google Test library required for unit tests") +endif() + +# Link SEAL +target_link_libraries(sealtest SEAL::seal gtest) + +# Add source files +add_subdirectory(seal) diff --git a/tests/SEALTest.vcxproj b/tests/SEALTest.vcxproj new file mode 100644 index 000000000..dbdbd3ff0 --- /dev/null +++ b/tests/SEALTest.vcxproj @@ -0,0 +1,126 @@ + + + + + Debug + x64 + + + Release + x64 + + + + {0345DC4D-EFE3-460E-AB7E-AA6E05BB8DFF} + Win32Proj + 10.0.16299.0 + Application + v141 + Unicode + + + + + + + + + $(SolutionDir)bin\$(Platform)\$(Configuration)\ + $(ProjectDir)obj\$(Platform)\$(Configuration)\ + sealtest + + + $(SolutionDir)bin\$(Platform)\$(Configuration)\ + $(ProjectDir)obj\$(Platform)\$(Configuration)\ + sealtest + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + NotUsing + Disabled + X64;_SILENCE_TR1_NAMESPACE_DEPRECATION_WARNING;_DEBUG;_CONSOLE;%(PreprocessorDefinitions) + true + EnableFastChecks + MultiThreadedDebugDLL + Level3 + $(SolutionDir)/src;%(AdditionalIncludeDirectories) + + + true + Console + seal.lib;%(AdditionalDependencies) + $(SolutionDir)\lib\$(Platform)\$(Configuration);%(AdditionalLibraryDirectories) + + + + + NotUsing + X64;_SILENCE_TR1_NAMESPACE_DEPRECATION_WARNING;NDEBUG;_CONSOLE;%(PreprocessorDefinitions) + MultiThreadedDLL + Level3 + ProgramDatabase + $(SolutionDir)/src;%(AdditionalIncludeDirectories) + + + true + Console + true + true + seal.lib;%(AdditionalDependencies) + $(SolutionDir)\lib\$(Platform)\$(Configuration);%(AdditionalLibraryDirectories) + + + + + This project references NuGet package(s) that are missing on this computer. Use NuGet Package Restore to download them. For more information, see http://go.microsoft.com/fwlink/?LinkID=322105. The missing file is {0}. + + + + \ No newline at end of file diff --git a/tests/SEALTest.vcxproj.filters b/tests/SEALTest.vcxproj.filters new file mode 100644 index 000000000..a7d1bc536 --- /dev/null +++ b/tests/SEALTest.vcxproj.filters @@ -0,0 +1,137 @@ + + + + + {4FC737F1-C7A5-4376-A066-2A32D752A2FF} + cpp;c;cc;cxx;def;odl;idl;hpj;bat;asm;asmx + + + {93995380-89BD-4b04-88EB-625FBE52EBFB} + h;hh;hpp;hxx;hm;inl;inc;xsd + + + {67DA6AB6-F800-4c08-8B7A-83BB121AAD01} + rc;ico;cur;bmp;dlg;rc2;rct;bin;rgs;gif;jpg;jpeg;jpe;resx;tiff;tif;png;wav;mfcribbon-ms + + + {6c39d93e-a64a-44b3-95ca-ba22fd03ea17} + cpp;c;cc;cxx;def;odl;idl;hpj;bat;asm;asmx + + + + + Source Files + + + Source Files + + + Source Files + + + Source Files + + + Source Files + + + Source Files + + + Source Files + + + Source Files\util + + + Source Files\util + + + Source Files\util + + + Source Files\util + + + Source Files\util + + + Source Files\util + + + Source Files\util + + + Source Files\util + + + Source Files\util + + + Source Files\util + + + Source Files\util + + + Source Files\util + + + Source Files + + + Source Files + + + Source Files + + + Source Files + + + Source Files + + + Source Files + + + Source Files\util + + + Source Files\util + + + Source Files\util + + + Source Files\util + + + Source Files\util + + + Source Files\util + + + Source Files + + + Source Files + + + Source Files + + + Source Files + + + Source Files + + + Source Files + + + + + + diff --git a/tests/packages.config b/tests/packages.config new file mode 100644 index 000000000..85c6f204b --- /dev/null +++ b/tests/packages.config @@ -0,0 +1,4 @@ + + + + diff --git a/tests/seal/CMakeLists.txt b/tests/seal/CMakeLists.txt new file mode 100644 index 000000000..155d34b28 --- /dev/null +++ b/tests/seal/CMakeLists.txt @@ -0,0 +1,27 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +target_sources(sealtest + PRIVATE + ${CMAKE_CURRENT_LIST_DIR}/batchencoder.cpp + ${CMAKE_CURRENT_LIST_DIR}/biguint.cpp + ${CMAKE_CURRENT_LIST_DIR}/ciphertext.cpp + ${CMAKE_CURRENT_LIST_DIR}/ckks.cpp + ${CMAKE_CURRENT_LIST_DIR}/context.cpp + ${CMAKE_CURRENT_LIST_DIR}/encoder.cpp + ${CMAKE_CURRENT_LIST_DIR}/encryptionparams.cpp + ${CMAKE_CURRENT_LIST_DIR}/encryptor.cpp + ${CMAKE_CURRENT_LIST_DIR}/evaluator.cpp + ${CMAKE_CURRENT_LIST_DIR}/galoiskeys.cpp + ${CMAKE_CURRENT_LIST_DIR}/intarray.cpp + ${CMAKE_CURRENT_LIST_DIR}/keygenerator.cpp + ${CMAKE_CURRENT_LIST_DIR}/memorymanager.cpp + ${CMAKE_CURRENT_LIST_DIR}/plaintext.cpp + ${CMAKE_CURRENT_LIST_DIR}/publickey.cpp + ${CMAKE_CURRENT_LIST_DIR}/randomgen.cpp + ${CMAKE_CURRENT_LIST_DIR}/relinkeys.cpp + ${CMAKE_CURRENT_LIST_DIR}/secretkey.cpp + ${CMAKE_CURRENT_LIST_DIR}/smallmodulus.cpp +) + +add_subdirectory(util) diff --git a/tests/seal/baseconverter.cpp b/tests/seal/baseconverter.cpp new file mode 100644 index 000000000..d0649c55e --- /dev/null +++ b/tests/seal/baseconverter.cpp @@ -0,0 +1,552 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "CppUnitTest.h" +#include +#include "util/mempool.h" +#include "util/uintcore.h" +#include "memorypoolhandle.h" +#include "smallmodulus.h" +#include "util/BaseConverter.h" +#include "util/uintarith.h" +#include "util/uintarithsmallmod.h" +#include "util/uintarithmod.h" +#include "primes.h" + +using namespace Microsoft::VisualStudio::CppUnitTestFramework; +using namespace seal::util; +using namespace seal; +using namespace std; + +namespace SEALTest +{ + namespace util + { + TEST_CLASS(BaseConverterClass) + { + public: + TEST_METHOD(BaseConverterConstructor) + { + MemoryPoolMT &pool = *MemoryPoolMT::default_pool(); + vector coeff_base; + vector aux_base; + SmallModulus mtilda = small_mods[10]; + SmallModulus msk = small_mods[11]; + SmallModulus plain_t = small_mods[9]; + int coeff_base_count = 4; + int aux_base_count = 4; + + for (int i = 0; i < coeff_base_count; ++i) + { + coeff_base.push_back(small_mods[i]); + aux_base.push_back(small_mods[i + coeff_base_count]); + } + + BaseConverter BaseConverter(coeff_base, 4, plain_t); + Assert::IsTrue(BaseConverter.is_generated()); + } + + TEST_METHOD(FastBConverter) + { + { + MemoryPoolMT &pool = *MemoryPoolMT::default_pool(); + vector coeff_base; + vector aux_base; + SmallModulus plain_t = small_mods[9]; + int coeff_base_count = 2; + int aux_base_count = 2; + + for (int i = 0; i < coeff_base_count; ++i) + { + coeff_base.push_back(small_mods[i]); + aux_base.push_back(small_mods[i + coeff_base_count + 2]); + } + + BaseConverter BaseConverter(coeff_base, 1, plain_t); + Pointer input(allocate_uint(2, pool)); + Pointer output(allocate_uint(3, pool)); + + // the composed input is 0xffffffffffffff00ffffffffffffff + + input[0] = 4395513236581707780; + input[1] = 4395513390924464132; + + + output[0] = 0xFFFFFFFFFFFFFFFF; + output[1] = 0xFFFFFFFFFFFFFFFF; + output[2] = 0; + + Assert::IsTrue(BaseConverter.fastbconv(input.get(), output.get())); + Assert::AreEqual(static_cast(3116074317392112723), output[0]); + Assert::AreEqual(static_cast(1254200639185090240), output[1]); + Assert::AreEqual(static_cast(3528328721557038672), output[2]); + } + + { + MemoryPoolMT &pool = *MemoryPoolMT::default_pool(); + vector coeff_base; + vector aux_base; + SmallModulus mtilda = small_mods[10]; + SmallModulus msk = small_mods[11]; + SmallModulus plain_t = small_mods[9]; + int coeff_base_count = 2; + int aux_base_count = 2; + + for (int i = 0; i < coeff_base_count; ++i) + { + coeff_base.push_back(small_mods[i]); + aux_base.push_back(small_mods[i + coeff_base_count + 2]); + } + BaseConverter BaseConverter(coeff_base, 4, plain_t); + Pointer input(allocate_uint(8, pool)); + Pointer output(allocate_uint(12, pool)); + + // the composed input is 0xffffffffffffff00ffffffffffffff for all coeffs + // mod q1 + input[0] = 4395513236581707780; // cons + input[1] = 4395513236581707780; // x + input[2] = 4395513236581707780; // x^2 + input[3] = 4395513236581707780; // x^3 + + //mod q2 + input[4] = 4395513390924464132; + input[5] = 4395513390924464132; + input[6] = 4395513390924464132; + input[7] = 4395513390924464132; + + output[0] = 0xFFFFFFFFFFFFFFFF; + output[1] = 0xFFFFFFFFFFFFFFFF; + output[2] = 0; + + Assert::IsTrue(BaseConverter.fastbconv(input.get(), output.get())); + Assert::AreEqual(static_cast(3116074317392112723), output[0]); + Assert::AreEqual(static_cast(3116074317392112723), output[1]); + Assert::AreEqual(static_cast(3116074317392112723), output[2]); + Assert::AreEqual(static_cast(3116074317392112723), output[3]); + + Assert::AreEqual(static_cast(1254200639185090240), output[4]); + Assert::AreEqual(static_cast(1254200639185090240), output[5]); + Assert::AreEqual(static_cast(1254200639185090240), output[6]); + Assert::AreEqual(static_cast(1254200639185090240), output[7]); + + Assert::AreEqual(static_cast(3528328721557038672), output[8]); + Assert::AreEqual(static_cast(3528328721557038672), output[9]); + Assert::AreEqual(static_cast(3528328721557038672), output[10]); + Assert::AreEqual(static_cast(3528328721557038672), output[11]); + } + } + + TEST_METHOD(FastBConvSK) + { + { + MemoryPoolMT &pool = *MemoryPoolMT::default_pool(); + vector coeff_base; + vector aux_base; + SmallModulus mtilda = small_mods[10]; + SmallModulus msk = small_mods[4]; + SmallModulus plain_t = small_mods[9]; + + int coeff_base_count = 2; + int aux_base_count = 2; + for (int i = 0; i < coeff_base_count; ++i) + { + coeff_base.push_back(small_mods[i]); + aux_base.push_back(small_mods[i + coeff_base_count]); + } + + BaseConverter BaseConverter(coeff_base, 1, plain_t); + Pointer input(allocate_uint(3, pool)); + Pointer output(allocate_uint(2, pool)); + + // The composed input is 0xffffffffffffff00ffffffffffffff + + input[0] = 4395583330278772740; + input[1] = 4396634741790752772; + input[2] = 4396375252835237892; // mod msk + + output[0] = 0xFFFFFFFFFFFFFFF; + output[1] = 0xFFFFFFFFFFFFFFF; + + Assert::IsTrue(BaseConverter.fastbconv_sk(input.get(), output.get())); + Assert::AreEqual(static_cast(2494482839790051254), output[0]); + Assert::AreEqual(static_cast(218180408843610743), output[1]); + } + + { + MemoryPoolMT &pool = *MemoryPoolMT::default_pool(); + vector coeff_base; + vector aux_base; + SmallModulus mtilda = small_mods[10]; + SmallModulus msk = small_mods[4]; + SmallModulus plain_t = small_mods[9]; + + int coeff_base_count = 2; + int aux_base_count = 2; + for (int i = 0; i < coeff_base_count; ++i) + { + coeff_base.push_back(small_mods[i]); + aux_base.push_back(small_mods[i + coeff_base_count]); + } + + BaseConverter BaseConverter(coeff_base, 4, plain_t); + Pointer input(allocate_uint(12, pool)); + Pointer output(allocate_uint(8, pool)); + + // The composed input is 0xffffffffffffff00ffffffffffffff + + input[0] = 4395583330278772740; // cons + input[1] = 4395583330278772740; // x + input[2] = 4395583330278772740; // x^2 + input[3] = 4395583330278772740; // x^3 + + input[4] = 4396634741790752772; + input[5] = 4396634741790752772; + input[6] = 4396634741790752772; + input[7] = 4396634741790752772; + + input[8] = 4396375252835237892; // mod msk + input[9] = 4396375252835237892; + input[10] = 4396375252835237892; + input[11] = 4396375252835237892; + + output[0] = 0xFFFFFFFFFFFFFFF; + output[1] = 0xFFFFFFFFFFFFFFF; + + Assert::IsTrue(BaseConverter.fastbconv_sk(input.get(), output.get())); + Assert::AreEqual(static_cast(2494482839790051254), output[0]); //mod q1 + Assert::AreEqual(static_cast(2494482839790051254), output[1]); + Assert::AreEqual(static_cast(2494482839790051254), output[2]); + Assert::AreEqual(static_cast(2494482839790051254), output[3]); + + Assert::AreEqual(static_cast(218180408843610743), output[4]); //mod q2 + Assert::AreEqual(static_cast(218180408843610743), output[5]); + Assert::AreEqual(static_cast(218180408843610743), output[6]); + Assert::AreEqual(static_cast(218180408843610743), output[7]); + } + + } + + TEST_METHOD(MontRq) + { + { + MemoryPoolMT &pool = *MemoryPoolMT::default_pool(); + vector coeff_base; + vector aux_base; + SmallModulus mtilda = small_mods[5]; + SmallModulus msk = small_mods[4]; + SmallModulus plain_t = small_mods[9]; + + int coeff_base_count = 2; + int aux_base_count = 2; + for (int i = 0; i < coeff_base_count; ++i) + { + coeff_base.push_back(small_mods[i]); + aux_base.push_back(small_mods[i + coeff_base_count]); + } + + BaseConverter BaseConverter(coeff_base, 1, plain_t); + Pointer input(allocate_uint(4, pool)); + Pointer output(allocate_uint(3, pool)); + + // The composed input is 0xffffffffffffff00ffffffffffffff + + input[0] = 4395583330278772740; // mod m1 + input[1] = 4396634741790752772; // mod m2 + input[2] = 4396375252835237892; // mod msk + input[3] = 4396146554501595140; // mod m_tilde + + output[0] = 0xfffffffff; + output[1] = 0x00fffffff; + output[2] = 0; + + Assert::IsTrue(BaseConverter.mont_rq(input.get(), output.get())); + Assert::AreEqual(static_cast(1412154008057360306), output[0]); + Assert::AreEqual(static_cast(3215947095329058299), output[1]); + Assert::AreEqual(static_cast(1636465626706639696), output[2]); + } + + { + MemoryPoolMT &pool = *MemoryPoolMT::default_pool(); + vector coeff_base; + vector aux_base; + SmallModulus mtilda = small_mods[5]; + SmallModulus msk = small_mods[4]; + SmallModulus plain_t = small_mods[9]; + + int coeff_base_count = 2; + int aux_base_count = 2; + for (int i = 0; i < coeff_base_count; ++i) + { + coeff_base.push_back(small_mods[i]); + aux_base.push_back(small_mods[i + coeff_base_count]); + } + + BaseConverter BaseConverter(coeff_base, 3, plain_t); + Pointer input(allocate_uint(12, pool)); + Pointer output(allocate_uint(9, pool)); + + // The composed input is 0xffffffffffffff00ffffffffffffff for all coeffs + + input[0] = 4395583330278772740; // cons mod m1 + input[1] = 4395583330278772740; // x mod m1 + input[2] = 4395583330278772740; // x^2 mod m1 + + input[3] = 4396634741790752772; // cons mod m2 + input[4] = 4396634741790752772; // x mod m2 + input[5] = 4396634741790752772; // x^2 mod m2 + + input[6] = 4396375252835237892; // cons mod msk + input[7] = 4396375252835237892; // x mod msk + input[8] = 4396375252835237892; // x^2 mod msk + + input[9] = 4396146554501595140; // cons mod m_tilde + input[10] = 4396146554501595140; // x mod m_tilde + input[11] = 4396146554501595140; // x^2 mod m_tilde + + output[0] = 0xfffffffff; + output[1] = 0x00fffffff; + output[2] = 0; + + Assert::IsTrue(BaseConverter.mont_rq(input.get(), output.get())); + Assert::AreEqual(static_cast(1412154008057360306), output[0]); + Assert::AreEqual(static_cast(1412154008057360306), output[1]); + Assert::AreEqual(static_cast(1412154008057360306), output[2]); + + Assert::AreEqual(static_cast(3215947095329058299), output[3]); + Assert::AreEqual(static_cast(3215947095329058299), output[4]); + Assert::AreEqual(static_cast(3215947095329058299), output[5]); + + Assert::AreEqual(static_cast(1636465626706639696), output[6]); + Assert::AreEqual(static_cast(1636465626706639696), output[7]); + Assert::AreEqual(static_cast(1636465626706639696), output[8]); + } + } + + TEST_METHOD(FastFloor) + { + { + MemoryPoolMT &pool = *MemoryPoolMT::default_pool(); + vector coeff_base; + vector aux_base; + SmallModulus mtilda = small_mods[5]; + SmallModulus msk = small_mods[4]; + SmallModulus plain_t = small_mods[9]; + + int coeff_base_count = 2; + int aux_base_count = 2; + for (int i = 0; i < coeff_base_count; ++i) + { + coeff_base.push_back(small_mods[i]); + aux_base.push_back(small_mods[i + coeff_base_count]); + } + + BaseConverter BaseConverter(coeff_base, 1, plain_t); + Pointer input(allocate_uint(5, pool)); + Pointer output(allocate_uint(3, pool)); + + // The composed input is 0xffffffffffffff00ffffffffffffff + + input[0] = 4395513236581707780; // mod q1 + input[1] = 4395513390924464132; // mod q2 + input[2] = 4395583330278772740; // mod m1 + input[3] = 4396634741790752772; // mod m2 + input[4] = 4396375252835237892; // mod msk + + output[0] = 0xfffffffff; + output[1] = 0x00fffffff; + output[2] = 0; + + Assert::IsTrue(BaseConverter.fast_floor(input.get(), output.get())); + + // The result for all moduli is equal to -1 since the composed input is small + // Assert::AreEqual(static_cast(4611686018393899008), output[0]); + // Assert::AreEqual(static_cast(4611686018293432320), output[1]); + // Assert::AreEqual(static_cast(4611686018309947392), output[2]); + + // The composed input is 0xffffffffffffff00ffffffffffffff00ff + + input[0] = 17574536613119; // mod q1 + input[1] = 10132675570633983; // mod q2 + input[2] = 3113399115422302529; // mod m1 + input[3] = 1298513899176416785; // mod m2 + input[4] = 3518991311999157564; // mod msk + + output[0] = 0xfffffffff; + output[1] = 0x00fffffff; + output[2] = 0; + + // Since input > q1*q2, the result should be floor(x/(q1*q2)) - alpha (alpha = {0 or 1}) + Assert::IsTrue(BaseConverter.fast_floor(input.get(), output.get())); + Assert::AreEqual(static_cast(0xfff), output[0]); + Assert::AreEqual(static_cast(0xfff), output[1]); + Assert::AreEqual(static_cast(0xfff), output[2]); + + // The composed input is 0xffffffffffffff00ffffffffffffff00ffff + + input[0] = 4499081372958719; // mod q1 + input[1] = 2593964946082299903; // mod q2 + input[2] = 4013821342825660755; // mod m1 + input[3] = 457963018288239031; // mod m2 + input[4] = 1691919900291185724; // mod msk + + output[0] = 0xfffffffff; + output[1] = 0x00fffffff; + output[2] = 0; + + // Since input > q1*q2, the result should be floor(x/(q1*q2)) - alpha (alpha = {0 or 1}) + Assert::IsTrue(BaseConverter.fast_floor(input.get(), output.get())); + Assert::AreEqual(static_cast(0xfffff), output[0]); + Assert::AreEqual(static_cast(0xfffff), output[1]); + Assert::AreEqual(static_cast(0xfffff), output[2]); + } + + { + MemoryPoolMT &pool = *MemoryPoolMT::default_pool(); + vector coeff_base; + vector aux_base; + SmallModulus plain_t = small_mods[9]; + + int coeff_base_count = 2; + int aux_base_count = 2; + for (int i = 0; i < coeff_base_count; ++i) + { + coeff_base.push_back(small_mods[i]); + } + + BaseConverter BaseConverter(coeff_base, 2, plain_t); + Pointer input(allocate_uint(10, pool)); + Pointer output(allocate_uint(6, pool)); + + input[0] = 4499081372958719; // mod q1 + input[1] = 4499081372958719; // mod q1 + + input[2] = 2593964946082299903; // mod q2 + input[3] = 2593964946082299903; // mod q2 + + input[4] = 4013821342825660755; // mod m1 + input[5] = 4013821342825660755; // mod m1 + + input[6] = 457963018288239031; // mod m2 + input[7] = 457963018288239031; // mod m2 + + input[8] = 1691919900291185724; // mod msk + input[9] = 1691919900291185724; // mod msk + + output[0] = 0xfffffffff; + output[1] = 0x00fffffff; + output[2] = 0; + + // Since input > q1*q2, the result should be floor(x/(q1*q2)) - alpha (alpha = {0 or 1}) + Assert::IsTrue(BaseConverter.fast_floor(input.get(), output.get())); + Assert::AreEqual(static_cast(0xfffff), output[0]); + Assert::AreEqual(static_cast(0xfffff), output[1]); + + Assert::AreEqual(static_cast(0xfffff), output[2]); + Assert::AreEqual(static_cast(0xfffff), output[3]); + + Assert::AreEqual(static_cast(0xfffff), output[4]); + Assert::AreEqual(static_cast(0xfffff), output[5]); + } + + } + + TEST_METHOD(FastBConver_mtilde) + { + MemoryPoolMT &pool = *MemoryPoolMT::default_pool(); + vector coeff_base; + vector aux_base; + SmallModulus mtilda = small_mods[5]; + SmallModulus msk = small_mods[4]; + SmallModulus plain_t = small_mods[9]; + + int coeff_base_count = 2; + int aux_base_count = 2; + for (int i = 0; i < coeff_base_count; ++i) + { + coeff_base.push_back(small_mods[i]); + aux_base.push_back(small_mods[i + coeff_base_count]); + } + + BaseConverter BaseConverter(coeff_base, 3, plain_t); + Pointer input(allocate_uint(6, pool)); + Pointer output(allocate_uint(12, pool)); + + // The composed input is 0xffffffffffffff00ffffffffffffff for all coeffs + + input[0] = 4395513236581707780; // cons mod q1 + input[1] = 4395513236581707780; // x mod q1 + input[2] = 4395513236581707780; // x^2 mod q1 + + input[3] = 4395513390924464132; // cons mod q2 + input[4] = 4395513390924464132; // x mod q2 + input[5] = 4395513390924464132; // x^2 mod q2 + + output[0] = 0xffffffff; + output[1] = 0; + output[2] = 0xffffff; + output[3] = 0xffffff; + + Assert::IsTrue(BaseConverter.fastbconv_mtilde(input.get(), output.get())); + Assert::AreEqual(static_cast(3116074317392112723), output[0]);//mod m1 + Assert::AreEqual(static_cast(3116074317392112723), output[1]); + Assert::AreEqual(static_cast(3116074317392112723), output[2]); + + Assert::AreEqual(static_cast(1254200639185090240), output[3]);//mod m2 + Assert::AreEqual(static_cast(1254200639185090240), output[4]); + Assert::AreEqual(static_cast(1254200639185090240), output[5]); + + Assert::AreEqual(static_cast(3528328721557038672), output[6]);//mod msk + Assert::AreEqual(static_cast(3528328721557038672), output[7]); + Assert::AreEqual(static_cast(3528328721557038672), output[8]); + + Assert::AreEqual(static_cast(849325434816160659), output[9]);//mod m_tilde + Assert::AreEqual(static_cast(849325434816160659), output[10]); + Assert::AreEqual(static_cast(849325434816160659), output[11]); + } + + TEST_METHOD(FastBConvert_plain_gamma) + { + MemoryPoolMT &pool = *MemoryPoolMT::default_pool(); + vector coeff_base; + vector aux_base; + SmallModulus plain_t = small_mods[9]; + + int coeff_base_count = 2; + int aux_base_count = 2; + for (int i = 0; i < coeff_base_count; ++i) + { + coeff_base.push_back(small_mods[i]); + aux_base.push_back(small_mods[i + coeff_base_count]); + } + + BaseConverter BaseConverter(coeff_base, 3, plain_t); + Pointer input(allocate_uint(6, pool)); + Pointer output(allocate_uint(6, pool)); + + // The composed input is 0xffffffffffffff00ffffffffffffff for all coeffs + + input[0] = 4395513236581707780; // cons mod q1 + input[1] = 4395513236581707780; // x mod q1 + input[2] = 4395513236581707780; // x^2 mod q1 + + input[3] = 4395513390924464132; // cons mod q2 + input[4] = 4395513390924464132; // x mod q2 + input[5] = 4395513390924464132; // x^2 mod q2 + + output[0] = 0xffffffff; + output[1] = 0; + output[2] = 0xffffff; + output[3] = 0xffffff; + + Assert::IsTrue(BaseConverter.fastbconv_plain_gamma(input.get(), output.get())); + Assert::AreEqual(static_cast(1950841694949736435), output[0]);//mod plain modulus + Assert::AreEqual(static_cast(1950841694949736435), output[1]); + Assert::AreEqual(static_cast(1950841694949736435), output[2]); + + Assert::AreEqual(static_cast(3744510248429639755), output[3]);//mod gamma + Assert::AreEqual(static_cast(3744510248429639755), output[4]); + Assert::AreEqual(static_cast(3744510248429639755), output[5]); + } + }; + } +} diff --git a/tests/seal/batchencoder.cpp b/tests/seal/batchencoder.cpp new file mode 100644 index 000000000..a213f0aa1 --- /dev/null +++ b/tests/seal/batchencoder.cpp @@ -0,0 +1,177 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "gtest/gtest.h" +#include "seal/batchencoder.h" +#include "seal/context.h" +#include "seal/defaultparams.h" +#include "seal/keygenerator.h" +#include +#include + +using namespace seal; +using namespace seal::util; +using namespace std; + +namespace SEALTest +{ + TEST(BatchEncoderTest, BatchUnbatchUIntVector) + { + EncryptionParameters parms(scheme_type::BFV); + parms.set_poly_modulus_degree(64); + parms.set_coeff_modulus({ small_mods_60bit(0) }); + parms.set_plain_modulus(257); + + auto context = SEALContext::Create(parms); + ASSERT_TRUE(context->context_data()->qualifiers().using_batching); + + BatchEncoder batch_encoder(context); + ASSERT_EQ(64ULL, batch_encoder.slot_count()); + vector plain_vec; + for (size_t i = 0; i < batch_encoder.slot_count(); i++) + { + plain_vec.push_back(i); + } + + Plaintext plain; + batch_encoder.encode(plain_vec, plain); + vector plain_vec2; + batch_encoder.decode(plain, plain_vec2); + ASSERT_TRUE(plain_vec == plain_vec2); + + for (size_t i = 0; i < batch_encoder.slot_count(); i++) + { + plain_vec[i] = 5; + } + batch_encoder.encode(plain_vec, plain); + ASSERT_TRUE(plain.to_string() == "5"); + batch_encoder.decode(plain, plain_vec2); + ASSERT_TRUE(plain_vec == plain_vec2); + + vector short_plain_vec; + for (size_t i = 0; i < 20; i++) + { + short_plain_vec.push_back(i); + } + batch_encoder.encode(short_plain_vec, plain); + vector short_plain_vec2; + batch_encoder.decode(plain, short_plain_vec2); + ASSERT_EQ(20ULL, short_plain_vec.size()); + ASSERT_EQ(64ULL, short_plain_vec2.size()); + for (size_t i = 0; i < 20; i++) + { + ASSERT_EQ(short_plain_vec[i], short_plain_vec2[i]); + } + for (size_t i = 20; i < batch_encoder.slot_count(); i++) + { + ASSERT_EQ(0ULL, short_plain_vec2[i]); + } + } + + TEST(BatchEncoderTest, BatchUnbatchIntVector) + { + EncryptionParameters parms(scheme_type::BFV); + parms.set_poly_modulus_degree(64); + parms.set_coeff_modulus({ small_mods_60bit(0) }); + parms.set_plain_modulus(257); + + auto context = SEALContext::Create(parms); + ASSERT_TRUE(context->context_data()->qualifiers().using_batching); + + BatchEncoder batch_encoder(context); + ASSERT_EQ(64ULL, batch_encoder.slot_count()); + vector plain_vec; + for (size_t i = 0; i < batch_encoder.slot_count(); i++) + { + plain_vec.push_back(static_cast(i * (1 - 2 * (i % 2)))); + } + + Plaintext plain; + batch_encoder.encode(plain_vec, plain); + vector plain_vec2; + batch_encoder.decode(plain, plain_vec2); + ASSERT_TRUE(plain_vec == plain_vec2); + + for (size_t i = 0; i < batch_encoder.slot_count(); i++) + { + plain_vec[i] = -5; + } + batch_encoder.encode(plain_vec, plain); + ASSERT_TRUE(plain.to_string() == "FC"); + batch_encoder.decode(plain, plain_vec2); + ASSERT_TRUE(plain_vec == plain_vec2); + + vector short_plain_vec; + for (size_t i = 0; i < 20; i++) + { + short_plain_vec.push_back(static_cast(i * (1 - 2 * (i % 2)))); + } + batch_encoder.encode(short_plain_vec, plain); + vector short_plain_vec2; + batch_encoder.decode(plain, short_plain_vec2); + ASSERT_EQ(20ULL, short_plain_vec.size()); + ASSERT_EQ(64ULL, short_plain_vec2.size()); + for (size_t i = 0; i < 20; i++) + { + ASSERT_TRUE(short_plain_vec[i] == short_plain_vec2[i]); + } + for (size_t i = 20; i < batch_encoder.slot_count(); i++) + { + ASSERT_TRUE(0LL == short_plain_vec2[i]); + } + } + + TEST(BatchEncoderTest, BatchUnbatchPlaintext) + { + EncryptionParameters parms(scheme_type::BFV); + parms.set_poly_modulus_degree(64); + parms.set_coeff_modulus({ small_mods_60bit(0) }); + parms.set_plain_modulus(257); + + auto context = SEALContext::Create(parms); + ASSERT_TRUE(context->context_data()->qualifiers().using_batching); + + BatchEncoder batch_encoder(context); + ASSERT_EQ(64ULL, batch_encoder.slot_count()); + Plaintext plain(batch_encoder.slot_count()); + for (size_t i = 0; i < batch_encoder.slot_count(); i++) + { + plain[i] = i; + } + + batch_encoder.encode(plain); + batch_encoder.decode(plain); + for (size_t i = 0; i < batch_encoder.slot_count(); i++) + { + ASSERT_TRUE(plain[i] == i); + } + + for (size_t i = 0; i < batch_encoder.slot_count(); i++) + { + plain[i] = 5; + } + batch_encoder.encode(plain); + ASSERT_TRUE(plain.to_string() == "5"); + batch_encoder.decode(plain); + for (size_t i = 0; i < batch_encoder.slot_count(); i++) + { + ASSERT_EQ(5ULL, plain[i]); + } + + Plaintext short_plain(20); + for (size_t i = 0; i < 20; i++) + { + short_plain[i] = i; + } + batch_encoder.encode(short_plain); + batch_encoder.decode(short_plain); + for (size_t i = 0; i < 20; i++) + { + ASSERT_TRUE(short_plain[i] == i); + } + for (size_t i = 20; i < batch_encoder.slot_count(); i++) + { + ASSERT_TRUE(short_plain[i] == 0); + } + } +} diff --git a/tests/seal/biguint.cpp b/tests/seal/biguint.cpp new file mode 100644 index 000000000..c8903bd96 --- /dev/null +++ b/tests/seal/biguint.cpp @@ -0,0 +1,378 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "gtest/gtest.h" +#include "seal/biguint.h" +#include "seal/util/defines.h" + +using namespace seal; +using namespace std; + +namespace SEALTest +{ + TEST(BigUnsignedInt, EmptyBigUInt) + { + BigUInt uint; + ASSERT_EQ(0, uint.bit_count()); + ASSERT_TRUE(nullptr == uint.data()); + ASSERT_EQ(0ULL, uint.byte_count()); + ASSERT_EQ(0ULL, uint.uint64_count()); + ASSERT_EQ(0, uint.significant_bit_count()); + ASSERT_TRUE("0" == uint.to_string()); + ASSERT_TRUE(uint.is_zero()); + ASSERT_FALSE(uint.is_alias()); + uint.set_zero(); + + BigUInt uint2; + ASSERT_TRUE(uint == uint2); + ASSERT_FALSE(uint != uint2); + + uint.resize(1); + ASSERT_EQ(1, uint.bit_count()); + ASSERT_TRUE(nullptr != uint.data()); + ASSERT_FALSE(uint.is_alias()); + + uint.resize(0); + ASSERT_EQ(0, uint.bit_count()); + ASSERT_TRUE(nullptr == uint.data()); + ASSERT_FALSE(uint.is_alias()); + } + + TEST(BigUnsignedInt, BigUInt64Bits) + { + BigUInt uint(64); + ASSERT_EQ(64, uint.bit_count()); + ASSERT_TRUE(nullptr != uint.data()); + ASSERT_EQ(8ULL, uint.byte_count()); + ASSERT_EQ(1ULL, uint.uint64_count()); + ASSERT_EQ(0, uint.significant_bit_count()); + ASSERT_TRUE("0" == uint.to_string()); + ASSERT_TRUE(uint.is_zero()); + ASSERT_EQ(static_cast(0), *uint.data()); + ASSERT_TRUE(SEAL_BYTE(0) == uint[0]); + ASSERT_TRUE(SEAL_BYTE(0) == uint[1]); + ASSERT_TRUE(SEAL_BYTE(0) == uint[2]); + ASSERT_TRUE(SEAL_BYTE(0) == uint[3]); + ASSERT_TRUE(SEAL_BYTE(0) == uint[4]); + ASSERT_TRUE(SEAL_BYTE(0) == uint[5]); + ASSERT_TRUE(SEAL_BYTE(0) == uint[6]); + ASSERT_TRUE(SEAL_BYTE(0) == uint[7]); + + uint = "1"; + ASSERT_EQ(1, uint.significant_bit_count()); + ASSERT_TRUE("1" == uint.to_string()); + ASSERT_FALSE(uint.is_zero()); + ASSERT_EQ(1ULL, *uint.data()); + ASSERT_TRUE(SEAL_BYTE(1) == uint[0]); + ASSERT_TRUE(SEAL_BYTE(0) == uint[1]); + ASSERT_TRUE(SEAL_BYTE(0) == uint[2]); + ASSERT_TRUE(SEAL_BYTE(0) == uint[3]); + ASSERT_TRUE(SEAL_BYTE(0) == uint[4]); + ASSERT_TRUE(SEAL_BYTE(0) == uint[5]); + ASSERT_TRUE(SEAL_BYTE(0) == uint[6]); + ASSERT_TRUE(SEAL_BYTE(0) == uint[7]); + uint.set_zero(); + ASSERT_TRUE(uint.is_zero()); + ASSERT_EQ(static_cast(0), *uint.data()); + + uint = "7FFFFFFFFFFFFFFF"; + ASSERT_EQ(63, uint.significant_bit_count()); + ASSERT_TRUE("7FFFFFFFFFFFFFFF" == uint.to_string()); + ASSERT_EQ(static_cast(0x7FFFFFFFFFFFFFFF), *uint.data()); + ASSERT_TRUE(SEAL_BYTE(0xFF) == uint[0]); + ASSERT_TRUE(SEAL_BYTE(0xFF) == uint[1]); + ASSERT_TRUE(SEAL_BYTE(0xFF) == uint[2]); + ASSERT_TRUE(SEAL_BYTE(0xFF) == uint[3]); + ASSERT_TRUE(SEAL_BYTE(0xFF) == uint[4]); + ASSERT_TRUE(SEAL_BYTE(0xFF) == uint[5]); + ASSERT_TRUE(SEAL_BYTE(0xFF) == uint[6]); + ASSERT_TRUE(SEAL_BYTE(0x7F) == uint[7]); + ASSERT_FALSE(uint.is_zero()); + + uint = "FFFFFFFFFFFFFFFF"; + ASSERT_EQ(64, uint.significant_bit_count()); + ASSERT_TRUE("FFFFFFFFFFFFFFFF" == uint.to_string()); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), *uint.data()); + ASSERT_TRUE(SEAL_BYTE(0xFF) == uint[0]); + ASSERT_TRUE(SEAL_BYTE(0xFF) == uint[1]); + ASSERT_TRUE(SEAL_BYTE(0xFF) == uint[2]); + ASSERT_TRUE(SEAL_BYTE(0xFF) == uint[3]); + ASSERT_TRUE(SEAL_BYTE(0xFF) == uint[4]); + ASSERT_TRUE(SEAL_BYTE(0xFF) == uint[5]); + ASSERT_TRUE(SEAL_BYTE(0xFF) == uint[6]); + ASSERT_TRUE(SEAL_BYTE(0xFF) == uint[7]); + ASSERT_FALSE(uint.is_zero()); + + uint = 0x8001; + ASSERT_EQ(16, uint.significant_bit_count()); + ASSERT_TRUE("8001" == uint.to_string()); + ASSERT_EQ(static_cast(0x8001), *uint.data()); + ASSERT_TRUE(SEAL_BYTE(0x01) == uint[0]); + ASSERT_TRUE(SEAL_BYTE(0x80) == uint[1]); + ASSERT_TRUE(SEAL_BYTE(0x00) == uint[2]); + ASSERT_TRUE(SEAL_BYTE(0x00) == uint[3]); + ASSERT_TRUE(SEAL_BYTE(0x00) == uint[4]); + ASSERT_TRUE(SEAL_BYTE(0x00) == uint[5]); + ASSERT_TRUE(SEAL_BYTE(0x00) == uint[6]); + ASSERT_TRUE(SEAL_BYTE(0x00) == uint[7]); + } + + TEST(BigUnsignedInt, BigUInt99Bits) + { + BigUInt uint(99); + ASSERT_EQ(99, uint.bit_count()); + ASSERT_TRUE(nullptr != uint.data()); + ASSERT_EQ(13ULL, uint.byte_count()); + ASSERT_EQ(2ULL, uint.uint64_count()); + ASSERT_EQ(0, uint.significant_bit_count()); + ASSERT_TRUE("0" == uint.to_string()); + ASSERT_TRUE(uint.is_zero()); + ASSERT_EQ(static_cast(0), uint.data()[0]); + ASSERT_EQ(static_cast(0), uint.data()[1]); + ASSERT_TRUE(SEAL_BYTE(0) == uint[0]); + ASSERT_TRUE(SEAL_BYTE(0) == uint[1]); + ASSERT_TRUE(SEAL_BYTE(0) == uint[2]); + ASSERT_TRUE(SEAL_BYTE(0) == uint[3]); + ASSERT_TRUE(SEAL_BYTE(0) == uint[4]); + ASSERT_TRUE(SEAL_BYTE(0) == uint[5]); + ASSERT_TRUE(SEAL_BYTE(0) == uint[6]); + ASSERT_TRUE(SEAL_BYTE(0) == uint[7]); + ASSERT_TRUE(SEAL_BYTE(0) == uint[8]); + ASSERT_TRUE(SEAL_BYTE(0) == uint[9]); + ASSERT_TRUE(SEAL_BYTE(0) == uint[10]); + ASSERT_TRUE(SEAL_BYTE(0) == uint[11]); + ASSERT_TRUE(SEAL_BYTE(0) == uint[12]); + + uint = "1"; + ASSERT_EQ(1, uint.significant_bit_count()); + ASSERT_TRUE("1" == uint.to_string()); + ASSERT_FALSE(uint.is_zero()); + ASSERT_EQ(1ULL, uint.data()[0]); + ASSERT_EQ(static_cast(0), uint.data()[1]); + ASSERT_TRUE(SEAL_BYTE(1) == uint[0]); + ASSERT_TRUE(SEAL_BYTE(0) == uint[1]); + ASSERT_TRUE(SEAL_BYTE(0) == uint[2]); + ASSERT_TRUE(SEAL_BYTE(0) == uint[3]); + ASSERT_TRUE(SEAL_BYTE(0) == uint[4]); + ASSERT_TRUE(SEAL_BYTE(0) == uint[5]); + ASSERT_TRUE(SEAL_BYTE(0) == uint[6]); + ASSERT_TRUE(SEAL_BYTE(0) == uint[7]); + ASSERT_TRUE(SEAL_BYTE(0) == uint[8]); + ASSERT_TRUE(SEAL_BYTE(0) == uint[9]); + ASSERT_TRUE(SEAL_BYTE(0) == uint[10]); + ASSERT_TRUE(SEAL_BYTE(0) == uint[11]); + ASSERT_TRUE(SEAL_BYTE(0) == uint[12]); + uint.set_zero(); + ASSERT_TRUE(uint.is_zero()); + ASSERT_EQ(static_cast(0), uint.data()[0]); + ASSERT_EQ(static_cast(0), uint.data()[1]); + + uint = "7FFFFFFFFFFFFFFFFFFFFFFFF"; + ASSERT_EQ(99, uint.significant_bit_count()); + ASSERT_TRUE("7FFFFFFFFFFFFFFFFFFFFFFFF" == uint.to_string()); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), uint.data()[0]); + ASSERT_EQ(static_cast(0x7FFFFFFFF), uint.data()[1]); + ASSERT_TRUE(SEAL_BYTE(0xFF) == uint[0]); + ASSERT_TRUE(SEAL_BYTE(0xFF) == uint[1]); + ASSERT_TRUE(SEAL_BYTE(0xFF) == uint[2]); + ASSERT_TRUE(SEAL_BYTE(0xFF) == uint[3]); + ASSERT_TRUE(SEAL_BYTE(0xFF) == uint[4]); + ASSERT_TRUE(SEAL_BYTE(0xFF) == uint[5]); + ASSERT_TRUE(SEAL_BYTE(0xFF) == uint[6]); + ASSERT_TRUE(SEAL_BYTE(0xFF) == uint[7]); + ASSERT_TRUE(SEAL_BYTE(0xFF) == uint[8]); + ASSERT_TRUE(SEAL_BYTE(0xFF) == uint[9]); + ASSERT_TRUE(SEAL_BYTE(0xFF) == uint[10]); + ASSERT_TRUE(SEAL_BYTE(0xFF) == uint[11]); + ASSERT_TRUE(SEAL_BYTE(0x07) == uint[12]); + ASSERT_FALSE(uint.is_zero()); + uint.set_zero(); + ASSERT_TRUE(uint.is_zero()); + ASSERT_EQ(static_cast(0), uint.data()[0]); + ASSERT_EQ(static_cast(0), uint.data()[1]); + + uint = "4000000000000000000000000"; + ASSERT_EQ(99, uint.significant_bit_count()); + ASSERT_TRUE("4000000000000000000000000" == uint.to_string()); + ASSERT_EQ(static_cast(0x0000000000000000), uint.data()[0]); + ASSERT_EQ(static_cast(0x400000000), uint.data()[1]); + ASSERT_TRUE(SEAL_BYTE(0x00) == uint[0]); + ASSERT_TRUE(SEAL_BYTE(0x00) == uint[1]); + ASSERT_TRUE(SEAL_BYTE(0x00) == uint[2]); + ASSERT_TRUE(SEAL_BYTE(0x00) == uint[3]); + ASSERT_TRUE(SEAL_BYTE(0x00) == uint[4]); + ASSERT_TRUE(SEAL_BYTE(0x00) == uint[5]); + ASSERT_TRUE(SEAL_BYTE(0x00) == uint[6]); + ASSERT_TRUE(SEAL_BYTE(0x00) == uint[7]); + ASSERT_TRUE(SEAL_BYTE(0x00) == uint[8]); + ASSERT_TRUE(SEAL_BYTE(0x00) == uint[9]); + ASSERT_TRUE(SEAL_BYTE(0x00) == uint[10]); + ASSERT_TRUE(SEAL_BYTE(0x00) == uint[11]); + ASSERT_TRUE(SEAL_BYTE(0x04) == uint[12]); + ASSERT_FALSE(uint.is_zero()); + + uint = 0x8001; + ASSERT_EQ(16, uint.significant_bit_count()); + ASSERT_TRUE("8001" == uint.to_string()); + ASSERT_EQ(static_cast(0x8001), uint.data()[0]); + ASSERT_EQ(static_cast(0), uint.data()[1]); + ASSERT_TRUE(SEAL_BYTE(0x01) == uint[0]); + ASSERT_TRUE(SEAL_BYTE(0x80) == uint[1]); + ASSERT_TRUE(SEAL_BYTE(0x00) == uint[2]); + ASSERT_TRUE(SEAL_BYTE(0x00) == uint[3]); + ASSERT_TRUE(SEAL_BYTE(0x00) == uint[4]); + ASSERT_TRUE(SEAL_BYTE(0x00) == uint[5]); + ASSERT_TRUE(SEAL_BYTE(0x00) == uint[6]); + ASSERT_TRUE(SEAL_BYTE(0x00) == uint[7]); + ASSERT_TRUE(SEAL_BYTE(0x00) == uint[8]); + ASSERT_TRUE(SEAL_BYTE(0x00) == uint[9]); + ASSERT_TRUE(SEAL_BYTE(0x00) == uint[10]); + ASSERT_TRUE(SEAL_BYTE(0x00) == uint[11]); + ASSERT_TRUE(SEAL_BYTE(0x00) == uint[12]); + + BigUInt uint2("123"); + ASSERT_FALSE(uint == uint2); + ASSERT_FALSE(uint2 == uint); + ASSERT_TRUE(uint != uint2); + ASSERT_TRUE(uint2 != uint); + + uint = uint2; + ASSERT_TRUE(uint == uint2); + ASSERT_FALSE(uint != uint2); + ASSERT_EQ(9, uint.significant_bit_count()); + ASSERT_TRUE("123" == uint.to_string()); + ASSERT_EQ(static_cast(0x123), uint.data()[0]); + ASSERT_EQ(static_cast(0), uint.data()[1]); + ASSERT_TRUE(SEAL_BYTE(0x23) == uint[0]); + ASSERT_TRUE(SEAL_BYTE(0x01) == uint[1]); + ASSERT_TRUE(SEAL_BYTE(0x00) == uint[2]); + ASSERT_TRUE(SEAL_BYTE(0x00) == uint[3]); + ASSERT_TRUE(SEAL_BYTE(0x00) == uint[4]); + ASSERT_TRUE(SEAL_BYTE(0x00) == uint[5]); + ASSERT_TRUE(SEAL_BYTE(0x00) == uint[6]); + ASSERT_TRUE(SEAL_BYTE(0x00) == uint[7]); + ASSERT_TRUE(SEAL_BYTE(0x00) == uint[8]); + ASSERT_TRUE(SEAL_BYTE(0x00) == uint[9]); + ASSERT_TRUE(SEAL_BYTE(0x00) == uint[10]); + ASSERT_TRUE(SEAL_BYTE(0x00) == uint[11]); + ASSERT_TRUE(SEAL_BYTE(0x00) == uint[12]); + + uint.resize(8); + ASSERT_EQ(8, uint.bit_count()); + ASSERT_EQ(1ULL, uint.uint64_count()); + ASSERT_TRUE("23" == uint.to_string()); + + uint.resize(100); + ASSERT_EQ(100, uint.bit_count()); + ASSERT_EQ(2ULL, uint.uint64_count()); + ASSERT_TRUE("23" == uint.to_string()); + + uint.resize(0); + ASSERT_EQ(0, uint.bit_count()); + ASSERT_EQ(0ULL, uint.uint64_count()); + ASSERT_TRUE(nullptr == uint.data()); + } + + TEST(BigUnsignedInt, SaveLoadUInt) + { + stringstream stream; + + BigUInt value; + BigUInt value2("100"); + value.save(stream); + value2.load(stream); + ASSERT_TRUE(value == value2); + + value = "123"; + value.save(stream); + value2.load(stream); + ASSERT_TRUE(value == value2); + + value = "FFFFFFFFFFFFFFFFFFFFFFFFFF"; + value.save(stream); + value2.load(stream); + ASSERT_TRUE(value == value2); + + value = "0"; + value.save(stream); + value2.load(stream); + ASSERT_TRUE(value == value2); + } + + TEST(BigUnsignedInt, DuplicateTo) + { + BigUInt original(123); + original = 56789; + + BigUInt target; + + original.duplicate_to(target); + ASSERT_EQ(target.bit_count(), original.bit_count()); + ASSERT_TRUE(target == original); + } + + TEST(BigUnsignedInt, DuplicateFrom) + { + BigUInt original(123); + original = 56789; + + BigUInt target; + + target.duplicate_from(original); + ASSERT_EQ(target.bit_count(), original.bit_count()); + ASSERT_TRUE(target == original); + } + + TEST(BigUnsignedInt, BigUIntCopyMoveAssign) + { + { + BigUInt p1("123"); + BigUInt p2("456"); + BigUInt p3; + + p1.operator =(p2); + p3.operator =(p1); + ASSERT_TRUE(p1 == p2); + ASSERT_TRUE(p3 == p1); + } + { + BigUInt p1("123"); + BigUInt p2("456"); + BigUInt p3; + BigUInt p4(p2); + + p1.operator =(move(p2)); + p3.operator =(move(p1)); + ASSERT_TRUE(p3 == p4); + ASSERT_TRUE(p1 == p2); + ASSERT_TRUE(p3 == p1); + } + { + uint64_t p1_anchor = 123; + uint64_t p2_anchor = 456; + BigUInt p1(64, &p1_anchor); + BigUInt p2(64, &p2_anchor); + BigUInt p3; + + p1.operator =(p2); + p3.operator =(p1); + ASSERT_TRUE(p1 == p2); + ASSERT_TRUE(p3 == p1); + } + { + uint64_t p1_anchor = 123; + uint64_t p2_anchor = 456; + BigUInt p1(64, &p1_anchor); + BigUInt p2(64, &p2_anchor); + BigUInt p3; + BigUInt p4(p2); + + p1.operator =(move(p2)); + p3.operator =(move(p1)); + ASSERT_TRUE(p3 == p4); + ASSERT_TRUE(p2 == 456); + ASSERT_TRUE(p1 == 456); + ASSERT_TRUE(p3 == 456); + } + } +} diff --git a/tests/seal/ciphertext.cpp b/tests/seal/ciphertext.cpp new file mode 100644 index 000000000..627664798 --- /dev/null +++ b/tests/seal/ciphertext.cpp @@ -0,0 +1,130 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "gtest/gtest.h" +#include "seal/ciphertext.h" +#include "seal/context.h" +#include "seal/keygenerator.h" +#include "seal/encryptor.h" +#include "seal/memorymanager.h" +#include "seal/defaultparams.h" + +using namespace seal; +using namespace seal::util; +using namespace std; + +namespace SEALTest +{ + TEST(CiphertextTest, CiphertextBasics) + { + EncryptionParameters parms(scheme_type::BFV); + + parms.set_poly_modulus_degree(2); + parms.set_coeff_modulus({ small_mods_30bit(0) }); + parms.set_plain_modulus(2); + parms.set_noise_standard_deviation(1.0); + auto context = SEALContext::Create(parms); + + Ciphertext ctxt(context); + ctxt.reserve(10); + ASSERT_EQ(0ULL, ctxt.size()); + ASSERT_EQ(0ULL, ctxt.uint64_count()); + ASSERT_EQ(10ULL * 2 * 1, ctxt.uint64_count_capacity()); + ASSERT_EQ(2ULL, ctxt.poly_modulus_degree()); + ASSERT_TRUE(ctxt.parms_id() == parms.parms_id()); + ASSERT_FALSE(ctxt.is_ntt_form()); + const uint64_t *ptr = ctxt.data(); + + ctxt.reserve(5); + ASSERT_EQ(0ULL, ctxt.size()); + ASSERT_EQ(0ULL, ctxt.uint64_count()); + ASSERT_EQ(5ULL * 2 * 1, ctxt.uint64_count_capacity()); + ASSERT_EQ(2ULL, ctxt.poly_modulus_degree()); + ASSERT_TRUE(ptr != ctxt.data()); + ASSERT_TRUE(ctxt.parms_id() == parms.parms_id()); + ptr = ctxt.data(); + + ctxt.reserve(10); + ASSERT_EQ(0ULL, ctxt.size()); + ASSERT_EQ(0ULL, ctxt.uint64_count()); + ASSERT_EQ(10ULL * 2 * 1, ctxt.uint64_count_capacity()); + ASSERT_EQ(2ULL, ctxt.poly_modulus_degree()); + ASSERT_TRUE(ptr != ctxt.data()); + ASSERT_TRUE(ctxt.parms_id() == parms.parms_id()); + ASSERT_FALSE(ctxt.is_ntt_form()); + ptr = ctxt.data(); + + ctxt.reserve(2); + ASSERT_EQ(0ULL, ctxt.size()); + ASSERT_EQ(2ULL * 2 * 1, ctxt.uint64_count_capacity()); + ASSERT_EQ(0ULL, ctxt.uint64_count()); + ASSERT_EQ(2ULL, ctxt.poly_modulus_degree()); + ASSERT_TRUE(ptr != ctxt.data()); + ASSERT_TRUE(ctxt.parms_id() == parms.parms_id()); + ASSERT_FALSE(ctxt.is_ntt_form()); + ptr = ctxt.data(); + + ctxt.reserve(5); + ASSERT_EQ(0ULL, ctxt.size()); + ASSERT_EQ(5ULL * 2 * 1, ctxt.uint64_count_capacity()); + ASSERT_EQ(0ULL, ctxt.uint64_count()); + ASSERT_EQ(2ULL, ctxt.poly_modulus_degree()); + ASSERT_TRUE(ptr != ctxt.data()); + ASSERT_TRUE(ctxt.parms_id() == parms.parms_id()); + ASSERT_FALSE(ctxt.is_ntt_form()); + + Ciphertext ctxt2{ ctxt }; + ASSERT_EQ(ctxt.coeff_mod_count(), ctxt2.coeff_mod_count()); + ASSERT_EQ(ctxt.is_ntt_form(), ctxt2.is_ntt_form()); + ASSERT_EQ(ctxt.poly_modulus_degree(), ctxt2.poly_modulus_degree()); + ASSERT_TRUE(ctxt.parms_id() == ctxt2.parms_id()); + ASSERT_EQ(ctxt.poly_modulus_degree(), ctxt2.poly_modulus_degree()); + ASSERT_EQ(ctxt.size(), ctxt2.size()); + + Ciphertext ctxt3; + ctxt3 = ctxt; + ASSERT_EQ(ctxt.coeff_mod_count(), ctxt3.coeff_mod_count()); + ASSERT_EQ(ctxt.poly_modulus_degree(), ctxt3.poly_modulus_degree()); + ASSERT_EQ(ctxt.is_ntt_form(), ctxt3.is_ntt_form()); + ASSERT_TRUE(ctxt.parms_id() == ctxt3.parms_id()); + ASSERT_EQ(ctxt.poly_modulus_degree(), ctxt3.poly_modulus_degree()); + ASSERT_EQ(ctxt.size(), ctxt3.size()); + } + + TEST(CiphertextTest, SaveLoadCiphertext) + { + stringstream stream; + EncryptionParameters parms(scheme_type::BFV); + parms.set_poly_modulus_degree(2); + parms.set_coeff_modulus({ small_mods_30bit(0) }); + parms.set_plain_modulus(2); + parms.set_noise_standard_deviation(1.0); + + auto context = SEALContext::Create(parms); + + Ciphertext ctxt(context); + Ciphertext ctxt2; + ctxt.save(stream); + ctxt2.load(context, stream); + ASSERT_TRUE(ctxt.parms_id() == ctxt2.parms_id()); + ASSERT_FALSE(ctxt.is_ntt_form()); + ASSERT_FALSE(ctxt2.is_ntt_form()); + + parms.set_poly_modulus_degree(1024); + parms.set_coeff_modulus(coeff_modulus_128(1024)); + parms.set_plain_modulus(0xF0F0); + parms.set_noise_standard_deviation(3.14159); + context = SEALContext::Create(parms); + KeyGenerator keygen(context); + Encryptor encryptor(context, keygen.public_key()); + encryptor.encrypt(Plaintext("Ax^10 + 9x^9 + 8x^8 + 7x^7 + 6x^6 + 5x^5 + 4x^4 + 3x^3 + 2x^2 + 1"), ctxt); + ctxt.save(stream); + ctxt2.load(context, stream); + ASSERT_TRUE(ctxt.parms_id() == ctxt2.parms_id()); + ASSERT_FALSE(ctxt.is_ntt_form()); + ASSERT_FALSE(ctxt2.is_ntt_form()); + ASSERT_TRUE(is_equal_uint_uint(ctxt.data(), ctxt2.data(), + parms.poly_modulus_degree() * parms.coeff_modulus().size() * 2)); + ASSERT_TRUE(ctxt.data() != ctxt2.data()); + } +} diff --git a/tests/seal/ckks.cpp b/tests/seal/ckks.cpp new file mode 100644 index 000000000..02c340b91 --- /dev/null +++ b/tests/seal/ckks.cpp @@ -0,0 +1,320 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "gtest/gtest.h" +#include "seal/ckks.h" +#include "seal/context.h" +#include "seal/defaultparams.h" +#include "seal/keygenerator.h" +#include +#include + +using namespace seal; +using namespace seal::util; +using namespace std; + +namespace SEALTest +{ + TEST(CKKSEncoderTest, CKKSEncoderEncodeVectorDecodeTest) + { + EncryptionParameters parms(scheme_type::CKKS); + { + uint32_t slots = 32; + parms.set_poly_modulus_degree(2 * slots); + parms.set_coeff_modulus({ small_mods_40bit(0), small_mods_40bit(1), + small_mods_40bit(2), small_mods_40bit(3) }); + auto context = SEALContext::Create(parms); + + std::vector> values(slots); + + for (size_t i = 0; i < slots; i++) + { + std::complex value(0.0, 0.0); + values[i] = value; + } + + CKKSEncoder encoder(context); + double delta = (1ULL << 16); + Plaintext plain; + encoder.encode(values, parms.parms_id(), delta, plain); + std::vector> result; + encoder.decode(plain, result); + + for (size_t i = 0; i < slots; ++i) + { + auto tmp = abs(values[i].real() - result[i].real()); + ASSERT_TRUE(tmp < 0.5); + } + } + { + uint32_t slots = 32; + parms.set_poly_modulus_degree(2 * slots); + parms.set_coeff_modulus({ small_mods_60bit(0), small_mods_60bit(1), + small_mods_60bit(2), small_mods_60bit(3) }); + auto context = SEALContext::Create(parms); + + std::vector> values(slots); + + srand(static_cast(time(NULL))); + int data_bound = (1 << 30); + + for (size_t i = 0; i < slots; i++) + { + std::complex value(static_cast(rand() % data_bound), 0); + values[i] = value; + } + + CKKSEncoder encoder(context); + double delta = (1ULL << 40); + Plaintext plain; + encoder.encode(values, parms.parms_id(), delta, plain); + std::vector> result; + encoder.decode(plain, result); + + for (size_t i = 0; i < slots; ++i) + { + auto tmp = abs(values[i].real() - result[i].real()); + ASSERT_TRUE(tmp < 0.5); + } + } + { + uint32_t slots = 64; + parms.set_poly_modulus_degree(2 * slots); + parms.set_coeff_modulus({ small_mods_60bit(0), + small_mods_60bit(1), small_mods_60bit(2) }); + auto context = SEALContext::Create(parms); + + std::vector> values(slots); + + srand(static_cast(time(NULL))); + int data_bound = (1 << 30); + + for (size_t i = 0; i < slots; i++) + { + std::complex value(static_cast(rand() % data_bound), 0); + values[i] = value; + } + + CKKSEncoder encoder(context); + double delta = (1ULL << 40); + Plaintext plain; + encoder.encode(values, parms.parms_id(), delta, plain); + std::vector> result; + encoder.decode(plain, result); + + for (size_t i = 0; i < slots; ++i) + { + auto tmp = abs(values[i].real() - result[i].real()); + ASSERT_TRUE(tmp < 0.5); + } + } + { + uint32_t slots = 64; + parms.set_poly_modulus_degree(2 * slots); + parms.set_coeff_modulus({ small_mods_30bit(0), small_mods_30bit(1), + small_mods_30bit(2), small_mods_30bit(3), small_mods_30bit(4) }); + auto context = SEALContext::Create(parms); + + std::vector> values(slots); + + srand(static_cast(time(NULL))); + int data_bound = (1 << 30); + + for (size_t i = 0; i < slots; i++) + { + std::complex value(static_cast(rand() % data_bound), 0); + values[i] = value; + } + + CKKSEncoder encoder(context); + double delta = (1ULL << 40); + Plaintext plain; + encoder.encode(values, parms.parms_id(), delta, plain); + std::vector> result; + encoder.decode(plain, result); + + for (size_t i = 0; i < slots; ++i) + { + auto tmp = abs(values[i].real() - result[i].real()); + ASSERT_TRUE(tmp < 0.5); + } + } + { + uint32_t slots = 32; + parms.set_poly_modulus_degree(128); + parms.set_coeff_modulus({ small_mods_30bit(0), small_mods_30bit(1), + small_mods_30bit(2), small_mods_30bit(3), small_mods_30bit(4) }); + auto context = SEALContext::Create(parms); + + std::vector> values(slots); + + srand(static_cast(time(NULL))); + int data_bound = (1 << 30); + + for (size_t i = 0; i < slots; i++) + { + std::complex value(static_cast(rand() % data_bound), 0); + values[i] = value; + } + + CKKSEncoder encoder(context); + double delta = (1ULL << 40); + Plaintext plain; + encoder.encode(values, parms.parms_id(), delta, plain); + std::vector> result; + encoder.decode(plain, result); + + for (size_t i = 0; i < slots; ++i) + { + auto tmp = abs(values[i].real() - result[i].real()); + ASSERT_TRUE(tmp < 0.5); + } + } + { + uint32_t slots = 64; + parms.set_poly_modulus_degree(2 * slots); + parms.set_coeff_modulus({ small_mods_40bit(0), small_mods_40bit(1), + small_mods_40bit(2), small_mods_40bit(3), small_mods_40bit(4) }); + auto context = SEALContext::Create(parms); + + std::vector> values(slots); + + srand(static_cast(time(NULL))); + int data_bound = (1 << 20); + + for (size_t i = 0; i < slots; i++) + { + std::complex value(static_cast(rand() % data_bound), 0); + values[i] = value; + } + + CKKSEncoder encoder(context); + { + // Use a very large scale + double delta = pow(2.0, 110); + Plaintext plain; + encoder.encode(values, parms.parms_id(), delta, plain); + std::vector> result; + encoder.decode(plain, result); + + for (size_t i = 0; i < slots; ++i) + { + auto tmp = abs(values[i].real() - result[i].real()); + ASSERT_TRUE(tmp < 0.5); + } + } + { + // Use a scale over 128 bits + double delta = pow(2.0, 130); + Plaintext plain; + encoder.encode(values, parms.parms_id(), delta, plain); + std::vector> result; + encoder.decode(plain, result); + + for (size_t i = 0; i < slots; ++i) + { + auto tmp = abs(values[i].real() - result[i].real()); + ASSERT_TRUE(tmp < 0.5); + } + } + } + } + + TEST(CKKSEncoderTest, CKKSEncoderEncodeSingleDecodeTest) + { + EncryptionParameters parms(scheme_type::CKKS); + { + uint32_t slots = 16; + parms.set_poly_modulus_degree(64); + parms.set_coeff_modulus({ small_mods_40bit(0), small_mods_40bit(1), + small_mods_40bit(2), small_mods_40bit(3) }); + auto context = SEALContext::Create(parms); + CKKSEncoder encoder(context); + + srand(static_cast(time(NULL))); + int data_bound = (1 << 30); + double delta = (1ULL << 16); + Plaintext plain; + std::vector> result; + + for (int iRun = 0; iRun < 50; iRun++) + { + double value = static_cast(rand() % data_bound); + encoder.encode(value, parms.parms_id(), delta, plain); + encoder.decode(plain, result); + + for (size_t i = 0; i < slots; ++i) + { + auto tmp = abs(value - result[i].real()); + ASSERT_TRUE(tmp < 0.5); + } + } + } + { + uint32_t slots = 32; + parms.set_poly_modulus_degree(slots * 2); + parms.set_coeff_modulus({ small_mods_40bit(0), small_mods_40bit(1), + small_mods_40bit(2), small_mods_40bit(3) }); + auto context = SEALContext::Create(parms); + CKKSEncoder encoder(context); + + srand(static_cast(time(NULL))); + { + int data_bound = (1 << 30); + Plaintext plain; + std::vector> result; + + for (int iRun = 0; iRun < 50; iRun++) + { + int value = static_cast(rand() % data_bound); + encoder.encode(value, parms.parms_id(), plain); + encoder.decode(plain, result); + + for (size_t i = 0; i < slots; ++i) + { + auto tmp = abs(value - result[i].real()); + ASSERT_TRUE(tmp < 0.5); + } + } + } + { + // Use a very large scale + int data_bound = (1 << 20); + Plaintext plain; + std::vector> result; + + for (int iRun = 0; iRun < 50; iRun++) + { + int value = static_cast(rand() % data_bound); + encoder.encode(value, parms.parms_id(), plain); + encoder.decode(plain, result); + + for (size_t i = 0; i < slots; ++i) + { + auto tmp = abs(value - result[i].real()); + ASSERT_TRUE(tmp < 0.5); + } + } + } + { + // Use a scale over 128 bits + int data_bound = (1 << 20); + Plaintext plain; + std::vector> result; + + for (int iRun = 0; iRun < 50; iRun++) + { + int value = static_cast(rand() % data_bound); + encoder.encode(value, parms.parms_id(), plain); + encoder.decode(plain, result); + + for (size_t i = 0; i < slots; ++i) + { + auto tmp = abs(value - result[i].real()); + ASSERT_TRUE(tmp < 0.5); + } + } + } + } + } +} diff --git a/tests/seal/context.cpp b/tests/seal/context.cpp new file mode 100644 index 000000000..ca18ddbab --- /dev/null +++ b/tests/seal/context.cpp @@ -0,0 +1,228 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "gtest/gtest.h" +#include "seal/context.h" + +using namespace seal; +using namespace std; + +namespace SEALTest +{ + TEST(ContextTest, ContextConstructor) + { + // Nothing set + auto scheme = scheme_type::BFV; + EncryptionParameters parms(scheme); + { + auto context = SEALContext::Create(parms); + auto qualifiers = context->context_data()->qualifiers(); + ASSERT_FALSE(qualifiers.parameters_set); + ASSERT_FALSE(qualifiers.using_fft); + ASSERT_FALSE(qualifiers.using_ntt); + ASSERT_FALSE(qualifiers.using_batching); + ASSERT_FALSE(qualifiers.using_fast_plain_lift); + } + + // Not relatively prime coeff moduli + parms.set_poly_modulus_degree(4); + parms.set_coeff_modulus({ 2, 30 }); + parms.set_plain_modulus(2); + parms.set_noise_standard_deviation(3.20); + parms.set_random_generator(UniformRandomGeneratorFactory::default_factory()); + { + auto context = SEALContext::Create(parms); + auto qualifiers = context->context_data()->qualifiers(); + ASSERT_FALSE(qualifiers.parameters_set); + ASSERT_FALSE(qualifiers.using_fft); + ASSERT_FALSE(qualifiers.using_ntt); + ASSERT_FALSE(qualifiers.using_batching); + ASSERT_FALSE(qualifiers.using_fast_plain_lift); + } + + // Plain modulus not relatively prime to coeff moduli + parms.set_poly_modulus_degree(4); + parms.set_coeff_modulus({ 17, 41 }); + parms.set_plain_modulus(34); + parms.set_noise_standard_deviation(3.20); + parms.set_random_generator(UniformRandomGeneratorFactory::default_factory()); + { + auto context = SEALContext::Create(parms); + auto qualifiers = context->context_data()->qualifiers(); + ASSERT_FALSE(qualifiers.parameters_set); + ASSERT_TRUE(qualifiers.using_fft); + ASSERT_TRUE(qualifiers.using_ntt); + ASSERT_FALSE(qualifiers.using_batching); + ASSERT_FALSE(qualifiers.using_fast_plain_lift); + } + + // Plain modulus not smaller than product of coeff moduli + parms.set_poly_modulus_degree(4); + parms.set_coeff_modulus({ 2 }); + parms.set_plain_modulus(3); + parms.set_noise_standard_deviation(3.20); + parms.set_random_generator(UniformRandomGeneratorFactory::default_factory()); + { + auto context = SEALContext::Create(parms); + ASSERT_EQ(2ULL, *context->context_data()->total_coeff_modulus()); + auto qualifiers = context->context_data()->qualifiers(); + ASSERT_FALSE(qualifiers.parameters_set); + ASSERT_TRUE(qualifiers.using_fft); + ASSERT_FALSE(qualifiers.using_ntt); + ASSERT_FALSE(qualifiers.using_batching); + ASSERT_FALSE(qualifiers.using_fast_plain_lift); + } + + // FFT poly but not NTT modulus + parms.set_poly_modulus_degree(4); + parms.set_coeff_modulus({ 3 }); + parms.set_plain_modulus(2); + parms.set_noise_standard_deviation(3.20); + parms.set_random_generator(UniformRandomGeneratorFactory::default_factory()); + { + auto context = SEALContext::Create(parms); + ASSERT_EQ(3ULL, *context->context_data()->total_coeff_modulus()); + auto qualifiers = context->context_data()->qualifiers(); + ASSERT_FALSE(qualifiers.parameters_set); + ASSERT_TRUE(qualifiers.using_fft); + ASSERT_FALSE(qualifiers.using_ntt); + ASSERT_FALSE(qualifiers.using_batching); + ASSERT_FALSE(qualifiers.using_fast_plain_lift); + } + + // Parameters OK; no fast plain lift + parms.set_poly_modulus_degree(4); + parms.set_coeff_modulus({ 17, 41 }); + parms.set_plain_modulus(18); + parms.set_noise_standard_deviation(3.20); + parms.set_random_generator(UniformRandomGeneratorFactory::default_factory()); + { + auto context = SEALContext::Create(parms); + ASSERT_EQ(697ULL, *context->context_data()->total_coeff_modulus()); + auto qualifiers = context->context_data()->qualifiers(); + ASSERT_TRUE(qualifiers.parameters_set); + ASSERT_TRUE(qualifiers.using_fft); + ASSERT_TRUE(qualifiers.using_ntt); + ASSERT_FALSE(qualifiers.using_batching); + ASSERT_FALSE(qualifiers.using_fast_plain_lift); + } + + // Parameters OK; fast plain lift + parms.set_poly_modulus_degree(4); + parms.set_coeff_modulus({ 17, 41 }); + parms.set_plain_modulus(16); + parms.set_noise_standard_deviation(3.20); + parms.set_random_generator(UniformRandomGeneratorFactory::default_factory()); + { + auto context = SEALContext::Create(parms); + ASSERT_EQ(697ULL, *context->context_data()->total_coeff_modulus()); + auto qualifiers = context->context_data()->qualifiers(); + ASSERT_TRUE(qualifiers.parameters_set); + ASSERT_TRUE(qualifiers.using_fft); + ASSERT_TRUE(qualifiers.using_ntt); + ASSERT_FALSE(qualifiers.using_batching); + ASSERT_TRUE(qualifiers.using_fast_plain_lift); + } + + // Parameters OK; no batching due to non-prime plain modulus + parms.set_poly_modulus_degree(4); + parms.set_coeff_modulus({ 17, 41 }); + parms.set_plain_modulus(49); + parms.set_noise_standard_deviation(3.20); + parms.set_random_generator(UniformRandomGeneratorFactory::default_factory()); + { + auto context = SEALContext::Create(parms); + ASSERT_EQ(697ULL, *context->context_data()->total_coeff_modulus()); + auto qualifiers = context->context_data()->qualifiers(); + ASSERT_TRUE(qualifiers.parameters_set); + ASSERT_TRUE(qualifiers.using_fft); + ASSERT_TRUE(qualifiers.using_ntt); + ASSERT_FALSE(qualifiers.using_batching); + ASSERT_FALSE(qualifiers.using_fast_plain_lift); + } + + // Parameters OK; batching enabled + parms.set_poly_modulus_degree(4); + parms.set_coeff_modulus({ 17, 41 }); + parms.set_plain_modulus(73); + parms.set_noise_standard_deviation(3.20); + parms.set_random_generator(UniformRandomGeneratorFactory::default_factory()); + { + auto context = SEALContext::Create(parms); + ASSERT_EQ(697ULL, *context->context_data()->total_coeff_modulus()); + auto qualifiers = context->context_data()->qualifiers(); + ASSERT_TRUE(qualifiers.parameters_set); + ASSERT_TRUE(qualifiers.using_fft); + ASSERT_TRUE(qualifiers.using_ntt); + ASSERT_TRUE(qualifiers.using_batching); + ASSERT_FALSE(qualifiers.using_fast_plain_lift); + } + + // Parameters OK; batching and fast plain lift enabled + parms.set_poly_modulus_degree(4); + parms.set_coeff_modulus({ 137, 193 }); + parms.set_plain_modulus(73); + parms.set_noise_standard_deviation(3.20); + parms.set_random_generator(UniformRandomGeneratorFactory::default_factory()); + { + auto context = SEALContext::Create(parms); + ASSERT_EQ(26441ULL, *context->context_data()->total_coeff_modulus()); + auto qualifiers = context->context_data()->qualifiers(); + ASSERT_TRUE(qualifiers.parameters_set); + ASSERT_TRUE(qualifiers.using_fft); + ASSERT_TRUE(qualifiers.using_ntt); + ASSERT_TRUE(qualifiers.using_batching); + ASSERT_TRUE(qualifiers.using_fast_plain_lift); + } + + // Parameters OK; batching and fast plain lift enabled; nullptr RNG + parms.set_poly_modulus_degree(4); + parms.set_coeff_modulus({ 137, 193 }); + parms.set_plain_modulus(73); + parms.set_noise_standard_deviation(3.20); + parms.set_random_generator(nullptr); + { + auto context = SEALContext::Create(parms); + ASSERT_EQ(26441ULL, *context->context_data()->total_coeff_modulus()); + auto qualifiers = context->context_data()->qualifiers(); + ASSERT_TRUE(qualifiers.parameters_set); + ASSERT_TRUE(qualifiers.using_fft); + ASSERT_TRUE(qualifiers.using_ntt); + ASSERT_TRUE(qualifiers.using_batching); + ASSERT_TRUE(qualifiers.using_fast_plain_lift); + } + } + + TEST(ContextTest, ModulusChainExpansion) + { + { + EncryptionParameters parms(scheme_type::BFV); + parms.set_poly_modulus_degree(4); + parms.set_coeff_modulus({ 41, 137, 193, 65537 }); + parms.set_plain_modulus(73); + auto context = SEALContext::Create(parms, true); + ASSERT_EQ(size_t(2), context->context_data()->chain_index()); + ASSERT_EQ(71047416497ULL, *context->context_data()->total_coeff_modulus()); + ASSERT_TRUE(!!context->context_data()->next_context_data()); + + context = SEALContext::Create(parms, false); + ASSERT_EQ(size_t(0), context->context_data()->chain_index()); + ASSERT_EQ(71047416497ULL, *context->context_data()->total_coeff_modulus()); + ASSERT_FALSE(!!context->context_data()->next_context_data()); + } + { + EncryptionParameters parms(scheme_type::CKKS); + parms.set_poly_modulus_degree(4); + parms.set_coeff_modulus({ 41, 137, 193, 65537 }); + auto context = SEALContext::Create(parms, true); + ASSERT_EQ(size_t(3), context->context_data()->chain_index()); + ASSERT_EQ(71047416497ULL, *context->context_data()->total_coeff_modulus()); + ASSERT_TRUE(!!context->context_data()->next_context_data()); + + context = SEALContext::Create(parms, false); + ASSERT_EQ(size_t(0), context->context_data()->chain_index()); + ASSERT_EQ(71047416497ULL, *context->context_data()->total_coeff_modulus()); + ASSERT_FALSE(!!context->context_data()->next_context_data()); + } + } +} diff --git a/tests/seal/encoder.cpp b/tests/seal/encoder.cpp new file mode 100644 index 000000000..853d23dbc --- /dev/null +++ b/tests/seal/encoder.cpp @@ -0,0 +1,1109 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "gtest/gtest.h" +#include "seal/encoder.h" +#include "seal/context.h" +#include "seal/defaultparams.h" +#include +#include + +using namespace seal; +using namespace std; + +namespace SEALTest +{ + TEST(Encoder, BinaryEncodeDecodeBigUInt) + { + SmallModulus modulus(0xFFFFFFFFFFFFFFF); + BinaryEncoder encoder(modulus); + + BigUInt value(64); + value = "0"; + Plaintext poly = encoder.encode(value); + ASSERT_EQ(0ULL, poly.significant_coeff_count()); + ASSERT_TRUE(poly.is_zero()); + ASSERT_TRUE(value == encoder.decode_biguint(poly)); + + value = "1"; + Plaintext poly1 = encoder.encode(value); + ASSERT_EQ(1ULL, poly1.coeff_count()); + ASSERT_TRUE("1" == poly1.to_string()); + ASSERT_TRUE(value == encoder.decode_biguint(poly1)); + + value = "2"; + Plaintext poly2 = encoder.encode(value); + ASSERT_EQ(2ULL, poly2.coeff_count()); + ASSERT_TRUE("1x^1" == poly2.to_string()); + ASSERT_TRUE(value == encoder.decode_biguint(poly2)); + + value = "3"; + Plaintext poly3 = encoder.encode(value); + ASSERT_EQ(2ULL, poly3.coeff_count()); + ASSERT_TRUE("1x^1 + 1" == poly3.to_string()); + ASSERT_TRUE(value == encoder.decode_biguint(poly3)); + + value = "FFFFFFFFFFFFFFFF"; + Plaintext poly4 = encoder.encode(value); + ASSERT_EQ(64ULL, poly4.coeff_count()); + for (size_t i = 0; i < 64; ++i) + { + ASSERT_TRUE(poly4[i] == 1); + } + ASSERT_TRUE(value == encoder.decode_biguint(poly4)); + + value = "80F02"; + Plaintext poly5 = encoder.encode(value); + ASSERT_EQ(20ULL, poly5.coeff_count()); + for (size_t i = 0; i < 20; ++i) + { + if (i == 19 || (i >= 8 && i <= 11) || i == 1) + { + ASSERT_TRUE(poly5[i] == 1); + } + else + { + ASSERT_TRUE(poly5[i] == 0); + } + } + ASSERT_TRUE(value == encoder.decode_biguint(poly5)); + + Plaintext poly6(3); + poly6[0] = 1; + poly6[1] = 500; + poly6[2] = 1023; + value = 1 + 500 * 2 + 1023 * 4; + ASSERT_TRUE(value == encoder.decode_biguint(poly6)); + + modulus = 1024; + BinaryEncoder encoder2(modulus); + Plaintext poly7(4); + poly7[0] = 1023; // -1 (*1) + poly7[1] = 512; // -512 (*2) + poly7[2] = 511; // 511 (*4) + poly7[3] = 1; // 1 (*8) + value = -1 + -512 * 2 + 511 * 4 + 1 * 8; + ASSERT_TRUE(value == encoder2.decode_biguint(poly7)); + } + + TEST(Encoder, BalancedEncodeDecodeBigUInt) + { + SmallModulus modulus(0x10000UL); + BalancedEncoder encoder(modulus); + + BigUInt value(64); + value = "0"; + Plaintext poly = encoder.encode(value); + ASSERT_EQ(0ULL, poly.significant_coeff_count()); + ASSERT_TRUE(poly.is_zero()); + ASSERT_TRUE(value == encoder.decode_biguint(poly)); + + value = "1"; + Plaintext poly1 = encoder.encode(value); + ASSERT_EQ(1ULL, poly1.significant_coeff_count()); + ASSERT_TRUE("1" == poly1.to_string()); + ASSERT_TRUE(value == encoder.decode_biguint(poly1)); + + value = "2"; + Plaintext poly2 = encoder.encode(value); + ASSERT_EQ(2ULL, poly2.significant_coeff_count()); + ASSERT_TRUE("1x^1 + FFFF" == poly2.to_string()); + ASSERT_TRUE(value == encoder.decode_biguint(poly2)); + + value = "3"; + Plaintext poly3 = encoder.encode(value); + ASSERT_EQ(2ULL, poly3.significant_coeff_count()); + ASSERT_TRUE("1x^1" == poly3.to_string()); + ASSERT_TRUE(value == encoder.decode_biguint(poly3)); + + value = "2671"; + Plaintext poly4 = encoder.encode(value); + ASSERT_EQ(9ULL, poly4.significant_coeff_count()); + for (size_t i = 0; i < 9; ++i) + { + ASSERT_TRUE(poly4[i] == 1); + } + ASSERT_TRUE(value == encoder.decode_biguint(poly4)); + + value = "D4EB"; + Plaintext poly5 = encoder.encode(value); + ASSERT_EQ(11ULL, poly5.significant_coeff_count()); + for (size_t i = 0; i < 11; ++i) + { + if (i % 3 == 1) + { + ASSERT_TRUE(poly5[i] == 1); + } + else if (i % 3 == 0) + { + ASSERT_TRUE(poly5[i] == 0); + } + else + { + ASSERT_TRUE(poly5[i] == 0xFFFF); + } + } + ASSERT_TRUE(value == encoder.decode_biguint(poly5)); + + Plaintext poly6(3); + poly6[0] = 1; + poly6[1] = 500; + poly6[2] = 1023; + value = 1 + 500 * 3 + 1023 * 9; + ASSERT_TRUE(value == encoder.decode_biguint(poly6)); + + BalancedEncoder encoder2(modulus, 7); + Plaintext poly7(4); + poly7[0] = 123; // 123 (*1) + poly7[1] = 0xFFFF; // -1 (*7) + poly7[2] = 511; // 511 (*49) + poly7[3] = 1; // 1 (*343) + value = 123 + -1 * 7 + 511 * 49 + 1 * 343; + ASSERT_TRUE(value == encoder2.decode_biguint(poly7)); + + BalancedEncoder encoder3(modulus, 6); + Plaintext poly8(4); + poly8[0] = 5; + poly8[1] = 4; + poly8[2] = 3; + poly8[3] = 2; + value = 5 + 4 * 6 + 3 * 36 + 2 * 216; + ASSERT_TRUE(value == encoder3.decode_biguint(poly8)); + + BalancedEncoder encoder4(modulus, 10); + Plaintext poly9(4); + poly9[0] = 1; + poly9[1] = 2; + poly9[2] = 3; + poly9[3] = 4; + value = 4321; + ASSERT_TRUE(value == encoder4.decode_biguint(poly9)); + + value = "4D2"; + Plaintext poly10 = encoder2.encode(value); + ASSERT_EQ(5ULL, poly10.significant_coeff_count()); + ASSERT_TRUE(value == encoder2.decode_biguint(poly10)); + + value = "4D2"; + Plaintext poly11 = encoder3.encode(value); + ASSERT_EQ(5ULL, poly11.significant_coeff_count()); + ASSERT_TRUE(value == encoder3.decode_biguint(poly11)); + + value = "4D2"; + Plaintext poly12 = encoder4.encode(value); + ASSERT_EQ(4ULL, poly12.significant_coeff_count()); + ASSERT_TRUE(value == encoder4.decode_biguint(poly12)); + } + + TEST(Encoder, BinaryEncodeDecodeUInt64) + { + SmallModulus modulus(0xFFFFFFFFFFFFFFF); + BinaryEncoder encoder(modulus); + + Plaintext poly = encoder.encode(static_cast(0)); + ASSERT_EQ(0ULL, poly.significant_coeff_count()); + ASSERT_TRUE(poly.is_zero()); + ASSERT_EQ(static_cast(0), encoder.decode_uint64(poly)); + + Plaintext poly1 = encoder.encode(1u); + ASSERT_EQ(1ULL, poly1.coeff_count()); + ASSERT_TRUE("1" == poly1.to_string()); + ASSERT_EQ(1ULL, encoder.decode_uint64(poly1)); + + Plaintext poly2 = encoder.encode(static_cast(2)); + ASSERT_EQ(2ULL, poly2.coeff_count()); + ASSERT_TRUE("1x^1" == poly2.to_string()); + ASSERT_EQ(static_cast(2), encoder.decode_uint64(poly2)); + + Plaintext poly3 = encoder.encode(static_cast(3)); + ASSERT_EQ(2ULL, poly3.coeff_count()); + ASSERT_TRUE("1x^1 + 1" == poly3.to_string()); + ASSERT_EQ(static_cast(3), encoder.decode_uint64(poly3)); + + Plaintext poly4 = encoder.encode(static_cast(0xFFFFFFFFFFFFFFFF)); + ASSERT_EQ(64ULL, poly4.coeff_count()); + for (size_t i = 0; i < 64; ++i) + { + ASSERT_TRUE(poly4[i] == 1); + } + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), encoder.decode_uint64(poly4)); + + Plaintext poly5 = encoder.encode(static_cast(0x80F02)); + ASSERT_EQ(20ULL, poly5.coeff_count()); + for (size_t i = 0; i < 20; ++i) + { + if (i == 19 || (i >= 8 && i <= 11) || i == 1) + { + ASSERT_TRUE(poly5[i] == 1); + } + else + { + ASSERT_TRUE(poly5[i] == 0); + } + } + ASSERT_EQ(static_cast(0x80F02), encoder.decode_uint64(poly5)); + + Plaintext poly6(3); + poly6[0] = 1; + poly6[1] = 500; + poly6[2] = 1023; + ASSERT_EQ(static_cast(1 + 500 * 2 + 1023 * 4), encoder.decode_uint64(poly6)); + + modulus = 1024; + BinaryEncoder encoder2(modulus); + Plaintext poly7(4); + poly7[0] = 1023; // -1 (*1) + poly7[1] = 512; // -512 (*2) + poly7[2] = 511; // 511 (*4) + poly7[3] = 1; // 1 (*8) + ASSERT_EQ(static_cast(-1 + -512 * 2 + 511 * 4 + 1 * 8), encoder2.decode_uint64(poly7)); + } + + TEST(Encoder, BalancedEncodeDecodeUInt64) + { + SmallModulus modulus(0x10000UL); + BalancedEncoder encoder(modulus); + + Plaintext poly = encoder.encode(static_cast(0)); + ASSERT_EQ(0ULL, poly.significant_coeff_count()); + ASSERT_TRUE(poly.is_zero()); + ASSERT_EQ(static_cast(0), encoder.decode_uint64(poly)); + + Plaintext poly1 = encoder.encode(1u); + ASSERT_EQ(1ULL, poly1.significant_coeff_count()); + ASSERT_TRUE("1" == poly1.to_string()); + ASSERT_EQ(1ULL, encoder.decode_uint64(poly1)); + + Plaintext poly2 = encoder.encode(static_cast(2)); + ASSERT_EQ(2ULL, poly2.significant_coeff_count()); + ASSERT_TRUE("1x^1 + FFFF" == poly2.to_string()); + ASSERT_EQ(static_cast(2), encoder.decode_uint64(poly2)); + + Plaintext poly3 = encoder.encode(static_cast(3)); + ASSERT_EQ(2ULL, poly3.significant_coeff_count()); + ASSERT_TRUE("1x^1" == poly3.to_string()); + ASSERT_EQ(static_cast(3), encoder.decode_uint64(poly3)); + + Plaintext poly4 = encoder.encode(static_cast(0x2671)); + ASSERT_EQ(9ULL, poly4.significant_coeff_count()); + for (size_t i = 0; i < 9; ++i) + { + ASSERT_TRUE(1 == poly4[i]); + } + ASSERT_EQ(static_cast(0x2671), encoder.decode_uint64(poly4)); + + Plaintext poly5 = encoder.encode(static_cast(0xD4EB)); + ASSERT_EQ(11ULL, poly5.significant_coeff_count()); + for (size_t i = 0; i < 11; ++i) + { + if (i % 3 == 1) + { + ASSERT_TRUE(1 == poly5[i]); + } + else if (i % 3 == 0) + { + ASSERT_TRUE(poly5[i] == 0); + } + else + { + ASSERT_TRUE(0xFFFF == poly5[i]); + } + } + ASSERT_EQ(static_cast(0xD4EB), encoder.decode_uint64(poly5)); + + Plaintext poly6(3); + poly6[0] = 1; + poly6[1] = 500; + poly6[2] = 1023; + ASSERT_EQ(static_cast(1 + 500 * 3 + 1023 * 9), encoder.decode_uint64(poly6)); + + BalancedEncoder encoder2(modulus, 7); + Plaintext poly7(4); + poly7[0] = 123; // 123 (*1) + poly7[1] = 0xFFFF; // -1 (*7) + poly7[2] = 511; // 511 (*49) + poly7[3] = 1; // 1 (*343) + ASSERT_EQ(static_cast(123 + -1 * 7 + 511 * 49 + 1 * 343), encoder2.decode_uint64(poly7)); + + BalancedEncoder encoder3(modulus, 6); + Plaintext poly8(4); + poly8[0] = 5; + poly8[1] = 4; + poly8[2] = 3; + poly8[3] = 2; + uint64_t value = 5 + 4 * 6 + 3 * 36 + 2 * 216; + ASSERT_TRUE(value == encoder3.decode_uint64(poly8)); + + BalancedEncoder encoder4(modulus, 10); + Plaintext poly9(4); + poly9[0] = 1; + poly9[1] = 2; + poly9[2] = 3; + poly9[3] = 4; + value = 4321; + ASSERT_TRUE(value == encoder4.decode_uint64(poly9)); + + value = 1234; + Plaintext poly10 = encoder2.encode(value); + ASSERT_EQ(5ULL, poly10.significant_coeff_count()); + ASSERT_TRUE(value == encoder2.decode_uint64(poly10)); + + value = 1234; + Plaintext poly11 = encoder3.encode(value); + ASSERT_EQ(5ULL, poly11.significant_coeff_count()); + ASSERT_TRUE(value == encoder3.decode_uint64(poly11)); + + value = 1234; + Plaintext poly12 = encoder4.encode(value); + ASSERT_EQ(4ULL, poly12.significant_coeff_count()); + ASSERT_TRUE(value == encoder4.decode_uint64(poly12)); + } + + TEST(Encoder, BinaryEncodeDecodeUInt32) + { + SmallModulus modulus(0xFFFFFFFFFFFFFFF); + BinaryEncoder encoder(modulus); + + Plaintext poly = encoder.encode(static_cast(0)); + ASSERT_EQ(0ULL, poly.significant_coeff_count()); + ASSERT_TRUE(poly.is_zero()); + ASSERT_EQ(static_cast(0), encoder.decode_uint32(poly)); + + Plaintext poly1 = encoder.encode(static_cast(1)); + ASSERT_EQ(1ULL, poly1.significant_coeff_count()); + ASSERT_TRUE("1" == poly1.to_string()); + ASSERT_EQ(static_cast(1), encoder.decode_uint32(poly1)); + + Plaintext poly2 = encoder.encode(static_cast(2)); + ASSERT_EQ(2ULL, poly2.significant_coeff_count()); + ASSERT_TRUE("1x^1" == poly2.to_string()); + ASSERT_EQ(static_cast(2), encoder.decode_uint32(poly2)); + + Plaintext poly3 = encoder.encode(static_cast(3)); + ASSERT_EQ(2ULL, poly3.significant_coeff_count()); + ASSERT_TRUE("1x^1 + 1" == poly3.to_string()); + ASSERT_EQ(static_cast(3), encoder.decode_uint32(poly3)); + + Plaintext poly4 = encoder.encode(static_cast(0xFFFFFFFF)); + ASSERT_EQ(32ULL, poly4.significant_coeff_count()); + for (size_t i = 0; i < 32; ++i) + { + ASSERT_TRUE(1 == poly4[i]); + } + ASSERT_EQ(static_cast(0xFFFFFFFF), encoder.decode_uint32(poly4)); + + Plaintext poly5 = encoder.encode(static_cast(0x80F02)); + ASSERT_EQ(20ULL, poly5.significant_coeff_count()); + for (size_t i = 0; i < 20; ++i) + { + if (i == 19 || (i >= 8 && i <= 11) || i == 1) + { + ASSERT_TRUE(1 == poly5[i]); + } + else + { + ASSERT_TRUE(poly5[i] == 0); + } + } + ASSERT_EQ(static_cast(0x80F02), encoder.decode_uint32(poly5)); + + Plaintext poly6(3); + poly6[0] = 1; + poly6[1] = 500; + poly6[2] = 1023; + ASSERT_EQ(static_cast(1 + 500 * 2 + 1023 * 4), encoder.decode_uint32(poly6)); + + modulus = 1024; + BinaryEncoder encoder2(modulus); + Plaintext poly7(4); + poly7[0] = 1023; // -1 (*1) + poly7[1] = 512; // -512 (*2) + poly7[2] = 511; // 511 (*4) + poly7[3] = 1; // 1 (*8) + ASSERT_EQ(static_cast(-1 + -512 * 2 + 511 * 4 + 1 * 8), encoder2.decode_uint32(poly7)); + } + + TEST(Encoder, BalancedEncodeDecodeUInt32) + { + SmallModulus modulus(0x10000UL); + BalancedEncoder encoder(modulus); + + Plaintext poly = encoder.encode(static_cast(0)); + ASSERT_EQ(0ULL, poly.significant_coeff_count()); + ASSERT_TRUE(poly.is_zero()); + ASSERT_EQ(static_cast(0), encoder.decode_uint32(poly)); + + Plaintext poly1 = encoder.encode(static_cast(1)); + ASSERT_EQ(1ULL, poly1.significant_coeff_count()); + ASSERT_TRUE("1" == poly1.to_string()); + ASSERT_EQ(static_cast(1), encoder.decode_uint32(poly1)); + + Plaintext poly2 = encoder.encode(static_cast(2)); + ASSERT_EQ(2ULL, poly2.significant_coeff_count()); + ASSERT_TRUE("1x^1 + FFFF" == poly2.to_string()); + ASSERT_EQ(static_cast(2), encoder.decode_uint32(poly2)); + + Plaintext poly3 = encoder.encode(static_cast(3)); + ASSERT_EQ(2ULL, poly3.significant_coeff_count()); + ASSERT_TRUE("1x^1" == poly3.to_string()); + ASSERT_EQ(static_cast(3), encoder.decode_uint32(poly3)); + + Plaintext poly4 = encoder.encode(static_cast(0x2671)); + ASSERT_EQ(9ULL, poly4.significant_coeff_count()); + for (size_t i = 0; i < 9; ++i) + { + ASSERT_TRUE(1 == poly4[i]); + } + ASSERT_EQ(static_cast(0x2671), encoder.decode_uint32(poly4)); + + Plaintext poly5 = encoder.encode(static_cast(0xD4EB)); + ASSERT_EQ(11ULL, poly5.significant_coeff_count()); + for (size_t i = 0; i < 11; ++i) + { + if (i % 3 == 1) + { + ASSERT_TRUE(1 == poly5[i]); + } + else if (i % 3 == 0) + { + ASSERT_TRUE(poly5[i] == 0); + } + else + { + ASSERT_TRUE(0xFFFF == poly5[i]); + } + } + ASSERT_EQ(static_cast(0xD4EB), encoder.decode_uint32(poly5)); + + Plaintext poly6(3); + poly6[0] = 1; + poly6[1] = 500; + poly6[2] = 1023; + ASSERT_EQ(static_cast(1 + 500 * 3 + 1023 * 9), encoder.decode_uint32(poly6)); + + BalancedEncoder encoder2(modulus, 7); + Plaintext poly7(4); + poly7[0] = 123; // 123 (*1) + poly7[1] = 0xFFFF; // -1 (*7) + poly7[2] = 511; // 511 (*49) + poly7[3] = 1; // 1 (*343) + ASSERT_EQ(static_cast(123 + -1 * 7 + 511 * 49 + 1 * 343), encoder2.decode_uint32(poly7)); + + BalancedEncoder encoder3(modulus, 6); + Plaintext poly8(4); + poly8[0] = 5; + poly8[1] = 4; + poly8[2] = 3; + poly8[3] = 2; + uint64_t value = 5 + 4 * 6 + 3 * 36 + 2 * 216; + ASSERT_TRUE(value == encoder3.decode_uint32(poly8)); + + BalancedEncoder encoder4(modulus, 10); + Plaintext poly9(4); + poly9[0] = 1; + poly9[1] = 2; + poly9[2] = 3; + poly9[3] = 4; + value = 4321; + ASSERT_TRUE(value == encoder4.decode_uint32(poly9)); + + value = 1234; + Plaintext poly10 = encoder2.encode(value); + ASSERT_EQ(5ULL, poly10.significant_coeff_count()); + ASSERT_TRUE(value == encoder2.decode_uint32(poly10)); + + value = 1234; + Plaintext poly11 = encoder3.encode(value); + ASSERT_EQ(5ULL, poly11.significant_coeff_count()); + ASSERT_TRUE(value == encoder3.decode_uint32(poly11)); + + value = 1234; + Plaintext poly12 = encoder4.encode(value); + ASSERT_EQ(4ULL, poly12.significant_coeff_count()); + ASSERT_TRUE(value == encoder4.decode_uint32(poly12)); + } + + TEST(Encoder, BinaryEncodeDecodeInt64) + { + SmallModulus modulus(0x7FFFFFFFFFFFF); + BinaryEncoder encoder(modulus); + + Plaintext poly = encoder.encode(static_cast(0)); + ASSERT_EQ(0ULL, poly.significant_coeff_count()); + ASSERT_TRUE(poly.is_zero()); + ASSERT_EQ(static_cast(0), static_cast(encoder.decode_int64(poly))); + + Plaintext poly1 = encoder.encode(static_cast(1)); + ASSERT_EQ(1ULL, poly1.significant_coeff_count()); + ASSERT_TRUE("1" == poly1.to_string()); + ASSERT_EQ(1ULL, static_cast(encoder.decode_int64(poly1))); + + Plaintext poly2 = encoder.encode(static_cast(2)); + ASSERT_EQ(2ULL, poly2.significant_coeff_count()); + ASSERT_TRUE("1x^1" == poly2.to_string()); + ASSERT_EQ(static_cast(2), static_cast(encoder.decode_int64(poly2))); + + Plaintext poly3 = encoder.encode(static_cast(3)); + ASSERT_EQ(2ULL, poly3.significant_coeff_count()); + ASSERT_TRUE("1x^1 + 1" == poly3.to_string()); + ASSERT_EQ(static_cast(3), static_cast(encoder.decode_int64(poly3))); + + Plaintext poly4 = encoder.encode(static_cast(-1)); + ASSERT_EQ(1ULL, poly4.significant_coeff_count()); + ASSERT_TRUE("7FFFFFFFFFFFE" == poly4.to_string()); + ASSERT_EQ(static_cast(-1), static_cast(encoder.decode_int64(poly4))); + + Plaintext poly5 = encoder.encode(static_cast(-2)); + ASSERT_EQ(2ULL, poly5.significant_coeff_count()); + ASSERT_TRUE("7FFFFFFFFFFFEx^1" == poly5.to_string()); + ASSERT_EQ(static_cast(-2), static_cast(encoder.decode_int64(poly5))); + + Plaintext poly6 = encoder.encode(static_cast(-3)); + ASSERT_EQ(2ULL, poly6.significant_coeff_count()); + ASSERT_TRUE("7FFFFFFFFFFFEx^1 + 7FFFFFFFFFFFE" == poly6.to_string()); + ASSERT_EQ(static_cast(-3), static_cast(encoder.decode_int64(poly6))); + + Plaintext poly7 = encoder.encode(static_cast(0x7FFFFFFFFFFFF)); + ASSERT_EQ(51ULL, poly7.significant_coeff_count()); + for (size_t i = 0; i < 51; ++i) + { + ASSERT_TRUE(1 == poly7[i]); + } + ASSERT_EQ(static_cast(0x7FFFFFFFFFFFF), static_cast(encoder.decode_int64(poly7))); + + Plaintext poly8 = encoder.encode(static_cast(0x8000000000000)); + ASSERT_EQ(52ULL, poly8.significant_coeff_count()); + ASSERT_TRUE(poly8[51] == 1); + for (size_t i = 0; i < 51; ++i) + { + ASSERT_TRUE(poly8[i] == 0); + } + ASSERT_EQ(static_cast(0x8000000000000), static_cast(encoder.decode_int64(poly8))); + + Plaintext poly9 = encoder.encode(static_cast(0x80F02)); + ASSERT_EQ(20ULL, poly9.significant_coeff_count()); + for (size_t i = 0; i < 20; ++i) + { + if (i == 19 || (i >= 8 && i <= 11) || i == 1) + { + ASSERT_TRUE(1 == poly9[i]); + } + else + { + ASSERT_TRUE(poly9[i] == 0); + } + } + ASSERT_EQ(static_cast(0x80F02), static_cast(encoder.decode_int64(poly9))); + + Plaintext poly10 = encoder.encode(static_cast(-1073)); + ASSERT_EQ(11ULL, poly10.significant_coeff_count()); + ASSERT_TRUE(0x7FFFFFFFFFFFE == poly10[10]); + ASSERT_TRUE(poly10[9] == 0); + ASSERT_TRUE(poly10[8] == 0); + ASSERT_TRUE(poly10[7] == 0); + ASSERT_TRUE(poly10[6] == 0); + ASSERT_TRUE(0x7FFFFFFFFFFFE == poly10[5]); + ASSERT_TRUE(0x7FFFFFFFFFFFE == poly10[4]); + ASSERT_TRUE(poly10[3] == 0); + ASSERT_TRUE(poly10[2] == 0); + ASSERT_TRUE(poly10[1] == 0); + ASSERT_TRUE(0x7FFFFFFFFFFFE == poly10[0]); + ASSERT_EQ(static_cast(-1073), static_cast(encoder.decode_int64(poly10))); + + modulus = 0xFFFF; + BinaryEncoder encoder2(modulus); + Plaintext poly11(6); + poly11[0] = 1; + poly11[1] = 0xFFFE; // -1 + poly11[2] = 0xFFFD; // -2 + poly11[3] = 0x8000; // -32767 + poly11[4] = 0x7FFF; // 32767 + poly11[5] = 0x7FFE; // 32766 + ASSERT_EQ(static_cast(1 + -1 * 2 + -2 * 4 + -32767 * 8 + 32767 * 16 + 32766 * 32), static_cast(encoder2.decode_int64(poly11))); + } + + TEST(Encoder, BalancedEncodeDecodeInt64) + { + SmallModulus modulus(0x10000UL); + BalancedEncoder encoder(modulus); + + Plaintext poly = encoder.encode(static_cast(0)); + ASSERT_EQ(0ULL, poly.significant_coeff_count()); + ASSERT_TRUE(poly.is_zero()); + ASSERT_EQ(static_cast(0), static_cast(encoder.decode_int64(poly))); + + Plaintext poly1 = encoder.encode(static_cast(1)); + ASSERT_EQ(1ULL, poly1.significant_coeff_count()); + ASSERT_TRUE("1" == poly1.to_string()); + ASSERT_EQ(1ULL, static_cast(encoder.decode_int64(poly1))); + + Plaintext poly2 = encoder.encode(static_cast(2)); + ASSERT_EQ(2ULL, poly2.significant_coeff_count()); + ASSERT_TRUE("1x^1 + FFFF" == poly2.to_string()); + ASSERT_EQ(static_cast(2), static_cast(encoder.decode_int64(poly2))); + + Plaintext poly3 = encoder.encode(static_cast(3)); + ASSERT_EQ(2ULL, poly3.significant_coeff_count()); + ASSERT_TRUE("1x^1" == poly3.to_string()); + ASSERT_EQ(static_cast(3), static_cast(encoder.decode_int64(poly3))); + + Plaintext poly4 = encoder.encode(static_cast(-1)); + ASSERT_EQ(1ULL, poly4.significant_coeff_count()); + ASSERT_TRUE("FFFF" == poly4.to_string()); + ASSERT_EQ(static_cast(-1), static_cast(encoder.decode_int64(poly4))); + + Plaintext poly5 = encoder.encode(static_cast(-2)); + ASSERT_EQ(2ULL, poly5.significant_coeff_count()); + ASSERT_TRUE("FFFFx^1 + 1" == poly5.to_string()); + ASSERT_EQ(static_cast(-2), static_cast(encoder.decode_int64(poly5))); + + Plaintext poly6 = encoder.encode(static_cast(-3)); + ASSERT_EQ(2ULL, poly6.significant_coeff_count()); + ASSERT_TRUE("FFFFx^1" == poly6.to_string()); + ASSERT_EQ(static_cast(-3), static_cast(encoder.decode_int64(poly6))); + + Plaintext poly7 = encoder.encode(static_cast(-0x2671)); + ASSERT_EQ(9ULL, poly7.significant_coeff_count()); + for (size_t i = 0; i < 9; ++i) + { + ASSERT_TRUE(0xFFFF == poly7[i]); + } + ASSERT_EQ(static_cast(-0x2671), static_cast(encoder.decode_int64(poly7))); + + Plaintext poly8 = encoder.encode(static_cast(-4374)); + ASSERT_EQ(9ULL, poly8.significant_coeff_count()); + ASSERT_TRUE(0xFFFF == poly8[8]); + ASSERT_TRUE(1 == poly8[7]); + for (size_t i = 0; i < 7; ++i) + { + ASSERT_TRUE(poly8[i] == 0); + } + ASSERT_EQ(static_cast(-4374), static_cast(encoder.decode_int64(poly8))); + + Plaintext poly9 = encoder.encode(static_cast(-0xD4EB)); + ASSERT_EQ(11ULL, poly9.significant_coeff_count()); + for (size_t i = 0; i < 11; ++i) + { + if (i % 3 == 1) + { + ASSERT_TRUE(0xFFFF == poly9[i]); + } + else if (i % 3 == 0) + { + ASSERT_TRUE(poly9[i] == 0); + } + else + { + ASSERT_TRUE(1 == poly9[i]); + } + } + ASSERT_EQ(static_cast(-0xD4EB), static_cast(encoder.decode_int64(poly9))); + + Plaintext poly10 = encoder.encode(static_cast(-30724)); + ASSERT_EQ(11ULL, poly10.significant_coeff_count()); + ASSERT_TRUE(0xFFFF == poly10[10]); + ASSERT_TRUE(1 == poly10[9]); + ASSERT_TRUE(1 == poly10[8]); + ASSERT_TRUE(1 == poly10[7]); + ASSERT_TRUE(poly10[6] == 0); + ASSERT_TRUE(poly10[5] == 0); + ASSERT_TRUE(0xFFFF == poly10[4]); + ASSERT_TRUE(0xFFFF == poly10[3]); + ASSERT_TRUE(poly10[2] == 0); + ASSERT_TRUE(1 == poly10[1]); + ASSERT_TRUE(0xFFFF == poly10[0]); + ASSERT_EQ(static_cast(-30724), static_cast(encoder.decode_int64(poly10))); + + BalancedEncoder encoder2(modulus, 13); + Plaintext poly11 = encoder2.encode(static_cast(-126375543984)); + ASSERT_EQ(11ULL, poly11.significant_coeff_count()); + ASSERT_TRUE(0xFFFF == poly11[10]); + ASSERT_TRUE(1 == poly11[9]); + ASSERT_TRUE(1 == poly11[8]); + ASSERT_TRUE(1 == poly11[7]); + ASSERT_TRUE(poly11[6] == 0); + ASSERT_TRUE(poly11[5] == 0); + ASSERT_TRUE(0xFFFF == poly11[4]); + ASSERT_TRUE(0xFFFF == poly11[3]); + ASSERT_TRUE(poly11[2] == 0); + ASSERT_TRUE(1 == poly11[1]); + ASSERT_TRUE(0xFFFF == poly11[0]); + ASSERT_EQ(static_cast(-126375543984), static_cast(encoder2.decode_int64(poly11))); + + modulus = 0xFFFFUL; + BalancedEncoder encoder3(modulus, 7); + Plaintext poly12(6); + poly12[0] = 1; + poly12[1] = 0xFFFE; // -1 + poly12[2] = 0xFFFD; // -2 + poly12[3] = 0x8000; // -32767 + poly12[4] = 0x7FFF; // 32767 + poly12[5] = 0x7FFE; // 32766 + ASSERT_EQ(static_cast(1 + -1 * 7 + -2 * 49 + -32767 * 343 + 32767 * 2401 + 32766 * 16807), static_cast(encoder3.decode_int64(poly12))); + + BalancedEncoder encoder4(modulus, 6); + poly8.resize(4); + poly8[0] = 5; + poly8[1] = 4; + poly8[2] = 3; + poly8[3] = *modulus.data() - 2; + int64_t value = 5 + 4 * 6 + 3 * 36 - 2 * 216; + ASSERT_TRUE(value == encoder4.decode_int64(poly8)); + + BalancedEncoder encoder5(modulus, 10); + poly9.resize(4); + poly9[0] = 1; + poly9[1] = 2; + poly9[2] = 3; + poly9[3] = 4; + value = 4321; + ASSERT_TRUE(value == encoder5.decode_int64(poly9)); + + value = -1234; + poly10 = encoder3.encode(value); + ASSERT_EQ(5ULL, poly10.significant_coeff_count()); + ASSERT_TRUE(value == encoder3.decode_int64(poly10)); + + value = -1234; + poly11 = encoder4.encode(value); + ASSERT_EQ(5ULL, poly11.significant_coeff_count()); + ASSERT_TRUE(value == encoder4.decode_int64(poly11)); + + value = -1234; + poly12 = encoder5.encode(value); + ASSERT_EQ(4ULL, poly12.significant_coeff_count()); + ASSERT_TRUE(value == encoder5.decode_int64(poly12)); + } + + TEST(Encoder, EncodeDecodeInt32) + { + SmallModulus modulus(0x7FFFFFFFFFFFFF); + BinaryEncoder encoder(modulus); + + Plaintext poly = encoder.encode(static_cast(0)); + ASSERT_EQ(0ULL, poly.significant_coeff_count()); + ASSERT_TRUE(poly.is_zero()); + ASSERT_EQ(static_cast(0), encoder.decode_int32(poly)); + + Plaintext poly1 = encoder.encode(static_cast(1)); + ASSERT_EQ(1ULL, poly1.significant_coeff_count()); + ASSERT_TRUE("1" == poly1.to_string()); + ASSERT_EQ(static_cast(1), encoder.decode_int32(poly1)); + + Plaintext poly2 = encoder.encode(static_cast(2)); + ASSERT_EQ(2ULL, poly2.significant_coeff_count()); + ASSERT_TRUE("1x^1" == poly2.to_string()); + ASSERT_EQ(static_cast(2), encoder.decode_int32(poly2)); + + Plaintext poly3 = encoder.encode(static_cast(3)); + ASSERT_EQ(2ULL, poly3.significant_coeff_count()); + ASSERT_TRUE("1x^1 + 1" == poly3.to_string()); + ASSERT_EQ(static_cast(3), encoder.decode_int32(poly3)); + + Plaintext poly4 = encoder.encode(static_cast(-1)); + ASSERT_EQ(1ULL, poly4.significant_coeff_count()); + ASSERT_TRUE("7FFFFFFFFFFFFE" == poly4.to_string()); + ASSERT_EQ(static_cast(-1), encoder.decode_int32(poly4)); + + Plaintext poly5 = encoder.encode(static_cast(-2)); + ASSERT_EQ(2ULL, poly5.significant_coeff_count()); + ASSERT_TRUE("7FFFFFFFFFFFFEx^1" == poly5.to_string()); + ASSERT_EQ(static_cast(-2), encoder.decode_int32(poly5)); + + Plaintext poly6 = encoder.encode(static_cast(-3)); + ASSERT_EQ(2ULL, poly6.significant_coeff_count()); + ASSERT_TRUE("7FFFFFFFFFFFFEx^1 + 7FFFFFFFFFFFFE" == poly6.to_string()); + ASSERT_EQ(static_cast(-3), encoder.decode_int32(poly6)); + + Plaintext poly7 = encoder.encode(static_cast(0x7FFFFFFF)); + ASSERT_EQ(31ULL, poly7.significant_coeff_count()); + for (size_t i = 0; i < 31; ++i) + { + ASSERT_TRUE(1 == poly7[i]); + } + ASSERT_EQ(static_cast(0x7FFFFFFF), encoder.decode_int32(poly7)); + + Plaintext poly8 = encoder.encode(static_cast(0x80000000)); + ASSERT_EQ(32ULL, poly8.significant_coeff_count()); + ASSERT_TRUE(0x7FFFFFFFFFFFFE == poly8[31]); + for (size_t i = 0; i < 31; ++i) + { + ASSERT_TRUE(poly8[i] == 0); + } + ASSERT_EQ(static_cast(0x80000000), encoder.decode_int32(poly8)); + + Plaintext poly9 = encoder.encode(static_cast(0x80F02)); + ASSERT_EQ(20ULL, poly9.significant_coeff_count()); + for (size_t i = 0; i < 20; ++i) + { + if (i == 19 || (i >= 8 && i <= 11) || i == 1) + { + ASSERT_TRUE(1 == poly9[i]); + } + else + { + ASSERT_TRUE(poly9[i] == 0); + } + } + ASSERT_EQ(static_cast(0x80F02), encoder.decode_int32(poly9)); + + Plaintext poly10 = encoder.encode(static_cast(-1073)); + ASSERT_EQ(11ULL, poly10.significant_coeff_count()); + ASSERT_TRUE(0x7FFFFFFFFFFFFE == poly10[10]); + ASSERT_TRUE(poly10[9] == 0); + ASSERT_TRUE(poly10[8] == 0); + ASSERT_TRUE(poly10[7] == 0); + ASSERT_TRUE(poly10[6] == 0); + ASSERT_TRUE(0x7FFFFFFFFFFFFE == poly10[5]); + ASSERT_TRUE(0x7FFFFFFFFFFFFE == poly10[4]); + ASSERT_TRUE(poly10[3] == 0); + ASSERT_TRUE(poly10[2] == 0); + ASSERT_TRUE(poly10[1] == 0); + ASSERT_TRUE(0x7FFFFFFFFFFFFE == poly10[0]); + ASSERT_EQ(static_cast(-1073), encoder.decode_int32(poly10)); + + modulus = 0xFFFF; + BinaryEncoder encoder2(modulus); + Plaintext poly11(6); + poly11[0] = 1; + poly11[1] = 0xFFFE; // -1 + poly11[2] = 0xFFFD; // -2 + poly11[3] = 0x8000; // -32767 + poly11[4] = 0x7FFF; // 32767 + poly11[5] = 0x7FFE; // 32766 + ASSERT_EQ(static_cast(1 + -1 * 2 + -2 * 4 + -32767 * 8 + 32767 * 16 + 32766 * 32), encoder2.decode_int32(poly11)); + } + + TEST(Encoder, BalancedEncodeDecodeInt32) + { + SmallModulus modulus(0x10000UL); + BalancedEncoder encoder(modulus); + + Plaintext poly = encoder.encode(static_cast(0)); + ASSERT_EQ(0ULL, poly.significant_coeff_count()); + ASSERT_TRUE(poly.is_zero()); + ASSERT_EQ(static_cast(0), encoder.decode_int32(poly)); + + Plaintext poly1 = encoder.encode(static_cast(1)); + ASSERT_EQ(1ULL, poly1.significant_coeff_count()); + ASSERT_TRUE("1" == poly1.to_string()); + ASSERT_EQ(static_cast(1), encoder.decode_int32(poly1)); + + Plaintext poly2 = encoder.encode(static_cast(2)); + ASSERT_EQ(2ULL, poly2.significant_coeff_count()); + ASSERT_TRUE("1x^1 + FFFF" == poly2.to_string()); + ASSERT_EQ(static_cast(2), encoder.decode_int32(poly2)); + + Plaintext poly3 = encoder.encode(static_cast(3)); + ASSERT_EQ(2ULL, poly3.significant_coeff_count()); + ASSERT_TRUE("1x^1" == poly3.to_string()); + ASSERT_EQ(static_cast(3), encoder.decode_int32(poly3)); + + Plaintext poly4 = encoder.encode(static_cast(-1)); + ASSERT_EQ(1ULL, poly4.significant_coeff_count()); + ASSERT_TRUE("FFFF" == poly4.to_string()); + ASSERT_EQ(static_cast(-1), encoder.decode_int32(poly4)); + + Plaintext poly5 = encoder.encode(static_cast(-2)); + ASSERT_EQ(2ULL, poly5.significant_coeff_count()); + ASSERT_TRUE("FFFFx^1 + 1" == poly5.to_string()); + ASSERT_EQ(static_cast(-2), encoder.decode_int32(poly5)); + + Plaintext poly6 = encoder.encode(static_cast(-3)); + ASSERT_EQ(2ULL, poly6.significant_coeff_count()); + ASSERT_TRUE("FFFFx^1" == poly6.to_string()); + ASSERT_EQ(static_cast(-3), encoder.decode_int32(poly6)); + + Plaintext poly7 = encoder.encode(static_cast(-0x2671)); + ASSERT_EQ(9ULL, poly7.significant_coeff_count()); + for (size_t i = 0; i < 9; ++i) + { + ASSERT_TRUE(0xFFFF == poly7[i]); + } + ASSERT_EQ(static_cast(-0x2671), encoder.decode_int32(poly7)); + + Plaintext poly8 = encoder.encode(static_cast(-4374)); + ASSERT_EQ(9ULL, poly8.significant_coeff_count()); + ASSERT_TRUE(0xFFFF == poly8[8]); + ASSERT_TRUE(1 == poly8[7]); + for (size_t i = 0; i < 7; ++i) + { + ASSERT_TRUE(poly8[i] == 0); + } + ASSERT_EQ(static_cast(-4374), encoder.decode_int32(poly8)); + + Plaintext poly9 = encoder.encode(static_cast(-0xD4EB)); + ASSERT_EQ(11ULL, poly9.significant_coeff_count()); + for (size_t i = 0; i < 11; ++i) + { + if (i % 3 == 1) + { + ASSERT_TRUE(0xFFFF == poly9[i]); + } + else if (i % 3 == 0) + { + ASSERT_TRUE(poly9[i] == 0); + } + else + { + ASSERT_TRUE(1 == poly9[i]); + } + } + ASSERT_EQ(static_cast(-0xD4EB), encoder.decode_int32(poly9)); + + Plaintext poly10 = encoder.encode(static_cast(-30724)); + ASSERT_EQ(11ULL, poly10.significant_coeff_count()); + ASSERT_TRUE(0xFFFF == poly10[10]); + ASSERT_TRUE(1 == poly10[9]); + ASSERT_TRUE(1 == poly10[8]); + ASSERT_TRUE(1 == poly10[7]); + ASSERT_TRUE(poly10[6] == 0); + ASSERT_TRUE(poly10[5] == 0); + ASSERT_TRUE(0xFFFF == poly10[4]); + ASSERT_TRUE(0xFFFF == poly10[3]); + ASSERT_TRUE(poly10[2] == 0); + ASSERT_TRUE(1 == poly10[1]); + ASSERT_TRUE(0xFFFF == poly10[0]); + ASSERT_EQ(static_cast(-30724), encoder.decode_int32(poly10)); + + modulus = 0xFFFFUL; + BalancedEncoder encoder2(modulus, 7); + Plaintext poly12(6); + poly12[0] = 1; + poly12[1] = 0xFFFE; // -1 + poly12[2] = 0xFFFD; // -2 + poly12[3] = 0x8000; // -32767 + poly12[4] = 0x7FFF; // 32767 + poly12[5] = 0x7FFE; // 32766 + ASSERT_EQ(static_cast(1 + -1 * 7 + -2 * 49 + -32767 * 343 + 32767 * 2401 + 32766 * 16807), encoder2.decode_int32(poly12)); + + BalancedEncoder encoder4(modulus, 6); + poly8.resize(4); + poly8[0] = 5; + poly8[1] = 4; + poly8[2] = 3; + poly8[3] = *modulus.data() - 2; + int32_t value = 5 + 4 * 6 + 3 * 36 - 2 * 216; + ASSERT_TRUE(value == encoder4.decode_int32(poly8)); + + BalancedEncoder encoder5(modulus, 10); + poly9.resize(4); + poly9[0] = 1; + poly9[1] = 2; + poly9[2] = 3; + poly9[3] = 4; + value = 4321; + ASSERT_TRUE(value == encoder5.decode_int32(poly9)); + + value = -1234; + poly10 = encoder2.encode(value); + ASSERT_EQ(5ULL, poly10.significant_coeff_count()); + ASSERT_TRUE(value == encoder2.decode_int32(poly10)); + + value = -1234; + Plaintext poly11 = encoder4.encode(value); + ASSERT_EQ(5ULL, poly11.significant_coeff_count()); + ASSERT_TRUE(value == encoder4.decode_int32(poly11)); + + value = -1234; + poly12 = encoder5.encode(value); + ASSERT_EQ(4ULL, poly12.significant_coeff_count()); + ASSERT_TRUE(value == encoder5.decode_int32(poly12)); + } + + TEST(Encoder, BinaryFractionalEncodeDecode) + { + size_t poly_modulus_degree = 1024; + SmallModulus modulus(0x10000UL); + BinaryFractionalEncoder encoder(modulus, poly_modulus_degree, 500, 50); + + Plaintext poly = encoder.encode(0.0); + ASSERT_TRUE(poly.is_zero()); + ASSERT_EQ(0.0, encoder.decode(poly)); + + Plaintext poly1 = encoder.encode(-1.0); + ASSERT_EQ(-1.0, encoder.decode(poly1)); + + Plaintext poly2 = encoder.encode(0.1); + ASSERT_TRUE(fabs(encoder.decode(poly2) - 0.1) / 0.1 < 0.000001); + + Plaintext poly3 = encoder.encode(3.123); + ASSERT_TRUE(fabs(encoder.decode(poly3) - 3.123) / 3.123 < 0.000001); + + Plaintext poly4 = encoder.encode(-123.456); + ASSERT_TRUE(fabs(encoder.decode(poly4) + 123.456) / (-123.456) < 0.000001); + + Plaintext poly5 = encoder.encode(12345.98765); + ASSERT_TRUE(fabs(encoder.decode(poly5) - 12345.98765) / 12345.98765 < 0.000001); + } + + TEST(Encoder, BalancedFractionalEncodeDecode) + { + size_t poly_modulus_degree = 1024; + { + SmallModulus modulus(0x10000UL); + for (uint64_t b = 3; b < 20; ++b) + { + BalancedFractionalEncoder encoder(modulus, poly_modulus_degree, 500, 50, b); + + Plaintext poly = encoder.encode(0.0); + ASSERT_TRUE(poly.is_zero()); + ASSERT_EQ(0.0, encoder.decode(poly)); + + Plaintext poly1 = encoder.encode(-1.0); + ASSERT_EQ(-1.0, encoder.decode(poly1)); + + Plaintext poly2 = encoder.encode(0.1); + ASSERT_TRUE(fabs(encoder.decode(poly2) - 0.1) / 0.1 < 0.000001); + + Plaintext poly3 = encoder.encode(3.123); + ASSERT_TRUE(fabs(encoder.decode(poly3) - 3.123) / 3.123 < 0.000001); + + Plaintext poly4 = encoder.encode(-123.456); + ASSERT_TRUE(fabs(encoder.decode(poly4) + 123.456) / (-123.456) < 0.000001); + + Plaintext poly5 = encoder.encode(12345.98765); + ASSERT_TRUE(fabs(encoder.decode(poly5) - 12345.98765) / 12345.98765 < 0.000001); + + Plaintext poly6 = encoder.encode(-0.0); + ASSERT_EQ(0.0, encoder.decode(poly)); + + Plaintext poly7 = encoder.encode(0.115); + ASSERT_TRUE(fabs(encoder.decode(poly7) - 0.115) / 0.115 < 0.000001); + } + } + + { + SmallModulus modulus(0x100000000000); + for (uint64_t b = 3; b < 20; ++b) + { + BalancedFractionalEncoder encoder(modulus, poly_modulus_degree, 500, 50, b); + + Plaintext poly = encoder.encode(0.0); + ASSERT_TRUE(poly.is_zero()); + ASSERT_EQ(0.0, encoder.decode(poly)); + + Plaintext poly1 = encoder.encode(-1.0); + ASSERT_EQ(-1.0, encoder.decode(poly1)); + + Plaintext poly2 = encoder.encode(0.1); + ASSERT_TRUE(fabs(encoder.decode(poly2) - 0.1) / 0.1 < 0.000001); + + Plaintext poly3 = encoder.encode(3.123); + ASSERT_TRUE(fabs(encoder.decode(poly3) - 3.123) / 3.123 < 0.000001); + + Plaintext poly4 = encoder.encode(-123.456); + ASSERT_TRUE(fabs(encoder.decode(poly4) + 123.456) / (-123.456) < 0.000001); + + Plaintext poly5 = encoder.encode(12345.98765); + ASSERT_TRUE(fabs(encoder.decode(poly5) - 12345.98765) / 12345.98765 < 0.000001); + + Plaintext poly6 = encoder.encode(-0.0); + ASSERT_EQ(0.0, encoder.decode(poly)); + + Plaintext poly7 = encoder.encode(0.115); + ASSERT_TRUE(fabs(encoder.decode(poly7) - 0.115) / 0.115 < 0.000001); + } + } + } +} diff --git a/tests/seal/encryptionparams.cpp b/tests/seal/encryptionparams.cpp new file mode 100644 index 000000000..6f4d12ee6 --- /dev/null +++ b/tests/seal/encryptionparams.cpp @@ -0,0 +1,142 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "gtest/gtest.h" +#include "seal/encryptionparams.h" +#include "seal/defaultparams.h" + +using namespace seal; +using namespace std; + +namespace SEALTest +{ + TEST(EncryptionParametersTest, EncryptionParametersSet) + { + auto scheme = scheme_type::BFV; + EncryptionParameters parms(scheme); + parms.set_noise_standard_deviation(3.20); + parms.set_coeff_modulus({ 2, 3 }); + parms.set_plain_modulus(2); + parms.set_poly_modulus_degree(2); + parms.set_random_generator(UniformRandomGeneratorFactory::default_factory()); + + ASSERT_TRUE(scheme == parms.scheme()); + ASSERT_EQ(3.20, parms.noise_standard_deviation()); + ASSERT_EQ(3.20 * 6, parms.noise_max_deviation()); + ASSERT_TRUE(parms.coeff_modulus()[0] == 2); + ASSERT_TRUE(parms.coeff_modulus()[1] == 3); + ASSERT_TRUE(parms.plain_modulus() == 2); + ASSERT_TRUE(parms.poly_modulus_degree() == 2); + ASSERT_TRUE(parms.random_generator() == UniformRandomGeneratorFactory::default_factory()); + + parms.set_noise_standard_deviation(3.20); + parms.set_coeff_modulus({ + small_mods_30bit(0), small_mods_40bit(0), small_mods_50bit(0) + }); + parms.set_plain_modulus(2); + parms.set_poly_modulus_degree(128); + parms.set_random_generator(UniformRandomGeneratorFactory::default_factory()); + + ASSERT_EQ(3.20, parms.noise_standard_deviation()); + ASSERT_EQ(3.20 * 6, parms.noise_max_deviation()); + ASSERT_TRUE(parms.coeff_modulus()[0] == small_mods_30bit(0)); + ASSERT_TRUE(parms.coeff_modulus()[1] == small_mods_40bit(0)); + ASSERT_TRUE(parms.coeff_modulus()[2] == small_mods_50bit(0)); + ASSERT_TRUE(parms.plain_modulus() == 2); + ASSERT_TRUE(parms.poly_modulus_degree() == 128); + ASSERT_TRUE(parms.random_generator() == UniformRandomGeneratorFactory::default_factory()); + } + + TEST(EncryptionParametersTest, EncryptionParametersCompare) + { + auto scheme = scheme_type::BFV; + EncryptionParameters parms1(scheme); + parms1.set_noise_standard_deviation(3.20); + parms1.set_coeff_modulus({ small_mods_30bit(0) }); + parms1.set_plain_modulus(1 << 6); + parms1.set_poly_modulus_degree(64); + parms1.set_random_generator(UniformRandomGeneratorFactory::default_factory()); + + EncryptionParameters parms2(parms1); + ASSERT_TRUE(parms1 == parms2); + + EncryptionParameters parms3(scheme); + parms3 = parms2; + ASSERT_TRUE(parms3 == parms2); + parms3.set_coeff_modulus({ small_mods_30bit(1) }); + ASSERT_FALSE(parms3 == parms2); + + parms3 = parms2; + ASSERT_TRUE(parms3 == parms2); + parms3.set_coeff_modulus({ + small_mods_30bit(0), small_mods_30bit(1) + }); + ASSERT_FALSE(parms3 == parms2); + + parms3 = parms2; + parms3.set_poly_modulus_degree(128); + ASSERT_FALSE(parms3 == parms1); + + parms3 = parms2; + parms3.set_plain_modulus((1 << 6) + 1); + ASSERT_FALSE(parms3 == parms2); + + parms3 = parms2; + parms3.set_noise_standard_deviation(3.18); + ASSERT_FALSE(parms3 == parms2); + + parms3 = parms2; + parms3.set_random_generator(nullptr); + ASSERT_TRUE(parms3 == parms2); + + parms3 = parms2; + parms3.set_poly_modulus_degree(128); + parms3.set_poly_modulus_degree(64); + ASSERT_TRUE(parms3 == parms1); + + parms3 = parms2; + parms3.set_coeff_modulus({ 2 }); + parms3.set_coeff_modulus({ small_mods_50bit(0) }); + parms3.set_coeff_modulus(parms2.coeff_modulus()); + ASSERT_TRUE(parms3 == parms2); + } + + TEST(EncryptionParametersTest, EncryptionParametersSaveLoad) + { + stringstream stream; + + auto scheme = scheme_type::BFV; + EncryptionParameters parms(scheme); + EncryptionParameters parms2(scheme); + parms.set_noise_standard_deviation(3.20); + parms.set_coeff_modulus({ small_mods_30bit(0) }); + parms.set_plain_modulus(1 << 6); + parms.set_poly_modulus_degree(64); + EncryptionParameters::Save(parms, stream); + parms2 = EncryptionParameters::Load(stream); + ASSERT_TRUE(parms.scheme() == parms2.scheme()); + ASSERT_EQ(parms.noise_standard_deviation(), parms2.noise_standard_deviation()); + ASSERT_EQ(parms.noise_max_deviation(), parms2.noise_max_deviation()); + ASSERT_TRUE(parms.coeff_modulus() == parms2.coeff_modulus()); + ASSERT_TRUE(parms.plain_modulus() == parms2.plain_modulus()); + ASSERT_TRUE(parms.poly_modulus_degree() == parms2.poly_modulus_degree()); + ASSERT_TRUE(parms == parms2); + + parms.set_noise_standard_deviation(3.20); + parms.set_coeff_modulus({ + small_mods_30bit(0), small_mods_60bit(0), small_mods_60bit(1) + }); + parms.set_plain_modulus(1 << 30); + parms.set_poly_modulus_degree(256); + + EncryptionParameters::Save(parms, stream); + parms2 = EncryptionParameters::Load(stream); + ASSERT_TRUE(parms.scheme() == parms2.scheme()); + ASSERT_EQ(parms.noise_standard_deviation(), parms2.noise_standard_deviation()); + ASSERT_EQ(parms.noise_max_deviation(), parms2.noise_max_deviation()); + ASSERT_TRUE(parms.coeff_modulus() == parms2.coeff_modulus()); + ASSERT_TRUE(parms.plain_modulus() == parms2.plain_modulus()); + ASSERT_TRUE(parms.poly_modulus_degree() == parms2.poly_modulus_degree()); + ASSERT_TRUE(parms == parms2); + } +} diff --git a/tests/seal/encryptor.cpp b/tests/seal/encryptor.cpp new file mode 100644 index 000000000..5a94ba3e9 --- /dev/null +++ b/tests/seal/encryptor.cpp @@ -0,0 +1,400 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "gtest/gtest.h" +#include "seal/context.h" +#include "seal/encryptor.h" +#include "seal/decryptor.h" +#include "seal/keygenerator.h" +#include "seal/batchencoder.h" +#include "seal/ckks.h" +#include "seal/encoder.h" +#include "seal/defaultparams.h" +#include +#include +#include + +using namespace seal; +using namespace std; + +namespace SEALTest +{ + TEST(EncryptorTest, FVEncryptDecrypt) + { + EncryptionParameters parms(scheme_type::BFV); + SmallModulus plain_modulus(1 << 6); + parms.set_noise_standard_deviation(3.20); + parms.set_plain_modulus(plain_modulus); + { + parms.set_poly_modulus_degree(64); + parms.set_coeff_modulus({ small_mods_40bit(0) }); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + + BalancedEncoder encoder(plain_modulus); + + Encryptor encryptor(context, keygen.public_key()); + Decryptor decryptor(context, keygen.secret_key()); + + Ciphertext encrypted; + Plaintext plain; + encryptor.encrypt(encoder.encode(0x12345678), encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(0x12345678ULL, encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(0), encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(0ULL, encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(1), encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(1ULL, encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(2), encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(2ULL, encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(static_cast(0x7FFFFFFFFFFFFFFD)), encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(0x7FFFFFFFFFFFFFFDULL, encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(static_cast(0x7FFFFFFFFFFFFFFE)), encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(0x7FFFFFFFFFFFFFFEULL, encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(static_cast(0x7FFFFFFFFFFFFFFF)), encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(0x7FFFFFFFFFFFFFFFULL, encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(314159265), encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(314159265ULL, encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + } + { + parms.set_poly_modulus_degree(128); + parms.set_coeff_modulus({ small_mods_40bit(0), small_mods_40bit(1) }); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + + BalancedEncoder encoder(plain_modulus); + + Encryptor encryptor(context, keygen.public_key()); + Decryptor decryptor(context, keygen.secret_key()); + + Ciphertext encrypted; + Plaintext plain; + encryptor.encrypt(encoder.encode(0x12345678), encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(0x12345678ULL, encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(0), encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(0ULL, encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(1), encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(1ULL, encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(2), encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(2ULL, encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(static_cast(0x7FFFFFFFFFFFFFFD)), encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(0x7FFFFFFFFFFFFFFDULL, encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(static_cast(0x7FFFFFFFFFFFFFFE)), encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(0x7FFFFFFFFFFFFFFEULL, encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(static_cast(0x7FFFFFFFFFFFFFFF)), encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(0x7FFFFFFFFFFFFFFFULL, encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(314159265), encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(314159265ULL, encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + } + + { + parms.set_poly_modulus_degree(256); + parms.set_coeff_modulus({ small_mods_40bit(0), small_mods_40bit(1), small_mods_40bit(2) }); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + + BalancedEncoder encoder(plain_modulus); + + Encryptor encryptor(context, keygen.public_key()); + Decryptor decryptor(context, keygen.secret_key()); + + Ciphertext encrypted; + Plaintext plain; + encryptor.encrypt(encoder.encode(0x12345678), encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(0x12345678ULL, encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(0), encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(0ULL, encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(1), encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(1ULL, encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(2), encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(2ULL, encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(static_cast(0x7FFFFFFFFFFFFFFD)), encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(0x7FFFFFFFFFFFFFFDULL, encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(static_cast(0x7FFFFFFFFFFFFFFE)), encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(0x7FFFFFFFFFFFFFFEULL, encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(static_cast(0x7FFFFFFFFFFFFFFF)), encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(0x7FFFFFFFFFFFFFFFULL, encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(314159265), encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(314159265ULL, encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + } + } + + TEST(EncryptorTest, CKKSEncryptDecrypt) + { + EncryptionParameters parms(scheme_type::CKKS); + parms.set_noise_standard_deviation(3.20); + { + //input consists of ones + size_t slot_size = 32; + parms.set_poly_modulus_degree(2 * slot_size); + parms.set_coeff_modulus({ small_mods_40bit(0), small_mods_40bit(1), small_mods_40bit(2), small_mods_40bit(3) }); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + + CKKSEncoder encoder(context); + Encryptor encryptor(context, keygen.public_key()); + Decryptor decryptor(context, keygen.secret_key()); + + Ciphertext encrypted; + Plaintext plain; + Plaintext plainRes; + + std::vector> input(slot_size, 1.0); + std::vector> output(slot_size); + const double delta = static_cast(1 << 16); + + encoder.encode(input, parms.parms_id(), delta, plain); + encryptor.encrypt(plain, encrypted); + + //check correctness of encryption + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + + decryptor.decrypt(encrypted, plainRes); + encoder.decode(plainRes, output); + + for (size_t i = 0; i < slot_size; i++) + { + auto tmp = abs(input[i].real() - output[i].real()); + ASSERT_TRUE(tmp < 0.5); + } + } + { + //input consists of zeros + size_t slot_size = 32; + parms.set_poly_modulus_degree(2 * slot_size); + parms.set_coeff_modulus({ small_mods_40bit(0), small_mods_40bit(1), small_mods_40bit(2), small_mods_40bit(3) }); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + + CKKSEncoder encoder(context); + Encryptor encryptor(context, keygen.public_key()); + Decryptor decryptor(context, keygen.secret_key()); + + Ciphertext encrypted; + Plaintext plain; + Plaintext plainRes; + + std::vector> input(slot_size, 0.0); + std::vector> output(slot_size); + const double delta = static_cast(1 << 16); + + encoder.encode(input, parms.parms_id(), delta, plain); + encryptor.encrypt(plain, encrypted); + + //check correctness of encryption + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + + decryptor.decrypt(encrypted, plainRes); + encoder.decode(plainRes, output); + + for (size_t i = 0; i < slot_size; i++) + { + auto tmp = abs(input[i].real() - output[i].real()); + ASSERT_TRUE(tmp < 0.5); + } + } + { + // Input is a random mix of positive and negative integers + size_t slot_size = 64; + parms.set_poly_modulus_degree(2 * slot_size); + parms.set_coeff_modulus({ small_mods_60bit(0), small_mods_60bit(1), small_mods_60bit(2) }); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + + CKKSEncoder encoder(context); + Encryptor encryptor(context, keygen.public_key()); + Decryptor decryptor(context, keygen.secret_key()); + + Ciphertext encrypted; + Plaintext plain; + Plaintext plainRes; + + std::vector> input(slot_size); + std::vector> output(slot_size); + + srand(static_cast(time(NULL))); + int input_bound = 1 << 30; + const double delta = static_cast(1ULL << 50); + + for (int round = 0; round < 100; round++) + { + for (size_t i = 0; i < slot_size; i++) + { + input[i] = pow(-1.0, rand() % 2) * static_cast(rand() % input_bound); + } + + encoder.encode(input, parms.parms_id(), delta, plain); + encryptor.encrypt(plain, encrypted); + + //check correctness of encryption + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + + decryptor.decrypt(encrypted, plainRes); + encoder.decode(plainRes, output); + + for (size_t i = 0; i < slot_size; i++) + { + auto tmp = abs(input[i].real() - output[i].real()); + ASSERT_TRUE(tmp < 0.5); + } + } + } + { + // Input is a random mix of positive and negative integers + size_t slot_size = 32; + parms.set_poly_modulus_degree(128); + parms.set_coeff_modulus({ small_mods_60bit(0), small_mods_60bit(1), small_mods_60bit(2) }); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + + CKKSEncoder encoder(context); + Encryptor encryptor(context, keygen.public_key()); + Decryptor decryptor(context, keygen.secret_key()); + + Ciphertext encrypted; + Plaintext plain; + Plaintext plainRes; + + std::vector> input(slot_size); + std::vector> output(slot_size); + + srand(static_cast(time(NULL))); + int input_bound = 1 << 30; + const double delta = static_cast(1ULL << 60); + + for (int round = 0; round < 100; round++) + { + for (size_t i = 0; i < slot_size; i++) + { + input[i] = pow(-1.0, rand() % 2) * static_cast(rand() % input_bound); + } + + encoder.encode(input, parms.parms_id(), delta, plain); + encryptor.encrypt(plain, encrypted); + + //check correctness of encryption + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + + decryptor.decrypt(encrypted, plainRes); + encoder.decode(plain, output); + + for (size_t i = 0; i < slot_size; i++) + { + auto tmp = abs(input[i].real() - output[i].real()); + ASSERT_TRUE(tmp < 0.5); + } + } + } + { + // Encrypt at lower level + size_t slot_size = 32; + parms.set_poly_modulus_degree(2 * slot_size); + parms.set_coeff_modulus({ small_mods_40bit(0), + small_mods_40bit(1), small_mods_40bit(2), + small_mods_40bit(3) }); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + + CKKSEncoder encoder(context); + Encryptor encryptor(context, keygen.public_key()); + Decryptor decryptor(context, keygen.secret_key()); + + Ciphertext encrypted; + Plaintext plain; + Plaintext plainRes; + + std::vector> input(slot_size, 1.0); + std::vector> output(slot_size); + const double delta = static_cast(1 << 16); + + auto first_context_data = context->context_data(); + ASSERT_NE(nullptr, first_context_data.get()); + auto second_context_data = first_context_data->next_context_data(); + ASSERT_NE(nullptr, second_context_data.get()); + auto second_parms_id = second_context_data->parms().parms_id(); + + encoder.encode(input, second_parms_id, delta, plain); + encryptor.encrypt(plain, encrypted); + + // Check correctness of encryption + ASSERT_TRUE(encrypted.parms_id() == second_parms_id); + + decryptor.decrypt(encrypted, plainRes); + encoder.decode(plainRes, output); + + for (size_t i = 0; i < slot_size; i++) + { + auto tmp = abs(input[i].real() - output[i].real()); + ASSERT_TRUE(tmp < 0.5); + } + } + } +} diff --git a/tests/seal/evaluator.cpp b/tests/seal/evaluator.cpp new file mode 100644 index 000000000..f9f9e4e55 --- /dev/null +++ b/tests/seal/evaluator.cpp @@ -0,0 +1,3844 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "gtest/gtest.h" +#include "seal/context.h" +#include "seal/encryptor.h" +#include "seal/decryptor.h" +#include "seal/evaluator.h" +#include "seal/keygenerator.h" +#include "seal/batchencoder.h" +#include "seal/ckks.h" +#include "seal/encoder.h" +#include "seal/defaultparams.h" +#include +#include +#include +#include + +using namespace seal; +using namespace std; + +namespace SEALTest +{ + TEST(EvaluatorTest, FVEncryptNegateDecrypt) + { + EncryptionParameters parms(scheme_type::BFV); + SmallModulus plain_modulus(1 << 6); + parms.set_poly_modulus_degree(64); + parms.set_plain_modulus(plain_modulus); + parms.set_coeff_modulus({ small_mods_40bit(0) }); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + + BalancedEncoder encoder(plain_modulus); + Encryptor encryptor(context, keygen.public_key()); + Evaluator evaluator(context); + Decryptor decryptor(context, keygen.secret_key()); + + Ciphertext encrypted; + encryptor.encrypt(encoder.encode(0x12345678), encrypted); + evaluator.negate_inplace(encrypted); + Plaintext plain; + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(static_cast(-0x12345678), encoder.decode_int32(plain)); + ASSERT_TRUE(encrypted.parms_id() == encrypted.parms_id()); + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(0), encrypted); + evaluator.negate_inplace(encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(static_cast(0), encoder.decode_int32(plain)); + ASSERT_TRUE(encrypted.parms_id() == encrypted.parms_id()); + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(1), encrypted); + evaluator.negate_inplace(encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(static_cast(-1), encoder.decode_int32(plain)); + ASSERT_TRUE(encrypted.parms_id() == encrypted.parms_id()); + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(-1), encrypted); + evaluator.negate_inplace(encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(static_cast(1), encoder.decode_int32(plain)); + ASSERT_TRUE(encrypted.parms_id() == encrypted.parms_id()); + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(2), encrypted); + evaluator.negate_inplace(encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(static_cast(-2), encoder.decode_int32(plain)); + ASSERT_TRUE(encrypted.parms_id() == encrypted.parms_id()); + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(-5), encrypted); + evaluator.negate_inplace(encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(static_cast(5), encoder.decode_int32(plain)); + ASSERT_TRUE(encrypted.parms_id() == encrypted.parms_id()); + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + } + + TEST(EvaluatorTest, FVEncryptAddDecrypt) + { + EncryptionParameters parms(scheme_type::BFV); + SmallModulus plain_modulus(1 << 6); + parms.set_poly_modulus_degree(64); + parms.set_plain_modulus(plain_modulus); + parms.set_coeff_modulus({ small_mods_40bit(0) }); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + + BalancedEncoder encoder(plain_modulus); + Encryptor encryptor(context, keygen.public_key()); + Evaluator evaluator(context); + Decryptor decryptor(context, keygen.secret_key()); + + Ciphertext encrypted1; + encryptor.encrypt(encoder.encode(0x12345678), encrypted1); + Ciphertext encrypted2; + encryptor.encrypt(encoder.encode(0x54321), encrypted2); + evaluator.add_inplace(encrypted1, encrypted2); + Plaintext plain; + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ(static_cast(0x12399999), encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); + ASSERT_TRUE(encrypted1.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(0), encrypted1); + encryptor.encrypt(encoder.encode(0), encrypted2); + evaluator.add_inplace(encrypted1, encrypted2); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ(static_cast(0), encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); + ASSERT_TRUE(encrypted1.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(0), encrypted1); + encryptor.encrypt(encoder.encode(5), encrypted2); + evaluator.add_inplace(encrypted1, encrypted2); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ(static_cast(5), encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); + ASSERT_TRUE(encrypted1.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(5), encrypted1); + encryptor.encrypt(encoder.encode(-3), encrypted2); + evaluator.add_inplace(encrypted1, encrypted2); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ(static_cast(2), encoder.decode_int32(plain)); + ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); + ASSERT_TRUE(encrypted1.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(-7), encrypted1); + encryptor.encrypt(encoder.encode(2), encrypted2); + evaluator.add_inplace(encrypted1, encrypted2); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ(static_cast(-5), encoder.decode_int32(plain)); + ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); + ASSERT_TRUE(encrypted1.parms_id() == parms.parms_id()); + + Plaintext plain1("2x^2 + 1x^1 + 3"); + Plaintext plain2("3x^3 + 4x^2 + 5x^1 + 6"); + encryptor.encrypt(plain1, encrypted1); + encryptor.encrypt(plain2, encrypted2); + evaluator.add_inplace(encrypted1, encrypted2); + decryptor.decrypt(encrypted1, plain); + ASSERT_TRUE(plain.to_string() == "3x^3 + 6x^2 + 6x^1 + 9"); + ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); + ASSERT_TRUE(encrypted1.parms_id() == parms.parms_id()); + + plain1 = "3x^5 + 1x^4 + 4x^3 + 1"; + plain2 = "5x^2 + 9x^1 + 2"; + encryptor.encrypt(plain1, encrypted1); + encryptor.encrypt(plain2, encrypted2); + evaluator.add_inplace(encrypted1, encrypted2); + decryptor.decrypt(encrypted1, plain); + ASSERT_TRUE(plain.to_string() == "3x^5 + 1x^4 + 4x^3 + 5x^2 + 9x^1 + 3"); + ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); + ASSERT_TRUE(encrypted1.parms_id() == parms.parms_id()); + } + + TEST(EvaluatorTest, CKKSEncryptAddDecrypt) + { + EncryptionParameters parms(scheme_type::CKKS); + parms.set_noise_standard_deviation(3.20); + { + //adding two zero vectors + size_t slot_size = 32; + parms.set_poly_modulus_degree(slot_size * 2); + parms.set_coeff_modulus( + { small_mods_30bit(0), small_mods_30bit(1), + small_mods_30bit(2), small_mods_30bit(3), + small_mods_30bit(4) }); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + + CKKSEncoder encoder(context); + Encryptor encryptor(context, keygen.public_key()); + Decryptor decryptor(context, keygen.secret_key()); + Evaluator evaluator(context); + + Ciphertext encrypted; + Plaintext plain; + Plaintext plainRes; + + std::vector> input(slot_size, 0.0); + std::vector> output(slot_size); + const double delta = static_cast(1 << 16); + encoder.encode(input, parms.parms_id(), delta, plain); + + encryptor.encrypt(plain, encrypted); + evaluator.add_inplace(encrypted, encrypted); + + //check correctness of encryption + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + + decryptor.decrypt(encrypted, plainRes); + + encoder.decode(plainRes, output); + + for (size_t i = 0; i < slot_size; i++) + { + auto tmp = abs(input[i].real() - output[i].real()); + ASSERT_TRUE(tmp < 0.5); + } + } + { + //adding two random vectors 100 times + size_t slot_size = 32; + parms.set_poly_modulus_degree(slot_size * 2); + parms.set_coeff_modulus( + { small_mods_60bit(0), small_mods_60bit(1), small_mods_60bit(2)}); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + + CKKSEncoder encoder(context); + Encryptor encryptor(context, keygen.public_key()); + Decryptor decryptor(context, keygen.secret_key()); + Evaluator evaluator(context); + + Ciphertext encrypted1; + Ciphertext encrypted2; + Plaintext plain1; + Plaintext plain2; + Plaintext plainRes; + + std::vector> input1(slot_size, 0.0); + std::vector> input2(slot_size, 0.0); + std::vector> expected(slot_size, 0.0); + std::vector> output(slot_size); + + int data_bound = (1 << 30); + const double delta = static_cast(1 << 16); + + srand(static_cast(time(NULL))); + + for (int expCount = 0; expCount < 100; expCount++) + { + for (size_t i = 0; i < slot_size; i++) + { + input1[i] = static_cast(rand() % data_bound); + input2[i] = static_cast(rand() % data_bound); + expected[i] = input1[i] + input2[i]; + } + + encoder.encode(input1, parms.parms_id(), delta, plain1); + encoder.encode(input2, parms.parms_id(), delta, plain2); + + encryptor.encrypt(plain1, encrypted1); + encryptor.encrypt(plain2, encrypted2); + evaluator.add_inplace(encrypted1, encrypted2); + + //check correctness of encryption + ASSERT_TRUE(encrypted1.parms_id() == parms.parms_id()); + + decryptor.decrypt(encrypted1, plainRes); + + encoder.decode(plainRes, output); + + for (size_t i = 0; i < slot_size; i++) + { + auto tmp = abs(expected[i].real() - output[i].real()); + ASSERT_TRUE(tmp < 0.5); + } + } + + } + { + //adding two random vectors 100 times + size_t slot_size = 8; + parms.set_poly_modulus_degree(64); + parms.set_coeff_modulus( + { small_mods_60bit(0), small_mods_60bit(1), small_mods_60bit(2) }); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + + CKKSEncoder encoder(context); + Encryptor encryptor(context, keygen.public_key()); + Decryptor decryptor(context, keygen.secret_key()); + Evaluator evaluator(context); + + Ciphertext encrypted1; + Ciphertext encrypted2; + Plaintext plain1; + Plaintext plain2; + Plaintext plainRes; + + std::vector> input1(slot_size, 0.0); + std::vector> input2(slot_size, 0.0); + std::vector> expected(slot_size, 0.0); + std::vector> output(slot_size); + + int data_bound = (1 << 30); + const double delta = static_cast(1 << 16); + + srand(static_cast(time(NULL))); + + for (int expCount = 0; expCount < 100; expCount++) + { + for (size_t i = 0; i < slot_size; i++) + { + input1[i] = static_cast(rand() % data_bound); + input2[i] = static_cast(rand() % data_bound); + expected[i] = input1[i] + input2[i]; + } + + encoder.encode(input1, parms.parms_id(), delta, plain1); + encoder.encode(input2, parms.parms_id(), delta, plain2); + + encryptor.encrypt(plain1, encrypted1); + encryptor.encrypt(plain2, encrypted2); + evaluator.add_inplace(encrypted1, encrypted2); + + //check correctness of encryption + ASSERT_TRUE(encrypted1.parms_id() == parms.parms_id()); + + decryptor.decrypt(encrypted1, plainRes); + + encoder.decode(plainRes, output); + + for (size_t i = 0; i < slot_size; i++) + { + auto tmp = abs(expected[i].real() - output[i].real()); + ASSERT_TRUE(tmp < 0.5); + } + } + + } + } + TEST(EvaluatorTest, CKKSEncryptAddPlainDecrypt) + { + EncryptionParameters parms(scheme_type::CKKS); + parms.set_noise_standard_deviation(3.20); + { + //adding two zero vectors + size_t slot_size = 32; + parms.set_poly_modulus_degree(slot_size * 2); + parms.set_coeff_modulus( + { small_mods_30bit(0), small_mods_30bit(1), + small_mods_30bit(2), small_mods_30bit(3), small_mods_30bit(4) }); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + + CKKSEncoder encoder(context); + Encryptor encryptor(context, keygen.public_key()); + Decryptor decryptor(context, keygen.secret_key()); + Evaluator evaluator(context); + + Ciphertext encrypted; + Plaintext plain; + Plaintext plainRes; + + std::vector> input(slot_size, 0.0); + std::vector> output(slot_size); + const double delta = static_cast(1 << 16); + encoder.encode(input, parms.parms_id(), delta, plain); + + encryptor.encrypt(plain, encrypted); + evaluator.add_plain_inplace(encrypted, plain); + + //check correctness of encryption + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + + decryptor.decrypt(encrypted, plainRes); + + encoder.decode(plainRes, output); + + for (size_t i = 0; i < slot_size; i++) + { + auto tmp = abs(input[i].real() - output[i].real()); + ASSERT_TRUE(tmp < 0.5); + } + } + { + //adding two random vectors 50 times + size_t slot_size = 32; + parms.set_poly_modulus_degree(slot_size * 2); + parms.set_coeff_modulus( + { small_mods_60bit(0), small_mods_60bit(1), small_mods_60bit(2) }); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + + CKKSEncoder encoder(context); + Encryptor encryptor(context, keygen.public_key()); + Decryptor decryptor(context, keygen.secret_key()); + Evaluator evaluator(context); + + Ciphertext encrypted1; + Plaintext plain1; + Plaintext plain2; + Plaintext plainRes; + + std::vector> input1(slot_size, 0.0); + std::vector> input2(slot_size, 0.0); + std::vector> expected(slot_size, 0.0); + std::vector> output(slot_size); + + int data_bound = (1 << 8); + const double delta = static_cast(1ULL << 16); + + srand(static_cast(time(NULL))); + + for (int expCount = 0; expCount < 50; expCount++) + { + for (size_t i = 0; i < slot_size; i++) + { + input1[i] = static_cast(rand() % data_bound); + input2[i] = static_cast(rand() % data_bound); + expected[i] = input1[i] + input2[i]; + } + + encoder.encode(input1, parms.parms_id(), delta, plain1); + encoder.encode(input2, parms.parms_id(), delta, plain2); + + encryptor.encrypt(plain1, encrypted1); + evaluator.add_plain_inplace(encrypted1, plain2); + + //check correctness of encryption + ASSERT_TRUE(encrypted1.parms_id() == parms.parms_id()); + + decryptor.decrypt(encrypted1, plainRes); + + encoder.decode(plainRes, output); + + for (size_t i = 0; i < slot_size; i++) + { + auto tmp = abs(expected[i].real() - output[i].real()); + ASSERT_TRUE(tmp < 0.5); + } + } + + } + { + //adding two random vectors 50 times + size_t slot_size = 32; + parms.set_poly_modulus_degree(slot_size * 2); + parms.set_coeff_modulus( + { small_mods_60bit(0), small_mods_60bit(1), small_mods_60bit(2) }); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + + CKKSEncoder encoder(context); + Encryptor encryptor(context, keygen.public_key()); + Decryptor decryptor(context, keygen.secret_key()); + Evaluator evaluator(context); + + Ciphertext encrypted1; + Plaintext plain1; + Plaintext plain2; + Plaintext plainRes; + + std::vector> input1(slot_size, 0.0); + double input2; + std::vector> expected(slot_size, 0.0); + std::vector> output(slot_size); + + int data_bound = (1 << 8); + const double delta = static_cast(1ULL << 16); + + srand(static_cast(time(NULL))); + + for (int expCount = 0; expCount < 50; expCount++) + { + input2 = static_cast(rand() % (data_bound*data_bound))/data_bound; + for (size_t i = 0; i < slot_size; i++) + { + input1[i] = static_cast(rand() % data_bound); + expected[i] = input1[i] + input2; + } + + encoder.encode(input1, parms.parms_id(), delta, plain1); + encoder.encode(input2, parms.parms_id(), delta, plain2); + + encryptor.encrypt(plain1, encrypted1); + evaluator.add_plain_inplace(encrypted1, plain2); + + //check correctness of encryption + ASSERT_TRUE(encrypted1.parms_id() == parms.parms_id()); + + decryptor.decrypt(encrypted1, plainRes); + + encoder.decode(plainRes, output); + + for (size_t i = 0; i < slot_size; i++) + { + auto tmp = abs(expected[i].real() - output[i].real()); + ASSERT_TRUE(tmp < 0.5); + } + } + + } + { + //adding two random vectors 50 times + size_t slot_size = 8; + parms.set_poly_modulus_degree(64); + parms.set_coeff_modulus( + { small_mods_60bit(0), small_mods_60bit(1), small_mods_60bit(2) }); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + + CKKSEncoder encoder(context); + Encryptor encryptor(context, keygen.public_key()); + Decryptor decryptor(context, keygen.secret_key()); + Evaluator evaluator(context); + + Ciphertext encrypted1; + Plaintext plain1; + Plaintext plain2; + Plaintext plainRes; + + std::vector> input1(slot_size, 0.0); + double input2; + std::vector> expected(slot_size, 0.0); + std::vector> output(slot_size); + + int data_bound = (1 << 8); + const double delta = static_cast(1ULL << 16); + + srand(static_cast(time(NULL))); + + for (int expCount = 0; expCount < 50; expCount++) + { + input2 = static_cast(rand() % (data_bound*data_bound)) / data_bound; + for (size_t i = 0; i < slot_size; i++) + { + input1[i] = static_cast(rand() % data_bound); + expected[i] = input1[i] + input2; + } + + encoder.encode(input1, parms.parms_id(), delta, plain1); + encoder.encode(input2, parms.parms_id(), delta, plain2); + + encryptor.encrypt(plain1, encrypted1); + evaluator.add_plain_inplace(encrypted1, plain2); + + //check correctness of encryption + ASSERT_TRUE(encrypted1.parms_id() == parms.parms_id()); + + decryptor.decrypt(encrypted1, plainRes); + + encoder.decode(plainRes, output); + + for (size_t i = 0; i < slot_size; i++) + { + auto tmp = abs(expected[i].real() - output[i].real()); + ASSERT_TRUE(tmp < 0.5); + } + } + + } + } + + TEST(EvaluatorTest, CKKSEncryptSubPlainDecrypt) + { + EncryptionParameters parms(scheme_type::CKKS); + parms.set_noise_standard_deviation(3.20); + { + //adding two zero vectors + size_t slot_size = 32; + parms.set_poly_modulus_degree(slot_size * 2); + parms.set_coeff_modulus( + { small_mods_30bit(0), small_mods_30bit(1), + small_mods_30bit(2), small_mods_30bit(3), small_mods_30bit(4) }); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + + CKKSEncoder encoder(context); + Encryptor encryptor(context, keygen.public_key()); + Decryptor decryptor(context, keygen.secret_key()); + Evaluator evaluator(context); + + Ciphertext encrypted; + Plaintext plain; + Plaintext plainRes; + + std::vector> input(slot_size, 0.0); + std::vector> output(slot_size); + const double delta = static_cast(1 << 16); + encoder.encode(input, parms.parms_id(), delta, plain); + + encryptor.encrypt(plain, encrypted); + evaluator.add_plain_inplace(encrypted, plain); + + //check correctness of encryption + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + + decryptor.decrypt(encrypted, plainRes); + + encoder.decode(plainRes, output); + + for (size_t i = 0; i < slot_size; i++) + { + auto tmp = abs(input[i].real() - output[i].real()); + ASSERT_TRUE(tmp < 0.5); + } + } + { + //adding two random vectors 100 times + size_t slot_size = 32; + parms.set_poly_modulus_degree(slot_size * 2); + parms.set_coeff_modulus( + { small_mods_60bit(0), small_mods_60bit(1), small_mods_60bit(2) }); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + + CKKSEncoder encoder(context); + Encryptor encryptor(context, keygen.public_key()); + Decryptor decryptor(context, keygen.secret_key()); + Evaluator evaluator(context); + + Ciphertext encrypted1; + Plaintext plain1; + Plaintext plain2; + Plaintext plainRes; + + std::vector> input1(slot_size, 0.0); + std::vector> input2(slot_size, 0.0); + std::vector> expected(slot_size, 0.0); + std::vector> output(slot_size); + + int data_bound = (1 << 8); + const double delta = static_cast(1ULL << 16); + + srand(static_cast(time(NULL))); + + for (int expCount = 0; expCount < 100; expCount++) + { + for (size_t i = 0; i < slot_size; i++) + { + input1[i] = static_cast(rand() % data_bound); + input2[i] = static_cast(rand() % data_bound); + expected[i] = input1[i] - input2[i]; + } + + encoder.encode(input1, parms.parms_id(), delta, plain1); + encoder.encode(input2, parms.parms_id(), delta, plain2); + + encryptor.encrypt(plain1, encrypted1); + evaluator.sub_plain_inplace(encrypted1, plain2); + + //check correctness of encryption + ASSERT_TRUE(encrypted1.parms_id() == parms.parms_id()); + + decryptor.decrypt(encrypted1, plainRes); + + encoder.decode(plainRes, output); + + for (size_t i = 0; i < slot_size; i++) + { + auto tmp = abs(expected[i].real() - output[i].real()); + ASSERT_TRUE(tmp < 0.5); + } + } + + } + { + //adding two random vectors 100 times + size_t slot_size = 8; + parms.set_poly_modulus_degree(64); + parms.set_coeff_modulus( + { small_mods_60bit(0), small_mods_60bit(1), small_mods_60bit(2) }); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + + CKKSEncoder encoder(context); + Encryptor encryptor(context, keygen.public_key()); + Decryptor decryptor(context, keygen.secret_key()); + Evaluator evaluator(context); + + Ciphertext encrypted1; + Plaintext plain1; + Plaintext plain2; + Plaintext plainRes; + + std::vector> input1(slot_size, 0.0); + std::vector> input2(slot_size, 0.0); + std::vector> expected(slot_size, 0.0); + std::vector> output(slot_size); + + int data_bound = (1 << 8); + const double delta = static_cast(1ULL << 16); + + srand(static_cast(time(NULL))); + + for (int expCount = 0; expCount < 100; expCount++) + { + for (size_t i = 0; i < slot_size; i++) + { + input1[i] = static_cast(rand() % data_bound); + input2[i] = static_cast(rand() % data_bound); + expected[i] = input1[i] - input2[i]; + } + + encoder.encode(input1, parms.parms_id(), delta, plain1); + encoder.encode(input2, parms.parms_id(), delta, plain2); + + encryptor.encrypt(plain1, encrypted1); + evaluator.sub_plain_inplace(encrypted1, plain2); + + //check correctness of encryption + ASSERT_TRUE(encrypted1.parms_id() == parms.parms_id()); + + decryptor.decrypt(encrypted1, plainRes); + + encoder.decode(plainRes, output); + + for (size_t i = 0; i < slot_size; i++) + { + auto tmp = abs(expected[i].real() - output[i].real()); + ASSERT_TRUE(tmp < 0.5); + } + } + + } + } + + TEST(EvaluatorTest, FVEncryptSubDecrypt) + { + EncryptionParameters parms(scheme_type::BFV); + SmallModulus plain_modulus(1 << 6); + parms.set_poly_modulus_degree(64); + parms.set_plain_modulus(plain_modulus); + parms.set_coeff_modulus({ small_mods_40bit(0) }); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + + BalancedEncoder encoder(plain_modulus); + Encryptor encryptor(context, keygen.public_key()); + Evaluator evaluator(context); + Decryptor decryptor(context, keygen.secret_key()); + + Ciphertext encrypted1; + encryptor.encrypt(encoder.encode(0x12345678), encrypted1); + Ciphertext encrypted2; + encryptor.encrypt(encoder.encode(0x54321), encrypted2); + evaluator.sub_inplace(encrypted1, encrypted2); + Plaintext plain; + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ(static_cast(0x122F1357), encoder.decode_int32(plain)); + ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); + ASSERT_TRUE(encrypted1.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(0), encrypted1); + encryptor.encrypt(encoder.encode(0), encrypted2); + evaluator.sub_inplace(encrypted1, encrypted2); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ(static_cast(0), encoder.decode_int32(plain)); + ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); + ASSERT_TRUE(encrypted1.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(0), encrypted1); + encryptor.encrypt(encoder.encode(5), encrypted2); + evaluator.sub_inplace(encrypted1, encrypted2); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ(static_cast(-5), encoder.decode_int32(plain)); + ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); + ASSERT_TRUE(encrypted1.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(5), encrypted1); + encryptor.encrypt(encoder.encode(-3), encrypted2); + evaluator.sub_inplace(encrypted1, encrypted2); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ(static_cast(8), encoder.decode_int32(plain)); + ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); + ASSERT_TRUE(encrypted1.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(-7), encrypted1); + encryptor.encrypt(encoder.encode(2), encrypted2); + evaluator.sub_inplace(encrypted1, encrypted2); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ(static_cast(-9), encoder.decode_int32(plain)); + ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); + ASSERT_TRUE(encrypted1.parms_id() == parms.parms_id()); + } + + TEST(EvaluatorTest, FVEncryptAddPlainDecrypt) + { + EncryptionParameters parms(scheme_type::BFV); + SmallModulus plain_modulus(1 << 6); + parms.set_poly_modulus_degree(64); + parms.set_plain_modulus(plain_modulus); + parms.set_coeff_modulus({ small_mods_40bit(0) }); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + + BalancedEncoder encoder(plain_modulus); + Encryptor encryptor(context, keygen.public_key()); + Evaluator evaluator(context); + Decryptor decryptor(context, keygen.secret_key()); + + Ciphertext encrypted1; + Ciphertext encrypted2; + Plaintext plain; + encryptor.encrypt(encoder.encode(0x12345678), encrypted1); + plain = encoder.encode(0x54321); + evaluator.add_plain_inplace(encrypted1, plain); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ(static_cast(0x12399999), encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted1.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(0), encrypted1); + plain = encoder.encode(0); + evaluator.add_plain_inplace(encrypted1, plain); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ(static_cast(0), encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted1.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(0), encrypted1); + plain = encoder.encode(5); + evaluator.add_plain_inplace(encrypted1, plain); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ(static_cast(5), encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted1.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(5), encrypted1); + plain = encoder.encode(-3); + evaluator.add_plain_inplace(encrypted1, plain); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ(static_cast(2), encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted1.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(-7), encrypted1); + plain = encoder.encode(7); + evaluator.add_plain_inplace(encrypted1, plain); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ(static_cast(0), encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted1.parms_id() == parms.parms_id()); + } + + TEST(EvaluatorTest, FVEncryptSubPlainDecrypt) + { + EncryptionParameters parms(scheme_type::BFV); + SmallModulus plain_modulus(1 << 6); + parms.set_poly_modulus_degree(64); + parms.set_plain_modulus(plain_modulus); + parms.set_coeff_modulus({ small_mods_40bit(0) }); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + + BalancedEncoder encoder(plain_modulus); + Encryptor encryptor(context, keygen.public_key()); + Evaluator evaluator(context); + Decryptor decryptor(context, keygen.secret_key()); + + Ciphertext encrypted1; + Plaintext plain; + encryptor.encrypt(encoder.encode(0x12345678), encrypted1); + plain = encoder.encode(0x54321); + evaluator.sub_plain_inplace(encrypted1, plain); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ(static_cast(0x122F1357), encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted1.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(0), encrypted1); + plain = encoder.encode(0); + evaluator.sub_plain_inplace(encrypted1, plain); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ(static_cast(0), encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted1.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(0), encrypted1); + plain = encoder.encode(5); + evaluator.sub_plain_inplace(encrypted1, plain); + decryptor.decrypt(encrypted1, plain); + ASSERT_TRUE(static_cast(-5) == encoder.decode_int64(plain)); + ASSERT_TRUE(encrypted1.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(5), encrypted1); + plain = encoder.encode(-3); + evaluator.sub_plain_inplace(encrypted1, plain); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ(static_cast(8), encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted1.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(-7), encrypted1); + plain = encoder.encode(2); + evaluator.sub_plain_inplace(encrypted1, plain); + decryptor.decrypt(encrypted1, plain); + ASSERT_TRUE(static_cast(-9) == encoder.decode_int64(plain)); + ASSERT_TRUE(encrypted1.parms_id() == parms.parms_id()); + } + + TEST(EvaluatorTest, FVEncryptMultiplyPlainDecrypt) + { + { + EncryptionParameters parms(scheme_type::BFV); + SmallModulus plain_modulus(1 << 6); + parms.set_poly_modulus_degree(64); + parms.set_plain_modulus(plain_modulus); + parms.set_coeff_modulus({ small_mods_40bit(0) }); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + + BalancedEncoder encoder(plain_modulus); + Encryptor encryptor(context, keygen.public_key()); + Evaluator evaluator(context); + Decryptor decryptor(context, keygen.secret_key()); + + Ciphertext encrypted; + Plaintext plain; + encryptor.encrypt(encoder.encode(0x12345678), encrypted); + plain = encoder.encode(0x54321); + evaluator.multiply_plain_inplace(encrypted, plain); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(static_cast(0x5FCBBBB88D78), encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(0), encrypted); + plain = encoder.encode(5); + evaluator.multiply_plain_inplace(encrypted, plain); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(static_cast(0), encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(7), encrypted); + plain = encoder.encode(1); + evaluator.multiply_plain_inplace(encrypted, plain); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(static_cast(7), encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(5), encrypted); + plain = encoder.encode(-3); + evaluator.multiply_plain_inplace(encrypted, plain); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE(static_cast(-15) == encoder.decode_int64(plain)); + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(-7), encrypted); + plain = encoder.encode(2); + evaluator.multiply_plain_inplace(encrypted, plain); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE(static_cast(-14) == encoder.decode_int64(plain)); + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + } + { + EncryptionParameters parms(scheme_type::BFV); + SmallModulus plain_modulus((1ULL << 20) - 1); + parms.set_poly_modulus_degree(64); + parms.set_plain_modulus(plain_modulus); + parms.set_coeff_modulus({ + small_mods_30bit(0), + small_mods_60bit(0), + small_mods_60bit(1) }); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + + BalancedEncoder encoder(plain_modulus); + Encryptor encryptor(context, keygen.public_key()); + Evaluator evaluator(context); + Decryptor decryptor(context, keygen.secret_key()); + + Ciphertext encrypted; + Plaintext plain; + encryptor.encrypt(encoder.encode(0x12345678), encrypted); + plain = "1"; + evaluator.multiply_plain_inplace(encrypted, plain); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(static_cast(0x12345678), encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + + plain = "5"; + evaluator.multiply_plain_inplace(encrypted, plain); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(static_cast(0x5B05B058), encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + } + { + EncryptionParameters parms(scheme_type::BFV); + SmallModulus plain_modulus((1ULL << 40) - 1); + parms.set_poly_modulus_degree(64); + parms.set_plain_modulus(plain_modulus); + parms.set_coeff_modulus({ + small_mods_30bit(0), + small_mods_60bit(0), + small_mods_60bit(1) }); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + + BalancedEncoder encoder(plain_modulus); + Encryptor encryptor(context, keygen.public_key()); + Evaluator evaluator(context); + Decryptor decryptor(context, keygen.secret_key()); + + Ciphertext encrypted; + Plaintext plain; + encryptor.encrypt(encoder.encode(0x12345678), encrypted); + plain = "1"; + evaluator.multiply_plain_inplace(encrypted, plain); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(static_cast(0x12345678), encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + + plain = "5"; + evaluator.multiply_plain_inplace(encrypted, plain); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(static_cast(0x5B05B058), encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + } + } + + TEST(EvaluatorTest, FVEncryptMultiplyDecrypt) + { + { + EncryptionParameters parms(scheme_type::BFV); + SmallModulus plain_modulus(1 << 6); + parms.set_poly_modulus_degree(64); + parms.set_plain_modulus(plain_modulus); + parms.set_coeff_modulus({ small_mods_40bit(0) }); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + + BalancedEncoder encoder(plain_modulus); + Encryptor encryptor(context, keygen.public_key()); + Evaluator evaluator(context); + Decryptor decryptor(context, keygen.secret_key()); + + Ciphertext encrypted1; + Ciphertext encrypted2; + Plaintext plain; + encryptor.encrypt(encoder.encode(0x12345678), encrypted1); + encryptor.encrypt(encoder.encode(0x54321), encrypted2); + evaluator.multiply_inplace(encrypted1, encrypted2); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ(static_cast(0x5FCBBBB88D78), encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); + ASSERT_TRUE(encrypted1.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(0), encrypted1); + encryptor.encrypt(encoder.encode(0), encrypted2); + evaluator.multiply_inplace(encrypted1, encrypted2); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ(static_cast(0), encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); + ASSERT_TRUE(encrypted1.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(0), encrypted1); + encryptor.encrypt(encoder.encode(5), encrypted2); + evaluator.multiply_inplace(encrypted1, encrypted2); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ(static_cast(0), encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); + ASSERT_TRUE(encrypted1.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(7), encrypted1); + encryptor.encrypt(encoder.encode(1), encrypted2); + evaluator.multiply_inplace(encrypted1, encrypted2); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ(static_cast(7), encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); + ASSERT_TRUE(encrypted1.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(5), encrypted1); + encryptor.encrypt(encoder.encode(-3), encrypted2); + evaluator.multiply_inplace(encrypted1, encrypted2); + decryptor.decrypt(encrypted1, plain); + ASSERT_TRUE(static_cast(-15) == encoder.decode_int64(plain)); + ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); + ASSERT_TRUE(encrypted1.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(0x10000), encrypted1); + encryptor.encrypt(encoder.encode(0x100), encrypted2); + evaluator.multiply_inplace(encrypted1, encrypted2); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ(static_cast(0x1000000), encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); + ASSERT_TRUE(encrypted1.parms_id() == parms.parms_id()); + } + { + EncryptionParameters parms(scheme_type::BFV); + SmallModulus plain_modulus((1ULL << 60) - 1); + parms.set_poly_modulus_degree(64); + parms.set_plain_modulus(plain_modulus); + parms.set_coeff_modulus({ + small_mods_60bit(0), + small_mods_60bit(1), + small_mods_60bit(2) }); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + + BalancedEncoder encoder(plain_modulus); + Encryptor encryptor(context, keygen.public_key()); + Evaluator evaluator(context); + Decryptor decryptor(context, keygen.secret_key()); + + Ciphertext encrypted1; + Ciphertext encrypted2; + Plaintext plain; + encryptor.encrypt(encoder.encode(0x12345678), encrypted1); + encryptor.encrypt(encoder.encode(0x54321), encrypted2); + evaluator.multiply_inplace(encrypted1, encrypted2); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ(static_cast(0x5FCBBBB88D78), encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); + ASSERT_TRUE(encrypted1.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(0), encrypted1); + encryptor.encrypt(encoder.encode(0), encrypted2); + evaluator.multiply_inplace(encrypted1, encrypted2); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ(static_cast(0), encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); + ASSERT_TRUE(encrypted1.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(0), encrypted1); + encryptor.encrypt(encoder.encode(5), encrypted2); + evaluator.multiply_inplace(encrypted1, encrypted2); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ(static_cast(0), encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); + ASSERT_TRUE(encrypted1.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(7), encrypted1); + encryptor.encrypt(encoder.encode(1), encrypted2); + evaluator.multiply_inplace(encrypted1, encrypted2); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ(static_cast(7), encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); + ASSERT_TRUE(encrypted1.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(5), encrypted1); + encryptor.encrypt(encoder.encode(-3), encrypted2); + evaluator.multiply_inplace(encrypted1, encrypted2); + decryptor.decrypt(encrypted1, plain); + ASSERT_TRUE(static_cast(-15) == encoder.decode_int64(plain)); + ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); + ASSERT_TRUE(encrypted1.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(0x10000), encrypted1); + encryptor.encrypt(encoder.encode(0x100), encrypted2); + evaluator.multiply_inplace(encrypted1, encrypted2); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ(static_cast(0x1000000), encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); + ASSERT_TRUE(encrypted1.parms_id() == parms.parms_id()); + } + { + EncryptionParameters parms(scheme_type::BFV); + SmallModulus plain_modulus(1 << 6); + parms.set_poly_modulus_degree(128); + parms.set_plain_modulus(plain_modulus); + parms.set_coeff_modulus({ small_mods_40bit(0), small_mods_40bit(1) }); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + + BalancedEncoder encoder(plain_modulus); + Encryptor encryptor(context, keygen.public_key()); + Evaluator evaluator(context); + Decryptor decryptor(context, keygen.secret_key()); + + Ciphertext encrypted1; + Ciphertext encrypted2; + Plaintext plain; + encryptor.encrypt(encoder.encode(0x12345678), encrypted1); + encryptor.encrypt(encoder.encode(0x54321), encrypted2); + evaluator.multiply_inplace(encrypted1, encrypted2); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ(static_cast(0x5FCBBBB88D78), encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); + ASSERT_TRUE(encrypted1.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(0), encrypted1); + encryptor.encrypt(encoder.encode(0), encrypted2); + evaluator.multiply_inplace(encrypted1, encrypted2); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ(static_cast(0), encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); + ASSERT_TRUE(encrypted1.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(0), encrypted1); + encryptor.encrypt(encoder.encode(5), encrypted2); + evaluator.multiply_inplace(encrypted1, encrypted2); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ(static_cast(0), encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); + ASSERT_TRUE(encrypted1.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(7), encrypted1); + encryptor.encrypt(encoder.encode(1), encrypted2); + evaluator.multiply_inplace(encrypted1, encrypted2); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ(static_cast(7), encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); + ASSERT_TRUE(encrypted1.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(5), encrypted1); + encryptor.encrypt(encoder.encode(-3), encrypted2); + evaluator.multiply_inplace(encrypted1, encrypted2); + decryptor.decrypt(encrypted1, plain); + ASSERT_TRUE(static_cast(-15) == encoder.decode_int64(plain)); + ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); + ASSERT_TRUE(encrypted1.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(0x10000), encrypted1); + encryptor.encrypt(encoder.encode(0x100), encrypted2); + evaluator.multiply_inplace(encrypted1, encrypted2); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ(static_cast(0x1000000), encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); + ASSERT_TRUE(encrypted1.parms_id() == parms.parms_id()); + } + { + EncryptionParameters parms(scheme_type::BFV); + SmallModulus plain_modulus(1 << 6); + parms.set_poly_modulus_degree(128); + parms.set_plain_modulus(plain_modulus); + parms.set_coeff_modulus({ small_mods_40bit(0), small_mods_40bit(1) }); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + + BalancedEncoder encoder(plain_modulus); + Encryptor encryptor(context, keygen.public_key()); + Evaluator evaluator(context); + Decryptor decryptor(context, keygen.secret_key()); + + Ciphertext encrypted1; + Plaintext plain; + encryptor.encrypt(encoder.encode(123), encrypted1); + evaluator.multiply(encrypted1, encrypted1, encrypted1); + evaluator.multiply(encrypted1, encrypted1, encrypted1); + decryptor.decrypt(encrypted1, plain); + ASSERT_EQ(static_cast(228886641), encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted1.parms_id() == parms.parms_id()); + } + } + TEST(EvaluatorTest, FVRelinearize) + { + EncryptionParameters parms(scheme_type::BFV); + SmallModulus plain_modulus(1 << 6); + parms.set_poly_modulus_degree(128); + parms.set_plain_modulus(plain_modulus); + parms.set_coeff_modulus({ small_mods_40bit(0), small_mods_40bit(1), small_mods_40bit(2) }); + parms.set_noise_standard_deviation(3.20); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + + RelinKeys rlk = keygen.relin_keys(60, 3); + + Encryptor encryptor(context, keygen.public_key()); + Evaluator evaluator(context); + Decryptor decryptor(context, keygen.secret_key()); + + Ciphertext encrypted(context); + Ciphertext encrypted2(context); + + Plaintext plain; + Plaintext plain2; + + plain = 0; + encryptor.encrypt(plain, encrypted); + evaluator.square_inplace(encrypted); + evaluator.relinearize_inplace(encrypted, rlk); + decryptor.decrypt(encrypted, plain2); + ASSERT_TRUE(plain == plain2); + + encryptor.encrypt(plain, encrypted); + evaluator.square_inplace(encrypted); + evaluator.square_inplace(encrypted); + evaluator.relinearize_inplace(encrypted, rlk); + decryptor.decrypt(encrypted, plain2); + ASSERT_TRUE(plain == plain2); + + plain = "1x^10 + 2"; + encryptor.encrypt(plain, encrypted); + evaluator.square_inplace(encrypted); + evaluator.relinearize_inplace(encrypted, rlk); + decryptor.decrypt(encrypted, plain2); + ASSERT_TRUE(plain2.to_string() == "1x^20 + 4x^10 + 4"); + + encryptor.encrypt(plain, encrypted); + evaluator.square_inplace(encrypted); + evaluator.square_inplace(encrypted); + evaluator.relinearize_inplace(encrypted, rlk); + decryptor.decrypt(encrypted, plain2); + ASSERT_TRUE(plain2.to_string() == "1x^40 + 8x^30 + 18x^20 + 20x^10 + 10"); + + // Relinearization with modulus switching + plain = "1x^10 + 2"; + encryptor.encrypt(plain, encrypted); + evaluator.square_inplace(encrypted); + evaluator.relinearize_inplace(encrypted, rlk); + evaluator.mod_switch_to_next_inplace(encrypted); + decryptor.decrypt(encrypted, plain2); + ASSERT_TRUE(plain2.to_string() == "1x^20 + 4x^10 + 4"); + + encryptor.encrypt(plain, encrypted); + evaluator.square_inplace(encrypted); + evaluator.relinearize_inplace(encrypted, rlk); + evaluator.mod_switch_to_next_inplace(encrypted); + evaluator.square_inplace(encrypted); + evaluator.relinearize_inplace(encrypted, rlk); + evaluator.mod_switch_to_next_inplace(encrypted); + decryptor.decrypt(encrypted, plain2); + ASSERT_TRUE(plain2.to_string() == "1x^40 + 8x^30 + 18x^20 + 20x^10 + 10"); + } + TEST(EvaluatorTest, CKKSEncryptNaiveMultiplyDecrypt) + { + EncryptionParameters parms(scheme_type::CKKS); + parms.set_noise_standard_deviation(3.20); + { + //multiplying two zero vectors + size_t slot_size = 32; + parms.set_poly_modulus_degree(slot_size * 2); + parms.set_coeff_modulus({ small_mods_30bit(0), small_mods_30bit(1), + small_mods_30bit(2) }); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + + CKKSEncoder encoder(context); + Encryptor encryptor(context, keygen.public_key()); + Decryptor decryptor(context, keygen.secret_key()); + Evaluator evaluator(context); + + Ciphertext encrypted; + Plaintext plain; + Plaintext plainRes; + + std::vector> input(slot_size, 0.0); + std::vector> output(slot_size); + const double delta = static_cast(1 << 30); + encoder.encode(input, parms.parms_id(), delta, plain); + + encryptor.encrypt(plain, encrypted); + evaluator.multiply_inplace(encrypted, encrypted); + + //check correctness of encryption + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + + decryptor.decrypt(encrypted, plainRes); + encoder.decode(plainRes, output); + for (size_t i = 0; i < slot_size; i++) + { + auto tmp = abs(input[i].real() - output[i].real()); + ASSERT_TRUE(tmp < 0.5); + } + } + { + //multiplying two random vectors + size_t slot_size = 32; + parms.set_poly_modulus_degree(slot_size * 2); + parms.set_coeff_modulus({ small_mods_60bit(0), small_mods_60bit(1)}); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + + CKKSEncoder encoder(context); + Encryptor encryptor(context, keygen.public_key()); + Decryptor decryptor(context, keygen.secret_key()); + Evaluator evaluator(context); + + Ciphertext encrypted1; + Ciphertext encrypted2; + Plaintext plain1; + Plaintext plain2; + Plaintext plainRes; + + std::vector> input1(slot_size, 0.0); + std::vector> input2(slot_size, 0.0); + std::vector> expected(slot_size, 0.0); + std::vector> output(slot_size); + const double delta = static_cast(1ULL << 40); + + int data_bound = (1 << 10); + srand(static_cast(time(NULL))); + + for (int round = 0; round < 100; round++) + { + for (size_t i = 0; i < slot_size; i++) + { + input1[i] = static_cast(rand() % data_bound); + input2[i] = static_cast(rand() % data_bound); + expected[i] = input1[i] * input2[i]; + } + encoder.encode(input1, parms.parms_id(), delta, plain1); + encoder.encode(input2, parms.parms_id(), delta, plain2); + + encryptor.encrypt(plain1, encrypted1); + encryptor.encrypt(plain2, encrypted2); + evaluator.multiply_inplace(encrypted1, encrypted2); + + //check correctness of encryption + ASSERT_TRUE(encrypted1.parms_id() == parms.parms_id()); + + decryptor.decrypt(encrypted1, plainRes); + + encoder.decode(plainRes, output); + + for (size_t i = 0; i < slot_size; i++) + { + auto tmp = abs(expected[i].real() - output[i].real()); + ASSERT_TRUE(tmp < 0.5); + } + } + } + { + //multiplying two random vectors + size_t slot_size = 16; + parms.set_poly_modulus_degree(64); + parms.set_coeff_modulus({ small_mods_60bit(0), small_mods_60bit(1) }); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + + CKKSEncoder encoder(context); + Encryptor encryptor(context, keygen.public_key()); + Decryptor decryptor(context, keygen.secret_key()); + Evaluator evaluator(context); + + Ciphertext encrypted1; + Ciphertext encrypted2; + Plaintext plain1; + Plaintext plain2; + Plaintext plainRes; + + std::vector> input1(slot_size, 0.0); + std::vector> input2(slot_size, 0.0); + std::vector> expected(slot_size, 0.0); + std::vector> output(slot_size); + const double delta = static_cast(1ULL << 40); + + int data_bound = (1 << 10); + srand(static_cast(time(NULL))); + + for (int round = 0; round < 100; round++) + { + for (size_t i = 0; i < slot_size; i++) + { + input1[i] = static_cast(rand() % data_bound); + input2[i] = static_cast(rand() % data_bound); + expected[i] = input1[i] * input2[i]; + } + encoder.encode(input1, parms.parms_id(), delta, plain1); + encoder.encode(input2, parms.parms_id(), delta, plain2); + + encryptor.encrypt(plain1, encrypted1); + encryptor.encrypt(plain2, encrypted2); + evaluator.multiply_inplace(encrypted1, encrypted2); + + //check correctness of encryption + ASSERT_TRUE(encrypted1.parms_id() == parms.parms_id()); + + decryptor.decrypt(encrypted1, plainRes); + + encoder.decode(plainRes, output); + + for (size_t i = 0; i < slot_size; i++) + { + auto tmp = abs(expected[i].real() - output[i].real()); + ASSERT_TRUE(tmp < 0.5); + } + } + } + } + + TEST(EvaluatorTest, CKKSEncryptMultiplyByNumberDecrypt) + { + EncryptionParameters parms(scheme_type::CKKS); + parms.set_noise_standard_deviation(3.20); + { + //multiplying two random vectors by an integer + size_t slot_size = 32; + parms.set_poly_modulus_degree(slot_size * 2); + parms.set_coeff_modulus({ small_mods_60bit(0), small_mods_60bit(1) }); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + + CKKSEncoder encoder(context); + Encryptor encryptor(context, keygen.public_key()); + Decryptor decryptor(context, keygen.secret_key()); + Evaluator evaluator(context); + + Ciphertext encrypted1; + Plaintext plain1; + Plaintext plain2; + Plaintext plainRes; + + std::vector> input1(slot_size, 0.0); + int64_t input2; + std::vector> expected(slot_size, 0.0); + + int data_bound = (1 << 10); + srand(static_cast(time(NULL))); + + for (int iExp = 0; iExp < 50; iExp++) + { + input2 = max(rand() % data_bound, 1); + for (size_t i = 0; i < slot_size; i++) + { + input1[i] = static_cast(rand() % data_bound); + expected[i] = input1[i] * static_cast(input2); + } + + std::vector> output(slot_size); + const double delta = static_cast(1ULL << 40); + encoder.encode(input1, parms.parms_id(), delta, plain1); + encoder.encode(input2, parms.parms_id(), plain2); + + encryptor.encrypt(plain1, encrypted1); + evaluator.multiply_plain_inplace(encrypted1, plain2); + + //check correctness of encryption + ASSERT_TRUE(encrypted1.parms_id() == parms.parms_id()); + + decryptor.decrypt(encrypted1, plainRes); + + encoder.decode(plainRes, output); + + for (size_t i = 0; i < slot_size; i++) + { + auto tmp = abs(expected[i].real() - output[i].real()); + ASSERT_TRUE(tmp < 0.5); + } + } + } + { + //multiplying two random vectors by an integer + size_t slot_size = 8; + parms.set_poly_modulus_degree(64); + parms.set_coeff_modulus({ small_mods_60bit(0), small_mods_60bit(1) }); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + + CKKSEncoder encoder(context); + Encryptor encryptor(context, keygen.public_key()); + Decryptor decryptor(context, keygen.secret_key()); + Evaluator evaluator(context); + + Ciphertext encrypted1; + Plaintext plain1; + Plaintext plain2; + Plaintext plainRes; + + std::vector> input1(slot_size, 0.0); + int64_t input2; + std::vector> expected(slot_size, 0.0); + + int data_bound = (1 << 10); + srand(static_cast(time(NULL))); + + for (int iExp = 0; iExp < 50; iExp++) + { + input2 = max(rand() % data_bound, 1); + for (size_t i = 0; i < slot_size; i++) + { + input1[i] = static_cast(rand() % data_bound); + expected[i] = input1[i] * static_cast(input2); + } + + std::vector> output(slot_size); + const double delta = static_cast(1ULL << 40); + encoder.encode(input1, parms.parms_id(), delta, plain1); + encoder.encode(input2, parms.parms_id(), plain2); + + encryptor.encrypt(plain1, encrypted1); + evaluator.multiply_plain_inplace(encrypted1, plain2); + + //check correctness of encryption + ASSERT_TRUE(encrypted1.parms_id() == parms.parms_id()); + + decryptor.decrypt(encrypted1, plainRes); + + encoder.decode(plainRes, output); + + for (size_t i = 0; i < slot_size; i++) + { + auto tmp = abs(expected[i].real() - output[i].real()); + ASSERT_TRUE(tmp < 0.5); + } + } + } + { + //multiplying two random vectors by a double + size_t slot_size = 32; + parms.set_poly_modulus_degree(slot_size * 2); + parms.set_coeff_modulus({ small_mods_60bit(0), small_mods_60bit(1) }); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + + CKKSEncoder encoder(context); + Encryptor encryptor(context, keygen.public_key()); + Decryptor decryptor(context, keygen.secret_key()); + Evaluator evaluator(context); + + Ciphertext encrypted1; + Plaintext plain1; + Plaintext plain2; + Plaintext plainRes; + + std::vector> input1(slot_size, 0.0); + double input2; + std::vector> expected(slot_size, 0.0); + std::vector> output(slot_size); + + int data_bound = (1 << 10); + srand(static_cast(time(NULL))); + + for (int iExp = 0; iExp < 50; iExp++) + { + input2 = static_cast(rand() % (data_bound*data_bound)) + /static_cast(data_bound); + for (size_t i = 0; i < slot_size; i++) + { + input1[i] = static_cast(rand() % data_bound); + expected[i] = input1[i] * input2; + } + + const double delta = static_cast(1ULL << 40); + encoder.encode(input1, parms.parms_id(), delta, plain1); + encoder.encode(input2, parms.parms_id(), delta, plain2); + + encryptor.encrypt(plain1, encrypted1); + evaluator.multiply_plain_inplace(encrypted1, plain2); + + //check correctness of encryption + ASSERT_TRUE(encrypted1.parms_id() == parms.parms_id()); + + decryptor.decrypt(encrypted1, plainRes); + + encoder.decode(plainRes, output); + + for (size_t i = 0; i < slot_size; i++) + { + auto tmp = abs(expected[i].real() - output[i].real()); + ASSERT_TRUE(tmp < 0.5); + } + } + } + { + //multiplying two random vectors by a double + size_t slot_size = 16; + parms.set_poly_modulus_degree(64); + parms.set_coeff_modulus({ small_mods_60bit(0), small_mods_60bit(1) }); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + + CKKSEncoder encoder(context); + Encryptor encryptor(context, keygen.public_key()); + Decryptor decryptor(context, keygen.secret_key()); + Evaluator evaluator(context); + + Ciphertext encrypted1; + Plaintext plain1; + Plaintext plain2; + Plaintext plainRes; + + std::vector> input1(slot_size, 2.1); + double input2; + std::vector> expected(slot_size, 2.1); + std::vector> output(slot_size); + + int data_bound = (1 << 10); + srand(static_cast(time(NULL))); + + for (int iExp = 0; iExp < 50; iExp++) + { + input2 = static_cast(rand() % (data_bound*data_bound)) + / static_cast(data_bound); + for (size_t i = 0; i < slot_size; i++) + { + input1[i] = static_cast(rand() % data_bound); + expected[i] = input1[i] * input2; + } + + const double delta = static_cast(1ULL << 40); + encoder.encode(input1, parms.parms_id(), delta, plain1); + encoder.encode(input2, parms.parms_id(), delta, plain2); + + encryptor.encrypt(plain1, encrypted1); + evaluator.multiply_plain_inplace(encrypted1, plain2); + + //check correctness of encryption + ASSERT_TRUE(encrypted1.parms_id() == parms.parms_id()); + + decryptor.decrypt(encrypted1, plainRes); + + encoder.decode(plainRes, output); + + for (size_t i = 0; i < slot_size; i++) + { + auto tmp = abs(expected[i].real() - output[i].real()); + ASSERT_TRUE(tmp < 0.5); + } + } + } + } + + TEST(EvaluatorTest, CKKSEncryptMultiplyRelinDecrypt) + { + EncryptionParameters parms(scheme_type::CKKS); + parms.set_noise_standard_deviation(3.20); + { + //multiplying two random vectors 50 times + size_t slot_size = 32; + parms.set_poly_modulus_degree(slot_size * 2); + parms.set_coeff_modulus({ + small_mods_60bit(0), small_mods_60bit(1)}); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + + CKKSEncoder encoder(context); + Encryptor encryptor(context, keygen.public_key()); + Decryptor decryptor(context, keygen.secret_key()); + Evaluator evaluator(context); + RelinKeys rlk = keygen.relin_keys(4, 1); + + Ciphertext encrypted1; + Ciphertext encrypted2; + Ciphertext encryptedRes; + Plaintext plain1; + Plaintext plain2; + Plaintext plainRes; + + std::vector> input1(slot_size, 0.0); + std::vector> input2(slot_size, 0.0); + std::vector> expected(slot_size, 0.0); + int data_bound = 1 << 10; + + for (int round = 0; round < 50; round++) + { + srand(static_cast(time(NULL))); + for (size_t i = 0; i < slot_size; i++) + { + input1[i] = static_cast(rand() % data_bound); + input2[i] = static_cast(rand() % data_bound); + expected[i] = input1[i] * input2[i]; + } + + std::vector> output(slot_size); + const double delta = static_cast(1ULL << 40); + encoder.encode(input1, parms.parms_id(), delta, plain1); + encoder.encode(input2, parms.parms_id(), delta, plain2); + + encryptor.encrypt(plain1, encrypted1); + encryptor.encrypt(plain2, encrypted2); + + //check correctness of encryption + ASSERT_TRUE(encrypted1.parms_id() == parms.parms_id()); + //check correctness of encryption + ASSERT_TRUE(encrypted2.parms_id() == parms.parms_id()); + + evaluator.multiply_inplace(encrypted1, encrypted2); + evaluator.relinearize_inplace(encrypted1, rlk); + + decryptor.decrypt(encrypted1, plainRes); + encoder.decode(plainRes, output); + + for (size_t i = 0; i < slot_size; i++) + { + auto tmp = abs(expected[i].real() - output[i].real()); + ASSERT_TRUE(tmp < 0.5); + } + } + } + { + //multiplying two random vectors 50 times + size_t slot_size = 32; + parms.set_poly_modulus_degree(slot_size * 2); + parms.set_coeff_modulus({ + small_mods_30bit(0), small_mods_30bit(1), + small_mods_30bit(2), small_mods_30bit(3) }); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + + CKKSEncoder encoder(context); + Encryptor encryptor(context, keygen.public_key()); + Decryptor decryptor(context, keygen.secret_key()); + Evaluator evaluator(context); + RelinKeys rlk = keygen.relin_keys(4, 1); + + Ciphertext encrypted1; + Ciphertext encrypted2; + Ciphertext encryptedRes; + Plaintext plain1; + Plaintext plain2; + Plaintext plainRes; + + std::vector> input1(slot_size, 0.0); + std::vector> input2(slot_size, 0.0); + std::vector> expected(slot_size, 0.0); + int data_bound = 1 << 10; + + for (int round = 0; round < 50; round++) + { + srand(static_cast(time(NULL))); + for (size_t i = 0; i < slot_size; i++) + { + input1[i] = static_cast(rand() % data_bound); + input2[i] = static_cast(rand() % data_bound); + expected[i] = input1[i] * input2[i]; + } + + std::vector> output(slot_size); + const double delta = static_cast(1ULL << 40); + encoder.encode(input1, parms.parms_id(), delta, plain1); + encoder.encode(input2, parms.parms_id(), delta, plain2); + + encryptor.encrypt(plain1, encrypted1); + encryptor.encrypt(plain2, encrypted2); + + //check correctness of encryption + ASSERT_TRUE(encrypted1.parms_id() == parms.parms_id()); + //check correctness of encryption + ASSERT_TRUE(encrypted2.parms_id() == parms.parms_id()); + + evaluator.multiply_inplace(encrypted1, encrypted2); + evaluator.relinearize_inplace(encrypted1, rlk); + + decryptor.decrypt(encrypted1, plainRes); + encoder.decode(plainRes, output); + + for (size_t i = 0; i < slot_size; i++) + { + auto tmp = abs(expected[i].real() - output[i].real()); + ASSERT_TRUE(tmp < 0.5); + } + } + } + { + //multiplying two random vectors 50 times + size_t slot_size = 2; + parms.set_poly_modulus_degree(8); + parms.set_coeff_modulus({ + small_mods_30bit(0), small_mods_30bit(1), + small_mods_30bit(2), small_mods_30bit(3) }); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + + CKKSEncoder encoder(context); + Encryptor encryptor(context, keygen.public_key()); + Decryptor decryptor(context, keygen.secret_key()); + Evaluator evaluator(context); + RelinKeys rlk = keygen.relin_keys(4, 1); + + Ciphertext encrypted1; + Ciphertext encrypted2; + Ciphertext encryptedRes; + Plaintext plain1; + Plaintext plain2; + Plaintext plainRes; + + std::vector> input1(slot_size, 0.0); + std::vector> input2(slot_size, 0.0); + std::vector> expected(slot_size, 0.0); + std::vector> output(slot_size); + int data_bound = 1 << 10; + const double delta = static_cast(1ULL << 40); + + for (int round = 0; round < 50; round++) + { + srand(static_cast(time(NULL))); + for (size_t i = 0; i < slot_size; i++) + { + input1[i] = static_cast(rand() % data_bound); + input2[i] = static_cast(rand() % data_bound); + expected[i] = input1[i] * input2[i]; + } + + encoder.encode(input1, parms.parms_id(), delta, plain1); + encoder.encode(input2, parms.parms_id(), delta, plain2); + + encryptor.encrypt(plain1, encrypted1); + encryptor.encrypt(plain2, encrypted2); + + //check correctness of encryption + ASSERT_TRUE(encrypted1.parms_id() == parms.parms_id()); + //check correctness of encryption + ASSERT_TRUE(encrypted2.parms_id() == parms.parms_id()); + + evaluator.multiply_inplace(encrypted1, encrypted2); + //evaluator.relinearize_inplace(encrypted1, rlk); + + decryptor.decrypt(encrypted1, plainRes); + encoder.decode(plainRes, output); + + for (size_t i = 0; i < slot_size; i++) + { + auto tmp = abs(expected[i].real() - output[i].real()); + ASSERT_TRUE(tmp < 0.5); + } + } + } + } + TEST(EvaluatorTest, CKKSEncryptSquareRelinDecrypt) + { + EncryptionParameters parms(scheme_type::CKKS); + { + //squaring two random vectors 100 times + size_t slot_size = 32; + parms.set_poly_modulus_degree(slot_size * 2); + parms.set_coeff_modulus({ + small_mods_60bit(0), small_mods_60bit(1)}); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + + CKKSEncoder encoder(context); + Encryptor encryptor(context, keygen.public_key()); + Decryptor decryptor(context, keygen.secret_key()); + Evaluator evaluator(context); + RelinKeys rlk = keygen.relin_keys(4, 1); + + Ciphertext encrypted; + Plaintext plain; + Plaintext plainRes; + + std::vector> input(slot_size, 0.0); + std::vector> expected(slot_size, 0.0); + + int data_bound = 1 << 7; + srand(static_cast(time(NULL))); + + for (int round = 0; round < 100; round++) + { + for (size_t i = 0; i < slot_size; i++) + { + input[i] = static_cast(rand() % data_bound); + expected[i] = input[i] * input[i]; + } + + std::vector> output(slot_size); + const double delta = static_cast(1ULL << 40); + encoder.encode(input, parms.parms_id(), delta, plain); + + encryptor.encrypt(plain, encrypted); + + //check correctness of encryption + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + + //evaluator.square_inplace(encrypted); + evaluator.multiply_inplace(encrypted, encrypted); + evaluator.relinearize_inplace(encrypted, rlk); + + decryptor.decrypt(encrypted, plainRes); + encoder.decode(plainRes, output); + + for (size_t i = 0; i < slot_size; i++) + { + auto tmp = abs(expected[i].real() - output[i].real()); + ASSERT_TRUE(tmp < 0.5); + } + } + } + { + //squaring two random vectors 100 times + size_t slot_size = 32; + parms.set_poly_modulus_degree(slot_size * 2); + parms.set_coeff_modulus({ + small_mods_30bit(0), small_mods_30bit(1), + small_mods_30bit(2), small_mods_30bit(3) }); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + + CKKSEncoder encoder(context); + Encryptor encryptor(context, keygen.public_key()); + Decryptor decryptor(context, keygen.secret_key()); + Evaluator evaluator(context); + RelinKeys rlk = keygen.relin_keys(4, 1); + + Ciphertext encrypted; + Plaintext plain; + Plaintext plainRes; + + std::vector> input(slot_size, 0.0); + std::vector> expected(slot_size, 0.0); + + int data_bound = 1 << 7; + srand(static_cast(time(NULL))); + + for (int round = 0; round < 100; round++) + { + for (size_t i = 0; i < slot_size; i++) + { + input[i] = static_cast(rand() % data_bound); + expected[i] = input[i] * input[i]; + } + + std::vector> output(slot_size); + const double delta = static_cast(1ULL << 40); + encoder.encode(input, parms.parms_id(), delta, plain); + + encryptor.encrypt(plain, encrypted); + + //check correctness of encryption + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + + //evaluator.square_inplace(encrypted); + evaluator.multiply_inplace(encrypted, encrypted); + evaluator.relinearize_inplace(encrypted, rlk); + + decryptor.decrypt(encrypted, plainRes); + encoder.decode(plainRes, output); + + for (size_t i = 0; i < slot_size; i++) + { + auto tmp = abs(expected[i].real() - output[i].real()); + ASSERT_TRUE(tmp < 0.5); + } + } + } + { + //squaring two random vectors 100 times + size_t slot_size = 16; + parms.set_poly_modulus_degree(64); + parms.set_coeff_modulus({ + small_mods_30bit(0), small_mods_30bit(1), + small_mods_30bit(2), small_mods_30bit(3) }); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + + CKKSEncoder encoder(context); + Encryptor encryptor(context, keygen.public_key()); + Decryptor decryptor(context, keygen.secret_key()); + Evaluator evaluator(context); + RelinKeys rlk = keygen.relin_keys(4, 1); + + Ciphertext encrypted; + Plaintext plain; + Plaintext plainRes; + + std::vector> input(slot_size, 0.0); + std::vector> expected(slot_size, 0.0); + + int data_bound = 1 << 7; + srand(static_cast(time(NULL))); + + for (int round = 0; round < 100; round++) + { + for (size_t i = 0; i < slot_size; i++) + { + input[i] = static_cast(rand() % data_bound); + expected[i] = input[i] * input[i]; + } + + std::vector> output(slot_size); + const double delta = static_cast(1ULL << 40); + encoder.encode(input, parms.parms_id(), delta, plain); + + encryptor.encrypt(plain, encrypted); + + //check correctness of encryption + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + + //evaluator.square_inplace(encrypted); + evaluator.multiply_inplace(encrypted, encrypted); + evaluator.relinearize_inplace(encrypted, rlk); + + decryptor.decrypt(encrypted, plainRes); + encoder.decode(plainRes, output); + + for (size_t i = 0; i < slot_size; i++) + { + auto tmp = abs(expected[i].real() - output[i].real()); + ASSERT_TRUE(tmp < 0.5); + } + } + } + } + TEST(EvaluatorTest, CKKSEncryptMultiplyRelinRescaleDecrypt) + { + EncryptionParameters parms(scheme_type::CKKS); + { + //multiplying two random vectors 100 times + size_t slot_size = 64; + parms.set_poly_modulus_degree(slot_size * 2); + parms.set_coeff_modulus({ + small_mods_30bit(0), small_mods_30bit(1), + small_mods_30bit(2), small_mods_30bit(3) }); + auto context = SEALContext::Create(parms); + auto next_parms_id = context->context_data()-> + next_context_data()->parms().parms_id(); + KeyGenerator keygen(context); + + CKKSEncoder encoder(context); + Encryptor encryptor(context, keygen.public_key()); + Decryptor decryptor(context, keygen.secret_key()); + Evaluator evaluator(context); + RelinKeys rlk = keygen.relin_keys(4, 1); + + Ciphertext encrypted1; + Ciphertext encrypted2; + Ciphertext encryptedRes; + Plaintext plain1; + Plaintext plain2; + Plaintext plainRes; + + std::vector> input1(slot_size, 0.0); + std::vector> input2(slot_size, 0.0); + std::vector> expected(slot_size, 0.0); + + for (int round = 0; round < 100; round++) + { + int data_bound = 1 << 7; + srand(static_cast(time(NULL))); + for (size_t i = 0; i < slot_size; i++) + { + input1[i] = static_cast(rand() % data_bound); + input2[i] = static_cast(rand() % data_bound); + expected[i] = input1[i] * input2[i]; + } + + std::vector> output(slot_size); + double delta = static_cast(1ULL << 40); + encoder.encode(input1, parms.parms_id(), delta, plain1); + encoder.encode(input2, parms.parms_id(), delta, plain2); + + encryptor.encrypt(plain1, encrypted1); + encryptor.encrypt(plain2, encrypted2); + + //check correctness of encryption + ASSERT_TRUE(encrypted1.parms_id() == parms.parms_id()); + //check correctness of encryption + ASSERT_TRUE(encrypted2.parms_id() == parms.parms_id()); + + evaluator.multiply_inplace(encrypted1, encrypted2); + evaluator.relinearize_inplace(encrypted1, rlk); + evaluator.rescale_to_next_inplace(encrypted1); + + //check correctness of modulo switching + ASSERT_TRUE(encrypted1.parms_id() == next_parms_id); + + decryptor.decrypt(encrypted1, plainRes); + encoder.decode(plainRes, output); + + for (size_t i = 0; i < slot_size; i++) + { + auto tmp = abs(expected[i].real() - output[i].real()); + ASSERT_TRUE(tmp < 0.5); + } + } + } + { + //multiplying two random vectors 100 times + size_t slot_size = 16; + parms.set_poly_modulus_degree(128); + parms.set_coeff_modulus({ + small_mods_30bit(0), small_mods_30bit(1), + small_mods_30bit(2), small_mods_30bit(3) }); + auto context = SEALContext::Create(parms); + auto next_parms_id = context->context_data()-> + next_context_data()->parms().parms_id(); + KeyGenerator keygen(context); + + CKKSEncoder encoder(context); + Encryptor encryptor(context, keygen.public_key()); + Decryptor decryptor(context, keygen.secret_key()); + Evaluator evaluator(context); + RelinKeys rlk = keygen.relin_keys(4, 1); + + Ciphertext encrypted1; + Ciphertext encrypted2; + Ciphertext encryptedRes; + Plaintext plain1; + Plaintext plain2; + Plaintext plainRes; + + std::vector> input1(slot_size, 0.0); + std::vector> input2(slot_size, 0.0); + std::vector> expected(slot_size, 0.0); + + for (int round = 0; round < 100; round++) + { + int data_bound = 1 << 7; + srand(static_cast(time(NULL))); + for (size_t i = 0; i < slot_size; i++) + { + input1[i] = static_cast(rand() % data_bound); + input2[i] = static_cast(rand() % data_bound); + expected[i] = input1[i] * input2[i]; + } + + std::vector> output(slot_size); + double delta = static_cast(1ULL << 40); + encoder.encode(input1, parms.parms_id(), delta, plain1); + encoder.encode(input2, parms.parms_id(), delta, plain2); + + encryptor.encrypt(plain1, encrypted1); + encryptor.encrypt(plain2, encrypted2); + + //check correctness of encryption + ASSERT_TRUE(encrypted1.parms_id() == parms.parms_id()); + //check correctness of encryption + ASSERT_TRUE(encrypted2.parms_id() == parms.parms_id()); + + evaluator.multiply_inplace(encrypted1, encrypted2); + evaluator.relinearize_inplace(encrypted1, rlk); + evaluator.rescale_to_next_inplace(encrypted1); + + //check correctness of modulo switching + ASSERT_TRUE(encrypted1.parms_id() == next_parms_id); + + decryptor.decrypt(encrypted1, plainRes); + encoder.decode(plainRes, output); + + for (size_t i = 0; i < slot_size; i++) + { + auto tmp = abs(expected[i].real() - output[i].real()); + ASSERT_TRUE(tmp < 0.5); + } + } + } + } + TEST(EvaluatorTest, CKKSEncryptSquareRelinRescaleDecrypt) + { + EncryptionParameters parms(scheme_type::CKKS); + { + //squaring two random vectors 100 times + size_t slot_size = 64; + parms.set_poly_modulus_degree(slot_size * 2); + parms.set_coeff_modulus({ + small_mods_50bit(0), small_mods_50bit(1), + small_mods_50bit(2) }); + auto context = SEALContext::Create(parms); + auto next_parms_id = context->context_data()-> + next_context_data()->parms().parms_id(); + KeyGenerator keygen(context); + + CKKSEncoder encoder(context); + Encryptor encryptor(context, keygen.public_key()); + Decryptor decryptor(context, keygen.secret_key()); + Evaluator evaluator(context); + RelinKeys rlk = keygen.relin_keys(4, 1); + + Ciphertext encrypted; + Plaintext plain; + Plaintext plainRes; + + std::vector> input(slot_size, 0.0); + std::vector> output(slot_size); + std::vector> expected(slot_size, 0.0); + int data_bound = 1 << 8; + + for (int round = 0; round < 100; round++) + { + srand(static_cast(time(NULL))); + for (size_t i = 0; i < slot_size; i++) + { + input[i] = static_cast(rand() % data_bound); + expected[i] = input[i] * input[i]; + } + + double delta = static_cast(1ULL << 40); + encoder.encode(input, parms.parms_id(), delta, plain); + + encryptor.encrypt(plain, encrypted); + + //check correctness of encryption + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + + evaluator.square_inplace(encrypted); + evaluator.relinearize_inplace(encrypted, rlk); + evaluator.rescale_to_next_inplace(encrypted); + + //check correctness of modulo switching + ASSERT_TRUE(encrypted.parms_id() == next_parms_id); + + decryptor.decrypt(encrypted, plainRes); + encoder.decode(plainRes, output); + + for (size_t i = 0; i < slot_size; i++) + { + auto tmp = abs(expected[i].real() - output[i].real()); + ASSERT_TRUE(tmp < 0.5); + } + } + } + { + //squaring two random vectors 100 times + size_t slot_size = 16; + parms.set_poly_modulus_degree(128); + parms.set_coeff_modulus({ + small_mods_50bit(0), small_mods_50bit(1), + small_mods_50bit(2) }); + auto context = SEALContext::Create(parms); + auto next_parms_id = context->context_data()-> + next_context_data()->parms().parms_id(); + KeyGenerator keygen(context); + + CKKSEncoder encoder(context); + Encryptor encryptor(context, keygen.public_key()); + Decryptor decryptor(context, keygen.secret_key()); + Evaluator evaluator(context); + RelinKeys rlk = keygen.relin_keys(4, 1); + + Ciphertext encrypted; + Plaintext plain; + Plaintext plainRes; + + std::vector> input(slot_size, 0.0); + std::vector> output(slot_size); + std::vector> expected(slot_size, 0.0); + int data_bound = 1 << 8; + + for (int round = 0; round < 100; round++) + { + srand(static_cast(time(NULL))); + for (size_t i = 0; i < slot_size; i++) + { + input[i] = static_cast(rand() % data_bound); + expected[i] = input[i] * input[i]; + } + + double delta = static_cast(1ULL << 40); + encoder.encode(input, parms.parms_id(), delta, plain); + + encryptor.encrypt(plain, encrypted); + + //check correctness of encryption + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + + evaluator.square_inplace(encrypted); + evaluator.relinearize_inplace(encrypted, rlk); + evaluator.rescale_to_next_inplace(encrypted); + + //check correctness of modulo switching + ASSERT_TRUE(encrypted.parms_id() == next_parms_id); + + decryptor.decrypt(encrypted, plainRes); + encoder.decode(plainRes, output); + + for (size_t i = 0; i < slot_size; i++) + { + auto tmp = abs(expected[i].real() - output[i].real()); + ASSERT_TRUE(tmp < 0.5); + } + } + } + } + TEST(EvaluatorTest, CKKSEncryptModSwitchDecrypt) + { + EncryptionParameters parms(scheme_type::CKKS); + { + //modulo switching without rescaling for random vectors + size_t slot_size = 64; + parms.set_poly_modulus_degree(slot_size * 2); + parms.set_coeff_modulus({ + small_mods_60bit(0), small_mods_60bit(1), + small_mods_60bit(2), small_mods_60bit(3), small_mods_60bit(4) }); + auto context = SEALContext::Create(parms); + auto next_parms_id = context->context_data()-> + next_context_data()->parms().parms_id(); + KeyGenerator keygen(context); + + CKKSEncoder encoder(context); + Encryptor encryptor(context, keygen.public_key()); + Decryptor decryptor(context, keygen.secret_key()); + Evaluator evaluator(context); + + int data_bound = 1 << 30; + srand(static_cast(time(NULL))); + + std::vector> input(slot_size, 0.0); + std::vector> output(slot_size); + + Ciphertext encrypted; + Plaintext plain; + Plaintext plainRes; + + for (int round = 0; round < 100; round++) + { + for (size_t i = 0; i < slot_size; i++) + { + input[i] = static_cast(rand() % data_bound); + } + + double delta = static_cast(1ULL << 40); + encoder.encode(input, parms.parms_id(), delta, plain); + + encryptor.encrypt(plain, encrypted); + + //check correctness of encryption + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + + evaluator.mod_switch_to_next_inplace(encrypted); + + //check correctness of modulo switching + ASSERT_TRUE(encrypted.parms_id() == next_parms_id); + + decryptor.decrypt(encrypted, plainRes); + encoder.decode(plainRes, output); + + for (size_t i = 0; i < slot_size; i++) + { + auto tmp = abs(input[i].real() - output[i].real()); + ASSERT_TRUE(tmp < 0.5); + } + } + } + { + //modulo switching without rescaling for random vectors + size_t slot_size = 32; + parms.set_poly_modulus_degree(slot_size * 2); + parms.set_coeff_modulus({ + small_mods_40bit(0), small_mods_40bit(1), small_mods_40bit(2), + small_mods_40bit(3), small_mods_40bit(4) }); + auto context = SEALContext::Create(parms); + auto next_parms_id = context->context_data()-> + next_context_data()->parms().parms_id(); + KeyGenerator keygen(context); + + CKKSEncoder encoder(context); + Encryptor encryptor(context, keygen.public_key()); + Decryptor decryptor(context, keygen.secret_key()); + Evaluator evaluator(context); + + int data_bound = 1 << 30; + srand(static_cast(time(NULL))); + + std::vector> input(slot_size, 0.0); + std::vector> output(slot_size); + + Ciphertext encrypted; + Plaintext plain; + Plaintext plainRes; + + for (int round = 0; round < 100; round++) + { + for (size_t i = 0; i < slot_size; i++) + { + input[i] = static_cast(rand() % data_bound); + } + + double delta = static_cast(1ULL << 40); + encoder.encode(input, parms.parms_id(), delta, plain); + + encryptor.encrypt(plain, encrypted); + + //check correctness of encryption + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + + evaluator.mod_switch_to_next_inplace(encrypted); + + //check correctness of modulo switching + ASSERT_TRUE(encrypted.parms_id() == next_parms_id); + + decryptor.decrypt(encrypted, plainRes); + encoder.decode(plainRes, output); + + for (size_t i = 0; i < slot_size; i++) + { + auto tmp = abs(input[i].real() - output[i].real()); + ASSERT_TRUE(tmp < 0.5); + } + } + } + { + //modulo switching without rescaling for random vectors + size_t slot_size = 32; + parms.set_poly_modulus_degree(128); + parms.set_coeff_modulus({ + small_mods_40bit(0), small_mods_40bit(1), small_mods_40bit(2), + small_mods_40bit(3), small_mods_40bit(4) }); + auto context = SEALContext::Create(parms); + auto next_parms_id = context->context_data()-> + next_context_data()->parms().parms_id(); + KeyGenerator keygen(context); + + CKKSEncoder encoder(context); + Encryptor encryptor(context, keygen.public_key()); + Decryptor decryptor(context, keygen.secret_key()); + Evaluator evaluator(context); + + int data_bound = 1 << 30; + srand(static_cast(time(NULL))); + + std::vector> input(slot_size, 0.0); + std::vector> output(slot_size); + + Ciphertext encrypted; + Plaintext plain; + Plaintext plainRes; + + for (int round = 0; round < 100; round++) + { + for (size_t i = 0; i < slot_size; i++) + { + input[i] = static_cast(rand() % data_bound); + } + + double delta = static_cast(1ULL << 40); + encoder.encode(input, parms.parms_id(), delta, plain); + + encryptor.encrypt(plain, encrypted); + + //check correctness of encryption + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + + evaluator.mod_switch_to_next_inplace(encrypted); + + //check correctness of modulo switching + ASSERT_TRUE(encrypted.parms_id() == next_parms_id); + + decryptor.decrypt(encrypted, plainRes); + encoder.decode(plainRes, output); + + for (size_t i = 0; i < slot_size; i++) + { + auto tmp = abs(input[i].real() - output[i].real()); + ASSERT_TRUE(tmp < 0.5); + } + } + } + } + TEST(EvaluatorTest, CKKSEncryptMultiplyRelinRescaleModSwitchAddDecrypt) + { + EncryptionParameters parms(scheme_type::CKKS); + { + //multiplication and addition without rescaling for random vectors + size_t slot_size = 64; + parms.set_poly_modulus_degree(slot_size * 2); + parms.set_coeff_modulus({ + small_mods_50bit(0), small_mods_50bit(1), + small_mods_50bit(2) }); + auto context = SEALContext::Create(parms); + auto next_parms_id = context->context_data()-> + next_context_data()->parms().parms_id(); + KeyGenerator keygen(context); + + CKKSEncoder encoder(context); + Encryptor encryptor(context, keygen.public_key()); + Decryptor decryptor(context, keygen.secret_key()); + Evaluator evaluator(context); + RelinKeys rlk = keygen.relin_keys(4, 1); + + Ciphertext encrypted1; + Ciphertext encrypted2; + Ciphertext encrypted3; + Plaintext plain1; + Plaintext plain2; + Plaintext plain3; + Plaintext plainRes; + + std::vector> input1(slot_size, 0.0); + std::vector> input2(slot_size, 0.0); + std::vector> input3(slot_size, 0.0); + std::vector> expected(slot_size, 0.0); + + for (int round = 0; round < 100; round++) + { + int data_bound = 1 << 8; + srand(static_cast(time(NULL))); + for (size_t i = 0; i < slot_size; i++) + { + input1[i] = static_cast(rand() % data_bound); + input2[i] = static_cast(rand() % data_bound); + expected[i] = input1[i] * input2[i] + input3[i]; + } + + std::vector> output(slot_size); + double delta = static_cast(1ULL << 40); + encoder.encode(input1, parms.parms_id(), delta, plain1); + encoder.encode(input2, parms.parms_id(), delta, plain2); + encoder.encode(input3, parms.parms_id(), delta * delta, plain3); + + encryptor.encrypt(plain1, encrypted1); + encryptor.encrypt(plain2, encrypted2); + encryptor.encrypt(plain3, encrypted3); + + //check correctness of encryption + ASSERT_TRUE(encrypted1.parms_id() == parms.parms_id()); + //check correctness of encryption + ASSERT_TRUE(encrypted2.parms_id() == parms.parms_id()); + //check correctness of encryption + ASSERT_TRUE(encrypted3.parms_id() == parms.parms_id()); + + //enc1*enc2 + evaluator.multiply_inplace(encrypted1, encrypted2); + evaluator.relinearize_inplace(encrypted1, rlk); + evaluator.rescale_to_next_inplace(encrypted1); + + //check correctness of modulo switching with rescaling + ASSERT_TRUE(encrypted1.parms_id() == next_parms_id); + + //move enc3 to the level of enc1 * enc2 + evaluator.rescale_to_inplace(encrypted3, next_parms_id); + + //enc1*enc2 + enc3 + evaluator.add_inplace(encrypted1, encrypted3); + + //decryption + decryptor.decrypt(encrypted1, plainRes); + //decoding + encoder.decode(plainRes, output); + + for (size_t i = 0; i < slot_size; i++) + { + auto tmp = abs(expected[i].real() - output[i].real()); + ASSERT_TRUE(tmp < 0.5); + } + } + } + { + //multiplication and addition without rescaling for random vectors + size_t slot_size = 16; + parms.set_poly_modulus_degree(128); + parms.set_coeff_modulus({ + small_mods_50bit(0), small_mods_50bit(1), + small_mods_50bit(2) }); + auto context = SEALContext::Create(parms); + auto next_parms_id = context->context_data()-> + next_context_data()->parms().parms_id(); + KeyGenerator keygen(context); + + CKKSEncoder encoder(context); + Encryptor encryptor(context, keygen.public_key()); + Decryptor decryptor(context, keygen.secret_key()); + Evaluator evaluator(context); + RelinKeys rlk = keygen.relin_keys(4, 1); + + Ciphertext encrypted1; + Ciphertext encrypted2; + Ciphertext encrypted3; + Plaintext plain1; + Plaintext plain2; + Plaintext plain3; + Plaintext plainRes; + + std::vector> input1(slot_size, 0.0); + std::vector> input2(slot_size, 0.0); + std::vector> input3(slot_size, 0.0); + std::vector> expected(slot_size, 0.0); + std::vector> output(slot_size); + + for (int round = 0; round < 100; round++) + { + int data_bound = 1 << 8; + srand(static_cast(time(NULL))); + for (size_t i = 0; i < slot_size; i++) + { + input1[i] = static_cast(rand() % data_bound); + input2[i] = static_cast(rand() % data_bound); + expected[i] = input1[i] * input2[i] + input3[i]; + } + + double delta = static_cast(1ULL << 40); + encoder.encode(input1, parms.parms_id(), delta, plain1); + encoder.encode(input2, parms.parms_id(), delta, plain2); + encoder.encode(input3, parms.parms_id(), delta * delta, plain3); + + encryptor.encrypt(plain1, encrypted1); + encryptor.encrypt(plain2, encrypted2); + encryptor.encrypt(plain3, encrypted3); + + //check correctness of encryption + ASSERT_TRUE(encrypted1.parms_id() == parms.parms_id()); + //check correctness of encryption + ASSERT_TRUE(encrypted2.parms_id() == parms.parms_id()); + //check correctness of encryption + ASSERT_TRUE(encrypted3.parms_id() == parms.parms_id()); + + //enc1*enc2 + evaluator.multiply_inplace(encrypted1, encrypted2); + evaluator.relinearize_inplace(encrypted1, rlk); + evaluator.rescale_to_next_inplace(encrypted1); + + //check correctness of modulo switching with rescaling + ASSERT_TRUE(encrypted1.parms_id() == next_parms_id); + + //move enc3 to the level of enc1 * enc2 + evaluator.rescale_to_inplace(encrypted3, next_parms_id); + + //enc1*enc2 + enc3 + evaluator.add_inplace(encrypted1, encrypted3); + + //decryption + decryptor.decrypt(encrypted1, plainRes); + //decoding + encoder.decode(plainRes, output); + + for (size_t i = 0; i < slot_size; i++) + { + auto tmp = abs(expected[i].real() - output[i].real()); + ASSERT_TRUE(tmp < 0.5); + } + } + } + } + TEST(EvaluatorTest, CKKSEncryptRotateDecrypt) + { + EncryptionParameters parms(scheme_type::CKKS); + { + // maximal number of slots + size_t slot_size = 4; + parms.set_poly_modulus_degree(slot_size * 2); + parms.set_coeff_modulus({ small_mods_40bit(0), small_mods_40bit(1), small_mods_40bit(2), small_mods_40bit(3) }); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + GaloisKeys glk = keygen.galois_keys(4); + + Encryptor encryptor(context, keygen.public_key()); + Evaluator evaluator(context); + Decryptor decryptor(context, keygen.secret_key()); + CKKSEncoder encoder(context); + const double delta = static_cast(1ULL << 30); + + Ciphertext encrypted; + Plaintext plain; + + vector> input{ + std::complex(1, 1), + std::complex(2, 2), + std::complex(3, 3), + std::complex(4, 4) + }; + input.resize(slot_size); + + vector> output(slot_size, 0); + + encoder.encode(input, parms.parms_id(), delta, plain); + int shift = 1; + encryptor.encrypt(plain, encrypted); + evaluator.rotate_vector_inplace(encrypted, shift, glk); + decryptor.decrypt(encrypted, plain); + encoder.decode(plain, output); + for (size_t i = 0; i < slot_size; i++) + { + ASSERT_EQ(input[(i + static_cast(shift)) % slot_size].real(), round(output[i].real())); + ASSERT_EQ(input[(i + static_cast(shift)) % slot_size].imag(), round(output[i].imag())); + } + + encoder.encode(input, parms.parms_id(), delta, plain); + shift = 2; + encryptor.encrypt(plain, encrypted); + evaluator.rotate_vector_inplace(encrypted, shift, glk); + decryptor.decrypt(encrypted, plain); + encoder.decode(plain, output); + for (size_t i = 0; i < slot_size; i++) + { + ASSERT_EQ(input[(i + static_cast(shift)) % slot_size].real(), round(output[i].real())); + ASSERT_EQ(input[(i + static_cast(shift)) % slot_size].imag(), round(output[i].imag())); + } + + encoder.encode(input, parms.parms_id(), delta, plain); + shift = 3; + encryptor.encrypt(plain, encrypted); + evaluator.rotate_vector_inplace(encrypted, shift, glk); + decryptor.decrypt(encrypted, plain); + encoder.decode(plain, output); + for (size_t i = 0; i < slot_size; i++) + { + ASSERT_EQ(input[(i + static_cast(shift)) % slot_size].real(), round(output[i].real())); + ASSERT_EQ(input[(i + static_cast(shift)) % slot_size].imag(), round(output[i].imag())); + } + + encoder.encode(input, parms.parms_id(), delta, plain); + encryptor.encrypt(plain, encrypted); + evaluator.complex_conjugate_inplace(encrypted, glk); + decryptor.decrypt(encrypted, plain); + encoder.decode(plain, output); + for (size_t i = 0; i < slot_size; i++) + { + ASSERT_EQ(input[i].real(), round(output[i].real())); + ASSERT_EQ(-input[i].imag(), round(output[i].imag())); + } + } + { + size_t slot_size = 32; + parms.set_poly_modulus_degree(64); + parms.set_coeff_modulus({ small_mods_40bit(0), small_mods_40bit(1), small_mods_40bit(2), small_mods_40bit(3) }); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + GaloisKeys glk = keygen.galois_keys(4); + + Encryptor encryptor(context, keygen.public_key()); + Evaluator evaluator(context); + Decryptor decryptor(context, keygen.secret_key()); + CKKSEncoder encoder(context); + const double delta = static_cast(1ULL << 30); + + Ciphertext encrypted; + Plaintext plain; + + vector> input{ + std::complex(1, 1), + std::complex(2, 2), + std::complex(3, 3), + std::complex(4, 4) + }; + input.resize(slot_size); + + vector> output(slot_size, 0); + + encoder.encode(input, parms.parms_id(), delta, plain); + int shift = 1; + encryptor.encrypt(plain, encrypted); + evaluator.rotate_vector_inplace(encrypted, shift, glk); + decryptor.decrypt(encrypted, plain); + encoder.decode(plain, output); + for (size_t i = 0; i < input.size(); i++) + { + ASSERT_EQ(round(input[(i + static_cast(shift)) % slot_size].real()), round(output[i].real())); + ASSERT_EQ(round(input[(i + static_cast(shift)) % slot_size].imag()), round(output[i].imag())); + } + + encoder.encode(input, parms.parms_id(), delta, plain); + shift = 2; + encryptor.encrypt(plain, encrypted); + evaluator.rotate_vector_inplace(encrypted, shift, glk); + decryptor.decrypt(encrypted, plain); + encoder.decode(plain, output); + for (size_t i = 0; i < slot_size; i++) + { + ASSERT_EQ(round(input[(i + static_cast(shift)) % slot_size].real()), round(output[i].real())); + ASSERT_EQ(round(input[(i + static_cast(shift)) % slot_size].imag()), round(output[i].imag())); + } + + encoder.encode(input, parms.parms_id(), delta, plain); + shift = 3; + encryptor.encrypt(plain, encrypted); + evaluator.rotate_vector_inplace(encrypted, shift, glk); + decryptor.decrypt(encrypted, plain); + encoder.decode(plain, output); + for (size_t i = 0; i < slot_size; i++) + { + ASSERT_EQ(round(input[(i + static_cast(shift)) % slot_size].real()), round(output[i].real())); + ASSERT_EQ(round(input[(i + static_cast(shift)) % slot_size].imag()), round(output[i].imag())); + } + + encoder.encode(input, parms.parms_id(), delta, plain); + encryptor.encrypt(plain, encrypted); + evaluator.complex_conjugate_inplace(encrypted, glk); + decryptor.decrypt(encrypted, plain); + encoder.decode(plain, output); + for (size_t i = 0; i < slot_size; i++) + { + ASSERT_EQ(round(input[i].real()), round(output[i].real())); + ASSERT_EQ(round(-input[i].imag()), round(output[i].imag())); + } + } + } + + TEST(EvaluatorTest, CKKSEncryptRescaleRotateDecrypt) + { + EncryptionParameters parms(scheme_type::CKKS); + { + // maximal number of slots + size_t slot_size = 4; + parms.set_poly_modulus_degree(slot_size * 2); + parms.set_coeff_modulus({ small_mods_40bit(0), + small_mods_40bit(1), small_mods_40bit(2), small_mods_40bit(3) }); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + GaloisKeys glk = keygen.galois_keys(4); + + Encryptor encryptor(context, keygen.public_key()); + Evaluator evaluator(context); + Decryptor decryptor(context, keygen.secret_key()); + CKKSEncoder encoder(context); + const double delta = std::pow(2.0, 70); + + Ciphertext encrypted; + Plaintext plain; + + vector> input{ + std::complex(1, 1), + std::complex(2, 2), + std::complex(3, 3), + std::complex(4, 4) + }; + input.resize(slot_size); + + vector> output(slot_size, 0); + + encoder.encode(input, parms.parms_id(), delta, plain); + int shift = 1; + encryptor.encrypt(plain, encrypted); + evaluator.rescale_to_next_inplace(encrypted); + evaluator.rotate_vector_inplace(encrypted, shift, glk); + decryptor.decrypt(encrypted, plain); + encoder.decode(plain, output); + for (size_t i = 0; i < slot_size; i++) + { + ASSERT_EQ(input[(i + static_cast(shift)) % slot_size].real(), round(output[i].real())); + ASSERT_EQ(input[(i + static_cast(shift)) % slot_size].imag(), round(output[i].imag())); + } + + encoder.encode(input, parms.parms_id(), delta, plain); + shift = 2; + encryptor.encrypt(plain, encrypted); + evaluator.rescale_to_next_inplace(encrypted); + evaluator.rotate_vector_inplace(encrypted, shift, glk); + decryptor.decrypt(encrypted, plain); + encoder.decode(plain, output); + for (size_t i = 0; i < slot_size; i++) + { + ASSERT_EQ(input[(i + static_cast(shift)) % slot_size].real(), round(output[i].real())); + ASSERT_EQ(input[(i + static_cast(shift)) % slot_size].imag(), round(output[i].imag())); + } + + encoder.encode(input, parms.parms_id(), delta, plain); + shift = 3; + encryptor.encrypt(plain, encrypted); + evaluator.rescale_to_next_inplace(encrypted); + evaluator.rotate_vector_inplace(encrypted, shift, glk); + decryptor.decrypt(encrypted, plain); + encoder.decode(plain, output); + for (size_t i = 0; i < slot_size; i++) + { + ASSERT_EQ(input[(i + static_cast(shift)) % slot_size].real(), round(output[i].real())); + ASSERT_EQ(input[(i + static_cast(shift)) % slot_size].imag(), round(output[i].imag())); + } + + encoder.encode(input, parms.parms_id(), delta, plain); + encryptor.encrypt(plain, encrypted); + evaluator.rescale_to_next_inplace(encrypted); + evaluator.complex_conjugate_inplace(encrypted, glk); + decryptor.decrypt(encrypted, plain); + encoder.decode(plain, output); + for (size_t i = 0; i < slot_size; i++) + { + ASSERT_EQ(input[i].real(), round(output[i].real())); + ASSERT_EQ(-input[i].imag(), round(output[i].imag())); + } + } + { + size_t slot_size = 32; + parms.set_poly_modulus_degree(64); + parms.set_coeff_modulus({ small_mods_40bit(0), small_mods_40bit(1), + small_mods_40bit(2), small_mods_40bit(3) }); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + GaloisKeys glk = keygen.galois_keys(4); + + Encryptor encryptor(context, keygen.public_key()); + Evaluator evaluator(context); + Decryptor decryptor(context, keygen.secret_key()); + CKKSEncoder encoder(context); + const double delta = std::pow(2, 70); + + Ciphertext encrypted; + Plaintext plain; + + vector> input{ + std::complex(1, 1), + std::complex(2, 2), + std::complex(3, 3), + std::complex(4, 4) + }; + input.resize(slot_size); + + vector> output(slot_size, 0); + + encoder.encode(input, parms.parms_id(), delta, plain); + int shift = 1; + encryptor.encrypt(plain, encrypted); + evaluator.rescale_to_next_inplace(encrypted); + evaluator.rotate_vector_inplace(encrypted, shift, glk); + decryptor.decrypt(encrypted, plain); + encoder.decode(plain, output); + for (size_t i = 0; i < slot_size; i++) + { + ASSERT_EQ(round(input[(i + static_cast(shift)) % slot_size].real()), round(output[i].real())); + ASSERT_EQ(round(input[(i + static_cast(shift)) % slot_size].imag()), round(output[i].imag())); + } + + encoder.encode(input, parms.parms_id(), delta, plain); + shift = 2; + encryptor.encrypt(plain, encrypted); + evaluator.rescale_to_next_inplace(encrypted); + evaluator.rotate_vector_inplace(encrypted, shift, glk); + decryptor.decrypt(encrypted, plain); + encoder.decode(plain, output); + for (size_t i = 0; i < slot_size; i++) + { + ASSERT_EQ(round(input[(i + static_cast(shift)) % slot_size].real()), round(output[i].real())); + ASSERT_EQ(round(input[(i + static_cast(shift)) % slot_size].imag()), round(output[i].imag())); + } + + encoder.encode(input, parms.parms_id(), delta, plain); + shift = 3; + encryptor.encrypt(plain, encrypted); + evaluator.rescale_to_next_inplace(encrypted); + evaluator.rotate_vector_inplace(encrypted, shift, glk); + decryptor.decrypt(encrypted, plain); + encoder.decode(plain, output); + for (size_t i = 0; i < slot_size; i++) + { + ASSERT_EQ(round(input[(i + static_cast(shift)) % slot_size].real()), round(output[i].real())); + ASSERT_EQ(round(input[(i + static_cast(shift)) % slot_size].imag()), round(output[i].imag())); + } + + encoder.encode(input, parms.parms_id(), delta, plain); + encryptor.encrypt(plain, encrypted); + evaluator.rescale_to_next_inplace(encrypted); + evaluator.complex_conjugate_inplace(encrypted, glk); + decryptor.decrypt(encrypted, plain); + encoder.decode(plain, output); + for (size_t i = 0; i < slot_size; i++) + { + ASSERT_EQ(round(input[i].real()), round(output[i].real())); + ASSERT_EQ(round(-input[i].imag()), round(output[i].imag())); + } + } + } + + TEST(EvaluatorTest, FVEncryptSquareDecrypt) + { + EncryptionParameters parms(scheme_type::BFV); + SmallModulus plain_modulus(1 << 6); + parms.set_poly_modulus_degree(128); + parms.set_plain_modulus(plain_modulus); + parms.set_coeff_modulus({ small_mods_40bit(0), small_mods_40bit(1) }); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + + BalancedEncoder encoder(plain_modulus); + Encryptor encryptor(context, keygen.public_key()); + Evaluator evaluator(context); + Decryptor decryptor(context, keygen.secret_key()); + + Ciphertext encrypted; + Plaintext plain; + encryptor.encrypt(encoder.encode(1), encrypted); + evaluator.square_inplace(encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(1ULL, encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(0), encrypted); + evaluator.square_inplace(encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(0ULL, encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(-5), encrypted); + evaluator.square_inplace(encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(25ULL, encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(-1), encrypted); + evaluator.square_inplace(encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(1ULL, encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(123), encrypted); + evaluator.square_inplace(encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(15129ULL, encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(0x10000), encrypted); + evaluator.square_inplace(encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(0x100000000ULL, encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(123), encrypted); + evaluator.square_inplace(encrypted); + evaluator.square_inplace(encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(228886641ULL, encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + } + + TEST(EvaluatorTest, FVEncryptMultiplyManyDecrypt) + { + EncryptionParameters parms(scheme_type::BFV); + SmallModulus plain_modulus(1 << 6); + parms.set_poly_modulus_degree(128); + parms.set_plain_modulus(plain_modulus); + parms.set_coeff_modulus({ small_mods_40bit(0), small_mods_40bit(1) }); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + + BalancedEncoder encoder(plain_modulus); + Encryptor encryptor(context, keygen.public_key()); + Evaluator evaluator(context); + Decryptor decryptor(context, keygen.secret_key()); + RelinKeys rlk = keygen.relin_keys(4); + + Ciphertext encrypted1, encrypted2, encrypted3, encrypted4, product; + Plaintext plain; + encryptor.encrypt(encoder.encode(5), encrypted1); + encryptor.encrypt(encoder.encode(6), encrypted2); + encryptor.encrypt(encoder.encode(7), encrypted3); + vector encrypteds{ encrypted1, encrypted2, encrypted3 }; + evaluator.multiply_many(encrypteds, rlk, product); + decryptor.decrypt(product, plain); + ASSERT_EQ(static_cast(210), encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted1.parms_id() == product.parms_id()); + ASSERT_TRUE(encrypted2.parms_id() == product.parms_id()); + ASSERT_TRUE(encrypted3.parms_id() == product.parms_id()); + ASSERT_TRUE(product.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(-9), encrypted1); + encryptor.encrypt(encoder.encode(-17), encrypted2); + encrypteds = { encrypted1, encrypted2 }; + evaluator.multiply_many(encrypteds, rlk, product); + decryptor.decrypt(product, plain); + ASSERT_EQ(static_cast(153), encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted1.parms_id() == product.parms_id()); + ASSERT_TRUE(encrypted2.parms_id() == product.parms_id()); + ASSERT_TRUE(product.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(2), encrypted1); + encryptor.encrypt(encoder.encode(-31), encrypted2); + encryptor.encrypt(encoder.encode(7), encrypted3); + encrypteds = { encrypted1, encrypted2, encrypted3 }; + evaluator.multiply_many(encrypteds, rlk, product); + decryptor.decrypt(product, plain); + ASSERT_TRUE(static_cast(-434) == encoder.decode_int64(plain)); + ASSERT_TRUE(encrypted1.parms_id() == product.parms_id()); + ASSERT_TRUE(encrypted2.parms_id() == product.parms_id()); + ASSERT_TRUE(encrypted3.parms_id() == product.parms_id()); + ASSERT_TRUE(product.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(1), encrypted1); + encryptor.encrypt(encoder.encode(-1), encrypted2); + encryptor.encrypt(encoder.encode(1), encrypted3); + encryptor.encrypt(encoder.encode(-1), encrypted4); + encrypteds = { encrypted1, encrypted2, encrypted3, encrypted4 }; + evaluator.multiply_many(encrypteds, rlk, product); + decryptor.decrypt(product, plain); + ASSERT_EQ(static_cast(1), encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted1.parms_id() == product.parms_id()); + ASSERT_TRUE(encrypted2.parms_id() == product.parms_id()); + ASSERT_TRUE(encrypted3.parms_id() == product.parms_id()); + ASSERT_TRUE(encrypted4.parms_id() == product.parms_id()); + ASSERT_TRUE(product.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(98765), encrypted1); + encryptor.encrypt(encoder.encode(0), encrypted2); + encryptor.encrypt(encoder.encode(12345), encrypted3); + encryptor.encrypt(encoder.encode(34567), encrypted4); + encrypteds = { encrypted1, encrypted2, encrypted3, encrypted4 }; + evaluator.multiply_many(encrypteds, rlk, product); + decryptor.decrypt(product, plain); + ASSERT_EQ(static_cast(0), encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted1.parms_id() == product.parms_id()); + ASSERT_TRUE(encrypted2.parms_id() == product.parms_id()); + ASSERT_TRUE(encrypted3.parms_id() == product.parms_id()); + ASSERT_TRUE(encrypted4.parms_id() == product.parms_id()); + ASSERT_TRUE(product.parms_id() == parms.parms_id()); + } + + TEST(EvaluatorTest, FVEncryptExponentiateDecrypt) + { + EncryptionParameters parms(scheme_type::BFV); + SmallModulus plain_modulus(1 << 6); + parms.set_poly_modulus_degree(128); + parms.set_plain_modulus(plain_modulus); + parms.set_coeff_modulus({ small_mods_40bit(0), small_mods_40bit(1) }); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + + BalancedEncoder encoder(plain_modulus); + Encryptor encryptor(context, keygen.public_key()); + Evaluator evaluator(context); + Decryptor decryptor(context, keygen.secret_key()); + RelinKeys rlk = keygen.relin_keys(4); + + Ciphertext encrypted; + Plaintext plain; + encryptor.encrypt(encoder.encode(5), encrypted); + evaluator.exponentiate_inplace(encrypted, 1, rlk); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(static_cast(5), encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted.parms_id() == encrypted.parms_id()); + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(7), encrypted); + evaluator.exponentiate_inplace(encrypted, 2, rlk); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(static_cast(49), encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted.parms_id() == encrypted.parms_id()); + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(-7), encrypted); + evaluator.exponentiate_inplace(encrypted, 3, rlk); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE(static_cast(-343) == encoder.decode_int64(plain)); + ASSERT_TRUE(encrypted.parms_id() == encrypted.parms_id()); + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(0x100), encrypted); + evaluator.exponentiate_inplace(encrypted, 4, rlk); + decryptor.decrypt(encrypted, plain); + ASSERT_EQ(static_cast(0x100000000), encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted.parms_id() == encrypted.parms_id()); + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + } + + TEST(EvaluatorTest, FVEncryptAddManyDecrypt) + { + EncryptionParameters parms(scheme_type::BFV); + SmallModulus plain_modulus(1 << 6); + parms.set_poly_modulus_degree(128); + parms.set_plain_modulus(plain_modulus); + parms.set_coeff_modulus({ small_mods_40bit(0), small_mods_40bit(1) }); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + + BalancedEncoder encoder(plain_modulus); + Encryptor encryptor(context, keygen.public_key()); + Evaluator evaluator(context); + Decryptor decryptor(context, keygen.secret_key()); + + Ciphertext encrypted1, encrypted2, encrypted3, encrypted4, sum; + Plaintext plain; + encryptor.encrypt(encoder.encode(5), encrypted1); + encryptor.encrypt(encoder.encode(6), encrypted2); + encryptor.encrypt(encoder.encode(7), encrypted3); + vector encrypteds = { encrypted1, encrypted2, encrypted3 }; + evaluator.add_many(encrypteds, sum); + decryptor.decrypt(sum, plain); + ASSERT_EQ(static_cast(18), encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted1.parms_id() == sum.parms_id()); + ASSERT_TRUE(encrypted2.parms_id() == sum.parms_id()); + ASSERT_TRUE(encrypted3.parms_id() == sum.parms_id()); + ASSERT_TRUE(sum.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(-9), encrypted1); + encryptor.encrypt(encoder.encode(-17), encrypted2); + encrypteds = { encrypted1, encrypted2, }; + evaluator.add_many(encrypteds, sum); + decryptor.decrypt(sum, plain); + ASSERT_TRUE(static_cast(-26) == encoder.decode_int64(plain)); + ASSERT_TRUE(encrypted1.parms_id() == sum.parms_id()); + ASSERT_TRUE(encrypted2.parms_id() == sum.parms_id()); + ASSERT_TRUE(sum.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(2), encrypted1); + encryptor.encrypt(encoder.encode(-31), encrypted2); + encryptor.encrypt(encoder.encode(7), encrypted3); + encrypteds = { encrypted1, encrypted2, encrypted3 }; + evaluator.add_many(encrypteds, sum); + decryptor.decrypt(sum, plain); + ASSERT_TRUE(static_cast(-22) == encoder.decode_int64(plain)); + ASSERT_TRUE(encrypted1.parms_id() == sum.parms_id()); + ASSERT_TRUE(encrypted2.parms_id() == sum.parms_id()); + ASSERT_TRUE(encrypted3.parms_id() == sum.parms_id()); + ASSERT_TRUE(sum.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(1), encrypted1); + encryptor.encrypt(encoder.encode(-1), encrypted2); + encryptor.encrypt(encoder.encode(1), encrypted3); + encryptor.encrypt(encoder.encode(-1), encrypted4); + encrypteds = { encrypted1, encrypted2, encrypted3, encrypted4 }; + evaluator.add_many(encrypteds, sum); + decryptor.decrypt(sum, plain); + ASSERT_EQ(static_cast(0), encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted1.parms_id() == sum.parms_id()); + ASSERT_TRUE(encrypted2.parms_id() == sum.parms_id()); + ASSERT_TRUE(encrypted3.parms_id() == sum.parms_id()); + ASSERT_TRUE(encrypted4.parms_id() == sum.parms_id()); + ASSERT_TRUE(sum.parms_id() == parms.parms_id()); + + encryptor.encrypt(encoder.encode(98765), encrypted1); + encryptor.encrypt(encoder.encode(0), encrypted2); + encryptor.encrypt(encoder.encode(12345), encrypted3); + encryptor.encrypt(encoder.encode(34567), encrypted4); + encrypteds = { encrypted1, encrypted2, encrypted3, encrypted4 }; + evaluator.add_many(encrypteds, sum); + decryptor.decrypt(sum, plain); + ASSERT_EQ(static_cast(145677), encoder.decode_uint64(plain)); + ASSERT_TRUE(encrypted1.parms_id() == sum.parms_id()); + ASSERT_TRUE(encrypted2.parms_id() == sum.parms_id()); + ASSERT_TRUE(encrypted3.parms_id() == sum.parms_id()); + ASSERT_TRUE(encrypted4.parms_id() == sum.parms_id()); + ASSERT_TRUE(sum.parms_id() == parms.parms_id()); + + BalancedFractionalEncoder frac_encoder(plain_modulus, 128, 10, 15); + encryptor.encrypt(frac_encoder.encode(3.1415), encrypted1); + encryptor.encrypt(frac_encoder.encode(12.345), encrypted2); + encryptor.encrypt(frac_encoder.encode(98.765), encrypted3); + encryptor.encrypt(frac_encoder.encode(1.1111), encrypted4); + encrypteds = { encrypted1, encrypted2, encrypted3, encrypted4 }; + evaluator.add_many(encrypteds, sum); + decryptor.decrypt(sum, plain); + ASSERT_TRUE(abs(frac_encoder.decode(plain) - 115.3626) < 0.000001); + ASSERT_TRUE(encrypted1.parms_id() == sum.parms_id()); + ASSERT_TRUE(encrypted2.parms_id() == sum.parms_id()); + ASSERT_TRUE(encrypted3.parms_id() == sum.parms_id()); + ASSERT_TRUE(encrypted4.parms_id() == sum.parms_id()); + ASSERT_TRUE(sum.parms_id() == parms.parms_id()); + } + + TEST(EvaluatorTest, TransformPlainToNTT) + { + EncryptionParameters parms(scheme_type::BFV); + SmallModulus plain_modulus(1 << 6); + parms.set_poly_modulus_degree(128); + parms.set_plain_modulus(plain_modulus); + parms.set_coeff_modulus({ small_mods_40bit(0), small_mods_40bit(1) }); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + + Evaluator evaluator(context); + Plaintext plain("0"); + ASSERT_FALSE(plain.is_ntt_form()); + evaluator.transform_to_ntt_inplace(plain, parms.parms_id()); + ASSERT_TRUE(plain.is_zero()); + ASSERT_TRUE(plain.is_ntt_form()); + ASSERT_TRUE(plain.parms_id() == parms.parms_id()); + + plain.release(); + plain = "0"; + ASSERT_FALSE(plain.is_ntt_form()); + auto next_parms_id = context->context_data()-> + next_context_data()->parms().parms_id(); + evaluator.transform_to_ntt_inplace(plain, next_parms_id); + ASSERT_TRUE(plain.is_zero()); + ASSERT_TRUE(plain.is_ntt_form()); + ASSERT_TRUE(plain.parms_id() == next_parms_id); + + plain.release(); + plain = "1"; + ASSERT_FALSE(plain.is_ntt_form()); + evaluator.transform_to_ntt_inplace(plain, parms.parms_id()); + for (size_t i = 0; i < 256; i++) + { + ASSERT_TRUE(plain[i] == uint64_t(1)); + } + ASSERT_TRUE(plain.is_ntt_form()); + ASSERT_TRUE(plain.parms_id() == parms.parms_id()); + + plain.release(); + plain = "1"; + ASSERT_FALSE(plain.is_ntt_form()); + evaluator.transform_to_ntt_inplace(plain, next_parms_id); + for (size_t i = 0; i < 128; i++) + { + ASSERT_TRUE(plain[i] == uint64_t(1)); + } + ASSERT_TRUE(plain.is_ntt_form()); + ASSERT_TRUE(plain.parms_id() == next_parms_id); + + plain.release(); + plain = "2"; + ASSERT_FALSE(plain.is_ntt_form()); + evaluator.transform_to_ntt_inplace(plain, parms.parms_id()); + for (size_t i = 0; i < 256; i++) + { + ASSERT_TRUE(plain[i] == uint64_t(2)); + } + ASSERT_TRUE(plain.is_ntt_form()); + ASSERT_TRUE(plain.parms_id() == parms.parms_id()); + + plain.release(); + plain = "2"; + evaluator.transform_to_ntt_inplace(plain, next_parms_id); + for (size_t i = 0; i < 128; i++) + { + ASSERT_TRUE(plain[i] == uint64_t(2)); + } + ASSERT_TRUE(plain.is_ntt_form()); + ASSERT_TRUE(plain.parms_id() == next_parms_id); + } + + TEST(EvaluatorTest, TransformEncryptedToFromNTT) + { + EncryptionParameters parms(scheme_type::BFV); + SmallModulus plain_modulus(1 << 6); + parms.set_poly_modulus_degree(128); + parms.set_plain_modulus(plain_modulus); + parms.set_coeff_modulus({ small_mods_40bit(0), small_mods_40bit(1) }); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + + Encryptor encryptor(context, keygen.public_key()); + Evaluator evaluator(context); + Decryptor decryptor(context, keygen.secret_key()); + + Plaintext plain; + Ciphertext encrypted; + plain = "0"; + encryptor.encrypt(plain, encrypted); + evaluator.transform_to_ntt_inplace(encrypted); + evaluator.transform_from_ntt_inplace(encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE(plain.to_string() == "0"); + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + + plain = "1"; + encryptor.encrypt(plain, encrypted); + evaluator.transform_to_ntt_inplace(encrypted); + evaluator.transform_from_ntt_inplace(encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE(plain.to_string() == "1"); + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + + plain = "Fx^10 + Ex^9 + Dx^8 + Cx^7 + Bx^6 + Ax^5 + 1x^4 + 2x^3 + 3x^2 + 4x^1 + 5"; + encryptor.encrypt(plain, encrypted); + evaluator.transform_to_ntt_inplace(encrypted); + evaluator.transform_from_ntt_inplace(encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE(plain.to_string() == "Fx^10 + Ex^9 + Dx^8 + Cx^7 + Bx^6 + Ax^5 + 1x^4 + 2x^3 + 3x^2 + 4x^1 + 5"); + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + } + + TEST(EvaluatorTest, FVEncryptMultiplyPlainNTTDecrypt) + { + EncryptionParameters parms(scheme_type::BFV); + SmallModulus plain_modulus(1 << 6); + parms.set_poly_modulus_degree(128); + parms.set_plain_modulus(plain_modulus); + parms.set_coeff_modulus({ small_mods_40bit(0), small_mods_40bit(1) }); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + + Encryptor encryptor(context, keygen.public_key()); + Evaluator evaluator(context); + Decryptor decryptor(context, keygen.secret_key()); + + Plaintext plain; + Plaintext plain_multiplier; + Ciphertext encrypted; + + plain = 0; + encryptor.encrypt(plain, encrypted); + evaluator.transform_to_ntt_inplace(encrypted); + plain_multiplier = 1; + evaluator.transform_to_ntt_inplace(plain_multiplier, parms.parms_id()); + evaluator.multiply_plain_inplace(encrypted, plain_multiplier); + evaluator.transform_from_ntt_inplace(encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE(plain.to_string() == "0"); + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + + plain = 2; + encryptor.encrypt(plain, encrypted); + evaluator.transform_to_ntt_inplace(encrypted); + plain_multiplier.release(); + plain_multiplier = 3; + evaluator.transform_to_ntt_inplace(plain_multiplier, parms.parms_id()); + evaluator.multiply_plain_inplace(encrypted, plain_multiplier); + evaluator.transform_from_ntt_inplace(encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE(plain.to_string() == "6"); + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + + plain = 1; + encryptor.encrypt(plain, encrypted); + evaluator.transform_to_ntt_inplace(encrypted); + plain_multiplier.release(); + plain_multiplier = "Fx^10 + Ex^9 + Dx^8 + Cx^7 + Bx^6 + Ax^5 + 1x^4 + 2x^3 + 3x^2 + 4x^1 + 5"; + evaluator.transform_to_ntt_inplace(plain_multiplier, parms.parms_id()); + evaluator.multiply_plain_inplace(encrypted, plain_multiplier); + evaluator.transform_from_ntt_inplace(encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE(plain.to_string() == "Fx^10 + Ex^9 + Dx^8 + Cx^7 + Bx^6 + Ax^5 + 1x^4 + 2x^3 + 3x^2 + 4x^1 + 5"); + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + + plain = "1x^20"; + encryptor.encrypt(plain, encrypted); + evaluator.transform_to_ntt_inplace(encrypted); + plain_multiplier.release(); + plain_multiplier = "Fx^10 + Ex^9 + Dx^8 + Cx^7 + Bx^6 + Ax^5 + 1x^4 + 2x^3 + 3x^2 + 4x^1 + 5"; + evaluator.transform_to_ntt_inplace(plain_multiplier, parms.parms_id()); + evaluator.multiply_plain_inplace(encrypted, plain_multiplier); + evaluator.transform_from_ntt_inplace(encrypted); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE(plain.to_string() == "Fx^30 + Ex^29 + Dx^28 + Cx^27 + Bx^26 + Ax^25 + 1x^24 + 2x^23 + 3x^22 + 4x^21 + 5x^20"); + ASSERT_TRUE(encrypted.parms_id() == parms.parms_id()); + } + + TEST(EvaluatorTest, FVEncryptApplyGaloisDecrypt) + { + EncryptionParameters parms(scheme_type::BFV); + SmallModulus plain_modulus(257); + parms.set_poly_modulus_degree(8); + parms.set_plain_modulus(plain_modulus); + parms.set_coeff_modulus({ small_mods_40bit(0), small_mods_40bit(1) }); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + GaloisKeys glk = keygen.galois_keys(24, vector{ 1, 3, 5, 15 }); + + Encryptor encryptor(context, keygen.public_key()); + Evaluator evaluator(context); + Decryptor decryptor(context, keygen.secret_key()); + + Plaintext plain("1"); + Ciphertext encrypted; + encryptor.encrypt(plain, encrypted); + evaluator.apply_galois_inplace(encrypted, 1, glk); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE("1" == plain.to_string()); + evaluator.apply_galois_inplace(encrypted, 3, glk); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE("1" == plain.to_string()); + evaluator.apply_galois_inplace(encrypted, 5, glk); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE("1" == plain.to_string()); + evaluator.apply_galois_inplace(encrypted, 15, glk); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE("1" == plain.to_string()); + + plain = "1x^1"; + encryptor.encrypt(plain, encrypted); + evaluator.apply_galois_inplace(encrypted, 1, glk); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE("1x^1" == plain.to_string()); + evaluator.apply_galois_inplace(encrypted, 3, glk); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE("1x^3" == plain.to_string()); + evaluator.apply_galois_inplace(encrypted, 5, glk); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE("100x^7" == plain.to_string()); + evaluator.apply_galois_inplace(encrypted, 15, glk); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE("1x^1" == plain.to_string()); + + plain = "1x^2"; + encryptor.encrypt(plain, encrypted); + evaluator.apply_galois_inplace(encrypted, 1, glk); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE("1x^2" == plain.to_string()); + evaluator.apply_galois_inplace(encrypted, 3, glk); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE("1x^6" == plain.to_string()); + evaluator.apply_galois_inplace(encrypted, 5, glk); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE("100x^6" == plain.to_string()); + evaluator.apply_galois_inplace(encrypted, 15, glk); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE("1x^2" == plain.to_string()); + + plain = "1x^3 + 2x^2 + 1x^1 + 1"; + encryptor.encrypt(plain, encrypted); + evaluator.apply_galois_inplace(encrypted, 1, glk); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE("1x^3 + 2x^2 + 1x^1 + 1" == plain.to_string()); + evaluator.apply_galois_inplace(encrypted, 3, glk); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE("2x^6 + 1x^3 + 100x^1 + 1" == plain.to_string()); + evaluator.apply_galois_inplace(encrypted, 5, glk); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE("100x^7 + FFx^6 + 100x^5 + 1" == plain.to_string()); + evaluator.apply_galois_inplace(encrypted, 15, glk); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE("1x^3 + 2x^2 + 1x^1 + 1" == plain.to_string()); + } + + TEST(EvaluatorTest, FVEncryptRotateMatrixDecrypt) + { + EncryptionParameters parms(scheme_type::BFV); + SmallModulus plain_modulus(257); + parms.set_poly_modulus_degree(8); + parms.set_plain_modulus(plain_modulus); + parms.set_coeff_modulus({ small_mods_40bit(0), small_mods_40bit(1) }); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + GaloisKeys glk = keygen.galois_keys(24); + + Encryptor encryptor(context, keygen.public_key()); + Evaluator evaluator(context); + Decryptor decryptor(context, keygen.secret_key()); + BatchEncoder batch_encoder(context); + + Plaintext plain; + vector plain_vec{ + 1, 2, 3, 4, + 5, 6, 7, 8 + }; + batch_encoder.encode(plain_vec, plain); + Ciphertext encrypted; + encryptor.encrypt(plain, encrypted); + + evaluator.rotate_columns_inplace(encrypted, glk); + decryptor.decrypt(encrypted, plain); + batch_encoder.decode(plain, plain_vec); + ASSERT_TRUE((plain_vec == vector{ + 5, 6, 7, 8, + 1, 2, 3, 4 + })); + + evaluator.rotate_rows_inplace(encrypted, -1, glk); + decryptor.decrypt(encrypted, plain); + batch_encoder.decode(plain, plain_vec); + ASSERT_TRUE((plain_vec == vector{ + 8, 5, 6, 7, + 4, 1, 2, 3 + })); + + evaluator.rotate_rows_inplace(encrypted, 2, glk); + decryptor.decrypt(encrypted, plain); + batch_encoder.decode(plain, plain_vec); + ASSERT_TRUE((plain_vec == vector{ + 6, 7, 8, 5, + 2, 3, 4, 1 + })); + + evaluator.rotate_columns_inplace(encrypted, glk); + decryptor.decrypt(encrypted, plain); + batch_encoder.decode(plain, plain_vec); + ASSERT_TRUE((plain_vec == vector{ + 2, 3, 4, 1, + 6, 7, 8, 5 + })); + + evaluator.rotate_rows_inplace(encrypted, 0, glk); + decryptor.decrypt(encrypted, plain); + batch_encoder.decode(plain, plain_vec); + ASSERT_TRUE((plain_vec == vector{ + 2, 3, 4, 1, + 6, 7, 8, 5 + })); + } + TEST(EvaluatorTest, FVEncryptModSwitchToNextDecrypt) + { + // the common parameters: the plaintext and the polynomial moduli + SmallModulus plain_modulus(1 << 6); + + // the parameters and the context of the higher level + EncryptionParameters parms(scheme_type::BFV); + parms.set_poly_modulus_degree(128); + parms.set_plain_modulus(plain_modulus); + parms.set_coeff_modulus({ small_mods_30bit(0), small_mods_30bit(1), small_mods_30bit(2) }); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + SecretKey secret_key = keygen.secret_key(); + Encryptor encryptor(context, keygen.public_key()); + Evaluator evaluator(context); + Decryptor decryptor(context, keygen.secret_key()); + auto parms_id = parms.parms_id(); + + Ciphertext encrypted(context); + Ciphertext encryptedRes; + Plaintext plain; + + plain = 0; + encryptor.encrypt(plain, encrypted); + evaluator.mod_switch_to_next(encrypted, encryptedRes); + decryptor.decrypt(encryptedRes, plain); + parms_id = context->context_data(parms_id)-> + next_context_data()->parms().parms_id(); + ASSERT_TRUE(encryptedRes.parms_id() == parms_id); + ASSERT_TRUE(plain.to_string() == "0"); + + evaluator.mod_switch_to_next_inplace(encryptedRes); + decryptor.decrypt(encryptedRes, plain); + parms_id = context->context_data(parms_id)-> + next_context_data()->parms().parms_id(); + ASSERT_TRUE(encryptedRes.parms_id() == parms_id); + ASSERT_TRUE(plain.to_string() == "0"); + + parms_id = parms.parms_id(); + plain = 1; + encryptor.encrypt(plain, encrypted); + evaluator.mod_switch_to_next(encrypted, encryptedRes); + decryptor.decrypt(encryptedRes, plain); + parms_id = context->context_data(parms_id)-> + next_context_data()->parms().parms_id(); + ASSERT_TRUE(encryptedRes.parms_id() == parms_id); + ASSERT_TRUE(plain.to_string() == "1"); + + evaluator.mod_switch_to_next_inplace(encryptedRes); + decryptor.decrypt(encryptedRes, plain); + parms_id = context->context_data(parms_id)-> + next_context_data()->parms().parms_id(); + ASSERT_TRUE(encryptedRes.parms_id() == parms_id); + ASSERT_TRUE(plain.to_string() == "1"); + + parms_id = parms.parms_id(); + plain = "1x^127"; + encryptor.encrypt(plain, encrypted); + evaluator.mod_switch_to_next(encrypted, encryptedRes); + decryptor.decrypt(encryptedRes, plain); + parms_id = context->context_data(parms_id)-> + next_context_data()->parms().parms_id(); + ASSERT_TRUE(encryptedRes.parms_id() == parms_id); + ASSERT_TRUE(plain.to_string() == "1x^127"); + + evaluator.mod_switch_to_next_inplace(encryptedRes); + decryptor.decrypt(encryptedRes, plain); + parms_id = context->context_data(parms_id)-> + next_context_data()->parms().parms_id(); + ASSERT_TRUE(encryptedRes.parms_id() == parms_id); + ASSERT_TRUE(plain.to_string() == "1x^127"); + + parms_id = parms.parms_id(); + plain = "5x^64 + Ax^5"; + encryptor.encrypt(plain, encrypted); + evaluator.mod_switch_to_next(encrypted, encryptedRes); + decryptor.decrypt(encryptedRes, plain); + parms_id = context->context_data(parms_id)-> + next_context_data()->parms().parms_id(); + ASSERT_TRUE(encryptedRes.parms_id() == parms_id); + ASSERT_TRUE(plain.to_string() == "5x^64 + Ax^5"); + + evaluator.mod_switch_to_next_inplace(encryptedRes); + decryptor.decrypt(encryptedRes, plain); + parms_id = context->context_data(parms_id)-> + next_context_data()->parms().parms_id(); + ASSERT_TRUE(encryptedRes.parms_id() == parms_id); + ASSERT_TRUE(plain.to_string() == "5x^64 + Ax^5"); + } + + TEST(EvaluatorTest, FVEncryptModSwitchToDecrypt) + { + // the common parameters: the plaintext and the polynomial moduli + SmallModulus plain_modulus(1 << 6); + + // the parameters and the context of the higher level + EncryptionParameters parms(scheme_type::BFV); + parms.set_poly_modulus_degree(128); + parms.set_plain_modulus(plain_modulus); + parms.set_coeff_modulus({ small_mods_30bit(0), small_mods_30bit(1), small_mods_30bit(2) }); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + SecretKey secret_key = keygen.secret_key(); + Encryptor encryptor(context, keygen.public_key()); + Evaluator evaluator(context); + Decryptor decryptor(context, keygen.secret_key()); + auto parms_id = parms.parms_id(); + + Ciphertext encrypted(context); + Plaintext plain; + + plain = 0; + encryptor.encrypt(plain, encrypted); + evaluator.mod_switch_to_inplace(encrypted, parms_id); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE(encrypted.parms_id() == parms_id); + ASSERT_TRUE(plain.to_string() == "0"); + + parms_id = context->context_data(parms_id)-> + next_context_data()->parms().parms_id(); + encryptor.encrypt(plain, encrypted); + evaluator.mod_switch_to_inplace(encrypted, parms_id); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE(encrypted.parms_id() == parms_id); + ASSERT_TRUE(plain.to_string() == "0"); + + parms_id = context->context_data(parms_id)-> + next_context_data()->parms().parms_id(); + encryptor.encrypt(plain, encrypted); + evaluator.mod_switch_to_inplace(encrypted, parms_id); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE(encrypted.parms_id() == parms_id); + ASSERT_TRUE(plain.to_string() == "0"); + + parms_id = parms.parms_id(); + encryptor.encrypt(plain, encrypted); + parms_id = context->context_data(parms_id)-> + next_context_data()-> + next_context_data()->parms().parms_id(); + evaluator.mod_switch_to_inplace(encrypted, parms_id); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE(encrypted.parms_id() == parms_id); + ASSERT_TRUE(plain.to_string() == "0"); + + parms_id = parms.parms_id(); + plain = 1; + encryptor.encrypt(plain, encrypted); + evaluator.mod_switch_to_inplace(encrypted, parms_id); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE(encrypted.parms_id() == parms_id); + ASSERT_TRUE(plain.to_string() == "1"); + + parms_id = context->context_data(parms_id)-> + next_context_data()->parms().parms_id(); + encryptor.encrypt(plain, encrypted); + evaluator.mod_switch_to_inplace(encrypted, parms_id); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE(encrypted.parms_id() == parms_id); + ASSERT_TRUE(plain.to_string() == "1"); + + parms_id = context->context_data(parms_id)-> + next_context_data()->parms().parms_id(); + encryptor.encrypt(plain, encrypted); + evaluator.mod_switch_to_inplace(encrypted, parms_id); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE(encrypted.parms_id() == parms_id); + ASSERT_TRUE(plain.to_string() == "1"); + + parms_id = parms.parms_id(); + encryptor.encrypt(plain, encrypted); + parms_id = context->context_data(parms_id)-> + next_context_data()-> + next_context_data()->parms().parms_id(); + evaluator.mod_switch_to_inplace(encrypted, parms_id); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE(encrypted.parms_id() == parms_id); + ASSERT_TRUE(plain.to_string() == "1"); + + parms_id = parms.parms_id(); + plain = "1x^127"; + encryptor.encrypt(plain, encrypted); + evaluator.mod_switch_to_inplace(encrypted, parms_id); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE(encrypted.parms_id() == parms_id); + ASSERT_TRUE(plain.to_string() == "1x^127"); + + parms_id = context->context_data(parms_id)-> + next_context_data()->parms().parms_id(); + encryptor.encrypt(plain, encrypted); + evaluator.mod_switch_to_inplace(encrypted, parms_id); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE(encrypted.parms_id() == parms_id); + ASSERT_TRUE(plain.to_string() == "1x^127"); + + parms_id = context->context_data(parms_id)-> + next_context_data()->parms().parms_id(); + encryptor.encrypt(plain, encrypted); + evaluator.mod_switch_to_inplace(encrypted, parms_id); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE(encrypted.parms_id() == parms_id); + ASSERT_TRUE(plain.to_string() == "1x^127"); + + parms_id = parms.parms_id(); + encryptor.encrypt(plain, encrypted); + parms_id = context->context_data(parms_id)-> + next_context_data()-> + next_context_data()->parms().parms_id(); + evaluator.mod_switch_to_inplace(encrypted, parms_id); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE(encrypted.parms_id() == parms_id); + ASSERT_TRUE(plain.to_string() == "1x^127"); + + parms_id = parms.parms_id(); + plain = "5x^64 + Ax^5"; + encryptor.encrypt(plain, encrypted); + evaluator.mod_switch_to_inplace(encrypted, parms_id); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE(encrypted.parms_id() == parms_id); + ASSERT_TRUE(plain.to_string() == "5x^64 + Ax^5"); + + parms_id = context->context_data(parms_id)-> + next_context_data()->parms().parms_id(); + encryptor.encrypt(plain, encrypted); + evaluator.mod_switch_to_inplace(encrypted, parms_id); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE(encrypted.parms_id() == parms_id); + ASSERT_TRUE(plain.to_string() == "5x^64 + Ax^5"); + + parms_id = context->context_data(parms_id)-> + next_context_data()->parms().parms_id(); + encryptor.encrypt(plain, encrypted); + evaluator.mod_switch_to_inplace(encrypted, parms_id); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE(encrypted.parms_id() == parms_id); + ASSERT_TRUE(plain.to_string() == "5x^64 + Ax^5"); + + parms_id = parms.parms_id(); + encryptor.encrypt(plain, encrypted); + parms_id = context->context_data(parms_id)-> + next_context_data()-> + next_context_data()->parms().parms_id(); + evaluator.mod_switch_to_inplace(encrypted, parms_id); + decryptor.decrypt(encrypted, plain); + ASSERT_TRUE(encrypted.parms_id() == parms_id); + ASSERT_TRUE(plain.to_string() == "5x^64 + Ax^5"); + } +} diff --git a/tests/seal/galoiskeys.cpp b/tests/seal/galoiskeys.cpp new file mode 100644 index 000000000..d21cc6145 --- /dev/null +++ b/tests/seal/galoiskeys.cpp @@ -0,0 +1,150 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "gtest/gtest.h" +#include "seal/galoiskeys.h" +#include "seal/context.h" +#include "seal/keygenerator.h" +#include "seal/util/uintcore.h" +#include "seal/defaultparams.h" +#include + +using namespace seal; +using namespace seal::util; +using namespace std; + +namespace SEALTest +{ + TEST(GaloisKeysTest, GaloisKeysSaveLoad) + { + stringstream stream; + { + EncryptionParameters parms(scheme_type::BFV); + parms.set_noise_standard_deviation(3.20); + parms.set_poly_modulus_degree(64); + parms.set_plain_modulus(65537); + parms.set_coeff_modulus({ small_mods_60bit(0) }); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + + GaloisKeys keys; + GaloisKeys test_keys; + ASSERT_EQ(keys.decomposition_bit_count(), 0); + keys.save(stream); + test_keys.unsafe_load(stream); + ASSERT_EQ(keys.size(), test_keys.size()); + ASSERT_TRUE(keys.parms_id() == test_keys.parms_id()); + ASSERT_EQ(keys.decomposition_bit_count(), test_keys.decomposition_bit_count()); + ASSERT_EQ(0ULL, keys.size()); + + keys = keygen.galois_keys(1); + ASSERT_EQ(keys.decomposition_bit_count(), 1); + keys.save(stream); + test_keys.load(context, stream); + ASSERT_EQ(keys.size(), test_keys.size()); + ASSERT_TRUE(keys.parms_id() == test_keys.parms_id()); + ASSERT_EQ(keys.decomposition_bit_count(), test_keys.decomposition_bit_count()); + for (size_t j = 0; j < test_keys.size(); j++) + { + for (size_t i = 0; i < test_keys.data()[j].size(); i++) + { + ASSERT_EQ(keys.data()[j][i].size(), test_keys.data()[j][i].size()); + ASSERT_EQ(keys.data()[j][i].uint64_count(), test_keys.data()[j][i].uint64_count()); + ASSERT_TRUE(is_equal_uint_uint(keys.data()[j][i].data(), test_keys.data()[j][i].data(), keys.data()[j][i].uint64_count())); + } + } + ASSERT_EQ(10ULL, keys.size()); + + keys = keygen.galois_keys(8); + ASSERT_EQ(keys.decomposition_bit_count(), 8); + keys.save(stream); + test_keys.load(context, stream); + ASSERT_EQ(keys.size(), test_keys.size()); + ASSERT_TRUE(keys.parms_id() == test_keys.parms_id()); + ASSERT_EQ(keys.decomposition_bit_count(), test_keys.decomposition_bit_count()); + for (size_t j = 0; j < test_keys.size(); j++) + { + for (size_t i = 0; i < test_keys.data()[j].size(); i++) + { + ASSERT_EQ(keys.data()[j][i].size(), test_keys.data()[j][i].size()); + ASSERT_EQ(keys.data()[j][i].uint64_count(), test_keys.data()[j][i].uint64_count()); + ASSERT_TRUE(is_equal_uint_uint(keys.data()[j][i].data(), test_keys.data()[j][i].data(), keys.data()[j][i].uint64_count())); + } + } + ASSERT_EQ(10ULL, keys.size()); + + keys = keygen.galois_keys(60); + ASSERT_EQ(keys.decomposition_bit_count(), 60); + keys.save(stream); + test_keys.load(context, stream); + ASSERT_EQ(keys.size(), test_keys.size()); + ASSERT_TRUE(keys.parms_id() == test_keys.parms_id()); + ASSERT_EQ(keys.decomposition_bit_count(), test_keys.decomposition_bit_count()); + for (size_t j = 0; j < test_keys.size(); j++) + { + for (size_t i = 0; i < test_keys.data()[j].size(); i++) + { + ASSERT_EQ(keys.data()[j][i].size(), test_keys.data()[j][i].size()); + ASSERT_EQ(keys.data()[j][i].uint64_count(), test_keys.data()[j][i].uint64_count()); + ASSERT_TRUE(is_equal_uint_uint(keys.data()[j][i].data(), test_keys.data()[j][i].data(), keys.data()[j][i].uint64_count())); + } + } + ASSERT_EQ(10ULL, keys.size()); + } + { + EncryptionParameters parms(scheme_type::BFV); + parms.set_noise_standard_deviation(3.20); + parms.set_poly_modulus_degree(256); + parms.set_plain_modulus(65537); + parms.set_coeff_modulus({ small_mods_60bit(0), small_mods_50bit(0) }); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + + GaloisKeys keys; + GaloisKeys test_keys; + ASSERT_EQ(keys.decomposition_bit_count(), 0); + keys.save(stream); + test_keys.unsafe_load(stream); + ASSERT_EQ(keys.size(), test_keys.size()); + ASSERT_TRUE(keys.parms_id() == test_keys.parms_id()); + ASSERT_EQ(keys.decomposition_bit_count(), test_keys.decomposition_bit_count()); + ASSERT_EQ(0ULL, keys.size()); + + keys = keygen.galois_keys(8); + ASSERT_EQ(keys.decomposition_bit_count(), 8); + keys.save(stream); + test_keys.load(context, stream); + ASSERT_EQ(keys.size(), test_keys.size()); + ASSERT_TRUE(keys.parms_id() == test_keys.parms_id()); + ASSERT_EQ(keys.decomposition_bit_count(), test_keys.decomposition_bit_count()); + for (size_t j = 0; j < test_keys.size(); j++) + { + for (size_t i = 0; i < test_keys.data()[j].size(); i++) + { + ASSERT_EQ(keys.data()[j][i].size(), test_keys.data()[j][i].size()); + ASSERT_EQ(keys.data()[j][i].uint64_count(), test_keys.data()[j][i].uint64_count()); + ASSERT_TRUE(is_equal_uint_uint(keys.data()[j][i].data(), test_keys.data()[j][i].data(), keys.data()[j][i].uint64_count())); + } + } + ASSERT_EQ(14ULL, keys.size()); + + keys = keygen.galois_keys(60); + ASSERT_EQ(keys.decomposition_bit_count(), 60); + keys.save(stream); + test_keys.load(context, stream); + ASSERT_EQ(keys.size(), test_keys.size()); + ASSERT_TRUE(keys.parms_id() == test_keys.parms_id()); + ASSERT_EQ(keys.decomposition_bit_count(), test_keys.decomposition_bit_count()); + for (size_t j = 0; j < test_keys.size(); j++) + { + for (size_t i = 0; i < test_keys.data()[j].size(); i++) + { + ASSERT_EQ(keys.data()[j][i].size(), test_keys.data()[j][i].size()); + ASSERT_EQ(keys.data()[j][i].uint64_count(), test_keys.data()[j][i].uint64_count()); + ASSERT_TRUE(is_equal_uint_uint(keys.data()[j][i].data(), test_keys.data()[j][i].data(), keys.data()[j][i].uint64_count())); + } + } + ASSERT_EQ(14ULL, keys.size()); + } + } +} diff --git a/tests/seal/intarray.cpp b/tests/seal/intarray.cpp new file mode 100644 index 000000000..24b3b4bc7 --- /dev/null +++ b/tests/seal/intarray.cpp @@ -0,0 +1,180 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "gtest/gtest.h" +#include "seal/intarray.h" +#include "seal/memorymanager.h" +#include + +using namespace seal; +using namespace std; + +namespace SEALTest +{ + TEST(IntArrayTest, IntArrayBasics) + { + { + auto pool = MemoryPoolHandle::New(); + MemoryManager::SwitchProfile(new MMProfFixed(pool)); + IntArray arr; + ASSERT_TRUE(arr.begin() == nullptr); + ASSERT_TRUE(arr.end() == nullptr); + ASSERT_EQ(0ULL, arr.size()); + ASSERT_EQ(0ULL, arr.capacity()); + ASSERT_TRUE(arr.empty()); + + arr.resize(1); + ASSERT_FALSE(arr.begin() == nullptr); + ASSERT_FALSE(arr.end() == nullptr); + ASSERT_FALSE(arr.begin() == arr.end()); + ASSERT_EQ(1, static_cast(arr.end() - arr.begin())); + ASSERT_EQ(1ULL, arr.size()); + ASSERT_EQ(1ULL, arr.capacity()); + ASSERT_FALSE(arr.empty()); + ASSERT_EQ(0, arr[0]); + arr.at(0) = 1; + ASSERT_EQ(1, arr[0]); + ASSERT_EQ(4, static_cast(pool.alloc_byte_count())); + + arr.reserve(6); + ASSERT_FALSE(arr.begin() == nullptr); + ASSERT_FALSE(arr.end() == nullptr); + ASSERT_FALSE(arr.begin() == arr.end()); + ASSERT_EQ(1, static_cast(arr.end() - arr.begin())); + ASSERT_EQ(1ULL, arr.size()); + ASSERT_EQ(6ULL, arr.capacity()); + ASSERT_FALSE(arr.empty()); + ASSERT_EQ(1, arr[0]); + ASSERT_EQ(28, static_cast(pool.alloc_byte_count())); + + arr.resize(4); + ASSERT_FALSE(arr.begin() == nullptr); + ASSERT_FALSE(arr.end() == nullptr); + ASSERT_FALSE(arr.begin() == arr.end()); + ASSERT_EQ(4, static_cast(arr.end() - arr.begin())); + ASSERT_EQ(4ULL, arr.size()); + ASSERT_EQ(6ULL, arr.capacity()); + ASSERT_FALSE(arr.empty()); + arr.at(0) = 0; + arr.at(1) = 1; + arr.at(2) = 2; + arr.at(3) = 3; + ASSERT_EQ(0, arr[0]); + ASSERT_EQ(1, arr[1]); + ASSERT_EQ(2, arr[2]); + ASSERT_EQ(3, arr[3]); + ASSERT_EQ(28, static_cast(pool.alloc_byte_count())); + + arr.shrink_to_fit(); + ASSERT_FALSE(arr.begin() == nullptr); + ASSERT_FALSE(arr.end() == nullptr); + ASSERT_FALSE(arr.begin() == arr.end()); + ASSERT_EQ(4, static_cast(arr.end() - arr.begin())); + ASSERT_EQ(4ULL, arr.size()); + ASSERT_EQ(4ULL, arr.capacity()); + ASSERT_FALSE(arr.empty()); + ASSERT_EQ(0, arr[0]); + ASSERT_EQ(1, arr[1]); + ASSERT_EQ(2, arr[2]); + ASSERT_EQ(3, arr[3]); + ASSERT_EQ(44, static_cast(pool.alloc_byte_count())); + } + { + auto pool = MemoryPoolHandle::New(); + MemoryManager::SwitchProfile(new MMProfFixed(pool)); + IntArray arr; + ASSERT_TRUE(arr.begin() == nullptr); + ASSERT_TRUE(arr.end() == nullptr); + ASSERT_EQ(0ULL, arr.size()); + ASSERT_EQ(0ULL, arr.capacity()); + ASSERT_TRUE(arr.empty()); + + arr.resize(1); + ASSERT_FALSE(arr.begin() == nullptr); + ASSERT_FALSE(arr.end() == nullptr); + ASSERT_FALSE(arr.begin() == arr.end()); + ASSERT_EQ(1, static_cast(arr.end() - arr.begin())); + ASSERT_EQ(1ULL, arr.size()); + ASSERT_EQ(1ULL, arr.capacity()); + ASSERT_FALSE(arr.empty()); + ASSERT_EQ(0ULL, arr[0]); + arr.at(0) = 1; + ASSERT_EQ(1ULL, arr[0]); + ASSERT_EQ(8, static_cast(pool.alloc_byte_count())); + + arr.reserve(6); + ASSERT_FALSE(arr.begin() == nullptr); + ASSERT_FALSE(arr.end() == nullptr); + ASSERT_FALSE(arr.begin() == arr.end()); + ASSERT_EQ(1, static_cast(arr.end() - arr.begin())); + ASSERT_EQ(1ULL, arr.size()); + ASSERT_EQ(6ULL, arr.capacity()); + ASSERT_FALSE(arr.empty()); + ASSERT_EQ(1ULL, arr[0]); + ASSERT_EQ(56, static_cast(pool.alloc_byte_count())); + + arr.resize(4); + ASSERT_FALSE(arr.begin() == nullptr); + ASSERT_FALSE(arr.end() == nullptr); + ASSERT_FALSE(arr.begin() == arr.end()); + ASSERT_EQ(4, static_cast(arr.end() - arr.begin())); + ASSERT_EQ(4ULL, arr.size()); + ASSERT_EQ(6ULL, arr.capacity()); + ASSERT_FALSE(arr.empty()); + arr.at(0) = 0; + arr.at(1) = 1; + arr.at(2) = 2; + arr.at(3) = 3; + ASSERT_EQ(0ULL, arr[0]); + ASSERT_EQ(1ULL, arr[1]); + ASSERT_EQ(2ULL, arr[2]); + ASSERT_EQ(3ULL, arr[3]); + ASSERT_EQ(56, static_cast(pool.alloc_byte_count())); + + arr.shrink_to_fit(); + ASSERT_FALSE(arr.begin() == nullptr); + ASSERT_FALSE(arr.end() == nullptr); + ASSERT_FALSE(arr.begin() == arr.end()); + ASSERT_EQ(4, static_cast(arr.end() - arr.begin())); + ASSERT_EQ(4ULL, arr.size()); + ASSERT_EQ(4ULL, arr.capacity()); + ASSERT_FALSE(arr.empty()); + ASSERT_EQ(0ULL, arr[0]); + ASSERT_EQ(1ULL, arr[1]); + ASSERT_EQ(2ULL, arr[2]); + ASSERT_EQ(3ULL, arr[3]); + ASSERT_EQ(88, static_cast(pool.alloc_byte_count())); + } + } + + TEST(IntArrayTest, SaveLoadIntArray) + { + IntArray arr(6, 4); + arr.at(0) = 0; + arr.at(1) = 1; + arr.at(2) = 2; + arr.at(3) = 3; + stringstream ss; + arr.save(ss); + IntArray arr2; + arr2.load(ss); + + ASSERT_EQ(arr.size(), arr2.size()); + ASSERT_EQ(arr.size(), arr2.capacity()); + ASSERT_EQ(arr[0], arr2[0]); + ASSERT_EQ(arr[1], arr2[1]); + ASSERT_EQ(arr[2], arr2[2]); + ASSERT_EQ(arr[3], arr2[3]); + + arr.resize(2); + arr[0] = 5; + arr[1] = 6; + arr.save(ss); + arr2.load(ss); + + ASSERT_EQ(arr.size(), arr2.size()); + ASSERT_EQ(4ULL, arr2.capacity()); + ASSERT_EQ(arr[0], arr2[0]); + ASSERT_EQ(arr[1], arr2[1]); + } +} diff --git a/tests/seal/keygenerator.cpp b/tests/seal/keygenerator.cpp new file mode 100644 index 000000000..1046d8612 --- /dev/null +++ b/tests/seal/keygenerator.cpp @@ -0,0 +1,574 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "gtest/gtest.h" +#include "seal/context.h" +#include "seal/keygenerator.h" +#include "seal/util/polycore.h" +#include "seal/defaultparams.h" +#include "seal/encryptor.h" +#include "seal/decryptor.h" + +using namespace seal; +using namespace seal::util; +using namespace std; + +namespace SEALTest +{ + TEST(KeyGeneratorTest, FVKeyGeneration) + { + EncryptionParameters parms(scheme_type::BFV); + { + parms.set_noise_standard_deviation(3.20); + parms.set_poly_modulus_degree(64); + parms.set_plain_modulus(1 << 6); + parms.set_coeff_modulus({ small_mods_60bit(0) }); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + + RelinKeys evk = keygen.relin_keys(60); + ASSERT_TRUE(evk.parms_id() == parms.parms_id()); + ASSERT_EQ(2ULL, evk.key(2)[0].size()); + for (size_t j = 0; j < evk.size(); j++) + { + for (size_t i = 0; i < evk.key(j + 2).size(); i++) + { + for (size_t k = 0; k < evk.key(j + 2)[i].size(); k++) + { + ASSERT_FALSE(is_zero_poly(evk.key(j + 2)[i].data(k), evk.key(j + 2)[i].poly_modulus_degree(), evk.key(j + 2)[i].coeff_mod_count())); + } + } + } + + evk = keygen.relin_keys(30, 1); + ASSERT_TRUE(evk.parms_id() == parms.parms_id()); + ASSERT_EQ(4ULL, evk.key(2)[0].size()); + for (size_t j = 0; j < evk.size(); j++) + { + for (size_t i = 0; i < evk.key(j + 2).size(); i++) + { + for (size_t k = 0; k < evk.key(j + 2)[i].size(); k++) + { + ASSERT_FALSE(is_zero_poly(evk.key(j + 2)[i].data(k), evk.key(j + 2)[i].poly_modulus_degree(), evk.key(j + 2)[i].coeff_mod_count())); + } + } + } + + evk = keygen.relin_keys(2, 2); + ASSERT_TRUE(evk.parms_id() == parms.parms_id()); + ASSERT_EQ(60ULL, evk.key(2)[0].size()); + for (size_t j = 0; j < evk.size(); j++) + { + for (size_t i = 0; i < evk.key(j + 2).size(); i++) + { + for (size_t k = 0; k < evk.key(j + 2)[i].size(); k++) + { + ASSERT_FALSE(is_zero_poly(evk.key(j + 2)[i].data(k), evk.key(j + 2)[i].poly_modulus_degree(), evk.key(j + 2)[i].coeff_mod_count())); + } + } + } + + GaloisKeys galks = keygen.galois_keys(60); + ASSERT_TRUE(galks.parms_id() == parms.parms_id()); + ASSERT_EQ(2ULL, galks.key(3)[0].size()); + ASSERT_EQ(10ULL, galks.size()); + + galks = keygen.galois_keys(30); + ASSERT_TRUE(galks.parms_id() == parms.parms_id()); + ASSERT_EQ(4ULL, galks.key(3)[0].size()); + ASSERT_EQ(10ULL, galks.size()); + + galks = keygen.galois_keys(2); + ASSERT_TRUE(galks.parms_id() == parms.parms_id()); + ASSERT_EQ(60ULL, galks.key(3)[0].size()); + ASSERT_EQ(10ULL, galks.size()); + + galks = keygen.galois_keys(60, vector{ 1, 3, 5, 7 }); + ASSERT_TRUE(galks.parms_id() == parms.parms_id()); + ASSERT_TRUE(galks.has_key(1)); + ASSERT_TRUE(galks.has_key(3)); + ASSERT_TRUE(galks.has_key(5)); + ASSERT_TRUE(galks.has_key(7)); + ASSERT_FALSE(galks.has_key(9)); + ASSERT_FALSE(galks.has_key(127)); + ASSERT_EQ(2ULL, galks.key(1)[0].size()); + ASSERT_EQ(2ULL, galks.key(3)[0].size()); + ASSERT_EQ(2ULL, galks.key(5)[0].size()); + ASSERT_EQ(2ULL, galks.key(7)[0].size()); + ASSERT_EQ(4ULL, galks.size()); + + galks = keygen.galois_keys(30, vector{ 1, 3, 5, 7 }); + ASSERT_TRUE(galks.parms_id() == parms.parms_id()); + ASSERT_TRUE(galks.has_key(1)); + ASSERT_TRUE(galks.has_key(3)); + ASSERT_TRUE(galks.has_key(5)); + ASSERT_TRUE(galks.has_key(7)); + ASSERT_FALSE(galks.has_key(9)); + ASSERT_FALSE(galks.has_key(127)); + ASSERT_EQ(4ULL, galks.key(1)[0].size()); + ASSERT_EQ(4ULL, galks.key(3)[0].size()); + ASSERT_EQ(4ULL, galks.key(5)[0].size()); + ASSERT_EQ(4ULL, galks.key(7)[0].size()); + ASSERT_EQ(4ULL, galks.size()); + + galks = keygen.galois_keys(2, vector{ 1, 3, 5, 7 }); + ASSERT_TRUE(galks.parms_id() == parms.parms_id()); + ASSERT_TRUE(galks.has_key(1)); + ASSERT_TRUE(galks.has_key(3)); + ASSERT_TRUE(galks.has_key(5)); + ASSERT_TRUE(galks.has_key(7)); + ASSERT_FALSE(galks.has_key(9)); + ASSERT_FALSE(galks.has_key(127)); + ASSERT_EQ(60ULL, galks.key(1)[0].size()); + ASSERT_EQ(60ULL, galks.key(3)[0].size()); + ASSERT_EQ(60ULL, galks.key(5)[0].size()); + ASSERT_EQ(60ULL, galks.key(7)[0].size()); + ASSERT_EQ(4ULL, galks.size()); + + galks = keygen.galois_keys(30, vector{ 1 }); + ASSERT_TRUE(galks.parms_id() == parms.parms_id()); + ASSERT_TRUE(galks.has_key(1)); + ASSERT_FALSE(galks.has_key(3)); + ASSERT_FALSE(galks.has_key(127)); + ASSERT_EQ(4ULL, galks.key(1)[0].size()); + ASSERT_EQ(1ULL, galks.size()); + + galks = keygen.galois_keys(30, vector{ 127 }); + ASSERT_TRUE(galks.parms_id() == parms.parms_id()); + ASSERT_FALSE(galks.has_key(1)); + ASSERT_TRUE(galks.has_key(127)); + ASSERT_EQ(4ULL, galks.key(127)[0].size()); + ASSERT_EQ(1ULL, galks.size()); + } + { + parms.set_noise_standard_deviation(3.20); + parms.set_poly_modulus_degree(256); + parms.set_plain_modulus(1 << 6); + parms.set_coeff_modulus({ small_mods_60bit(0), small_mods_30bit(0), small_mods_30bit(1) }); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + + RelinKeys evk = keygen.relin_keys(60, 2); + ASSERT_TRUE(evk.parms_id() == parms.parms_id()); + ASSERT_EQ(2ULL, evk.key(2)[0].size()); + for (size_t j = 0; j < evk.size(); j++) + { + for (size_t i = 0; i < evk.key(j + 2).size(); i++) + { + for (size_t k = 0; k < evk.key(j + 2)[i].size(); k++) + { + ASSERT_FALSE(is_zero_poly(evk.key(j + 2)[i].data(k), evk.key(j + 2)[i].poly_modulus_degree(), evk.key(j + 2)[i].coeff_mod_count())); + } + } + } + + evk = keygen.relin_keys(30, 2); + ASSERT_TRUE(evk.parms_id() == parms.parms_id()); + ASSERT_EQ(4ULL, evk.key(2)[0].size()); + for (size_t j = 0; j < evk.size(); j++) + { + for (size_t i = 0; i < evk.key(j + 2).size(); i++) + { + for (size_t k = 0; k < evk.key(j + 2)[i].size(); k++) + { + ASSERT_FALSE(is_zero_poly(evk.key(j + 2)[i].data(k), evk.key(j + 2)[i].poly_modulus_degree(), evk.key(j + 2)[i].coeff_mod_count())); + } + } + } + + evk = keygen.relin_keys(4, 1); + ASSERT_TRUE(evk.parms_id() == parms.parms_id()); + ASSERT_EQ(30ULL, evk.key(2)[0].size()); + for (size_t j = 0; j < evk.size(); j++) + { + for (size_t i = 0; i < evk.key(j + 2).size(); i++) + { + for (size_t k = 0; k < evk.key(j + 2)[i].size(); k++) + { + ASSERT_FALSE(is_zero_poly(evk.key(j + 2)[i].data(k), evk.key(j + 2)[i].poly_modulus_degree(), evk.key(j + 2)[i].coeff_mod_count())); + } + } + } + + GaloisKeys galks = keygen.galois_keys(60); + ASSERT_TRUE(galks.parms_id() == parms.parms_id()); + ASSERT_EQ(2ULL, galks.key(3)[0].size()); + ASSERT_EQ(14ULL, galks.size()); + + galks = keygen.galois_keys(30); + ASSERT_TRUE(galks.parms_id() == parms.parms_id()); + ASSERT_EQ(4ULL, galks.key(3)[0].size()); + ASSERT_EQ(14ULL, galks.size()); + + galks = keygen.galois_keys(2); + ASSERT_TRUE(galks.parms_id() == parms.parms_id()); + ASSERT_EQ(60ULL, galks.key(3)[0].size()); + ASSERT_EQ(14ULL, galks.size()); + + galks = keygen.galois_keys(60, vector{ 1, 3, 5, 7 }); + ASSERT_TRUE(galks.parms_id() == parms.parms_id()); + ASSERT_TRUE(galks.has_key(1)); + ASSERT_TRUE(galks.has_key(3)); + ASSERT_TRUE(galks.has_key(5)); + ASSERT_TRUE(galks.has_key(7)); + ASSERT_FALSE(galks.has_key(9)); + ASSERT_FALSE(galks.has_key(511)); + ASSERT_EQ(2ULL, galks.key(1)[0].size()); + ASSERT_EQ(2ULL, galks.key(3)[0].size()); + ASSERT_EQ(2ULL, galks.key(5)[0].size()); + ASSERT_EQ(2ULL, galks.key(7)[0].size()); + ASSERT_EQ(4ULL, galks.size()); + + galks = keygen.galois_keys(30, vector{ 1, 3, 5, 7 }); + ASSERT_TRUE(galks.parms_id() == parms.parms_id()); + ASSERT_TRUE(galks.has_key(1)); + ASSERT_TRUE(galks.has_key(3)); + ASSERT_TRUE(galks.has_key(5)); + ASSERT_TRUE(galks.has_key(7)); + ASSERT_FALSE(galks.has_key(9)); + ASSERT_FALSE(galks.has_key(511)); + ASSERT_EQ(4ULL, galks.key(1)[0].size()); + ASSERT_EQ(4ULL, galks.key(3)[0].size()); + ASSERT_EQ(4ULL, galks.key(5)[0].size()); + ASSERT_EQ(4ULL, galks.key(7)[0].size()); + ASSERT_EQ(4ULL, galks.size()); + + galks = keygen.galois_keys(2, vector{ 1, 3, 5, 7 }); + ASSERT_TRUE(galks.parms_id() == parms.parms_id()); + ASSERT_TRUE(galks.has_key(1)); + ASSERT_TRUE(galks.has_key(3)); + ASSERT_TRUE(galks.has_key(5)); + ASSERT_TRUE(galks.has_key(7)); + ASSERT_FALSE(galks.has_key(9)); + ASSERT_FALSE(galks.has_key(511)); + ASSERT_EQ(60ULL, galks.key(1)[0].size()); + ASSERT_EQ(60ULL, galks.key(3)[0].size()); + ASSERT_EQ(60ULL, galks.key(5)[0].size()); + ASSERT_EQ(60ULL, galks.key(7)[0].size()); + ASSERT_EQ(4ULL, galks.size()); + + galks = keygen.galois_keys(30, vector{ 1 }); + ASSERT_TRUE(galks.parms_id() == parms.parms_id()); + ASSERT_TRUE(galks.has_key(1)); + ASSERT_FALSE(galks.has_key(3)); + ASSERT_FALSE(galks.has_key(511)); + ASSERT_EQ(4ULL, galks.key(1)[0].size()); + ASSERT_EQ(1ULL, galks.size()); + + galks = keygen.galois_keys(30, vector{ 511 }); + ASSERT_TRUE(galks.parms_id() == parms.parms_id()); + ASSERT_FALSE(galks.has_key(1)); + ASSERT_TRUE(galks.has_key(511)); + ASSERT_EQ(4ULL, galks.key(511)[0].size()); + ASSERT_EQ(1ULL, galks.size()); + } + } + + TEST(KeyGeneratorTest, CKKSKeyGeneration) + { + EncryptionParameters parms(scheme_type::CKKS); + { + parms.set_noise_standard_deviation(3.20); + parms.set_poly_modulus_degree(64); + parms.set_coeff_modulus({ small_mods_60bit(0) }); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + + RelinKeys evk = keygen.relin_keys(60); + ASSERT_TRUE(evk.parms_id() == parms.parms_id()); + ASSERT_EQ(2ULL, evk.key(2)[0].size()); + for (size_t j = 0; j < evk.size(); j++) + { + for (size_t i = 0; i < evk.key(j + 2).size(); i++) + { + for (size_t k = 0; k < evk.key(j + 2)[i].size(); k++) + { + ASSERT_FALSE(is_zero_poly(evk.key(j + 2)[i].data(k), evk.key(j + 2)[i].poly_modulus_degree(), evk.key(j + 2)[i].coeff_mod_count())); + } + } + } + + evk = keygen.relin_keys(30, 1); + ASSERT_TRUE(evk.parms_id() == parms.parms_id()); + ASSERT_EQ(4ULL, evk.key(2)[0].size()); + for (size_t j = 0; j < evk.size(); j++) + { + for (size_t i = 0; i < evk.key(j + 2).size(); i++) + { + for (size_t k = 0; k < evk.key(j + 2)[i].size(); k++) + { + ASSERT_FALSE(is_zero_poly(evk.key(j + 2)[i].data(k), evk.key(j + 2)[i].poly_modulus_degree(), evk.key(j + 2)[i].coeff_mod_count())); + } + } + } + + evk = keygen.relin_keys(2, 2); + ASSERT_TRUE(evk.parms_id() == parms.parms_id()); + ASSERT_EQ(60ULL, evk.key(2)[0].size()); + for (size_t j = 0; j < evk.size(); j++) + { + for (size_t i = 0; i < evk.key(j + 2).size(); i++) + { + for (size_t k = 0; k < evk.key(j + 2)[i].size(); k++) + { + ASSERT_FALSE(is_zero_poly(evk.key(j + 2)[i].data(k), evk.key(j + 2)[i].poly_modulus_degree(), evk.key(j + 2)[i].coeff_mod_count())); + } + } + } + + GaloisKeys galks = keygen.galois_keys(60); + ASSERT_TRUE(galks.parms_id() == parms.parms_id()); + ASSERT_EQ(2ULL, galks.key(3)[0].size()); + ASSERT_EQ(10ULL, galks.size()); + + galks = keygen.galois_keys(30); + ASSERT_TRUE(galks.parms_id() == parms.parms_id()); + ASSERT_EQ(4ULL, galks.key(3)[0].size()); + ASSERT_EQ(10ULL, galks.size()); + + galks = keygen.galois_keys(2); + ASSERT_TRUE(galks.parms_id() == parms.parms_id()); + ASSERT_EQ(60ULL, galks.key(3)[0].size()); + ASSERT_EQ(10ULL, galks.size()); + + galks = keygen.galois_keys(60, vector{ 1, 3, 5, 7 }); + ASSERT_TRUE(galks.parms_id() == parms.parms_id()); + ASSERT_TRUE(galks.has_key(1)); + ASSERT_TRUE(galks.has_key(3)); + ASSERT_TRUE(galks.has_key(5)); + ASSERT_TRUE(galks.has_key(7)); + ASSERT_FALSE(galks.has_key(9)); + ASSERT_FALSE(galks.has_key(127)); + ASSERT_EQ(2ULL, galks.key(1)[0].size()); + ASSERT_EQ(2ULL, galks.key(3)[0].size()); + ASSERT_EQ(2ULL, galks.key(5)[0].size()); + ASSERT_EQ(2ULL, galks.key(7)[0].size()); + ASSERT_EQ(4ULL, galks.size()); + + galks = keygen.galois_keys(30, vector{ 1, 3, 5, 7 }); + ASSERT_TRUE(galks.parms_id() == parms.parms_id()); + ASSERT_TRUE(galks.has_key(1)); + ASSERT_TRUE(galks.has_key(3)); + ASSERT_TRUE(galks.has_key(5)); + ASSERT_TRUE(galks.has_key(7)); + ASSERT_FALSE(galks.has_key(9)); + ASSERT_FALSE(galks.has_key(127)); + ASSERT_EQ(4ULL, galks.key(1)[0].size()); + ASSERT_EQ(4ULL, galks.key(3)[0].size()); + ASSERT_EQ(4ULL, galks.key(5)[0].size()); + ASSERT_EQ(4ULL, galks.key(7)[0].size()); + ASSERT_EQ(4ULL, galks.size()); + + galks = keygen.galois_keys(2, vector{ 1, 3, 5, 7 }); + ASSERT_TRUE(galks.parms_id() == parms.parms_id()); + ASSERT_TRUE(galks.has_key(1)); + ASSERT_TRUE(galks.has_key(3)); + ASSERT_TRUE(galks.has_key(5)); + ASSERT_TRUE(galks.has_key(7)); + ASSERT_FALSE(galks.has_key(9)); + ASSERT_FALSE(galks.has_key(127)); + ASSERT_EQ(60ULL, galks.key(1)[0].size()); + ASSERT_EQ(60ULL, galks.key(3)[0].size()); + ASSERT_EQ(60ULL, galks.key(5)[0].size()); + ASSERT_EQ(60ULL, galks.key(7)[0].size()); + ASSERT_EQ(4ULL, galks.size()); + + galks = keygen.galois_keys(30, vector{ 1 }); + ASSERT_TRUE(galks.parms_id() == parms.parms_id()); + ASSERT_TRUE(galks.has_key(1)); + ASSERT_FALSE(galks.has_key(3)); + ASSERT_FALSE(galks.has_key(127)); + ASSERT_EQ(4ULL, galks.key(1)[0].size()); + ASSERT_EQ(1ULL, galks.size()); + + galks = keygen.galois_keys(30, vector{ 127 }); + ASSERT_TRUE(galks.parms_id() == parms.parms_id()); + ASSERT_FALSE(galks.has_key(1)); + ASSERT_TRUE(galks.has_key(127)); + ASSERT_EQ(4ULL, galks.key(127)[0].size()); + ASSERT_EQ(1ULL, galks.size()); + } + { + parms.set_noise_standard_deviation(3.20); + parms.set_poly_modulus_degree(256); + parms.set_coeff_modulus({ small_mods_60bit(0), small_mods_30bit(0), small_mods_30bit(1) }); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + + RelinKeys evk = keygen.relin_keys(60, 2); + ASSERT_TRUE(evk.parms_id() == parms.parms_id()); + ASSERT_EQ(2ULL, evk.key(2)[0].size()); + for (size_t j = 0; j < evk.size(); j++) + { + for (size_t i = 0; i < evk.key(j + 2).size(); i++) + { + for (size_t k = 0; k < evk.key(j + 2)[i].size(); k++) + { + ASSERT_FALSE(is_zero_poly(evk.key(j + 2)[i].data(k), evk.key(j + 2)[i].poly_modulus_degree(), evk.key(j + 2)[i].coeff_mod_count())); + } + } + } + + evk = keygen.relin_keys(30, 2); + ASSERT_TRUE(evk.parms_id() == parms.parms_id()); + ASSERT_EQ(4ULL, evk.key(2)[0].size()); + for (size_t j = 0; j < evk.size(); j++) + { + for (size_t i = 0; i < evk.key(j + 2).size(); i++) + { + for (size_t k = 0; k < evk.key(j + 2)[i].size(); k++) + { + ASSERT_FALSE(is_zero_poly(evk.key(j + 2)[i].data(k), evk.key(j + 2)[i].poly_modulus_degree(), evk.key(j + 2)[i].coeff_mod_count())); + } + } + } + + evk = keygen.relin_keys(4, 1); + ASSERT_TRUE(evk.parms_id() == parms.parms_id()); + ASSERT_EQ(30ULL, evk.key(2)[0].size()); + for (size_t j = 0; j < evk.size(); j++) + { + for (size_t i = 0; i < evk.key(j + 2).size(); i++) + { + for (size_t k = 0; k < evk.key(j + 2)[i].size(); k++) + { + ASSERT_FALSE(is_zero_poly(evk.key(j + 2)[i].data(k), evk.key(j + 2)[i].poly_modulus_degree(), evk.key(j + 2)[i].coeff_mod_count())); + } + } + } + + GaloisKeys galks = keygen.galois_keys(60); + ASSERT_TRUE(galks.parms_id() == parms.parms_id()); + ASSERT_EQ(2ULL, galks.key(3)[0].size()); + ASSERT_EQ(14ULL, galks.size()); + + galks = keygen.galois_keys(30); + ASSERT_TRUE(galks.parms_id() == parms.parms_id()); + ASSERT_EQ(4ULL, galks.key(3)[0].size()); + ASSERT_EQ(14ULL, galks.size()); + + galks = keygen.galois_keys(2); + ASSERT_TRUE(galks.parms_id() == parms.parms_id()); + ASSERT_EQ(60ULL, galks.key(3)[0].size()); + ASSERT_EQ(14ULL, galks.size()); + + galks = keygen.galois_keys(60, vector{ 1, 3, 5, 7 }); + ASSERT_TRUE(galks.parms_id() == parms.parms_id()); + ASSERT_TRUE(galks.has_key(1)); + ASSERT_TRUE(galks.has_key(3)); + ASSERT_TRUE(galks.has_key(5)); + ASSERT_TRUE(galks.has_key(7)); + ASSERT_FALSE(galks.has_key(9)); + ASSERT_FALSE(galks.has_key(511)); + ASSERT_EQ(2ULL, galks.key(1)[0].size()); + ASSERT_EQ(2ULL, galks.key(3)[0].size()); + ASSERT_EQ(2ULL, galks.key(5)[0].size()); + ASSERT_EQ(2ULL, galks.key(7)[0].size()); + ASSERT_EQ(4ULL, galks.size()); + + galks = keygen.galois_keys(30, vector{ 1, 3, 5, 7 }); + ASSERT_TRUE(galks.parms_id() == parms.parms_id()); + ASSERT_TRUE(galks.has_key(1)); + ASSERT_TRUE(galks.has_key(3)); + ASSERT_TRUE(galks.has_key(5)); + ASSERT_TRUE(galks.has_key(7)); + ASSERT_FALSE(galks.has_key(9)); + ASSERT_FALSE(galks.has_key(511)); + ASSERT_EQ(4ULL, galks.key(1)[0].size()); + ASSERT_EQ(4ULL, galks.key(3)[0].size()); + ASSERT_EQ(4ULL, galks.key(5)[0].size()); + ASSERT_EQ(4ULL, galks.key(7)[0].size()); + ASSERT_EQ(4ULL, galks.size()); + + galks = keygen.galois_keys(2, vector{ 1, 3, 5, 7 }); + ASSERT_TRUE(galks.parms_id() == parms.parms_id()); + ASSERT_TRUE(galks.has_key(1)); + ASSERT_TRUE(galks.has_key(3)); + ASSERT_TRUE(galks.has_key(5)); + ASSERT_TRUE(galks.has_key(7)); + ASSERT_FALSE(galks.has_key(9)); + ASSERT_FALSE(galks.has_key(511)); + ASSERT_EQ(60ULL, galks.key(1)[0].size()); + ASSERT_EQ(60ULL, galks.key(3)[0].size()); + ASSERT_EQ(60ULL, galks.key(5)[0].size()); + ASSERT_EQ(60ULL, galks.key(7)[0].size()); + ASSERT_EQ(4ULL, galks.size()); + + galks = keygen.galois_keys(30, vector{ 1 }); + ASSERT_TRUE(galks.parms_id() == parms.parms_id()); + ASSERT_TRUE(galks.has_key(1)); + ASSERT_FALSE(galks.has_key(3)); + ASSERT_FALSE(galks.has_key(511)); + ASSERT_EQ(4ULL, galks.key(1)[0].size()); + ASSERT_EQ(1ULL, galks.size()); + + galks = keygen.galois_keys(30, vector{ 511 }); + ASSERT_TRUE(galks.parms_id() == parms.parms_id()); + ASSERT_FALSE(galks.has_key(1)); + ASSERT_TRUE(galks.has_key(511)); + ASSERT_EQ(4ULL, galks.key(511)[0].size()); + ASSERT_EQ(1ULL, galks.size()); + } + } + + TEST(KeyGeneratorTest, FVSecretKeyGeneration) + { + EncryptionParameters parms(scheme_type::BFV); + parms.set_noise_standard_deviation(3.20); + parms.set_poly_modulus_degree(64); + parms.set_plain_modulus(1 << 6); + parms.set_coeff_modulus({ small_mods_60bit(0) }); + auto context = SEALContext::Create(parms); + { + KeyGenerator keygen(context); + auto pk = keygen.public_key(); + auto sk = keygen.secret_key(); + + KeyGenerator keygen2(context, sk); + auto sk2 = keygen.secret_key(); + auto pk2 = keygen2.public_key(); + + ASSERT_EQ(sk.data().coeff_count(), sk2.data().coeff_count()); + for (size_t i = 0; i < sk.data().coeff_count(); i++) + { + ASSERT_EQ(sk.data()[i], sk2.data()[i]); + } + + ASSERT_EQ(pk.data().uint64_count(), pk2.data().uint64_count()); + for (size_t i = 0; i < pk.data().uint64_count(); i++) + { + ASSERT_NE(pk.data()[i], pk2.data()[i]); + } + + Encryptor encryptor(context, pk2); + Decryptor decryptor(context, sk); + Ciphertext ctxt; + Plaintext pt1("1x^63 + 2x^33 + 3x^23 + 4x^13 + 5x^1 + 6"); + Plaintext pt2; + encryptor.encrypt(pt1, ctxt); + decryptor.decrypt(ctxt, pt2); + ASSERT_TRUE(pt1 == pt2); + } + { + KeyGenerator keygen(context); + auto pk = keygen.public_key(); + auto sk = keygen.secret_key(); + + KeyGenerator keygen2(context, sk, pk); + auto sk2 = keygen.secret_key(); + auto pk2 = keygen2.public_key(); + + ASSERT_EQ(sk.data().coeff_count(), sk2.data().coeff_count()); + for (size_t i = 0; i < sk.data().coeff_count(); i++) + { + ASSERT_EQ(sk.data()[i], sk2.data()[i]); + } + + ASSERT_EQ(pk.data().uint64_count(), pk2.data().uint64_count()); + for (size_t i = 0; i < pk.data().uint64_count(); i++) + { + ASSERT_EQ(pk.data()[i], pk2.data()[i]); + } + } + } +} diff --git a/tests/seal/memorymanager.cpp b/tests/seal/memorymanager.cpp new file mode 100644 index 000000000..2cca07d9f --- /dev/null +++ b/tests/seal/memorymanager.cpp @@ -0,0 +1,56 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "gtest/gtest.h" +#include "seal/util/pointer.h" +#include "seal/memorymanager.h" +#include "seal/util/uintcore.h" + +using namespace seal; +using namespace seal::util; +using namespace std; + +namespace SEALTest +{ + TEST(MemoryPoolHandleTest, MemoryPoolHandleConstructAssign) + { + MemoryPoolHandle pool; + ASSERT_FALSE(pool); + pool = MemoryPoolHandle::Global(); + ASSERT_TRUE(&static_cast(pool) == global_variables::global_memory_pool.get()); + pool = MemoryPoolHandle::New(); + ASSERT_FALSE(&pool.operator seal::util::MemoryPool &() == global_variables::global_memory_pool.get()); + MemoryPoolHandle pool2 = MemoryPoolHandle::New(); + ASSERT_FALSE(pool == pool2); + + pool = pool2; + ASSERT_TRUE(pool == pool2); + pool = MemoryPoolHandle::Global(); + ASSERT_FALSE(pool == pool2); + pool2 = MemoryPoolHandle::Global(); + ASSERT_TRUE(pool == pool2); + } + + TEST(MemoryPoolHandleTest, MemoryPoolHandleAllocate) + { + MemoryPoolHandle pool = MemoryPoolHandle::New(); + ASSERT_TRUE(0LL == pool.alloc_byte_count()); + { + auto ptr(allocate_uint(5, pool)); + ASSERT_TRUE(5LL * bytes_per_uint64 == pool.alloc_byte_count()); + } + + pool = MemoryPoolHandle::New(); + ASSERT_TRUE(0LL * bytes_per_uint64 == pool.alloc_byte_count()); + { + auto ptr(allocate_uint(5, pool)); + ASSERT_TRUE(5LL * bytes_per_uint64 == pool.alloc_byte_count()); + + ptr = allocate_uint(8, pool); + ASSERT_TRUE(13LL * bytes_per_uint64 == pool.alloc_byte_count()); + + auto ptr2(allocate_uint(2, pool)); + ASSERT_TRUE(15LL * bytes_per_uint64 == pool.alloc_byte_count()); + } + } +} diff --git a/tests/seal/plaintext.cpp b/tests/seal/plaintext.cpp new file mode 100644 index 000000000..6ae853bda --- /dev/null +++ b/tests/seal/plaintext.cpp @@ -0,0 +1,151 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include +#include "gtest/gtest.h" +#include "seal/plaintext.h" +#include "seal/evaluator.h" +#include "seal/context.h" +#include "seal/memorymanager.h" +#include "seal/defaultparams.h" +#include "seal/ckks.h" + +using namespace seal; +using namespace seal::util; +using namespace std; + +namespace SEALTest +{ + TEST(PlaintextTest, PlaintextBasics) + { + Plaintext plain(2); + ASSERT_EQ(2ULL, plain.capacity()); + ASSERT_EQ(2ULL, plain.coeff_count()); + ASSERT_EQ(0ULL, plain.significant_coeff_count()); + ASSERT_FALSE(plain.is_ntt_form()); + + plain[0] = 1; + plain[1] = 2; + + plain.reserve(10); + ASSERT_EQ(10ULL, plain.capacity()); + ASSERT_EQ(2ULL, plain.coeff_count()); + ASSERT_EQ(2ULL, plain.significant_coeff_count()); + ASSERT_EQ(1ULL, plain[0]); + ASSERT_EQ(2ULL, plain[1]); + ASSERT_FALSE(plain.is_ntt_form()); + + plain.resize(5); + ASSERT_EQ(10ULL, plain.capacity()); + ASSERT_EQ(5ULL, plain.coeff_count()); + ASSERT_EQ(2ULL, plain.significant_coeff_count()); + ASSERT_EQ(1ULL, plain[0]); + ASSERT_EQ(2ULL, plain[1]); + ASSERT_EQ(0ULL, plain[2]); + ASSERT_EQ(0ULL, plain[3]); + ASSERT_EQ(0ULL, plain[4]); + ASSERT_FALSE(plain.is_ntt_form()); + + Plaintext plain2; + plain2.resize(15); + ASSERT_EQ(15ULL, plain2.capacity()); + ASSERT_EQ(15ULL, plain2.coeff_count()); + ASSERT_EQ(0ULL, plain2.significant_coeff_count()); + ASSERT_FALSE(plain.is_ntt_form()); + + plain2 = plain; + ASSERT_EQ(15ULL, plain2.capacity()); + ASSERT_EQ(5ULL, plain2.coeff_count()); + ASSERT_EQ(2ULL, plain2.significant_coeff_count()); + ASSERT_EQ(1ULL, plain2[0]); + ASSERT_EQ(2ULL, plain2[1]); + ASSERT_EQ(0ULL, plain2[2]); + ASSERT_EQ(0ULL, plain2[3]); + ASSERT_EQ(0ULL, plain2[4]); + ASSERT_FALSE(plain.is_ntt_form()); + + plain.parms_id() = { 1ULL, 2ULL, 3ULL, 4ULL }; + ASSERT_TRUE(plain.is_ntt_form()); + plain2 = plain; + ASSERT_TRUE(plain == plain2); + plain2.parms_id() = parms_id_zero; + ASSERT_FALSE(plain2.is_ntt_form()); + ASSERT_FALSE(plain == plain2); + plain2.parms_id() = { 1ULL, 2ULL, 3ULL, 5ULL }; + ASSERT_FALSE(plain == plain2); + } + + TEST(PlaintextTest, SaveLoadPlaintext) + { + stringstream stream; + + Plaintext plain; + Plaintext plain2; + plain.save(stream); + plain2.unsafe_load(stream); + ASSERT_TRUE(plain.data() == plain2.data()); + ASSERT_TRUE(plain2.data() == nullptr); + ASSERT_EQ(0ULL, plain2.capacity()); + ASSERT_EQ(0ULL, plain2.coeff_count()); + ASSERT_FALSE(plain2.is_ntt_form()); + + plain.reserve(20); + plain.resize(5); + plain[0] = 1; + plain[1] = 2; + plain[2] = 3; + plain.save(stream); + plain2.unsafe_load(stream); + ASSERT_TRUE(plain.data() != plain2.data()); + ASSERT_EQ(5ULL, plain2.capacity()); + ASSERT_EQ(5ULL, plain2.coeff_count()); + ASSERT_EQ(1ULL, plain2[0]); + ASSERT_EQ(2ULL, plain2[1]); + ASSERT_EQ(3ULL, plain2[2]); + ASSERT_EQ(0ULL, plain2[3]); + ASSERT_EQ(0ULL, plain2[4]); + ASSERT_FALSE(plain2.is_ntt_form()); + + plain.parms_id() = { 1, 2, 3, 4 }; + plain.save(stream); + plain2.unsafe_load(stream); + ASSERT_TRUE(plain2.is_ntt_form()); + ASSERT_TRUE(plain2.parms_id() == plain.parms_id()); + + { + EncryptionParameters parms(scheme_type::BFV); + parms.set_poly_modulus_degree(64); + parms.set_coeff_modulus({ small_mods_30bit(0), small_mods_30bit(1) }); + parms.set_plain_modulus(65537); + auto context = SEALContext::Create(parms, false); + + plain.parms_id() = parms_id_zero; + plain = "1x^63 + 2x^62 + Fx^32 + Ax^9 + 1x^1 + 1"; + plain.save(stream); + plain2.load(context, stream); + ASSERT_TRUE(plain.data() != plain2.data()); + ASSERT_FALSE(plain2.is_ntt_form()); + + Evaluator evaluator(context); + evaluator.transform_to_ntt_inplace(plain, context->first_parms_id()); + plain.save(stream); + plain2.load(context, stream); + ASSERT_TRUE(plain.data() != plain2.data()); + ASSERT_TRUE(plain2.is_ntt_form()); + } + { + EncryptionParameters parms(scheme_type::CKKS); + parms.set_poly_modulus_degree(64); + parms.set_coeff_modulus({ small_mods_30bit(0), small_mods_30bit(1) }); + auto context = SEALContext::Create(parms, false); + CKKSEncoder encoder(context); + + encoder.encode(vector{ 0.1, 2.3, 34.4 }, pow(2.0, 20), plain); + ASSERT_TRUE(plain.is_ntt_form()); + plain.save(stream); + plain2.load(context, stream); + ASSERT_TRUE(plain.data() != plain2.data()); + ASSERT_TRUE(plain2.is_ntt_form()); + } + } +} diff --git a/tests/seal/publickey.cpp b/tests/seal/publickey.cpp new file mode 100644 index 000000000..e517e6965 --- /dev/null +++ b/tests/seal/publickey.cpp @@ -0,0 +1,65 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "gtest/gtest.h" +#include "seal/publickey.h" +#include "seal/context.h" +#include "seal/defaultparams.h" +#include "seal/keygenerator.h" + +using namespace seal; +using namespace std; + +namespace SEALTest +{ + TEST(PublicKeyTest, SaveLoadPublicKey) + { + stringstream stream; + { + EncryptionParameters parms(scheme_type::BFV); + parms.set_noise_standard_deviation(3.20); + parms.set_poly_modulus_degree(64); + parms.set_plain_modulus(1 << 6); + parms.set_coeff_modulus({ small_mods_60bit(0) }); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + + PublicKey pk = keygen.public_key(); + ASSERT_TRUE(pk.parms_id() == parms.parms_id()); + pk.save(stream); + + PublicKey pk2; + pk2.load(context, stream); + + ASSERT_EQ(pk.data().uint64_count(), pk2.data().uint64_count()); + for (size_t i = 0; i < pk.data().uint64_count(); i++) + { + ASSERT_EQ(pk.data().data()[i], pk2.data().data()[i]); + } + ASSERT_TRUE(pk.parms_id() == pk2.parms_id()); + } + { + EncryptionParameters parms(scheme_type::BFV); + parms.set_noise_standard_deviation(3.20); + parms.set_poly_modulus_degree(256); + parms.set_plain_modulus(1 << 20); + parms.set_coeff_modulus({ small_mods_30bit(0), small_mods_40bit(0) }); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + + PublicKey pk = keygen.public_key(); + ASSERT_TRUE(pk.parms_id() == parms.parms_id()); + pk.save(stream); + + PublicKey pk2; + pk2.load(context, stream); + + ASSERT_EQ(pk.data().uint64_count(), pk2.data().uint64_count()); + for (size_t i = 0; i < pk.data().uint64_count(); i++) + { + ASSERT_EQ(pk.data().data()[i], pk2.data().data()[i]); + } + ASSERT_TRUE(pk.parms_id() == pk2.parms_id()); + } + } +} diff --git a/tests/seal/randomgen.cpp b/tests/seal/randomgen.cpp new file mode 100644 index 000000000..15d07a4f7 --- /dev/null +++ b/tests/seal/randomgen.cpp @@ -0,0 +1,142 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "gtest/gtest.h" +#include "seal/randomgen.h" +#include "seal/keygenerator.h" +#include +#include +#include + +using namespace seal; +using namespace std; + +namespace SEALTest +{ + namespace + { + class CustomRandomEngine : public UniformRandomGenerator + { + public: + CustomRandomEngine() + { + } + + uint32_t generate() override + { + count_++; + return static_cast(engine_()); + } + + static int count() + { + return count_; + } + + private: + default_random_engine engine_; + + static int count_; + }; + + class CustomRandomEngineFactory : public UniformRandomGeneratorFactory + { + public: + shared_ptr create() override + { + return shared_ptr(new CustomRandomEngine()); + } + }; + + int CustomRandomEngine::count_ = 0; + } + + TEST(RandomGenerator, UniformRandomCreateDefault) + { + shared_ptr generator(UniformRandomGeneratorFactory::default_factory()->create()); + bool lower_half = false; + bool upper_half = false; + bool even = false; + bool odd = false; + for (int i = 0; i < 10; ++i) + { + uint32_t value = generator->generate(); + if (value < UINT32_MAX / 2) + { + lower_half = true; + } + else + { + upper_half = true; + } + if ((value % 2) == 0) + { + even = true; + } + else + { + odd = true; + } + } + ASSERT_TRUE(lower_half); + ASSERT_TRUE(upper_half); + ASSERT_TRUE(even); + ASSERT_TRUE(odd); + } + + TEST(RandomGenerator, StandardRandomAdapterGenerate) + { + StandardRandomAdapter generator; + bool lower_half = false; + bool upper_half = false; + bool even = false; + bool odd = false; + for (int i = 0; i < 10; ++i) + { + uint32_t value = generator.generate(); + if (value < UINT32_MAX / 2) + { + lower_half = true; + } + else + { + upper_half = true; + } + if ((value % 2) == 0) + { + even = true; + } + else + { + odd = true; + } + } + ASSERT_TRUE(lower_half); + ASSERT_TRUE(upper_half); + ASSERT_TRUE(even); + ASSERT_TRUE(odd); + } + + TEST(RandomGenerator, CustomRandomGenerator) + { + shared_ptr factory(new CustomRandomEngineFactory); + + EncryptionParameters parms(scheme_type::BFV); + uint64_t coeff_modulus; + SmallModulus plain_modulus; + parms.set_noise_standard_deviation(3.20); + coeff_modulus = 0xFFFFFFFFC001; + plain_modulus = 1 << 6; + parms.set_poly_modulus_degree(64); + parms.set_plain_modulus(plain_modulus); + parms.set_coeff_modulus({ coeff_modulus }); + parms.set_random_generator(factory); + auto context = SEALContext::Create(parms); + + ASSERT_EQ(0, CustomRandomEngine::count()); + + KeyGenerator keygen(context); + + ASSERT_NE(0, CustomRandomEngine::count()); + } +} diff --git a/tests/seal/relinkeys.cpp b/tests/seal/relinkeys.cpp new file mode 100644 index 000000000..4c74b04fd --- /dev/null +++ b/tests/seal/relinkeys.cpp @@ -0,0 +1,179 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "gtest/gtest.h" +#include "seal/relinkeys.h" +#include "seal/context.h" +#include "seal/keygenerator.h" +#include "seal/util/uintcore.h" +#include "seal/defaultparams.h" + +using namespace seal; +using namespace seal::util; +using namespace std; + +namespace SEALTest +{ + TEST(RelinKeysTest, RelinKeysSaveLoad) + { + stringstream stream; + { + EncryptionParameters parms(scheme_type::BFV); + parms.set_noise_standard_deviation(3.20); + parms.set_poly_modulus_degree(64); + parms.set_plain_modulus(1 << 6); + parms.set_coeff_modulus({ small_mods_60bit(0) }); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + + RelinKeys keys; + RelinKeys test_keys; + keys = keygen.relin_keys(1, 1); + ASSERT_EQ(keys.decomposition_bit_count(), 1); + keys.save(stream); + test_keys.load(context, stream); + ASSERT_EQ(keys.size(), test_keys.size()); + ASSERT_TRUE(keys.parms_id() == test_keys.parms_id()); + ASSERT_EQ(keys.decomposition_bit_count(), test_keys.decomposition_bit_count()); + for (size_t j = 0; j < test_keys.size(); j++) + { + for (size_t i = 0; i < test_keys.key(j + 2).size(); i++) + { + ASSERT_EQ(keys.key(j + 2)[i].size(), test_keys.key(j + 2)[i].size()); + ASSERT_EQ(keys.key(j + 2)[i].uint64_count(), test_keys.key(j + 2)[i].uint64_count()); + ASSERT_TRUE(is_equal_uint_uint(keys.key(j + 2)[i].data(), test_keys.key(j + 2)[i].data(), keys.key(j + 2)[i].uint64_count())); + } + } + + keys = keygen.relin_keys(2, 1); + ASSERT_EQ(keys.decomposition_bit_count(), 2); + keys.save(stream); + test_keys.load(context, stream); + ASSERT_EQ(keys.size(), test_keys.size()); + ASSERT_TRUE(keys.parms_id() == test_keys.parms_id()); + ASSERT_EQ(keys.decomposition_bit_count(), test_keys.decomposition_bit_count()); + for (size_t j = 0; j < test_keys.size(); j++) + { + for (size_t i = 0; i < test_keys.key(j + 2).size(); i++) + { + ASSERT_EQ(keys.key(j + 2)[i].size(), test_keys.key(j + 2)[i].size()); + ASSERT_EQ(keys.key(j + 2)[i].uint64_count(), test_keys.key(j + 2)[i].uint64_count()); + ASSERT_TRUE(is_equal_uint_uint(keys.key(j + 2)[i].data(), test_keys.key(j + 2)[i].data(), keys.key(j + 2)[i].uint64_count())); + } + } + + keys = keygen.relin_keys(59, 2); + ASSERT_EQ(keys.decomposition_bit_count(), 59); + keys.save(stream); + test_keys.load(context, stream); + ASSERT_EQ(keys.size(), test_keys.size()); + ASSERT_TRUE(keys.parms_id() == test_keys.parms_id()); + ASSERT_EQ(keys.decomposition_bit_count(), test_keys.decomposition_bit_count()); + for (size_t j = 0; j < test_keys.size(); j++) + { + for (size_t i = 0; i < test_keys.key(j + 2).size(); i++) + { + ASSERT_EQ(keys.key(j + 2)[i].size(), test_keys.key(j + 2)[i].size()); + ASSERT_EQ(keys.key(j + 2)[i].uint64_count(), test_keys.key(j + 2)[i].uint64_count()); + ASSERT_TRUE(is_equal_uint_uint(keys.key(j + 2)[i].data(), test_keys.key(j + 2)[i].data(), keys.key(j + 2)[i].uint64_count())); + } + } + + keys = keygen.relin_keys(60, 5); + ASSERT_EQ(keys.decomposition_bit_count(), 60); + keys.save(stream); + test_keys.load(context, stream); + ASSERT_EQ(keys.size(), test_keys.size()); + ASSERT_TRUE(keys.parms_id() == test_keys.parms_id()); + ASSERT_EQ(keys.decomposition_bit_count(), test_keys.decomposition_bit_count()); + for (size_t j = 0; j < test_keys.size(); j++) + { + for (size_t i = 0; i < test_keys.key(j + 2).size(); i++) + { + ASSERT_EQ(keys.key(j + 2)[i].size(), test_keys.key(j + 2)[i].size()); + ASSERT_EQ(keys.key(j + 2)[i].uint64_count(), test_keys.key(j + 2)[i].uint64_count()); + ASSERT_TRUE(is_equal_uint_uint(keys.key(j + 2)[i].data(), test_keys.key(j + 2)[i].data(), keys.key(j + 2)[i].uint64_count())); + } + } + } + { + EncryptionParameters parms(scheme_type::BFV); + parms.set_noise_standard_deviation(3.20); + parms.set_poly_modulus_degree(256); + parms.set_plain_modulus(1 << 6); + parms.set_coeff_modulus({ small_mods_60bit(0), small_mods_50bit(0) }); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + + RelinKeys keys; + RelinKeys test_keys; + keys = keygen.relin_keys(8, 1); + ASSERT_EQ(keys.decomposition_bit_count(), 8); + keys.save(stream); + test_keys.load(context, stream); + ASSERT_EQ(keys.size(), test_keys.size()); + ASSERT_TRUE(keys.parms_id() == test_keys.parms_id()); + ASSERT_EQ(keys.decomposition_bit_count(), test_keys.decomposition_bit_count()); + for (size_t j = 0; j < test_keys.size(); j++) + { + for (size_t i = 0; i < test_keys.key(j + 2).size(); i++) + { + ASSERT_EQ(keys.key(j + 2)[i].size(), test_keys.key(j + 2)[i].size()); + ASSERT_EQ(keys.key(j + 2)[i].uint64_count(), test_keys.key(j + 2)[i].uint64_count()); + ASSERT_TRUE(is_equal_uint_uint(keys.key(j + 2)[i].data(), test_keys.key(j + 2)[i].data(), keys.key(j + 2)[i].uint64_count())); + } + } + + keys = keygen.relin_keys(8, 2); + ASSERT_EQ(keys.decomposition_bit_count(), 8); + keys.save(stream); + test_keys.load(context, stream); + ASSERT_EQ(keys.size(), test_keys.size()); + ASSERT_TRUE(keys.parms_id() == test_keys.parms_id()); + ASSERT_EQ(keys.decomposition_bit_count(), test_keys.decomposition_bit_count()); + for (size_t j = 0; j < test_keys.size(); j++) + { + for (size_t i = 0; i < test_keys.key(j + 2).size(); i++) + { + ASSERT_EQ(keys.key(j + 2)[i].size(), test_keys.key(j + 2)[i].size()); + ASSERT_EQ(keys.key(j + 2)[i].uint64_count(), test_keys.key(j + 2)[i].uint64_count()); + ASSERT_TRUE(is_equal_uint_uint(keys.key(j + 2)[i].data(), test_keys.key(j + 2)[i].data(), keys.key(j + 2)[i].uint64_count())); + } + } + + keys = keygen.relin_keys(59, 2); + ASSERT_EQ(keys.decomposition_bit_count(), 59); + keys.save(stream); + test_keys.load(context, stream); + ASSERT_EQ(keys.size(), test_keys.size()); + ASSERT_TRUE(keys.parms_id() == test_keys.parms_id()); + ASSERT_EQ(keys.decomposition_bit_count(), test_keys.decomposition_bit_count()); + for (size_t j = 0; j < test_keys.size(); j++) + { + for (size_t i = 0; i < test_keys.key(j + 2).size(); i++) + { + ASSERT_EQ(keys.key(j + 2)[i].size(), test_keys.key(j + 2)[i].size()); + ASSERT_EQ(keys.key(j + 2)[i].uint64_count(), test_keys.key(j + 2)[i].uint64_count()); + ASSERT_TRUE(is_equal_uint_uint(keys.key(j + 2)[i].data(), test_keys.key(j + 2)[i].data(), keys.key(j + 2)[i].uint64_count())); + } + } + + keys = keygen.relin_keys(60, 5); + ASSERT_EQ(keys.decomposition_bit_count(), 60); + keys.save(stream); + test_keys.load(context, stream); + ASSERT_EQ(keys.size(), test_keys.size()); + ASSERT_TRUE(keys.parms_id() == test_keys.parms_id()); + ASSERT_EQ(keys.decomposition_bit_count(), test_keys.decomposition_bit_count()); + for (size_t j = 0; j < test_keys.size(); j++) + { + for (size_t i = 0; i < test_keys.key(j + 2).size(); i++) + { + ASSERT_EQ(keys.key(j + 2)[i].size(), test_keys.key(j + 2)[i].size()); + ASSERT_EQ(keys.key(j + 2)[i].uint64_count(), test_keys.key(j + 2)[i].uint64_count()); + ASSERT_TRUE(is_equal_uint_uint(keys.key(j + 2)[i].data(), test_keys.key(j + 2)[i].data(), keys.key(j + 2)[i].uint64_count())); + } + } + } + } +} diff --git a/tests/seal/secretkey.cpp b/tests/seal/secretkey.cpp new file mode 100644 index 000000000..8f32baf4c --- /dev/null +++ b/tests/seal/secretkey.cpp @@ -0,0 +1,57 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "gtest/gtest.h" +#include "seal/secretkey.h" +#include "seal/context.h" +#include "seal/keygenerator.h" +#include "seal/defaultparams.h" + +using namespace seal; +using namespace std; + +namespace SEALTest +{ + TEST(SecretKeyTest, SaveLoadSecretKey) + { + stringstream stream; + { + EncryptionParameters parms(scheme_type::BFV); + parms.set_noise_standard_deviation(3.20); + parms.set_poly_modulus_degree(64); + parms.set_plain_modulus(1 << 6); + parms.set_coeff_modulus({ small_mods_60bit(0) }); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + + SecretKey sk = keygen.secret_key(); + ASSERT_TRUE(sk.parms_id() == parms.parms_id()); + sk.save(stream); + + SecretKey sk2; + sk2.load(context, stream); + + ASSERT_TRUE(sk.data() == sk2.data()); + ASSERT_TRUE(sk.parms_id() == sk2.parms_id()); + } + { + EncryptionParameters parms(scheme_type::BFV); + parms.set_noise_standard_deviation(3.20); + parms.set_poly_modulus_degree(256); + parms.set_plain_modulus(1 << 20); + parms.set_coeff_modulus({ small_mods_30bit(0), small_mods_40bit(0) }); + auto context = SEALContext::Create(parms); + KeyGenerator keygen(context); + + SecretKey sk = keygen.secret_key(); + ASSERT_TRUE(sk.parms_id() == parms.parms_id()); + sk.save(stream); + + SecretKey sk2; + sk2.load(context, stream); + + ASSERT_TRUE(sk.data() == sk2.data()); + ASSERT_TRUE(sk.parms_id() == sk2.parms_id()); + } + } +} diff --git a/tests/seal/smallmodulus.cpp b/tests/seal/smallmodulus.cpp new file mode 100644 index 000000000..a766ed34e --- /dev/null +++ b/tests/seal/smallmodulus.cpp @@ -0,0 +1,92 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "gtest/gtest.h" +#include "seal/smallmodulus.h" + +using namespace seal; +using namespace std; + +namespace SEALTest +{ + TEST(SmallModulusTest, CreateSmallModulus) + { + SmallModulus mod; + ASSERT_TRUE(mod.is_zero()); + ASSERT_EQ(0ULL, mod.value()); + ASSERT_EQ(0, mod.bit_count()); + ASSERT_EQ(1ULL, mod.uint64_count()); + ASSERT_EQ(0ULL, mod.const_ratio()[0]); + ASSERT_EQ(0ULL, mod.const_ratio()[1]); + ASSERT_EQ(0ULL, mod.const_ratio()[2]); + + mod = 3; + ASSERT_FALSE(mod.is_zero()); + ASSERT_EQ(3ULL, mod.value()); + ASSERT_EQ(2, mod.bit_count()); + ASSERT_EQ(1ULL, mod.uint64_count()); + ASSERT_EQ(6148914691236517205ULL, mod.const_ratio()[0]); + ASSERT_EQ(6148914691236517205ULL, mod.const_ratio()[1]); + ASSERT_EQ(1ULL, mod.const_ratio()[2]); + + SmallModulus mod2(2); + SmallModulus mod3(3); + ASSERT_TRUE(mod != mod2); + ASSERT_TRUE(mod == mod3); + + mod = 0; + ASSERT_TRUE(mod.is_zero()); + ASSERT_EQ(0ULL, mod.value()); + ASSERT_EQ(0, mod.bit_count()); + ASSERT_EQ(1ULL, mod.uint64_count()); + ASSERT_EQ(0ULL, mod.const_ratio()[0]); + ASSERT_EQ(0ULL, mod.const_ratio()[1]); + ASSERT_EQ(0ULL, mod.const_ratio()[2]); + + mod = 0xF00000F00000F; + ASSERT_FALSE(mod.is_zero()); + ASSERT_EQ(0xF00000F00000FULL, mod.value()); + ASSERT_EQ(52, mod.bit_count()); + ASSERT_EQ(1ULL, mod.uint64_count()); + ASSERT_EQ(1224979098644774929ULL, mod.const_ratio()[0]); + ASSERT_EQ(4369ULL, mod.const_ratio()[1]); + ASSERT_EQ(281470698520321ULL, mod.const_ratio()[2]); + } + + TEST(SmallModulusTest, SaveLoadSmallModulus) + { + stringstream stream; + + SmallModulus mod; + mod.save(stream); + + SmallModulus mod2; + mod2.load(stream); + ASSERT_EQ(mod2.value(), mod.value()); + ASSERT_EQ(mod2.bit_count(), mod.bit_count()); + ASSERT_EQ(mod2.uint64_count(), mod.uint64_count()); + ASSERT_EQ(mod2.const_ratio()[0], mod.const_ratio()[0]); + ASSERT_EQ(mod2.const_ratio()[1], mod.const_ratio()[1]); + ASSERT_EQ(mod2.const_ratio()[2], mod.const_ratio()[2]); + + mod = 3; + mod.save(stream); + mod2.load(stream); + ASSERT_EQ(mod2.value(), mod.value()); + ASSERT_EQ(mod2.bit_count(), mod.bit_count()); + ASSERT_EQ(mod2.uint64_count(), mod.uint64_count()); + ASSERT_EQ(mod2.const_ratio()[0], mod.const_ratio()[0]); + ASSERT_EQ(mod2.const_ratio()[1], mod.const_ratio()[1]); + ASSERT_EQ(mod2.const_ratio()[2], mod.const_ratio()[2]); + + mod = 0xF00000F00000F; + mod.save(stream); + mod2.load(stream); + ASSERT_EQ(mod2.value(), mod.value()); + ASSERT_EQ(mod2.bit_count(), mod.bit_count()); + ASSERT_EQ(mod2.uint64_count(), mod.uint64_count()); + ASSERT_EQ(mod2.const_ratio()[0], mod.const_ratio()[0]); + ASSERT_EQ(mod2.const_ratio()[1], mod.const_ratio()[1]); + ASSERT_EQ(mod2.const_ratio()[2], mod.const_ratio()[2]); + } +} diff --git a/tests/seal/testrunner.cpp b/tests/seal/testrunner.cpp new file mode 100644 index 000000000..a7e7583bb --- /dev/null +++ b/tests/seal/testrunner.cpp @@ -0,0 +1,13 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "gtest/gtest.h" + +/** +Main entry point for Google Test unit tests. +*/ +int main(int argc, char** argv) +{ + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tests/seal/util/CMakeLists.txt b/tests/seal/util/CMakeLists.txt new file mode 100644 index 000000000..af5ae9f44 --- /dev/null +++ b/tests/seal/util/CMakeLists.txt @@ -0,0 +1,24 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +target_sources(sealtest + PRIVATE + ${CMAKE_CURRENT_LIST_DIR}/clipnormal.cpp + ${CMAKE_CURRENT_LIST_DIR}/common.cpp + ${CMAKE_CURRENT_LIST_DIR}/hash.cpp + ${CMAKE_CURRENT_LIST_DIR}/locks.cpp + ${CMAKE_CURRENT_LIST_DIR}/mempool.cpp + ${CMAKE_CURRENT_LIST_DIR}/numth.cpp + ${CMAKE_CURRENT_LIST_DIR}/polyarith.cpp + ${CMAKE_CURRENT_LIST_DIR}/polyarithmod.cpp + ${CMAKE_CURRENT_LIST_DIR}/polyarithsmallmod.cpp + ${CMAKE_CURRENT_LIST_DIR}/polycore.cpp + ${CMAKE_CURRENT_LIST_DIR}/randomtostd.cpp + ${CMAKE_CURRENT_LIST_DIR}/smallntt.cpp + ${CMAKE_CURRENT_LIST_DIR}/stringtouint64.cpp + ${CMAKE_CURRENT_LIST_DIR}/uint64tostring.cpp + ${CMAKE_CURRENT_LIST_DIR}/uintarith.cpp + ${CMAKE_CURRENT_LIST_DIR}/uintarithmod.cpp + ${CMAKE_CURRENT_LIST_DIR}/uintarithsmallmod.cpp + ${CMAKE_CURRENT_LIST_DIR}/uintcore.cpp +) diff --git a/tests/seal/util/clipnormal.cpp b/tests/seal/util/clipnormal.cpp new file mode 100644 index 000000000..f17f45c41 --- /dev/null +++ b/tests/seal/util/clipnormal.cpp @@ -0,0 +1,46 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "gtest/gtest.h" +#include "seal/randomgen.h" +#include "seal/util/randomtostd.h" +#include "seal/util/clipnormal.h" +#include +#include + +using namespace seal::util; +using namespace seal; +using namespace std; + +namespace SEALTest +{ + namespace util + { + TEST(ClipNormal, ClipNormalGenerate) + { + shared_ptr generator(UniformRandomGeneratorFactory::default_factory()->create()); + RandomToStandardAdapter rand(generator); + ClippedNormalDistribution dist(50.0, 10.0, 20.0); + + ASSERT_EQ(50.0, dist.mean()); + ASSERT_EQ(10.0, dist.standard_deviation()); + ASSERT_EQ(20.0, dist.max_deviation()); + ASSERT_EQ(30.0, dist.min()); + ASSERT_EQ(70.0, dist.max()); + double average = 0; + double stddev = 0; + for (int i = 0; i < 100; ++i) + { + double value = dist(rand); + average += value; + stddev += (value - 50.0) * (value - 50.0); + ASSERT_TRUE(value >= 30.0 && value <= 70.0); + } + average /= 100; + stddev /= 100; + stddev = sqrt(stddev); + ASSERT_TRUE(average >= 40.0 && average <= 60.0); + ASSERT_TRUE(stddev >= 5.0 && stddev <= 15.0); + } + } +} diff --git a/tests/seal/util/common.cpp b/tests/seal/util/common.cpp new file mode 100644 index 000000000..7e0ef9b3e --- /dev/null +++ b/tests/seal/util/common.cpp @@ -0,0 +1,282 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "gtest/gtest.h" +#include "seal/util/common.h" +#include + +using namespace seal; +using namespace seal::util; +using namespace std; + +namespace SEALTest +{ + namespace util + { + TEST(Common, Constants) + { + ASSERT_EQ(4, bits_per_nibble); + ASSERT_EQ(8, bits_per_byte); + ASSERT_EQ(4, bytes_per_uint32); + ASSERT_EQ(8, bytes_per_uint64); + ASSERT_EQ(32, bits_per_uint32); + ASSERT_EQ(64, bits_per_uint64); + ASSERT_EQ(2, nibbles_per_byte); + ASSERT_EQ(2, uint32_per_uint64); + ASSERT_EQ(16, nibbles_per_uint64); + ASSERT_EQ(static_cast(INT64_MAX) + 1, uint64_high_bit); + } + + TEST(Common, UnsignedComparisons) + { + int pos_i = 5; + int neg_i = -5; + unsigned pos_u = 6; + signed pos_s = 6; + unsigned char pos_uc = 1; + char neg_c = -1; + char pos_c = 1; + unsigned char pos_uc_max = 0xFF; + unsigned long long pos_ull = 1; + unsigned long long pos_ull_max = 0xFFFFFFFFFFFFFFFF; + long long neg_ull = -1; + + ASSERT_TRUE(unsigned_eq(pos_i, pos_i)); + ASSERT_FALSE(unsigned_eq(pos_i, neg_i)); + ASSERT_TRUE(unsigned_gt(pos_u, pos_i)); + ASSERT_TRUE(unsigned_lt(pos_i, neg_i)); + ASSERT_TRUE(unsigned_geq(pos_u, pos_s)); + ASSERT_TRUE(unsigned_gt(neg_c, pos_c)); + ASSERT_TRUE(unsigned_geq(neg_c, pos_c)); + ASSERT_FALSE(unsigned_eq(neg_c, pos_c)); + ASSERT_FALSE(unsigned_gt(pos_u, neg_c)); + ASSERT_TRUE(unsigned_eq(pos_uc, pos_c)); + ASSERT_TRUE(unsigned_geq(pos_uc, pos_c)); + ASSERT_TRUE(unsigned_leq(pos_uc, pos_c)); + ASSERT_TRUE(unsigned_lt(pos_uc_max, neg_c)); + ASSERT_TRUE(unsigned_eq(neg_c, pos_ull_max)); + ASSERT_TRUE(unsigned_eq(neg_ull, pos_ull_max)); + ASSERT_FALSE(unsigned_lt(neg_ull, pos_ull_max)); + ASSERT_TRUE(unsigned_lt(pos_ull, pos_ull_max)); + } + + TEST(Common, SafeArithmetic) + { + int pos_i = 5; + int neg_i = -5; + unsigned pos_u = 6; + unsigned char pos_uc_max = 0xFF; + unsigned long long pos_ull_max = 0xFFFFFFFFFFFFFFFF; + long long neg_ull = -1; + + ASSERT_EQ(25, mul_safe(pos_i, pos_i)); + ASSERT_EQ(25, mul_safe(neg_i, neg_i)); + ASSERT_EQ(10, add_safe(pos_i, pos_i)); + ASSERT_EQ(-10, add_safe(neg_i, neg_i)); + ASSERT_EQ(0, add_safe(pos_i, neg_i)); + ASSERT_EQ(0, add_safe(neg_i, pos_i)); + ASSERT_EQ(10, sub_safe(pos_i, neg_i)); + ASSERT_EQ(-10, sub_safe(neg_i, pos_i)); + ASSERT_EQ(unsigned(0), sub_safe(pos_u, pos_u)); + ASSERT_THROW(sub_safe(unsigned(0), pos_u), out_of_range); + ASSERT_THROW(sub_safe(unsigned(4), pos_u), out_of_range); + ASSERT_THROW(add_safe(pos_uc_max, (unsigned char)1), out_of_range); + ASSERT_TRUE(pos_uc_max == add_safe(pos_uc_max, (unsigned char)0)); + ASSERT_THROW(mul_safe(pos_ull_max, pos_ull_max), out_of_range); + ASSERT_EQ(0ULL, mul_safe(0ULL, pos_ull_max)); + ASSERT_TRUE((long long)1 == mul_safe(neg_ull, neg_ull)); + ASSERT_THROW(mul_safe(pos_uc_max, pos_uc_max), out_of_range); + ASSERT_EQ(15, add_safe(pos_i, -pos_i, pos_i, pos_i, pos_i)); + ASSERT_EQ(6, add_safe(0, -pos_i, pos_i, 1, pos_i)); + ASSERT_EQ(0, mul_safe(pos_i, pos_i, pos_i, 0, pos_i)); + ASSERT_EQ(625, mul_safe(pos_i, pos_i, pos_i, pos_i)); + ASSERT_THROW(mul_safe( + pos_i, pos_i, pos_i, pos_i, pos_i, pos_i, pos_i, + pos_i, pos_i, pos_i, pos_i, pos_i, pos_i, pos_i), out_of_range); + } + + TEST(Common, FitsIn) + { + int neg_i = -5; + signed pos_s = 6; + unsigned char pos_uc = 1; + unsigned char pos_uc_max = 0xFF; + float f = 1.234f; + double d = -1234; + + ASSERT_TRUE(fits_in(pos_s)); + ASSERT_TRUE(fits_in(pos_uc)); + ASSERT_FALSE(fits_in(neg_i)); + ASSERT_FALSE(fits_in(pos_uc_max)); + ASSERT_TRUE(fits_in(d)); + ASSERT_TRUE(fits_in(f)); + ASSERT_TRUE(fits_in(d)); + ASSERT_TRUE(fits_in(f)); + ASSERT_FALSE(fits_in(d)); + } + + TEST(Common, DivideRoundUp) + { + ASSERT_EQ(0, divide_round_up(0, 4)); + ASSERT_EQ(1, divide_round_up(1, 4)); + ASSERT_EQ(1, divide_round_up(2, 4)); + ASSERT_EQ(1, divide_round_up(3, 4)); + ASSERT_EQ(1, divide_round_up(4, 4)); + ASSERT_EQ(2, divide_round_up(5, 4)); + ASSERT_EQ(2, divide_round_up(6, 4)); + ASSERT_EQ(2, divide_round_up(7, 4)); + ASSERT_EQ(2, divide_round_up(8, 4)); + ASSERT_EQ(3, divide_round_up(9, 4)); + ASSERT_EQ(3, divide_round_up(12, 4)); + ASSERT_EQ(4, divide_round_up(13, 4)); + } + + TEST(Common, GetUInt64Byte) + { + uint64_t number[2]; + number[0] = 0x3456789ABCDEF121; + number[1] = 0x23456789ABCDEF12; + ASSERT_TRUE(SEAL_BYTE(0x21) == *get_uint64_byte(number, 0)); + ASSERT_TRUE(SEAL_BYTE(0xF1) == *get_uint64_byte(number, 1)); + ASSERT_TRUE(SEAL_BYTE(0xDE) == *get_uint64_byte(number, 2)); + ASSERT_TRUE(SEAL_BYTE(0xBC) == *get_uint64_byte(number, 3)); + ASSERT_TRUE(SEAL_BYTE(0x9A) == *get_uint64_byte(number, 4)); + ASSERT_TRUE(SEAL_BYTE(0x78) == *get_uint64_byte(number, 5)); + ASSERT_TRUE(SEAL_BYTE(0x56) == *get_uint64_byte(number, 6)); + ASSERT_TRUE(SEAL_BYTE(0x34) == *get_uint64_byte(number, 7)); + ASSERT_TRUE(SEAL_BYTE(0x12) == *get_uint64_byte(number, 8)); + ASSERT_TRUE(SEAL_BYTE(0xEF) == *get_uint64_byte(number, 9)); + ASSERT_TRUE(SEAL_BYTE(0xCD) == *get_uint64_byte(number, 10)); + ASSERT_TRUE(SEAL_BYTE(0xAB) == *get_uint64_byte(number, 11)); + ASSERT_TRUE(SEAL_BYTE(0x89) == *get_uint64_byte(number, 12)); + ASSERT_TRUE(SEAL_BYTE(0x67) == *get_uint64_byte(number, 13)); + ASSERT_TRUE(SEAL_BYTE(0x45) == *get_uint64_byte(number, 14)); + ASSERT_TRUE(SEAL_BYTE(0x23) == *get_uint64_byte(number, 15)); + } + + TEST(Common, ReverseBits32) + { + ASSERT_EQ(static_cast(0), reverse_bits(static_cast(0))); + ASSERT_EQ(static_cast(0x80000000), reverse_bits(static_cast(1))); + ASSERT_EQ(static_cast(0x40000000), reverse_bits(static_cast(2))); + ASSERT_EQ(static_cast(0xC0000000), reverse_bits(static_cast(3))); + ASSERT_EQ(static_cast(0x00010000), reverse_bits(static_cast(0x00008000))); + ASSERT_EQ(static_cast(0xFFFF0000), reverse_bits(static_cast(0x0000FFFF))); + ASSERT_EQ(static_cast(0x0000FFFF), reverse_bits(static_cast(0xFFFF0000))); + ASSERT_EQ(static_cast(0x00008000), reverse_bits(static_cast(0x00010000))); + ASSERT_EQ(static_cast(3), reverse_bits(static_cast(0xC0000000))); + ASSERT_EQ(static_cast(2), reverse_bits(static_cast(0x40000000))); + ASSERT_EQ(static_cast(1), reverse_bits(static_cast(0x80000000))); + ASSERT_EQ(static_cast(0xFFFFFFFF), reverse_bits(static_cast(0xFFFFFFFF))); + + // Reversing a 0-bit item should return 0 + ASSERT_EQ(static_cast(0), reverse_bits(static_cast(0xFFFFFFFF), 0)); + + // Reversing a 32-bit item returns is same as normal reverse + ASSERT_EQ(static_cast(0), reverse_bits(static_cast(0), 32)); + ASSERT_EQ(static_cast(0x80000000), reverse_bits(static_cast(1), 32)); + ASSERT_EQ(static_cast(0x40000000), reverse_bits(static_cast(2), 32)); + ASSERT_EQ(static_cast(0xC0000000), reverse_bits(static_cast(3), 32)); + ASSERT_EQ(static_cast(0x00010000), reverse_bits(static_cast(0x00008000), 32)); + ASSERT_EQ(static_cast(0xFFFF0000), reverse_bits(static_cast(0x0000FFFF), 32)); + ASSERT_EQ(static_cast(0x0000FFFF), reverse_bits(static_cast(0xFFFF0000), 32)); + ASSERT_EQ(static_cast(0x00008000), reverse_bits(static_cast(0x00010000), 32)); + ASSERT_EQ(static_cast(3), reverse_bits(static_cast(0xC0000000), 32)); + ASSERT_EQ(static_cast(2), reverse_bits(static_cast(0x40000000), 32)); + ASSERT_EQ(static_cast(1), reverse_bits(static_cast(0x80000000), 32)); + ASSERT_EQ(static_cast(0xFFFFFFFF), reverse_bits(static_cast(0xFFFFFFFF), 32)); + + // 16-bit reversal + ASSERT_EQ(static_cast(0), reverse_bits(static_cast(0), 16)); + ASSERT_EQ(static_cast(0x00008000), reverse_bits(static_cast(1), 16)); + ASSERT_EQ(static_cast(0x00004000), reverse_bits(static_cast(2), 16)); + ASSERT_EQ(static_cast(0x0000C000), reverse_bits(static_cast(3), 16)); + ASSERT_EQ(static_cast(0x00000001), reverse_bits(static_cast(0x00008000), 16)); + ASSERT_EQ(static_cast(0x0000FFFF), reverse_bits(static_cast(0x0000FFFF), 16)); + ASSERT_EQ(static_cast(0x00000000), reverse_bits(static_cast(0xFFFF0000), 16)); + ASSERT_EQ(static_cast(0x00000000), reverse_bits(static_cast(0x00010000), 16)); + ASSERT_EQ(static_cast(3), reverse_bits(static_cast(0x0000C000), 16)); + ASSERT_EQ(static_cast(2), reverse_bits(static_cast(0x00004000), 16)); + ASSERT_EQ(static_cast(1), reverse_bits(static_cast(0x00008000), 16)); + ASSERT_EQ(static_cast(0x0000FFFF), reverse_bits(static_cast(0xFFFFFFFF), 16)); + } + + TEST(Common, ReverseBits64) + { + ASSERT_EQ(0ULL, reverse_bits(0ULL)); + ASSERT_EQ(1ULL << 63, reverse_bits(1ULL)); + ASSERT_EQ(1ULL << 32, reverse_bits(1ULL << 31)); + ASSERT_EQ(0xFFFFULL << 32, reverse_bits(0xFFFFULL << 16)); + ASSERT_EQ(0x0000FFFFFFFF0000ULL, reverse_bits(0x0000FFFFFFFF0000ULL)); + ASSERT_EQ(0x0000FFFF0000FFFFULL, reverse_bits(0xFFFF0000FFFF0000ULL)); + + ASSERT_EQ(0ULL, reverse_bits(0ULL, 0)); + ASSERT_EQ(0ULL, reverse_bits(0ULL, 1)); + ASSERT_EQ(0ULL, reverse_bits(0ULL, 32)); + ASSERT_EQ(0ULL, reverse_bits(0ULL, 64)); + + ASSERT_EQ(0ULL, reverse_bits(1ULL, 0)); + ASSERT_EQ(1ULL, reverse_bits(1ULL, 1)); + ASSERT_EQ(1ULL << 31, reverse_bits(1ULL, 32)); + ASSERT_EQ(1ULL << 63, reverse_bits(1ULL, 64)); + + ASSERT_EQ(0ULL, reverse_bits(1ULL << 31, 0)); + ASSERT_EQ(0ULL, reverse_bits(1ULL << 31, 1)); + ASSERT_EQ(1ULL, reverse_bits(1ULL << 31, 32)); + ASSERT_EQ(1ULL << 32, reverse_bits(1ULL << 31, 64)); + + ASSERT_EQ(0ULL, reverse_bits(0xFFFFULL << 16, 0)); + ASSERT_EQ(0ULL, reverse_bits(0xFFFFULL << 16, 1)); + ASSERT_EQ(0xFFFFULL, reverse_bits(0xFFFFULL << 16, 32)); + ASSERT_EQ(0xFFFFULL << 32, reverse_bits(0xFFFFULL << 16, 64)); + + ASSERT_EQ(0ULL, reverse_bits(0x0000FFFFFFFF0000ULL, 0)); + ASSERT_EQ(0ULL, reverse_bits(0x0000FFFFFFFF0000ULL, 1)); + ASSERT_EQ(0xFFFFULL, reverse_bits(0x0000FFFFFFFF0000ULL, 32)); + ASSERT_EQ(0x0000FFFFFFFF0000ULL, reverse_bits(0x0000FFFFFFFF0000ULL, 64)); + + ASSERT_EQ(0ULL, reverse_bits(0xFFFF0000FFFF0000ULL, 0)); + ASSERT_EQ(0ULL, reverse_bits(0xFFFF0000FFFF0000ULL, 1)); + ASSERT_EQ(0xFFFFULL, reverse_bits(0xFFFF0000FFFF0000ULL, 32)); + ASSERT_EQ(0x0000FFFF0000FFFFULL, reverse_bits(0xFFFF0000FFFF0000ULL, 64)); + } + + TEST(Common, GetSignificantBitCount) + { + ASSERT_EQ(0, get_significant_bit_count(0)); + ASSERT_EQ(1, get_significant_bit_count(1)); + ASSERT_EQ(2, get_significant_bit_count(2)); + ASSERT_EQ(2, get_significant_bit_count(3)); + ASSERT_EQ(3, get_significant_bit_count(4)); + ASSERT_EQ(3, get_significant_bit_count(5)); + ASSERT_EQ(3, get_significant_bit_count(6)); + ASSERT_EQ(3, get_significant_bit_count(7)); + ASSERT_EQ(4, get_significant_bit_count(8)); + ASSERT_EQ(63, get_significant_bit_count(0x7000000000000000)); + ASSERT_EQ(63, get_significant_bit_count(0x7FFFFFFFFFFFFFFF)); + ASSERT_EQ(64, get_significant_bit_count(0x8000000000000000)); + ASSERT_EQ(64, get_significant_bit_count(0xFFFFFFFFFFFFFFFF)); + } + + TEST(Common, GetMSBIndexGeneric) + { + unsigned long result; + get_msb_index_generic(&result, 1); + ASSERT_EQ(static_cast(0), result); + get_msb_index_generic(&result, 2); + ASSERT_EQ(static_cast(1), result); + get_msb_index_generic(&result, 3); + ASSERT_EQ(static_cast(1), result); + get_msb_index_generic(&result, 4); + ASSERT_EQ(static_cast(2), result); + get_msb_index_generic(&result, 16); + ASSERT_EQ(static_cast(4), result); + get_msb_index_generic(&result, 0xFFFFFFFF); + ASSERT_EQ(static_cast(31), result); + get_msb_index_generic(&result, 0x100000000); + ASSERT_EQ(static_cast(32), result); + get_msb_index_generic(&result, 0xFFFFFFFFFFFFFFFF); + ASSERT_EQ(static_cast(63), result); + } + } +} diff --git a/tests/seal/util/hash.cpp b/tests/seal/util/hash.cpp new file mode 100644 index 000000000..f5cb3b121 --- /dev/null +++ b/tests/seal/util/hash.cpp @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "gtest/gtest.h" +#include "seal/util/hash.h" +#include + +using namespace seal::util; +using namespace std; + +namespace SEALTest +{ + namespace util + { + TEST(HashTest, SHA3Hash) + { + uint64_t input[3]{ 0, 0, 0 }; + HashFunction::sha3_block_type hash1, hash2; + HashFunction::sha3_hash(0, hash1); + + HashFunction::sha3_hash(input, 0, hash2); + ASSERT_TRUE(hash1 != hash2); + + HashFunction::sha3_hash(input, 1, hash2); + ASSERT_TRUE(hash1 == hash2); + + HashFunction::sha3_hash(input, 2, hash2); + ASSERT_TRUE(hash1 != hash2); + + HashFunction::sha3_hash(0x123456, hash1); + HashFunction::sha3_hash(0x023456, hash2); + ASSERT_TRUE(hash1 != hash2); + + input[0] = 0x123456; + input[1] = 1; + HashFunction::sha3_hash(0x123456, hash1); + HashFunction::sha3_hash(input, 2, hash2); + ASSERT_TRUE(hash1 != hash2); + } + } +} diff --git a/tests/seal/util/jsonparser.cpp b/tests/seal/util/jsonparser.cpp new file mode 100644 index 000000000..f82f2c393 --- /dev/null +++ b/tests/seal/util/jsonparser.cpp @@ -0,0 +1,33 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "CppUnitTest.h" +#include "seal/util/jsonparser.h" +#include + +using namespace Microsoft::VisualStudio::CppUnitTestFramework; +using namespace seal::util; +using namespace std; + +namespace SEALTest +{ + namespace util + { + TEST_CLASS(JsonParser) + { + public: + TEST_METHOD(StripWhitespace) + { + string json(" { \"name1\" : \"value1\", \"name2\" : \"value2\", \"array1\" : [ \"hello\", \"world\" ], \"object1\" : { \"subname1\" : \"subvalue1\" } } "); + string expected("{\"name1\":\"value1\",\"name2\":\"value2\",\"array1\":[\"hello\",\"world\"],\"object1\":{\"subname1\":\"subvalue1\"}}"); + Assert::AreEqual(expected, stripWhitespace(json)); + } + + TEST_METHOD(JsonParse) + { + string json("{\"name1\":\"value1\",\"name2\":\"value2\",\"array1\":[\"hello\",\"world\"],\"object1\":{\"subname1\":\"subvalue1\"}}"); + auto result = parseJSON(json); + } + }; + } +} diff --git a/tests/seal/util/locks.cpp b/tests/seal/util/locks.cpp new file mode 100644 index 000000000..6fdd48b10 --- /dev/null +++ b/tests/seal/util/locks.cpp @@ -0,0 +1,309 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "gtest/gtest.h" +#include "seal/util/locks.h" +#include +#include + +using namespace seal::util; +using namespace std; + +namespace SEALTest +{ + namespace util + { + class Reader + { + public: + Reader(ReaderWriterLocker &locker) : locker_(locker), locked_(false), trying_(false) + { + } + + bool is_locked() const + { + return locked_; + } + + bool is_trying_to_lock() const + { + return trying_; + } + + void acquire_read() + { + trying_ = true; + lock_ = locker_.acquire_read(); + locked_ = true; + trying_ = false; + } + + void release() + { + lock_.unlock(); + locked_ = false; + } + + void wait_until_trying() + { + while (!trying_); + } + + void wait_until_locked() + { + while (!locked_); + } + + private: + ReaderWriterLocker &locker_; + + ReaderLock lock_; + + volatile bool locked_; + + volatile bool trying_; + }; + + class Writer + { + public: + Writer(ReaderWriterLocker &locker) : locker_(locker), locked_(false), trying_(false) + { + } + + bool is_locked() const + { + return locked_; + } + + bool is_trying_to_lock() const + { + return trying_; + } + + void acquire_write() + { + trying_ = true; + lock_ = locker_.acquire_write(); + locked_ = true; + trying_ = false; + } + + void release() + { + lock_.unlock(); + locked_ = false; + } + + void wait_until_trying() + { + while (!trying_); + } + + void wait_until_locked() + { + while (!locked_); + } + + void wait_until_unlocked() + { + while(locked_); + } + + private: + ReaderWriterLocker &locker_; + + WriterLock lock_; + + volatile bool locked_; + + volatile bool trying_; + }; + + TEST(ReaderWriterLockerTests, ReaderWriterLockNonBlocking) + { + ReaderWriterLocker locker; + + WriterLock writeLock = locker.acquire_write(); + ASSERT_TRUE(writeLock.owns_lock()); + writeLock.unlock(); + ASSERT_FALSE(writeLock.owns_lock()); + + ReaderLock readLock = locker.acquire_read(); + ASSERT_TRUE(readLock.owns_lock()); + readLock.unlock(); + + ReaderLock readLock2 = locker.acquire_read(); + ASSERT_TRUE(readLock2.owns_lock()); + ASSERT_FALSE(readLock.owns_lock()); + readLock2.unlock(); + ASSERT_FALSE(readLock2.owns_lock()); + + readLock = locker.try_acquire_read(); + ASSERT_TRUE(readLock.owns_lock()); + writeLock = locker.try_acquire_write(); + ASSERT_FALSE(writeLock.owns_lock()); + + readLock2 = locker.try_acquire_read(); + ASSERT_TRUE(readLock2.owns_lock()); + writeLock = locker.try_acquire_write(); + ASSERT_FALSE(writeLock.owns_lock()); + + readLock.unlock(); + writeLock = locker.try_acquire_write(); + ASSERT_FALSE(writeLock.owns_lock()); + + readLock2.unlock(); + writeLock = locker.try_acquire_write(); + ASSERT_TRUE(writeLock.owns_lock()); + + WriterLock writeLock2 = locker.try_acquire_write(); + + ASSERT_FALSE(writeLock2.owns_lock()); + readLock2 = locker.try_acquire_read(); + ASSERT_FALSE(readLock2.owns_lock()); + + writeLock.unlock(); + + writeLock2 = locker.try_acquire_write(); + ASSERT_TRUE(writeLock2.owns_lock()); + readLock2 = locker.try_acquire_read(); + ASSERT_FALSE(readLock2.owns_lock()); + + writeLock2.unlock(); + } + + TEST(ReaderWriterLockerTests, ReaderWriterLockBlocking) + { + ReaderWriterLocker locker; + + Reader *reader1 = new Reader(locker); + Reader *reader2 = new Reader(locker); + Writer *writer1 = new Writer(locker); + Writer *writer2 = new Writer(locker); + + ASSERT_FALSE(reader1->is_locked()); + ASSERT_FALSE(reader2->is_locked()); + ASSERT_FALSE(writer1->is_locked()); + ASSERT_FALSE(writer2->is_locked()); + + reader1->acquire_read(); + ASSERT_TRUE(reader1->is_locked()); + ASSERT_FALSE(reader2->is_locked()); + reader2->acquire_read(); + ASSERT_TRUE(reader1->is_locked()); + ASSERT_TRUE(reader2->is_locked()); + + atomic should_unlock1{ false }; + atomic should_unlock2{ false }; + + thread writer1_thread([&] { + writer1->acquire_write(); + while (!should_unlock1) + { + this_thread::sleep_for(10ms); + } + writer1->release(); + }); + + writer1->wait_until_trying(); + ASSERT_TRUE(writer1->is_trying_to_lock()); + ASSERT_FALSE(writer1->is_locked()); + + reader2->release(); + ASSERT_TRUE(reader1->is_locked()); + ASSERT_FALSE(reader2->is_locked()); + ASSERT_TRUE(writer1->is_trying_to_lock()); + ASSERT_FALSE(writer1->is_locked()); + + thread writer2_thread([&] { + writer2->acquire_write(); + while (!should_unlock2) + { + this_thread::sleep_for(10ms); + } + writer2->release(); + }); + + writer2->wait_until_trying(); + ASSERT_TRUE(writer1->is_trying_to_lock()); + ASSERT_FALSE(writer1->is_locked()); + ASSERT_TRUE(writer2->is_trying_to_lock()); + ASSERT_FALSE(writer2->is_locked()); + + reader1->release(); + ASSERT_FALSE(reader1->is_locked()); + + while (writer1->is_trying_to_lock() && writer2->is_trying_to_lock()); + + Writer *winner; + Writer *waiting; + atomic* should_unlock_winner; + atomic* should_unlock_waiting; + + if (writer1->is_locked()) + { + winner = writer1; + waiting = writer2; + should_unlock_winner = &should_unlock1; + should_unlock_waiting = &should_unlock2; + } + else + { + winner = writer2; + waiting = writer1; + should_unlock_winner = &should_unlock2; + should_unlock_waiting = &should_unlock1; + } + + ASSERT_TRUE(winner->is_locked()); + ASSERT_FALSE(waiting->is_locked()); + + *should_unlock_winner = true; + winner->wait_until_unlocked(); + ASSERT_FALSE(winner->is_locked()); + + waiting->wait_until_locked(); + ASSERT_TRUE(waiting->is_locked()); + + thread reader1_thread(&Reader::acquire_read, reader1); + reader1->wait_until_trying(); + ASSERT_TRUE(reader1->is_trying_to_lock()); + ASSERT_FALSE(reader1->is_locked()); + + thread reader2_thread(&Reader::acquire_read, reader2); + reader2->wait_until_trying(); + ASSERT_TRUE(reader2->is_trying_to_lock()); + ASSERT_FALSE(reader2->is_locked()); + + *should_unlock_waiting = true; + + reader1->wait_until_locked(); + reader2->wait_until_locked(); + ASSERT_TRUE(reader1->is_locked()); + ASSERT_TRUE(reader2->is_locked()); + + reader1->release(); + reader2->release(); + + ASSERT_FALSE(reader1->is_locked()); + ASSERT_FALSE(reader2->is_locked()); + ASSERT_FALSE(writer1->is_locked()); + ASSERT_FALSE(reader2->is_locked()); + ASSERT_FALSE(reader1->is_trying_to_lock()); + ASSERT_FALSE(reader2->is_trying_to_lock()); + ASSERT_FALSE(writer1->is_trying_to_lock()); + ASSERT_FALSE(reader2->is_trying_to_lock()); + + writer1_thread.join(); + writer2_thread.join(); + reader1_thread.join(); + reader2_thread.join(); + + delete reader1; + delete reader2; + delete writer1; + delete writer2; + } + } +} diff --git a/tests/seal/util/mempool.cpp b/tests/seal/util/mempool.cpp new file mode 100644 index 000000000..9c34e60ff --- /dev/null +++ b/tests/seal/util/mempool.cpp @@ -0,0 +1,674 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "gtest/gtest.h" +#include "seal/util/mempool.h" +#include "seal/util/pointer.h" +#include "seal/util/common.h" +#include + +using namespace seal; +using namespace seal::util; +using namespace std; + +namespace SEALTest +{ + namespace util + { + TEST(MemoryPoolTests, TestMemoryPoolMT) + { + { + MemoryPoolMT pool; + ASSERT_TRUE(0LL == pool.pool_count()); + + Pointer pointer{ pool.get_for_byte_count(bytes_per_uint64 * 0) }; + ASSERT_FALSE(pointer.is_set()); + pointer.release(); + ASSERT_TRUE(0LL == pool.pool_count()); + + pointer = pool.get_for_byte_count(bytes_per_uint64 * 2); + uint64_t *allocation1 = pointer.get(); + ASSERT_TRUE(pointer.is_set()); + pointer.release(); + ASSERT_FALSE(pointer.is_set()); + ASSERT_TRUE(1LL == pool.pool_count()); + + pointer = pool.get_for_byte_count(bytes_per_uint64 * 2); + ASSERT_TRUE(allocation1 == pointer.get()); + ASSERT_TRUE(pointer.is_set()); + pointer.release(); + ASSERT_FALSE(pointer.is_set()); + ASSERT_TRUE(1LL == pool.pool_count()); + + pointer = pool.get_for_byte_count(bytes_per_uint64 * 1); + ASSERT_FALSE(allocation1 == pointer.get()); + ASSERT_TRUE(pointer.is_set()); + pointer.release(); + ASSERT_FALSE(pointer.is_set()); + ASSERT_TRUE(2LL == pool.pool_count()); + + pointer = pool.get_for_byte_count(bytes_per_uint64 * 2); + ASSERT_TRUE(allocation1 == pointer.get()); + Pointer pointer2 = pool.get_for_byte_count(bytes_per_uint64 * 2); + uint64_t *allocation2 = pointer2.get(); + ASSERT_FALSE(allocation2 == pointer.get()); + ASSERT_TRUE(pointer.is_set()); + pointer.release(); + pointer2.release(); + ASSERT_TRUE(2LL == pool.pool_count()); + + pointer = pool.get_for_byte_count(bytes_per_uint64 * 2); + ASSERT_TRUE(allocation2 == pointer.get()); + pointer2 = pool.get_for_byte_count(bytes_per_uint64 * 2); + ASSERT_TRUE(allocation1 == pointer2.get()); + Pointer pointer3 = pool.get_for_byte_count(bytes_per_uint64 * 1); + pointer.release(); + pointer2.release(); + pointer3.release(); + ASSERT_TRUE(2LL == pool.pool_count()); + + Pointer pointer4 = pool.get_for_byte_count(1); + Pointer pointer5 = pool.get_for_byte_count(2); + Pointer pointer6 = pool.get_for_byte_count(1); + pointer4.release(); + pointer5.release(); + pointer6.release(); + ASSERT_TRUE(4LL == pool.pool_count()); + } + { + MemoryPoolMT pool; + ASSERT_TRUE(0LL == pool.pool_count()); + + Pointer pointer{ pool.get_for_byte_count(4 * 0) }; + ASSERT_FALSE(pointer.is_set()); + pointer.release(); + ASSERT_TRUE(0LL == pool.pool_count()); + + pointer = pool.get_for_byte_count(4 * 2); + int *allocation1 = pointer.get(); + ASSERT_TRUE(pointer.is_set()); + pointer.release(); + ASSERT_FALSE(pointer.is_set()); + ASSERT_TRUE(1LL == pool.pool_count()); + + pointer = pool.get_for_byte_count(4 * 2); + ASSERT_TRUE(allocation1 == pointer.get()); + ASSERT_TRUE(pointer.is_set()); + pointer.release(); + ASSERT_FALSE(pointer.is_set()); + ASSERT_TRUE(1LL == pool.pool_count()); + + pointer = pool.get_for_byte_count(4 * 1); + ASSERT_FALSE(allocation1 == pointer.get()); + ASSERT_TRUE(pointer.is_set()); + pointer.release(); + ASSERT_FALSE(pointer.is_set()); + ASSERT_TRUE(2LL == pool.pool_count()); + + pointer = pool.get_for_byte_count(4 * 2); + ASSERT_TRUE(allocation1 == pointer.get()); + Pointer pointer2 = pool.get_for_byte_count(4 * 2); + int *allocation2 = pointer2.get(); + ASSERT_FALSE(allocation2 == pointer.get()); + ASSERT_TRUE(pointer.is_set()); + pointer.release(); + pointer2.release(); + ASSERT_TRUE(2LL == pool.pool_count()); + + pointer = pool.get_for_byte_count(4 * 2); + ASSERT_TRUE(allocation2 == pointer.get()); + pointer2 = pool.get_for_byte_count(4 * 2); + ASSERT_TRUE(allocation1 == pointer2.get()); + Pointer pointer3 = pool.get_for_byte_count(4 * 1); + pointer.release(); + pointer2.release(); + pointer3.release(); + ASSERT_TRUE(2LL == pool.pool_count()); + + Pointer pointer4 = pool.get_for_byte_count(1); + Pointer pointer5 = pool.get_for_byte_count(2); + Pointer pointer6 = pool.get_for_byte_count(1); + pointer4.release(); + pointer5.release(); + pointer6.release(); + ASSERT_TRUE(4LL == pool.pool_count()); + } + { + MemoryPoolMT pool; + ASSERT_TRUE(0LL == pool.pool_count()); + + Pointer pointer = pool.get_for_byte_count(0); + ASSERT_FALSE(pointer.is_set()); + pointer.release(); + ASSERT_TRUE(0LL == pool.pool_count()); + + pointer = pool.get_for_byte_count(2); + SEAL_BYTE *allocation1 = pointer.get(); + ASSERT_TRUE(pointer.is_set()); + pointer.release(); + ASSERT_FALSE(pointer.is_set()); + ASSERT_TRUE(1LL == pool.pool_count()); + + pointer = pool.get_for_byte_count(2); + ASSERT_TRUE(allocation1 == pointer.get()); + ASSERT_TRUE(pointer.is_set()); + pointer.release(); + ASSERT_FALSE(pointer.is_set()); + ASSERT_TRUE(1LL == pool.pool_count()); + + pointer = pool.get_for_byte_count(1); + ASSERT_FALSE(allocation1 == pointer.get()); + ASSERT_TRUE(pointer.is_set()); + pointer.release(); + ASSERT_FALSE(pointer.is_set()); + ASSERT_TRUE(2LL == pool.pool_count()); + + pointer = pool.get_for_byte_count(2); + ASSERT_TRUE(allocation1 == pointer.get()); + Pointer pointer2 = pool.get_for_byte_count(2); + SEAL_BYTE *allocation2 = pointer2.get(); + ASSERT_FALSE(allocation2 == pointer.get()); + ASSERT_TRUE(pointer.is_set()); + pointer.release(); + pointer2.release(); + ASSERT_TRUE(2LL == pool.pool_count()); + + pointer = pool.get_for_byte_count(2); + ASSERT_TRUE(allocation2 == pointer.get()); + pointer2 = pool.get_for_byte_count(2); + ASSERT_TRUE(allocation1 == pointer2.get()); + Pointer pointer3 = pool.get_for_byte_count(1); + pointer.release(); + pointer2.release(); + pointer3.release(); + ASSERT_TRUE(2LL == pool.pool_count()); + } + } + + TEST(MemoryPoolTests, PointerTestsMT) + { + MemoryPool &pool = *global_variables::global_memory_pool; + { + Pointer p1; + ASSERT_FALSE(p1.is_set()); + ASSERT_TRUE(p1.get() == nullptr); + + p1 = pool.get_for_byte_count(bytes_per_uint64 * 1); + uint64_t *allocation1 = p1.get(); + ASSERT_TRUE(p1.is_set()); + ASSERT_TRUE(p1.get() != nullptr); + + p1.release(); + ASSERT_FALSE(p1.is_set()); + ASSERT_TRUE(p1.get() == nullptr); + + p1 = pool.get_for_byte_count(bytes_per_uint64 * 1); + ASSERT_TRUE(p1.is_set()); + ASSERT_TRUE(p1.get() == allocation1); + + Pointer p2; + p2.acquire(p1); + ASSERT_FALSE(p1.is_set()); + ASSERT_TRUE(p2.is_set()); + ASSERT_TRUE(p2.get() == allocation1); + + ConstPointer cp2; + cp2.acquire(p2); + ASSERT_FALSE(p2.is_set()); + ASSERT_TRUE(cp2.is_set()); + ASSERT_TRUE(cp2.get() == allocation1); + cp2.release(); + + Pointer p3 = pool.get_for_byte_count(bytes_per_uint64 * 1); + ASSERT_TRUE(p3.is_set()); + ASSERT_TRUE(p3.get() == allocation1); + + Pointer p4 = pool.get_for_byte_count(bytes_per_uint64 * 2); + ASSERT_TRUE(p4.is_set()); + uint64_t *allocation2 = p4.get(); + p3.swap_with(p4); + ASSERT_TRUE(p3.is_set()); + ASSERT_TRUE(p3.get() == allocation2); + ASSERT_TRUE(p4.is_set()); + ASSERT_TRUE(p4.get() == allocation1); + p3.release(); + p4.release(); + } + { + Pointer p1; + ASSERT_FALSE(p1.is_set()); + ASSERT_TRUE(p1.get() == nullptr); + + p1 = pool.get_for_byte_count(bytes_per_uint64 * 1); + SEAL_BYTE *allocation1 = p1.get(); + ASSERT_TRUE(p1.is_set()); + ASSERT_TRUE(p1.get() != nullptr); + + p1.release(); + ASSERT_FALSE(p1.is_set()); + ASSERT_TRUE(p1.get() == nullptr); + + p1 = pool.get_for_byte_count(bytes_per_uint64 * 1); + ASSERT_TRUE(p1.is_set()); + ASSERT_TRUE(p1.get() == allocation1); + + Pointer p2; + p2.acquire(p1); + ASSERT_FALSE(p1.is_set()); + ASSERT_TRUE(p2.is_set()); + ASSERT_TRUE(p2.get() == allocation1); + + ConstPointer cp2; + cp2.acquire(p2); + ASSERT_FALSE(p2.is_set()); + ASSERT_TRUE(cp2.is_set()); + ASSERT_TRUE(cp2.get() == allocation1); + cp2.release(); + + Pointer p3 = pool.get_for_byte_count(bytes_per_uint64 * 1); + ASSERT_TRUE(p3.is_set()); + ASSERT_TRUE(p3.get() == allocation1); + + Pointer p4 = pool.get_for_byte_count(bytes_per_uint64 * 2); + ASSERT_TRUE(p4.is_set()); + SEAL_BYTE *allocation2 = p4.get(); + p3.swap_with(p4); + ASSERT_TRUE(p3.is_set()); + ASSERT_TRUE(p3.get() == allocation2); + ASSERT_TRUE(p4.is_set()); + ASSERT_TRUE(p4.get() == allocation1); + p3.release(); + p4.release(); + } + } + + TEST(MemoryPoolTests, DuplicateIfNeededMT) + { + { + unique_ptr allocation(new uint64_t[2]); + allocation[0] = 0x1234567812345678; + allocation[1] = 0x8765432187654321; + + MemoryPoolMT pool; + Pointer p1 = duplicate_if_needed(allocation.get(), 2, false, pool); + ASSERT_TRUE(p1.is_set()); + ASSERT_TRUE(p1.get() == allocation.get()); + ASSERT_TRUE(0LL == pool.pool_count()); + + p1 = duplicate_if_needed(allocation.get(), 2, true, pool); + ASSERT_TRUE(p1.is_set()); + ASSERT_FALSE(p1.get() == allocation.get()); + ASSERT_TRUE(1LL == pool.pool_count()); + ASSERT_TRUE(p1.get()[0] == 0x1234567812345678); + ASSERT_TRUE(p1.get()[1] == 0x8765432187654321); + p1.release(); + } + { + unique_ptr allocation(new int64_t[2]); + allocation[0] = 0x234567812345678; + allocation[1] = 0x765432187654321; + + MemoryPoolMT pool; + Pointer p1 = duplicate_if_needed(allocation.get(), 2, false, pool); + ASSERT_TRUE(p1.is_set()); + ASSERT_TRUE(p1.get() == allocation.get()); + ASSERT_TRUE(0LL == pool.pool_count()); + + p1 = duplicate_if_needed(allocation.get(), 2, true, pool); + ASSERT_TRUE(p1.is_set()); + ASSERT_FALSE(p1.get() == allocation.get()); + ASSERT_TRUE(1LL == pool.pool_count()); + ASSERT_TRUE(p1.get()[0] == 0x234567812345678); + ASSERT_TRUE(p1.get()[1] == 0x765432187654321); + p1.release(); + } + { + unique_ptr allocation(new int[2]); + allocation[0] = 0x123; + allocation[1] = 0x876; + + MemoryPoolMT pool; + Pointer p1 = duplicate_if_needed(allocation.get(), 2, false, pool); + ASSERT_TRUE(p1.is_set()); + ASSERT_TRUE(p1.get() == allocation.get()); + ASSERT_TRUE(0LL == pool.pool_count()); + + p1 = duplicate_if_needed(allocation.get(), 2, true, pool); + ASSERT_TRUE(p1.is_set()); + ASSERT_FALSE(p1.get() == allocation.get()); + ASSERT_TRUE(1LL == pool.pool_count()); + ASSERT_TRUE(p1.get()[0] == 0x123); + ASSERT_TRUE(p1.get()[1] == 0x876); + p1.release(); + } + } + + TEST(MemoryPoolTests, TestMemoryPoolST) + { + { + MemoryPoolMT pool; + ASSERT_TRUE(0LL == pool.pool_count()); + + Pointer pointer{ pool.get_for_byte_count(bytes_per_uint64 * 0) }; + ASSERT_FALSE(pointer.is_set()); + pointer.release(); + ASSERT_TRUE(0LL == pool.pool_count()); + + pointer = pool.get_for_byte_count(bytes_per_uint64 * 2); + uint64_t *allocation1 = pointer.get(); + ASSERT_TRUE(pointer.is_set()); + pointer.release(); + ASSERT_FALSE(pointer.is_set()); + ASSERT_TRUE(1LL == pool.pool_count()); + + pointer = pool.get_for_byte_count(bytes_per_uint64 * 2); + ASSERT_TRUE(allocation1 == pointer.get()); + ASSERT_TRUE(pointer.is_set()); + pointer.release(); + ASSERT_FALSE(pointer.is_set()); + ASSERT_TRUE(1LL == pool.pool_count()); + + pointer = pool.get_for_byte_count(bytes_per_uint64 * 1); + ASSERT_FALSE(allocation1 == pointer.get()); + ASSERT_TRUE(pointer.is_set()); + pointer.release(); + ASSERT_FALSE(pointer.is_set()); + ASSERT_TRUE(2LL == pool.pool_count()); + + pointer = pool.get_for_byte_count(bytes_per_uint64 * 2); + ASSERT_TRUE(allocation1 == pointer.get()); + Pointer pointer2 = pool.get_for_byte_count(bytes_per_uint64 * 2); + uint64_t *allocation2 = pointer2.get(); + ASSERT_FALSE(allocation2 == pointer.get()); + ASSERT_TRUE(pointer.is_set()); + pointer.release(); + pointer2.release(); + ASSERT_TRUE(2LL == pool.pool_count()); + + pointer = pool.get_for_byte_count(bytes_per_uint64 * 2); + ASSERT_TRUE(allocation2 == pointer.get()); + pointer2 = pool.get_for_byte_count(bytes_per_uint64 * 2); + ASSERT_TRUE(allocation1 == pointer2.get()); + Pointer pointer3 = pool.get_for_byte_count(bytes_per_uint64 * 1); + pointer.release(); + pointer2.release(); + pointer3.release(); + ASSERT_TRUE(2LL == pool.pool_count()); + + Pointer pointer4 = pool.get_for_byte_count(1); + Pointer pointer5 = pool.get_for_byte_count(2); + Pointer pointer6 = pool.get_for_byte_count(1); + pointer4.release(); + pointer5.release(); + pointer6.release(); + ASSERT_TRUE(4LL == pool.pool_count()); + } + { + MemoryPoolMT pool; + ASSERT_TRUE(0LL == pool.pool_count()); + + Pointer pointer{ pool.get_for_byte_count(4 * 0) }; + ASSERT_FALSE(pointer.is_set()); + pointer.release(); + ASSERT_TRUE(0LL == pool.pool_count()); + + pointer = pool.get_for_byte_count(4 * 2); + int *allocation1 = pointer.get(); + ASSERT_TRUE(pointer.is_set()); + pointer.release(); + ASSERT_FALSE(pointer.is_set()); + ASSERT_TRUE(1LL == pool.pool_count()); + + pointer = pool.get_for_byte_count(4 * 2); + ASSERT_TRUE(allocation1 == pointer.get()); + ASSERT_TRUE(pointer.is_set()); + pointer.release(); + ASSERT_FALSE(pointer.is_set()); + ASSERT_TRUE(1LL == pool.pool_count()); + + pointer = pool.get_for_byte_count(4 * 1); + ASSERT_FALSE(allocation1 == pointer.get()); + ASSERT_TRUE(pointer.is_set()); + pointer.release(); + ASSERT_FALSE(pointer.is_set()); + ASSERT_TRUE(2LL == pool.pool_count()); + + pointer = pool.get_for_byte_count(4 * 2); + ASSERT_TRUE(allocation1 == pointer.get()); + Pointer pointer2 = pool.get_for_byte_count(4 * 2); + int *allocation2 = pointer2.get(); + ASSERT_FALSE(allocation2 == pointer.get()); + ASSERT_TRUE(pointer.is_set()); + pointer.release(); + pointer2.release(); + ASSERT_TRUE(2LL == pool.pool_count()); + + pointer = pool.get_for_byte_count(4 * 2); + ASSERT_TRUE(allocation2 == pointer.get()); + pointer2 = pool.get_for_byte_count(4 * 2); + ASSERT_TRUE(allocation1 == pointer2.get()); + Pointer pointer3 = pool.get_for_byte_count(4 * 1); + pointer.release(); + pointer2.release(); + pointer3.release(); + ASSERT_TRUE(2LL == pool.pool_count()); + + Pointer pointer4 = pool.get_for_byte_count(1); + Pointer pointer5 = pool.get_for_byte_count(2); + Pointer pointer6 = pool.get_for_byte_count(1); + pointer4.release(); + pointer5.release(); + pointer6.release(); + ASSERT_TRUE(4LL == pool.pool_count()); + } + { + MemoryPoolMT pool; + ASSERT_TRUE(0LL == pool.pool_count()); + + Pointer pointer = pool.get_for_byte_count(0); + ASSERT_FALSE(pointer.is_set()); + pointer.release(); + ASSERT_TRUE(0LL == pool.pool_count()); + + pointer = pool.get_for_byte_count(2); + SEAL_BYTE *allocation1 = pointer.get(); + ASSERT_TRUE(pointer.is_set()); + pointer.release(); + ASSERT_FALSE(pointer.is_set()); + ASSERT_TRUE(1LL == pool.pool_count()); + + pointer = pool.get_for_byte_count(2); + ASSERT_TRUE(allocation1 == pointer.get()); + ASSERT_TRUE(pointer.is_set()); + pointer.release(); + ASSERT_FALSE(pointer.is_set()); + ASSERT_TRUE(1LL == pool.pool_count()); + + pointer = pool.get_for_byte_count(1); + ASSERT_FALSE(allocation1 == pointer.get()); + ASSERT_TRUE(pointer.is_set()); + pointer.release(); + ASSERT_FALSE(pointer.is_set()); + ASSERT_TRUE(2LL == pool.pool_count()); + + pointer = pool.get_for_byte_count(2); + ASSERT_TRUE(allocation1 == pointer.get()); + Pointer pointer2 = pool.get_for_byte_count(2); + SEAL_BYTE *allocation2 = pointer2.get(); + ASSERT_FALSE(allocation2 == pointer.get()); + ASSERT_TRUE(pointer.is_set()); + pointer.release(); + pointer2.release(); + ASSERT_TRUE(2LL == pool.pool_count()); + + pointer = pool.get_for_byte_count(2); + ASSERT_TRUE(allocation2 == pointer.get()); + pointer2 = pool.get_for_byte_count(2); + ASSERT_TRUE(allocation1 == pointer2.get()); + Pointer pointer3 = pool.get_for_byte_count(1); + pointer.release(); + pointer2.release(); + pointer3.release(); + ASSERT_TRUE(2LL == pool.pool_count()); + } + } + + TEST(MemoryPoolTests, PointerTestsST) + { + MemoryPool &pool = *global_variables::global_memory_pool; + { + Pointer p1; + ASSERT_FALSE(p1.is_set()); + ASSERT_TRUE(p1.get() == nullptr); + + p1 = pool.get_for_byte_count(bytes_per_uint64 * 1); + uint64_t *allocation1 = p1.get(); + ASSERT_TRUE(p1.is_set()); + ASSERT_TRUE(p1.get() != nullptr); + + p1.release(); + ASSERT_FALSE(p1.is_set()); + ASSERT_TRUE(p1.get() == nullptr); + + p1 = pool.get_for_byte_count(bytes_per_uint64 * 1); + ASSERT_TRUE(p1.is_set()); + ASSERT_TRUE(p1.get() == allocation1); + + Pointer p2; + p2.acquire(p1); + ASSERT_FALSE(p1.is_set()); + ASSERT_TRUE(p2.is_set()); + ASSERT_TRUE(p2.get() == allocation1); + + ConstPointer cp2; + cp2.acquire(p2); + ASSERT_FALSE(p2.is_set()); + ASSERT_TRUE(cp2.is_set()); + ASSERT_TRUE(cp2.get() == allocation1); + cp2.release(); + + Pointer p3 = pool.get_for_byte_count(bytes_per_uint64 * 1); + ASSERT_TRUE(p3.is_set()); + ASSERT_TRUE(p3.get() == allocation1); + + Pointer p4 = pool.get_for_byte_count(bytes_per_uint64 * 2); + ASSERT_TRUE(p4.is_set()); + uint64_t *allocation2 = p4.get(); + p3.swap_with(p4); + ASSERT_TRUE(p3.is_set()); + ASSERT_TRUE(p3.get() == allocation2); + ASSERT_TRUE(p4.is_set()); + ASSERT_TRUE(p4.get() == allocation1); + p3.release(); + p4.release(); + } + { + Pointer p1; + ASSERT_FALSE(p1.is_set()); + ASSERT_TRUE(p1.get() == nullptr); + + p1 = pool.get_for_byte_count(bytes_per_uint64 * 1); + SEAL_BYTE *allocation1 = p1.get(); + ASSERT_TRUE(p1.is_set()); + ASSERT_TRUE(p1.get() != nullptr); + + p1.release(); + ASSERT_FALSE(p1.is_set()); + ASSERT_TRUE(p1.get() == nullptr); + + p1 = pool.get_for_byte_count(bytes_per_uint64 * 1); + ASSERT_TRUE(p1.is_set()); + ASSERT_TRUE(p1.get() == allocation1); + + Pointer p2; + p2.acquire(p1); + ASSERT_FALSE(p1.is_set()); + ASSERT_TRUE(p2.is_set()); + ASSERT_TRUE(p2.get() == allocation1); + + ConstPointer cp2; + cp2.acquire(p2); + ASSERT_FALSE(p2.is_set()); + ASSERT_TRUE(cp2.is_set()); + ASSERT_TRUE(cp2.get() == allocation1); + cp2.release(); + + Pointer p3 = pool.get_for_byte_count(bytes_per_uint64 * 1); + ASSERT_TRUE(p3.is_set()); + ASSERT_TRUE(p3.get() == allocation1); + + Pointer p4 = pool.get_for_byte_count(bytes_per_uint64 * 2); + ASSERT_TRUE(p4.is_set()); + SEAL_BYTE *allocation2 = p4.get(); + p3.swap_with(p4); + ASSERT_TRUE(p3.is_set()); + ASSERT_TRUE(p3.get() == allocation2); + ASSERT_TRUE(p4.is_set()); + ASSERT_TRUE(p4.get() == allocation1); + p3.release(); + p4.release(); + } + } + + TEST(MemoryPoolTests, DuplicateIfNeededST) + { + { + unique_ptr allocation(new uint64_t[2]); + allocation[0] = 0x1234567812345678; + allocation[1] = 0x8765432187654321; + + MemoryPoolMT pool; + Pointer p1 = duplicate_if_needed(allocation.get(), 2, false, pool); + ASSERT_TRUE(p1.is_set()); + ASSERT_TRUE(p1.get() == allocation.get()); + ASSERT_TRUE(0LL == pool.pool_count()); + + p1 = duplicate_if_needed(allocation.get(), 2, true, pool); + ASSERT_TRUE(p1.is_set()); + ASSERT_FALSE(p1.get() == allocation.get()); + ASSERT_TRUE(1LL == pool.pool_count()); + ASSERT_TRUE(p1.get()[0] == 0x1234567812345678); + ASSERT_TRUE(p1.get()[1] == 0x8765432187654321); + p1.release(); + } + { + unique_ptr allocation(new int64_t[2]); + allocation[0] = 0x234567812345678; + allocation[1] = 0x765432187654321; + + MemoryPoolMT pool; + Pointer p1 = duplicate_if_needed(allocation.get(), 2, false, pool); + ASSERT_TRUE(p1.is_set()); + ASSERT_TRUE(p1.get() == allocation.get()); + ASSERT_TRUE(0LL == pool.pool_count()); + + p1 = duplicate_if_needed(allocation.get(), 2, true, pool); + ASSERT_TRUE(p1.is_set()); + ASSERT_FALSE(p1.get() == allocation.get()); + ASSERT_TRUE(1LL == pool.pool_count()); + ASSERT_TRUE(p1.get()[0] == 0x234567812345678); + ASSERT_TRUE(p1.get()[1] == 0x765432187654321); + p1.release(); + } + { + unique_ptr allocation(new int[2]); + allocation[0] = 0x123; + allocation[1] = 0x876; + + MemoryPoolMT pool; + Pointer p1 = duplicate_if_needed(allocation.get(), 2, false, pool); + ASSERT_TRUE(p1.is_set()); + ASSERT_TRUE(p1.get() == allocation.get()); + ASSERT_TRUE(0LL == pool.pool_count()); + + p1 = duplicate_if_needed(allocation.get(), 2, true, pool); + ASSERT_TRUE(p1.is_set()); + ASSERT_FALSE(p1.get() == allocation.get()); + ASSERT_TRUE(1LL == pool.pool_count()); + ASSERT_TRUE(p1.get()[0] == 0x123); + ASSERT_TRUE(p1.get()[1] == 0x876); + p1.release(); + } + } + } +} diff --git a/tests/seal/util/numth.cpp b/tests/seal/util/numth.cpp new file mode 100644 index 000000000..91f2fec7a --- /dev/null +++ b/tests/seal/util/numth.cpp @@ -0,0 +1,94 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "gtest/gtest.h" +#include "seal/util/numth.h" +#include + +using namespace seal::util; +using namespace std; + +namespace SEALTest +{ + namespace util + { + TEST(NumberTheoryTest, GCD) + { + ASSERT_EQ(1ULL, gcd(1, 1)); + ASSERT_EQ(1ULL, gcd(2, 1)); + ASSERT_EQ(1ULL, gcd(1, 2)); + ASSERT_EQ(2ULL, gcd(2, 2)); + ASSERT_EQ(3ULL, gcd(6, 15)); + ASSERT_EQ(3ULL, gcd(15, 6)); + ASSERT_EQ(1ULL, gcd(7, 15)); + ASSERT_EQ(1ULL, gcd(15, 7)); + ASSERT_EQ(1ULL, gcd(7, 15)); + ASSERT_EQ(3ULL, gcd(11112, 44445)); + } + + TEST(NumberTheoryTest, ExtendedGCD) + { + tuple result; + + // Corner case behavior + result = xgcd(7, 7); + ASSERT_TRUE(result == make_tuple<>(7, 0, 1)); + result = xgcd(2, 2); + ASSERT_TRUE(result == make_tuple<>(2, 0, 1)); + + result = xgcd(1, 1); + ASSERT_TRUE(result == make_tuple<>(1, 0, 1)); + result = xgcd(1, 2); + ASSERT_TRUE(result == make_tuple<>(1, 1, 0)); + result = xgcd(5, 6); + ASSERT_TRUE(result == make_tuple<>(1, -1, 1)); + result = xgcd(13, 19); + ASSERT_TRUE(result == make_tuple<>(1, 3, -2)); + result = xgcd(14, 21); + ASSERT_TRUE(result == make_tuple<>(7, -1, 1)); + + result = xgcd(2, 1); + ASSERT_TRUE(result == make_tuple<>(1, 0, 1)); + result = xgcd(6, 5); + ASSERT_TRUE(result == make_tuple<>(1, 1, -1)); + result = xgcd(19, 13); + ASSERT_TRUE(result == make_tuple<>(1, -2, 3)); + result = xgcd(21, 14); + ASSERT_TRUE(result == make_tuple<>(7, 1, -1)); + } + + TEST(NumberTheoryTest, TryModInverse) + { + uint64_t input, modulus, result; + + input = 1, modulus = 2; + ASSERT_TRUE(try_mod_inverse(input, modulus, result)); + ASSERT_EQ(result, 1ULL); + + input = 2, modulus = 2; + ASSERT_FALSE(try_mod_inverse(input, modulus, result)); + + input = 3, modulus = 2; + ASSERT_TRUE(try_mod_inverse(input, modulus, result)); + ASSERT_EQ(result, 1ULL); + + input = 0xFFFFFF, modulus = 2; + ASSERT_TRUE(try_mod_inverse(input, modulus, result)); + ASSERT_EQ(result, 1ULL); + + input = 0xFFFFFE, modulus = 2; + ASSERT_FALSE(try_mod_inverse(input, modulus, result)); + + input = 12345, modulus = 3; + ASSERT_FALSE(try_mod_inverse(input, modulus, result)); + + input = 5, modulus = 19; + ASSERT_TRUE(try_mod_inverse(input, modulus, result)); + ASSERT_EQ(result, 4ULL); + + input = 4, modulus = 19; + ASSERT_TRUE(try_mod_inverse(input, modulus, result)); + ASSERT_EQ(result, 5ULL); + } + } +} diff --git a/tests/seal/util/polyarith.cpp b/tests/seal/util/polyarith.cpp new file mode 100644 index 000000000..017bdd89b --- /dev/null +++ b/tests/seal/util/polyarith.cpp @@ -0,0 +1,505 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "gtest/gtest.h" +#include "seal/util/uintcore.h" +#include "seal/util/polyarith.h" +#include + +using namespace seal::util; +using namespace std; +using namespace seal; + +namespace SEALTest +{ + namespace util + { + TEST(PolyArith, RightShiftPolyCoeffs) + { + right_shift_poly_coeffs(nullptr, 0, 0, 0, nullptr); + right_shift_poly_coeffs(nullptr, 0, 0, 1, nullptr); + + MemoryPool &pool = *global_variables::global_memory_pool; + auto ptr(allocate_zero_poly(3, 2, pool)); + ptr[0] = 2; + ptr[1] = 4; + ptr[2] = 8; + right_shift_poly_coeffs(ptr.get(), 3, 1, 0, ptr.get()); + ASSERT_EQ(2ULL, ptr[0]); + ASSERT_EQ(4ULL, ptr[1]); + ASSERT_EQ(8ULL, ptr[2]); + + right_shift_poly_coeffs(ptr.get(), 3, 1, 1, ptr.get()); + ASSERT_EQ(1ULL, ptr[0]); + ASSERT_EQ(2ULL, ptr[1]); + ASSERT_EQ(4ULL, ptr[2]); + + right_shift_poly_coeffs(ptr.get(), 3, 1, 1, ptr.get()); + ASSERT_EQ(0ULL, ptr[0]); + ASSERT_EQ(1ULL, ptr[1]); + ASSERT_EQ(2ULL, ptr[2]); + + ptr[0] = 3; + ptr[1] = 5; + ptr[2] = 9; + right_shift_poly_coeffs(ptr.get(), 3, 1, 2, ptr.get()); + ASSERT_EQ(0ULL, ptr[0]); + ASSERT_EQ(1ULL, ptr[1]); + ASSERT_EQ(2ULL, ptr[2]); + + ptr[0] = 3; + ptr[1] = 5; + ptr[2] = 9; + right_shift_poly_coeffs(ptr.get(), 3, 1, 4, ptr.get()); + ASSERT_EQ(0ULL, ptr[0]); + ASSERT_EQ(0ULL, ptr[1]); + ASSERT_EQ(0ULL, ptr[2]); + + ptr[0] = 1; + ptr[1] = 1; + ptr[2] = 1; + right_shift_poly_coeffs(ptr.get(), 1, 2, 64, ptr.get()); + ASSERT_EQ(1ULL, ptr[0]); + ASSERT_EQ(0ULL, ptr[1]); + ASSERT_EQ(1ULL, ptr[2]); + + ptr[0] = 3; + ptr[1] = 5; + ptr[2] = 9; + right_shift_poly_coeffs(ptr.get(), 1, 3, 128, ptr.get()); + ASSERT_EQ(9ULL, ptr[0]); + ASSERT_EQ(0ULL, ptr[1]); + ASSERT_EQ(0ULL, ptr[2]); + + ptr[0] = 0xFFFFFFFFFFFFFFFF; + ptr[1] = 0xFFFFFFFFFFFFFFFF; + ptr[2] = 0xFFFFFFFFFFFFFFFF; + right_shift_poly_coeffs(ptr.get(), 1, 3, 191, ptr.get()); + ASSERT_EQ(1ULL, ptr[0]); + ASSERT_EQ(0ULL, ptr[1]); + ASSERT_EQ(0ULL, ptr[2]); + } + + TEST(PolyArith, NegatePoly) + { + negate_poly(nullptr, 0, 0, nullptr); + + MemoryPool &pool = *global_variables::global_memory_pool; + auto ptr(allocate_zero_poly(3, 2, pool)); + ptr[0] = 2; + ptr[2] = 3; + ptr[4] = 4; + negate_poly(ptr.get(), 3, 2, ptr.get()); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFE), ptr[0]); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr[1]); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFD), ptr[2]); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr[3]); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFC), ptr[4]); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr[5]); + } + + TEST(PolyArith, AddPolyPoly) + { + add_poly_poly(nullptr, nullptr, 0, 0, nullptr); + + MemoryPool &pool = *global_variables::global_memory_pool; + auto poly1(allocate_zero_poly(3, 2, pool)); + auto poly2(allocate_zero_poly(3, 2, pool)); + + poly1[0] = 0; + poly1[1] = 0xFFFFFFFFFFFFFFFF; + poly1[2] = 1; + poly1[3] = 0; + poly1[4] = 0xFFFFFFFFFFFFFFFF; + poly1[5] = 1; + poly2[0] = 1; + poly2[1] = 1; + poly2[2] = 1; + poly2[3] = 1; + poly2[4] = 0xFFFFFFFFFFFFFFFF; + poly2[5] = 1; + add_poly_poly(poly1.get(), poly2.get(), 3, 2, poly1.get()); + ASSERT_EQ(static_cast(1), poly1[0]); + ASSERT_EQ(static_cast(0), poly1[1]); + ASSERT_EQ(static_cast(2), poly1[2]); + ASSERT_EQ(static_cast(1), poly1[3]); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFE), poly1[4]); + ASSERT_EQ(static_cast(3), poly1[5]); + + poly1[0] = 2; + poly1[1] = 0; + poly1[2] = 3; + poly1[3] = 0; + poly1[4] = 0xFFFFFFFFFFFFFFFF; + poly1[5] = 0xFFFFFFFFFFFFFFFF; + poly2[0] = 5; + poly2[1] = 0; + poly2[2] = 6; + poly2[3] = 0; + poly2[4] = 0xFFFFFFFFFFFFFFFF; + poly2[5] = 0xFFFFFFFFFFFFFFFF; + add_poly_poly(poly1.get(), poly2.get(), 3, 2, poly1.get()); + ASSERT_EQ(static_cast(7), poly1[0]); + ASSERT_EQ(static_cast(0), poly1[1]); + ASSERT_EQ(static_cast(9), poly1[2]); + ASSERT_EQ(static_cast(0), poly1[3]); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFE), poly1[4]); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), poly1[5]); + } + + TEST(PolyArith, SubPolyPoly) + { + sub_poly_poly(nullptr, nullptr, 0, 0, nullptr); + + MemoryPool &pool = *global_variables::global_memory_pool; + auto poly1(allocate_zero_poly(3, 2, pool)); + auto poly2(allocate_zero_poly(3, 2, pool)); + + poly1[0] = 0; + poly1[1] = 0xFFFFFFFFFFFFFFFF; + poly1[2] = 1; + poly1[3] = 0; + poly1[4] = 0xFFFFFFFFFFFFFFFF; + poly1[5] = 1; + poly2[0] = 1; + poly2[1] = 1; + poly2[2] = 1; + poly2[3] = 1; + poly2[4] = 0xFFFFFFFFFFFFFFFF; + poly2[5] = 1; + sub_poly_poly(poly1.get(), poly2.get(), 6, 1, poly1.get()); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), poly1[0]); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFE), poly1[1]); + ASSERT_EQ(static_cast(0), poly1[2]); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), poly1[3]); + ASSERT_EQ(static_cast(0), poly1[4]); + ASSERT_EQ(static_cast(0), poly1[5]); + + poly1[0] = 5; + poly1[1] = 0; + poly1[2] = 6; + poly1[3] = 0; + poly1[4] = 0xFFFFFFFFFFFFFFFF; + poly1[5] = 0xFFFFFFFFFFFFFFFF; + poly2[0] = 2; + poly2[1] = 0; + poly2[2] = 8; + poly2[3] = 0; + poly2[4] = 0xFFFFFFFFFFFFFFFE; + poly2[5] = 0xFFFFFFFFFFFFFFFF; + sub_poly_poly(poly1.get(), poly2.get(), 3, 2, poly1.get()); + ASSERT_EQ(static_cast(3), poly1[0]); + ASSERT_EQ(static_cast(0), poly1[1]); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFE), poly1[2]); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), poly1[3]); + ASSERT_EQ(1ULL, poly1[4]); + ASSERT_EQ(static_cast(0), poly1[5]); + } + + TEST(PolyArith, MultiplyPolyPoly) + { + MemoryPool &pool = *global_variables::global_memory_pool; + auto poly1(allocate_zero_poly(3, 2, pool)); + auto poly2(allocate_zero_poly(3, 2, pool)); + auto result(allocate_zero_poly(5, 2, pool)); + poly1[0] = 1; + poly1[2] = 2; + poly1[4] = 3; + poly2[0] = 2; + poly2[2] = 3; + poly2[4] = 4; + multiply_poly_poly(poly1.get(), 3, 2, poly2.get(), 3, 2, 5, 2, result.get(), pool); + ASSERT_EQ(static_cast(2), result[0]); + ASSERT_EQ(static_cast(0), result[1]); + ASSERT_EQ(static_cast(7), result[2]); + ASSERT_EQ(static_cast(0), result[3]); + ASSERT_EQ(static_cast(16), result[4]); + ASSERT_EQ(static_cast(0), result[5]); + ASSERT_EQ(static_cast(17), result[6]); + ASSERT_EQ(static_cast(0), result[7]); + ASSERT_EQ(static_cast(12), result[8]); + ASSERT_EQ(static_cast(0), result[9]); + + poly2[0] = 2; + poly2[1] = 3; + multiply_poly_poly(poly1.get(), 3, 2, poly2.get(), 2, 1, 5, 2, result.get(), pool); + ASSERT_EQ(static_cast(2), result[0]); + ASSERT_EQ(static_cast(0), result[1]); + ASSERT_EQ(static_cast(7), result[2]); + ASSERT_EQ(static_cast(0), result[3]); + ASSERT_EQ(static_cast(12), result[4]); + ASSERT_EQ(static_cast(0), result[5]); + ASSERT_EQ(static_cast(9), result[6]); + ASSERT_EQ(static_cast(0), result[7]); + ASSERT_EQ(static_cast(0), result[8]); + ASSERT_EQ(static_cast(0), result[9]); + + multiply_poly_poly(poly1.get(), 3, 2, poly2.get(), 2, 1, 5, 1, result.get(), pool); + ASSERT_EQ(static_cast(2), result[0]); + ASSERT_EQ(static_cast(7), result[1]); + ASSERT_EQ(static_cast(12), result[2]); + ASSERT_EQ(static_cast(9), result[3]); + ASSERT_EQ(static_cast(0), result[4]); + } + + TEST(PolyArith, PolyInftyNorm) + { + MemoryPool &pool = *global_variables::global_memory_pool; + auto poly(allocate_zero_poly(10, 1, pool)); + uint64_t result[2]; + + poly[0] = 1, poly[1] = 0, poly[2] = 1, poly[3] = 0, poly[4] = 0; + poly[5] = 4, poly[6] = 0xB, poly[7] = 0xA, poly[8] = 5, poly[9] = 2; + poly_infty_norm(poly.get(), 10, 1, result); + ASSERT_EQ(result[0], 0xBULL); + + poly[0] = 2, poly[1] = 0, poly[2] = 1, poly[3] = 0, poly[4] = 0; + poly[5] = 0xF7, poly[6] = 0xFE, poly[7] = 0xCF, poly[8] = 0xCA, poly[9] = 0xAB; + poly_infty_norm(poly.get(), 10, 1, result); + ASSERT_EQ(result[0], 0xFEULL); + + poly[0] = 2, poly[1] = 0, poly[2] = 1, poly[3] = 0, poly[4] = 0; + poly[5] = 0xABCDEF, poly[6] = 0xABCDE, poly[7] = 0xABCD, poly[8] = 0xABC, poly[9] = 0xAB; + poly_infty_norm(poly.get(), 10, 1, result); + ASSERT_EQ(result[0], 0xABCDEFULL); + + poly[0] = 6, poly[1] = 5, poly[2] = 4, poly[3] = 3, poly[4] = 2; + poly[5] = 1, poly[6] = 0; + poly_infty_norm(poly.get(), 6, 1, result); + ASSERT_EQ(result[0], 6ULL); + + poly[0] = 1, poly[1] = 0, poly[2] = 1, poly[3] = 0, poly[4] = 0; + poly[5] = 4, poly[6] = 0xB, poly[7] = 0xA, poly[8] = 5, poly[9] = 2; + poly_infty_norm(poly.get(), 5, 2, result); + ASSERT_EQ(result[0], 0xBULL); + ASSERT_EQ(result[1], 0xAULL); + + poly[0] = 2, poly[1] = 0, poly[2] = 1, poly[3] = 0, poly[4] = 0; + poly[5] = 0xF7, poly[6] = 0xFE, poly[7] = 0xCF, poly[8] = 0xCA, poly[9] = 0xAB; + poly_infty_norm(poly.get(), 5, 2, result); + ASSERT_EQ(result[0], 0x0ULL); + ASSERT_EQ(result[1], 0xF7ULL); + + poly[0] = 2, poly[1] = 0, poly[2] = 1, poly[3] = 0, poly[4] = 0; + poly[5] = 0xABCDEF, poly[6] = 0xABCDE, poly[7] = 0xABCD, poly[8] = 0xABC, poly[9] = 0xAB; + poly_infty_norm(poly.get(), 5, 2, result); + ASSERT_EQ(result[0], 0ULL); + ASSERT_EQ(result[1], 0xABCDEFULL); + + poly[0] = 6, poly[1] = 5, poly[2] = 4, poly[3] = 3, poly[4] = 2; + poly[5] = 1, poly[6] = 0; + poly_infty_norm(poly.get(), 3, 2, result); + ASSERT_EQ(result[0], 6ULL); + ASSERT_EQ(result[1], 5ULL); + } + + TEST(PolyArith, PolyEvalPoly) + { + MemoryPool &pool = *global_variables::global_memory_pool; + auto poly1(allocate_zero_poly(4, 1, pool)); + auto poly2(allocate_zero_poly(4, 1, pool)); + auto poly3(allocate_zero_poly(8, 1, pool)); + + poly_eval_poly(poly1.get(), 4, 1, poly2.get(), 4, 1, 8, 1, poly3.get(), pool); + ASSERT_EQ(poly3[0], 0ULL); + ASSERT_EQ(poly3[1], 0ULL); + ASSERT_EQ(poly3[2], 0ULL); + ASSERT_EQ(poly3[3], 0ULL); + ASSERT_EQ(poly3[4], 0ULL); + ASSERT_EQ(poly3[5], 0ULL); + ASSERT_EQ(poly3[6], 0ULL); + ASSERT_EQ(poly3[7], 0ULL); + + poly1[0] = 1; + poly_eval_poly(poly1.get(), 4, 1, poly2.get(), 4, 1, 8, 1, poly3.get(), pool); + ASSERT_EQ(poly3[0], 1ULL); + ASSERT_EQ(poly3[1], 0ULL); + ASSERT_EQ(poly3[2], 0ULL); + ASSERT_EQ(poly3[3], 0ULL); + ASSERT_EQ(poly3[4], 0ULL); + ASSERT_EQ(poly3[5], 0ULL); + ASSERT_EQ(poly3[6], 0ULL); + ASSERT_EQ(poly3[7], 0ULL); + + poly1[0] = 2; + poly2[0] = 1; + poly_eval_poly(poly1.get(), 4, 1, poly2.get(), 4, 1, 8, 1, poly3.get(), pool); + ASSERT_EQ(poly3[0], 2ULL); + ASSERT_EQ(poly3[1], 0ULL); + ASSERT_EQ(poly3[2], 0ULL); + ASSERT_EQ(poly3[3], 0ULL); + ASSERT_EQ(poly3[4], 0ULL); + ASSERT_EQ(poly3[5], 0ULL); + ASSERT_EQ(poly3[6], 0ULL); + ASSERT_EQ(poly3[7], 0ULL); + + poly1[0] = 1; + poly1[1] = 1; + poly2[0] = 1; + poly_eval_poly(poly1.get(), 4, 1, poly2.get(), 4, 1, 8, 1, poly3.get(), pool); + ASSERT_EQ(poly3[0], 2ULL); + ASSERT_EQ(poly3[1], 0ULL); + ASSERT_EQ(poly3[2], 0ULL); + ASSERT_EQ(poly3[3], 0ULL); + ASSERT_EQ(poly3[4], 0ULL); + ASSERT_EQ(poly3[5], 0ULL); + ASSERT_EQ(poly3[6], 0ULL); + ASSERT_EQ(poly3[7], 0ULL); + + poly1[0] = 1; + poly1[1] = 1; + poly2[0] = 2; + poly2[1] = 0; + poly2[2] = 1; + poly_eval_poly(poly1.get(), 4, 1, poly2.get(), 4, 1, 8, 1, poly3.get(), pool); + ASSERT_EQ(poly3[0], 3ULL); + ASSERT_EQ(poly3[1], 0ULL); + ASSERT_EQ(poly3[2], 1ULL); + ASSERT_EQ(poly3[3], 0ULL); + ASSERT_EQ(poly3[4], 0ULL); + ASSERT_EQ(poly3[5], 0ULL); + ASSERT_EQ(poly3[6], 0ULL); + ASSERT_EQ(poly3[7], 0ULL); + + poly1[0] = 2; + poly1[1] = 0; + poly1[2] = 1; + poly2[0] = 1; + poly2[1] = 1; + poly2[2] = 0; + poly_eval_poly(poly1.get(), 4, 1, poly2.get(), 4, 1, 8, 1, poly3.get(), pool); + ASSERT_EQ(poly3[0], 3ULL); + ASSERT_EQ(poly3[1], 2ULL); + ASSERT_EQ(poly3[2], 1ULL); + ASSERT_EQ(poly3[3], 0ULL); + ASSERT_EQ(poly3[4], 0ULL); + ASSERT_EQ(poly3[5], 0ULL); + ASSERT_EQ(poly3[6], 0ULL); + ASSERT_EQ(poly3[7], 0ULL); + + poly1[0] = 0; + poly1[1] = 0; + poly1[2] = 0; + poly1[3] = 1; + poly2[0] = 2; + poly2[1] = 0; + poly2[2] = 0; + poly2[3] = 0; + poly_eval_poly(poly1.get(), 4, 1, poly2.get(), 4, 1, 8, 1, poly3.get(), pool); + ASSERT_EQ(poly3[0], 8ULL); + ASSERT_EQ(poly3[1], 0ULL); + ASSERT_EQ(poly3[2], 0ULL); + ASSERT_EQ(poly3[3], 0ULL); + ASSERT_EQ(poly3[4], 0ULL); + ASSERT_EQ(poly3[5], 0ULL); + ASSERT_EQ(poly3[6], 0ULL); + ASSERT_EQ(poly3[7], 0ULL); + + poly1[0] = 0; + poly1[1] = 0; + poly1[2] = 0; + poly1[3] = 1; + poly2[0] = 0; + poly2[1] = 0; + poly2[2] = 2; + poly2[3] = 0; + poly_eval_poly(poly1.get(), 4, 1, poly2.get(), 4, 1, 8, 1, poly3.get(), pool); + ASSERT_EQ(poly3[0], 0ULL); + ASSERT_EQ(poly3[1], 0ULL); + ASSERT_EQ(poly3[2], 0ULL); + ASSERT_EQ(poly3[3], 0ULL); + ASSERT_EQ(poly3[4], 0ULL); + ASSERT_EQ(poly3[5], 0ULL); + ASSERT_EQ(poly3[6], 8ULL); + ASSERT_EQ(poly3[7], 0ULL); + } + + TEST(PolyArith, ExponentiatePoly) + { + MemoryPool &pool = *global_variables::global_memory_pool; + auto poly1(allocate_zero_poly(4, 1, pool)); + auto poly2(allocate_zero_poly(12, 1, pool)); + + uint64_t exponent = 1; + exponentiate_poly(poly1.get(), 4, 1, &exponent, 1, 12, 1, poly2.get(), pool); + ASSERT_EQ(poly2[0], 0ULL); + ASSERT_EQ(poly2[1], 0ULL); + ASSERT_EQ(poly2[2], 0ULL); + ASSERT_EQ(poly2[3], 0ULL); + ASSERT_EQ(poly2[4], 0ULL); + ASSERT_EQ(poly2[5], 0ULL); + ASSERT_EQ(poly2[6], 0ULL); + ASSERT_EQ(poly2[7], 0ULL); + ASSERT_EQ(poly2[8], 0ULL); + ASSERT_EQ(poly2[9], 0ULL); + ASSERT_EQ(poly2[10], 0ULL); + ASSERT_EQ(poly2[11], 0ULL); + + exponent = 0; + exponentiate_poly(poly1.get(), 4, 1, &exponent, 1, 12, 1, poly2.get(), pool); + ASSERT_EQ(poly2[0], 1ULL); + ASSERT_EQ(poly2[1], 0ULL); + ASSERT_EQ(poly2[2], 0ULL); + ASSERT_EQ(poly2[3], 0ULL); + ASSERT_EQ(poly2[4], 0ULL); + ASSERT_EQ(poly2[5], 0ULL); + ASSERT_EQ(poly2[6], 0ULL); + ASSERT_EQ(poly2[7], 0ULL); + ASSERT_EQ(poly2[8], 0ULL); + ASSERT_EQ(poly2[9], 0ULL); + ASSERT_EQ(poly2[10], 0ULL); + ASSERT_EQ(poly2[11], 0ULL); + + exponent = 3; + poly1[1] = 2; + exponentiate_poly(poly1.get(), 4, 1, &exponent, 1, 12, 1, poly2.get(), pool); + ASSERT_EQ(poly2[0], 0ULL); + ASSERT_EQ(poly2[1], 0ULL); + ASSERT_EQ(poly2[2], 0ULL); + ASSERT_EQ(poly2[3], 8ULL); + ASSERT_EQ(poly2[4], 0ULL); + ASSERT_EQ(poly2[5], 0ULL); + ASSERT_EQ(poly2[6], 0ULL); + ASSERT_EQ(poly2[7], 0ULL); + ASSERT_EQ(poly2[8], 0ULL); + ASSERT_EQ(poly2[9], 0ULL); + ASSERT_EQ(poly2[10], 0ULL); + ASSERT_EQ(poly2[11], 0ULL); + + exponent = 3; + poly1[0] = 1; + poly1[1] = 1; + exponentiate_poly(poly1.get(), 4, 1, &exponent, 1, 12, 1, poly2.get(), pool); + ASSERT_EQ(poly2[0], 1ULL); + ASSERT_EQ(poly2[1], 3ULL); + ASSERT_EQ(poly2[2], 3ULL); + ASSERT_EQ(poly2[3], 1ULL); + ASSERT_EQ(poly2[4], 0ULL); + ASSERT_EQ(poly2[5], 0ULL); + ASSERT_EQ(poly2[6], 0ULL); + ASSERT_EQ(poly2[7], 0ULL); + ASSERT_EQ(poly2[8], 0ULL); + ASSERT_EQ(poly2[9], 0ULL); + ASSERT_EQ(poly2[10], 0ULL); + ASSERT_EQ(poly2[11], 0ULL); + + exponent = 5; + poly1[0] = 0; + poly1[1] = 0; + poly1[2] = 2; + exponentiate_poly(poly1.get(), 4, 1, &exponent, 1, 12, 1, poly2.get(), pool); + ASSERT_EQ(poly2[0], 0ULL); + ASSERT_EQ(poly2[1], 0ULL); + ASSERT_EQ(poly2[2], 0ULL); + ASSERT_EQ(poly2[3], 0ULL); + ASSERT_EQ(poly2[4], 0ULL); + ASSERT_EQ(poly2[5], 0ULL); + ASSERT_EQ(poly2[6], 0ULL); + ASSERT_EQ(poly2[7], 0ULL); + ASSERT_EQ(poly2[8], 0ULL); + ASSERT_EQ(poly2[9], 0ULL); + ASSERT_EQ(poly2[10], 32ULL); + ASSERT_EQ(poly2[11], 0ULL); + } + } +} diff --git a/tests/seal/util/polyarithmod.cpp b/tests/seal/util/polyarithmod.cpp new file mode 100644 index 000000000..3ad719b34 --- /dev/null +++ b/tests/seal/util/polyarithmod.cpp @@ -0,0 +1,96 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "gtest/gtest.h" +#include "seal/util/uintcore.h" +#include "seal/util/polycore.h" +#include "seal/util/polyarithmod.h" +#include + +using namespace seal; +using namespace seal::util; +using namespace std; + +namespace SEALTest +{ + namespace util + { + TEST(PolyArithMod, NegatePolyCoeffMod) + { + MemoryPool &pool = *global_variables::global_memory_pool; + auto poly(allocate_zero_poly(3, 2, pool)); + auto modulus(allocate_uint(2, pool)); + poly[0] = 2; + poly[2] = 3; + poly[4] = 4; + modulus[0] = 15; + modulus[1] = 0; + negate_poly_coeffmod(poly.get(), 3, modulus.get(), 2, poly.get()); + ASSERT_EQ(static_cast(13), poly[0]); + ASSERT_EQ(static_cast(0), poly[1]); + ASSERT_EQ(static_cast(12), poly[2]); + ASSERT_EQ(static_cast(0), poly[3]); + ASSERT_EQ(static_cast(11), poly[4]); + ASSERT_EQ(static_cast(0), poly[5]); + + poly[0] = 2; + poly[2] = 3; + poly[4] = 4; + modulus[0] = 0xFFFFFFFFFFFFFFFF; + modulus[1] = 0xFFFFFFFFFFFFFFFF; + negate_poly_coeffmod(poly.get(), 3, modulus.get(), 2, poly.get()); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFD), poly[0]); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), poly[1]); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFC), poly[2]); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), poly[3]); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFB), poly[4]); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), poly[5]); + } + + TEST(PolyArithMod, AddPolyPolyCoeffMod) + { + MemoryPool &pool = *global_variables::global_memory_pool; + auto poly1(allocate_zero_poly(3, 2, pool)); + auto poly2(allocate_zero_poly(3, 2, pool)); + auto modulus(allocate_uint(2, pool)); + poly1[0] = 1; + poly1[2] = 3; + poly1[4] = 4; + poly2[0] = 1; + poly2[2] = 2; + poly2[4] = 4; + modulus[0] = 5; + modulus[1] = 0; + add_poly_poly_coeffmod(poly1.get(), poly2.get(), 3, modulus.get(), 2, poly1.get()); + ASSERT_EQ(static_cast(2), poly1[0]); + ASSERT_EQ(static_cast(0), poly1[1]); + ASSERT_EQ(static_cast(0), poly1[2]); + ASSERT_EQ(static_cast(0), poly1[3]); + ASSERT_EQ(static_cast(3), poly1[4]); + ASSERT_EQ(static_cast(0), poly1[5]); + } + + TEST(PolyArithMod, SubPolyPolyCoeffMod) + { + MemoryPool &pool = *global_variables::global_memory_pool; + auto poly1(allocate_zero_poly(3, 2, pool)); + auto poly2(allocate_zero_poly(3, 2, pool)); + auto modulus(allocate_uint(2, pool)); + poly1[0] = 4; + poly1[2] = 3; + poly1[4] = 2; + poly2[0] = 2; + poly2[2] = 3; + poly2[4] = 4; + modulus[0] = 5; + modulus[1] = 0; + sub_poly_poly_coeffmod(poly1.get(), poly2.get(), 3, modulus.get(), 2, poly1.get()); + ASSERT_EQ(static_cast(2), poly1[0]); + ASSERT_EQ(static_cast(0), poly1[1]); + ASSERT_EQ(static_cast(0), poly1[2]); + ASSERT_EQ(static_cast(0), poly1[3]); + ASSERT_EQ(static_cast(3), poly1[4]); + ASSERT_EQ(static_cast(0), poly1[5]); + } + } +} diff --git a/tests/seal/util/polyarithsmallmod.cpp b/tests/seal/util/polyarithsmallmod.cpp new file mode 100644 index 000000000..e93f738cf --- /dev/null +++ b/tests/seal/util/polyarithsmallmod.cpp @@ -0,0 +1,391 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "gtest/gtest.h" +#include "seal/util/uintcore.h" +#include "seal/util/polycore.h" +#include "seal/util/polyarithsmallmod.h" +#include +#include + +using namespace seal; +using namespace seal::util; +using namespace std; + +namespace SEALTest +{ + namespace util + { + TEST(PolyArithSmallMod, SmallModuloPolyCoeffs) + { + MemoryPool &pool = *global_variables::global_memory_pool; + auto poly(allocate_zero_poly(3, 1, pool)); + auto modulus(allocate_uint(2, pool)); + poly[0] = 2; + poly[1] = 15; + poly[2] = 77; + SmallModulus mod(15); + modulo_poly_coeffs(poly.get(), 3, mod, poly.get()); + ASSERT_EQ(2ULL, poly[0]); + ASSERT_EQ(0ULL, poly[1]); + ASSERT_EQ(2ULL, poly[2]); + } + + TEST(PolyArithSmallMod, NegatePolyCoeffSmallMod) + { + MemoryPool &pool = *global_variables::global_memory_pool; + auto poly(allocate_zero_poly(3, 1, pool)); + poly[0] = 2; + poly[1] = 3; + poly[2] = 4; + SmallModulus mod(15); + negate_poly_coeffmod(poly.get(), 3, mod, poly.get()); + ASSERT_EQ(static_cast(13), poly[0]); + ASSERT_EQ(static_cast(12), poly[1]); + ASSERT_EQ(static_cast(11), poly[2]); + + poly[0] = 2; + poly[1] = 3; + poly[2] = 4; + mod = 0xFFFFFFFFFFFFFFULL; + negate_poly_coeffmod(poly.get(), 3, mod, poly.get()); + ASSERT_EQ(0xFFFFFFFFFFFFFDULL, poly[0]); + ASSERT_EQ(0xFFFFFFFFFFFFFCULL, poly[1]); + ASSERT_EQ(0xFFFFFFFFFFFFFBULL, poly[2]); + } + + TEST(PolyArithSmallMod, AddPolyPolyCoeffSmallMod) + { + MemoryPool &pool = *global_variables::global_memory_pool; + auto poly1(allocate_zero_poly(3, 1, pool)); + auto poly2(allocate_zero_poly(3, 1, pool)); + poly1[0] = 1; + poly1[1] = 3; + poly1[2] = 4; + poly2[0] = 1; + poly2[1] = 2; + poly2[2] = 4; + SmallModulus mod(5); + add_poly_poly_coeffmod(poly1.get(), poly2.get(), 3, mod, poly1.get()); + ASSERT_EQ(2ULL, poly1[0]); + ASSERT_EQ(0ULL, poly1[1]); + ASSERT_EQ(3ULL, poly1[2]); + } + + TEST(PolyArithSmallMod, SubPolyPolyCoeffSmallMod) + { + MemoryPool &pool = *global_variables::global_memory_pool; + auto poly1(allocate_zero_poly(3, 1, pool)); + auto poly2(allocate_zero_poly(3, 1, pool)); + poly1[0] = 4; + poly1[1] = 3; + poly1[2] = 2; + poly2[0] = 2; + poly2[1] = 3; + poly2[2] = 4; + SmallModulus mod(5); + sub_poly_poly_coeffmod(poly1.get(), poly2.get(), 3, mod, poly1.get()); + ASSERT_EQ(2ULL, poly1[0]); + ASSERT_EQ(0ULL, poly1[1]); + ASSERT_EQ(3ULL, poly1[2]); + } + + TEST(PolyArithSmallMod, MultiplyPolyScalarCoeffSmallMod) + { + MemoryPool &pool = *global_variables::global_memory_pool; + auto poly(allocate_zero_poly(3, 1, pool)); + poly[0] = 1; + poly[1] = 3; + poly[2] = 4; + uint64_t scalar = 3; + SmallModulus mod(5); + multiply_poly_scalar_coeffmod(poly.get(), 3, scalar, mod, poly.get()); + ASSERT_EQ(3ULL, poly[0]); + ASSERT_EQ(4ULL, poly[1]); + ASSERT_EQ(2ULL, poly[2]); + } + + TEST(PolyArithSmallMod, MultiplyPolyPolyCoeffSmallMod) + { + MemoryPool &pool = *global_variables::global_memory_pool; + auto poly1(allocate_zero_poly(3, 1, pool)); + auto poly2(allocate_zero_poly(3, 1, pool)); + auto result(allocate_zero_poly(5, 1, pool)); + poly1[0] = 1; + poly1[1] = 2; + poly1[2] = 3; + poly2[0] = 2; + poly2[1] = 3; + poly2[2] = 4; + SmallModulus mod(5); + multiply_poly_poly_coeffmod(poly1.get(), 3, poly2.get(), 3, mod, 5, result.get()); + ASSERT_EQ(2ULL, result[0]); + ASSERT_EQ(2ULL, result[1]); + ASSERT_EQ(1ULL, result[2]); + ASSERT_EQ(2ULL, result[3]); + ASSERT_EQ(2ULL, result[4]); + + poly2[0] = 2; + poly2[1] = 3; + multiply_poly_poly_coeffmod(poly1.get(), 3, poly2.get(), 2, mod, 5, result.get()); + ASSERT_EQ(2ULL, result[0]); + ASSERT_EQ(2ULL, result[1]); + ASSERT_EQ(2ULL, result[2]); + ASSERT_EQ(4ULL, result[3]); + ASSERT_EQ(0ULL, result[4]); + } + + TEST(PolyArithSmallMod, DividePolyPolyCoeffSmallMod) + { + MemoryPool &pool = *global_variables::global_memory_pool; + auto poly1(allocate_zero_poly(5, 1, pool)); + auto poly2(allocate_zero_poly(5, 1, pool)); + auto result(allocate_zero_poly(5, 1, pool)); + auto quotient(allocate_zero_poly(5, 1, pool)); + SmallModulus mod(5); + + poly1[0] = 2; + poly1[1] = 2; + poly2[0] = 2; + poly2[1] = 3; + poly2[2] = 4; + + divide_poly_poly_coeffmod_inplace(poly1.get(), poly2.get(), 5, mod, result.get()); + ASSERT_EQ(2ULL, poly1[0]); + ASSERT_EQ(2ULL, poly1[1]); + ASSERT_EQ(0ULL, poly1[2]); + ASSERT_EQ(0ULL, poly1[3]); + ASSERT_EQ(0ULL, poly1[4]); + ASSERT_EQ(0ULL, result[0]); + ASSERT_EQ(0ULL, result[1]); + ASSERT_EQ(0ULL, result[2]); + ASSERT_EQ(0ULL, result[3]); + ASSERT_EQ(0ULL, result[4]); + + poly1[0] = 2; + poly1[1] = 2; + poly1[2] = 1; + poly1[3] = 2; + poly1[4] = 2; + poly2[0] = 4; + poly2[1] = 3; + poly2[2] = 2; + + divide_poly_poly_coeffmod(poly1.get(), poly2.get(), 5, mod, quotient.get(), result.get()); + ASSERT_EQ(0ULL, result[0]); + ASSERT_EQ(0ULL, result[1]); + ASSERT_EQ(0ULL, result[2]); + ASSERT_EQ(0ULL, result[3]); + ASSERT_EQ(0ULL, result[4]); + ASSERT_EQ(3ULL, quotient[0]); + ASSERT_EQ(2ULL, quotient[1]); + ASSERT_EQ(1ULL, quotient[2]); + ASSERT_EQ(0ULL, quotient[3]); + ASSERT_EQ(0ULL, quotient[4]); + } + + TEST(PolyArithSmallMod, DyadicProductCoeffSmallMod) + { + MemoryPool &pool = *global_variables::global_memory_pool; + auto poly1(allocate_zero_poly(3, 1, pool)); + auto poly2(allocate_zero_poly(3, 1, pool)); + auto result(allocate_zero_poly(3, 1, pool)); + SmallModulus mod(13); + + poly1[0] = 1; + poly1[1] = 1; + poly1[2] = 1; + poly2[0] = 2; + poly2[1] = 3; + poly2[2] = 4; + + dyadic_product_coeffmod(poly1.get(), poly2.get(), 3, mod, result.get()); + ASSERT_EQ(2ULL, result[0]); + ASSERT_EQ(3ULL, result[1]); + ASSERT_EQ(4ULL, result[2]); + + poly1[0] = 0; + poly1[1] = 0; + poly1[2] = 0; + poly2[0] = 2; + poly2[1] = 3; + poly2[2] = 4; + + dyadic_product_coeffmod(poly1.get(), poly2.get(), 3, mod, result.get()); + ASSERT_EQ(0ULL, result[0]); + ASSERT_EQ(0ULL, result[1]); + ASSERT_EQ(0ULL, result[2]); + + poly1[0] = 3; + poly1[1] = 5; + poly1[2] = 8; + poly2[0] = 2; + poly2[1] = 3; + poly2[2] = 4; + + dyadic_product_coeffmod(poly1.get(), poly2.get(), 3, mod, result.get()); + ASSERT_EQ(6ULL, result[0]); + ASSERT_EQ(2ULL, result[1]); + ASSERT_EQ(6ULL, result[2]); + } + + TEST(PolyArithSmallMod, TryInvertPolyCoeffSmallMod) + { + MemoryPool &pool = *global_variables::global_memory_pool; + auto poly(allocate_zero_poly(4, 1, pool)); + auto polymod(allocate_zero_poly(4, 1, pool)); + auto result(allocate_zero_poly(4, 1, pool)); + SmallModulus mod(5); + + polymod[0] = 4; + polymod[1] = 3; + polymod[2] = 0; + polymod[3] = 2; + + ASSERT_FALSE(try_invert_poly_coeffmod(poly.get(), polymod.get(), 4, mod, result.get(), pool)); + + poly[0] = 1; + ASSERT_TRUE(try_invert_poly_coeffmod(poly.get(), polymod.get(), 4, mod, result.get(), pool)); + ASSERT_EQ(1ULL, result[0]); + ASSERT_EQ(0ULL, result[1]); + ASSERT_EQ(0ULL, result[2]); + ASSERT_EQ(0ULL, result[3]); + + poly[0] = 1; + poly[1] = 2; + poly[2] = 3; + ASSERT_TRUE(try_invert_poly_coeffmod(poly.get(), polymod.get(), 4, mod, result.get(), pool)); + ASSERT_EQ(4ULL, result[0]); + ASSERT_EQ(0ULL, result[1]); + ASSERT_EQ(2ULL, result[2]); + ASSERT_EQ(0ULL, result[3]); + } + + TEST(PolyArithSmallMod, PolyInftyNormCoeffSmallMod) + { + MemoryPool &pool = *global_variables::global_memory_pool; + auto poly(allocate_zero_poly(4, 1, pool)); + SmallModulus mod(10); + + poly[0] = 0; + poly[1] = 1; + poly[2] = 2; + poly[3] = 3; + ASSERT_EQ(0x3ULL, poly_infty_norm_coeffmod(poly.get(), 4, mod)); + + poly[0] = 0; + poly[1] = 1; + poly[2] = 2; + poly[3] = 8; + ASSERT_EQ(0x2ULL, poly_infty_norm_coeffmod(poly.get(), 4, mod)); + } + + TEST(PolyArithSmallMod, NegacyclicShiftPolyCoeffSmallMod) + { + MemoryPool &pool = *global_variables::global_memory_pool; + auto poly(allocate_zero_poly(4, 1, pool)); + auto result(allocate_zero_poly(4, 1, pool)); + + SmallModulus mod(10); + size_t coeff_count = 4; + + negacyclic_shift_poly_coeffmod(poly.get(), coeff_count, 0, mod, result.get()); + ASSERT_EQ(0ULL, result[0]); + ASSERT_EQ(0ULL, result[1]); + ASSERT_EQ(0ULL, result[2]); + ASSERT_EQ(0ULL, result[3]); + negacyclic_shift_poly_coeffmod(poly.get(), coeff_count, 1, mod, result.get()); + ASSERT_EQ(0ULL, result[0]); + ASSERT_EQ(0ULL, result[1]); + ASSERT_EQ(0ULL, result[2]); + ASSERT_EQ(0ULL, result[3]); + negacyclic_shift_poly_coeffmod(poly.get(), coeff_count, 4, mod, result.get()); + ASSERT_EQ(0ULL, result[0]); + ASSERT_EQ(0ULL, result[1]); + ASSERT_EQ(0ULL, result[2]); + ASSERT_EQ(0ULL, result[3]); + negacyclic_shift_poly_coeffmod(poly.get(), coeff_count, 5, mod, result.get()); + ASSERT_EQ(0ULL, result[0]); + ASSERT_EQ(0ULL, result[1]); + ASSERT_EQ(0ULL, result[2]); + ASSERT_EQ(0ULL, result[3]); + negacyclic_shift_poly_coeffmod(poly.get(), coeff_count, 8, mod, result.get()); + ASSERT_EQ(0ULL, result[0]); + ASSERT_EQ(0ULL, result[1]); + ASSERT_EQ(0ULL, result[2]); + ASSERT_EQ(0ULL, result[3]); + + poly[0] = 1; + poly[1] = 2; + poly[2] = 3; + poly[3] = 4; + negacyclic_shift_poly_coeffmod(poly.get(), coeff_count, 0, mod, result.get()); + ASSERT_EQ(1ULL, result[0]); + ASSERT_EQ(2ULL, result[1]); + ASSERT_EQ(3ULL, result[2]); + ASSERT_EQ(4ULL, result[3]); + negacyclic_shift_poly_coeffmod(poly.get(), coeff_count, 1, mod, result.get()); + ASSERT_EQ(6ULL, result[0]); + ASSERT_EQ(1ULL, result[1]); + ASSERT_EQ(2ULL, result[2]); + ASSERT_EQ(3ULL, result[3]); + negacyclic_shift_poly_coeffmod(poly.get(), coeff_count, 4, mod, result.get()); + ASSERT_EQ(9ULL, result[0]); + ASSERT_EQ(8ULL, result[1]); + ASSERT_EQ(7ULL, result[2]); + ASSERT_EQ(6ULL, result[3]); + negacyclic_shift_poly_coeffmod(poly.get(), coeff_count, 5, mod, result.get()); + ASSERT_EQ(4ULL, result[0]); + ASSERT_EQ(9ULL, result[1]); + ASSERT_EQ(8ULL, result[2]); + ASSERT_EQ(7ULL, result[3]); + negacyclic_shift_poly_coeffmod(poly.get(), coeff_count, 8, mod, result.get()); + ASSERT_EQ(1ULL, result[0]); + ASSERT_EQ(2ULL, result[1]); + ASSERT_EQ(3ULL, result[2]); + ASSERT_EQ(4ULL, result[3]); + + poly[0] = 1; + poly[1] = 2; + poly[2] = 0; + poly[3] = 4; + negacyclic_shift_poly_coeffmod(poly.get(), coeff_count, 0, mod, result.get()); + ASSERT_EQ(1ULL, result[0]); + ASSERT_EQ(2ULL, result[1]); + ASSERT_EQ(0ULL, result[2]); + ASSERT_EQ(4ULL, result[3]); + negacyclic_shift_poly_coeffmod(poly.get(), coeff_count, 1, mod, result.get()); + ASSERT_EQ(6ULL, result[0]); + ASSERT_EQ(1ULL, result[1]); + ASSERT_EQ(2ULL, result[2]); + ASSERT_EQ(0ULL, result[3]); + negacyclic_shift_poly_coeffmod(poly.get(), coeff_count, 4, mod, result.get()); + ASSERT_EQ(9ULL, result[0]); + ASSERT_EQ(8ULL, result[1]); + ASSERT_EQ(0ULL, result[2]); + ASSERT_EQ(6ULL, result[3]); + negacyclic_shift_poly_coeffmod(poly.get(), coeff_count, 5, mod, result.get()); + ASSERT_EQ(4ULL, result[0]); + ASSERT_EQ(9ULL, result[1]); + ASSERT_EQ(8ULL, result[2]); + ASSERT_EQ(0ULL, result[3]); + negacyclic_shift_poly_coeffmod(poly.get(), coeff_count, 8, mod, result.get()); + ASSERT_EQ(1ULL, result[0]); + ASSERT_EQ(2ULL, result[1]); + ASSERT_EQ(0ULL, result[2]); + ASSERT_EQ(4ULL, result[3]); + + poly[0] = 1; + poly[1] = 2; + poly[2] = 3; + poly[3] = 4; + coeff_count = 2; + negacyclic_shift_poly_coeffmod(poly.get(), coeff_count, 1, mod, result.get()); + negacyclic_shift_poly_coeffmod(poly.get() + 2, coeff_count, 1, mod, result.get() + 2); + ASSERT_EQ(8ULL, result[0]); + ASSERT_EQ(1ULL, result[1]); + ASSERT_EQ(6ULL, result[2]); + ASSERT_EQ(3ULL, result[3]); + } + } +} diff --git a/tests/seal/util/polycore.cpp b/tests/seal/util/polycore.cpp new file mode 100644 index 000000000..55e595a8d --- /dev/null +++ b/tests/seal/util/polycore.cpp @@ -0,0 +1,288 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "gtest/gtest.h" +#include "seal/util/polycore.h" +#include "seal/util/uintarith.h" +#include + +using namespace seal::util; +using namespace std; + +namespace SEALTest +{ + namespace util + { + TEST(PolyCore, AllocatePoly) + { + MemoryPool &pool = *global_variables::global_memory_pool; + auto ptr(allocate_poly(0, 0, pool)); + ASSERT_TRUE(nullptr == ptr.get()); + + ptr = allocate_poly(1, 0, pool); + ASSERT_TRUE(nullptr == ptr.get()); + + ptr = allocate_poly(0, 1, pool); + ASSERT_TRUE(nullptr == ptr.get()); + + ptr = allocate_poly(1, 1, pool); + ASSERT_TRUE(nullptr != ptr.get()); + + ptr = allocate_poly(2, 1, pool); + ASSERT_TRUE(nullptr != ptr.get()); + } + + TEST(PolyCore, SetZeroPoly) + { + set_zero_poly(0, 0, nullptr); + + MemoryPool &pool = *global_variables::global_memory_pool; + auto ptr(allocate_poly(1, 1, pool)); + ptr[0] = 0x1234567812345678; + set_zero_poly(1, 1, ptr.get()); + ASSERT_EQ(static_cast(0), ptr[0]); + + ptr = allocate_poly(2, 3, pool); + for (size_t i = 0; i < 6; ++i) + { + ptr[i] = 0x1234567812345678; + } + set_zero_poly(2, 3, ptr.get()); + for (size_t i = 0; i < 6; ++i) + { + ASSERT_EQ(static_cast(0), ptr[i]); + } + } + + TEST(PolyCore, AllocateZeroPoly) + { + MemoryPool &pool = *global_variables::global_memory_pool; + auto ptr(allocate_zero_poly(0, 0, pool)); + ASSERT_TRUE(nullptr == ptr.get()); + + ptr = allocate_zero_poly(1, 1, pool); + ASSERT_TRUE(nullptr != ptr.get()); + ASSERT_EQ(static_cast(0), ptr[0]); + + ptr = allocate_zero_poly(2, 3, pool); + ASSERT_TRUE(nullptr != ptr.get()); + for (size_t i = 0; i < 6; ++i) + { + ASSERT_EQ(static_cast(0), ptr[i]); + } + } + + TEST(PolyCore, GetPolyCoeff) + { + MemoryPool &pool = *global_variables::global_memory_pool; + auto ptr(allocate_zero_poly(2, 3, pool)); + *get_poly_coeff(ptr.get(), 0, 3) = 1; + *get_poly_coeff(ptr.get(), 1, 3) = 2; + ASSERT_EQ(1ULL, ptr[0]); + ASSERT_EQ(static_cast(2), ptr[3]); + ASSERT_EQ(1ULL, *get_poly_coeff(ptr.get(), 0, 3)); + ASSERT_EQ(static_cast(2), *get_poly_coeff(ptr.get(), 1, 3)); + } + + TEST(PolyCore, SetPolyPoly) + { + MemoryPool &pool = *global_variables::global_memory_pool; + auto ptr1(allocate_poly(2, 3, pool)); + auto ptr2(allocate_zero_poly(2, 3, pool)); + for (size_t i = 0; i < 6; ++i) + { + ptr1[i] = static_cast(i + 1); + } + set_poly_poly(ptr1.get(), 2, 3, ptr2.get()); + for (size_t i = 0; i < 6; ++i) + { + ASSERT_EQ(static_cast(i + 1), ptr2[i]); + } + + set_poly_poly(ptr1.get(), 2, 3, ptr1.get()); + for (size_t i = 0; i < 6; ++i) + { + ASSERT_EQ(static_cast(i + 1), ptr2[i]); + } + + ptr2 = allocate_poly(3, 4, pool); + for (size_t i = 0; i < 12; ++i) + { + ptr2[i] = 1ULL; + } + set_poly_poly(ptr1.get(), 2, 3, 3, 4, ptr2.get()); + ASSERT_EQ(1ULL, ptr2[0]); + ASSERT_EQ(static_cast(2), ptr2[1]); + ASSERT_EQ(static_cast(3), ptr2[2]); + ASSERT_EQ(static_cast(0), ptr2[3]); + ASSERT_EQ(static_cast(4), ptr2[4]); + ASSERT_EQ(static_cast(5), ptr2[5]); + ASSERT_EQ(static_cast(6), ptr2[6]); + ASSERT_EQ(static_cast(0), ptr2[7]); + ASSERT_EQ(static_cast(0), ptr2[8]); + ASSERT_EQ(static_cast(0), ptr2[9]); + ASSERT_EQ(static_cast(0), ptr2[10]); + ASSERT_EQ(static_cast(0), ptr2[11]); + + ptr2 = allocate_poly(1, 2, pool); + ptr2[0] = 1; + ptr2[1] = 1; + set_poly_poly(ptr1.get(), 2, 3, 1, 2, ptr2.get()); + ASSERT_EQ(1ULL, ptr2[0]); + ASSERT_EQ(static_cast(2), ptr2[1]); + } + + TEST(PolyCore, IsZeroPoly) + { + ASSERT_TRUE(is_zero_poly(nullptr, 0, 0)); + + MemoryPool &pool = *global_variables::global_memory_pool; + auto ptr(allocate_poly(2, 3, pool)); + for (size_t i = 0; i < 6; ++i) + { + ptr[i] = 0; + } + ASSERT_TRUE(is_zero_poly(ptr.get(), 2, 3)); + for (size_t i = 0; i < 6; ++i) + { + ptr[i] = 1; + ASSERT_FALSE(is_zero_poly(ptr.get(), 2, 3)); + ptr[i] = 0; + } + } + + TEST(PolyCore, IsEqualPolyPoly) + { + ASSERT_TRUE(is_equal_poly_poly(nullptr, nullptr, 0, 0)); + + MemoryPool &pool = *global_variables::global_memory_pool; + auto ptr1(allocate_poly(2, 3, pool)); + auto ptr2(allocate_poly(2, 3, pool)); + for (size_t i = 0; i < 6; ++i) + { + ptr2[i] = ptr1[i] = static_cast(i + 1); + } + ASSERT_TRUE(is_equal_poly_poly(ptr1.get(), ptr2.get(), 2, 3)); + for (size_t i = 0; i < 6; ++i) + { + ptr2[i]--; + ASSERT_FALSE(is_equal_poly_poly(ptr1.get(), ptr2.get(), 2, 3)); + ptr2[i]++; + } + } + + TEST(PolyCore, IsOneZeroOnePoly) + { + ASSERT_FALSE(is_one_zero_one_poly(nullptr, 0, 0)); + + MemoryPool &pool = *global_variables::global_memory_pool; + auto poly(allocate_zero_poly(4, 2, pool)); + ASSERT_FALSE(is_one_zero_one_poly(poly.get(), 0, 2)); + ASSERT_FALSE(is_one_zero_one_poly(poly.get(), 1, 2)); + ASSERT_FALSE(is_one_zero_one_poly(poly.get(), 2, 2)); + ASSERT_FALSE(is_one_zero_one_poly(poly.get(), 3, 2)); + + poly[0] = 2; + ASSERT_FALSE(is_one_zero_one_poly(poly.get(), 1, 2)); + ASSERT_FALSE(is_one_zero_one_poly(poly.get(), 2, 2)); + + poly[0] = 1; + ASSERT_TRUE(is_one_zero_one_poly(poly.get(), 1, 2)); + ASSERT_FALSE(is_one_zero_one_poly(poly.get(), 2, 2)); + + poly[2] = 2; + ASSERT_FALSE(is_one_zero_one_poly(poly.get(), 2, 2)); + ASSERT_FALSE(is_one_zero_one_poly(poly.get(), 3, 2)); + + poly[2] = 1; + ASSERT_TRUE(is_one_zero_one_poly(poly.get(), 2, 2)); + ASSERT_FALSE(is_one_zero_one_poly(poly.get(), 3, 2)); + + poly[4] = 1; + ASSERT_FALSE(is_one_zero_one_poly(poly.get(), 3, 2)); + ASSERT_FALSE(is_one_zero_one_poly(poly.get(), 4, 2)); + + poly[2] = 0; + ASSERT_TRUE(is_one_zero_one_poly(poly.get(), 3, 2)); + ASSERT_FALSE(is_one_zero_one_poly(poly.get(), 4, 2)); + + poly[6] = 2; + ASSERT_FALSE(is_one_zero_one_poly(poly.get(), 4, 2)); + + poly[6] = 1; + ASSERT_FALSE(is_one_zero_one_poly(poly.get(), 4, 2)); + + poly[4] = 0; + ASSERT_TRUE(is_one_zero_one_poly(poly.get(), 4, 2)); + } + + TEST(PolyCore, GetSignificantCoeffCountPoly) + { + ASSERT_EQ(0ULL, get_significant_coeff_count_poly(nullptr, 0, 0)); + + MemoryPool &pool = *global_variables::global_memory_pool; + auto ptr(allocate_zero_poly(3, 2, pool)); + ASSERT_EQ(0ULL, get_significant_coeff_count_poly(ptr.get(), 3, 2)); + ptr[0] = 1; + ASSERT_EQ(1ULL, get_significant_coeff_count_poly(ptr.get(), 3, 2)); + ptr[1] = 1; + ASSERT_EQ(1ULL, get_significant_coeff_count_poly(ptr.get(), 3, 2)); + ptr[4] = 1; + ASSERT_EQ(3ULL, get_significant_coeff_count_poly(ptr.get(), 3, 2)); + ptr[4] = 0; + ptr[5] = 1; + ASSERT_EQ(3ULL, get_significant_coeff_count_poly(ptr.get(), 3, 2)); + } + + TEST(PolyCore, DuplicatePolyIfNeeded) + { + ASSERT_EQ(0ULL, get_significant_coeff_count_poly(nullptr, 0, 0)); + + MemoryPool &pool = *global_variables::global_memory_pool; + auto poly(allocate_poly(3, 2, pool)); + for (size_t i = 0; i < 6; i++) + { + poly[i] = i + 1; + } + + auto ptr = duplicate_poly_if_needed(poly.get(), 3, 2, 3, 2, false, pool); + ASSERT_TRUE(ptr.get() == poly.get()); + ptr = duplicate_poly_if_needed(poly.get(), 3, 2, 2, 2, false, pool); + ASSERT_TRUE(ptr.get() == poly.get()); + ptr = duplicate_poly_if_needed(poly.get(), 3, 2, 2, 3, false, pool); + ASSERT_TRUE(ptr.get() != poly.get()); + ASSERT_EQ(1ULL, ptr[0]); + ASSERT_EQ(static_cast(2), ptr[1]); + ASSERT_EQ(static_cast(0), ptr[2]); + ASSERT_EQ(static_cast(3), ptr[3]); + ASSERT_EQ(static_cast(4), ptr[4]); + ASSERT_EQ(static_cast(0), ptr[5]); + + ptr = duplicate_poly_if_needed(poly.get(), 3, 2, 3, 2, true, pool); + ASSERT_TRUE(ptr.get() != poly.get()); + for (size_t i = 0; i < 6; i++) + { + ASSERT_EQ(static_cast(i + 1), ptr[i]); + } + } + + TEST(PolyCore, ArePolyCoeffsLessThan) + { + MemoryPool &pool = *global_variables::global_memory_pool; + auto poly(allocate_zero_poly(3, 2, pool)); + poly[0] = 3; + poly[2] = 5; + poly[4] = 4; + + auto max(allocate_uint(1, pool)); + max[0] = 1; + ASSERT_FALSE(are_poly_coefficients_less_than(poly.get(), 3, 2, max.get(), 1)); + max[0] = 5; + ASSERT_FALSE(are_poly_coefficients_less_than(poly.get(), 3, 2, max.get(), 1)); + max[0] = 6; + ASSERT_TRUE(are_poly_coefficients_less_than(poly.get(), 3, 2, max.get(), 1)); + max[0] = 10; + ASSERT_TRUE(are_poly_coefficients_less_than(poly.get(), 3, 2, max.get(), 1)); + } + } +} diff --git a/tests/seal/util/randomtostd.cpp b/tests/seal/util/randomtostd.cpp new file mode 100644 index 000000000..8d98dd4af --- /dev/null +++ b/tests/seal/util/randomtostd.cpp @@ -0,0 +1,55 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "gtest/gtest.h" +#include "seal/randomgen.h" +#include "seal/util/randomtostd.h" +#include +#include + +using namespace seal::util; +using namespace seal; +using namespace std; + +namespace SEALTest +{ + namespace util + { + TEST(RandomToStandard, RandomToStandardGenerate) + { + shared_ptr generator(UniformRandomGeneratorFactory::default_factory()->create()); + RandomToStandardAdapter rand(generator); + ASSERT_TRUE(rand.generator() == generator); + ASSERT_EQ(static_cast(0), rand.min()); + ASSERT_EQ(static_cast(UINT32_MAX), rand.max()); + bool lower_half = false; + bool upper_half = false; + bool even = false; + bool odd = false; + for (int i = 0; i < 10; i++) + { + uint32_t value = rand(); + if (value < UINT32_MAX / 2) + { + lower_half = true; + } + else + { + upper_half = true; + } + if ((value % 2) == 0) + { + even = true; + } + else + { + odd = true; + } + } + ASSERT_TRUE(lower_half); + ASSERT_TRUE(upper_half); + ASSERT_TRUE(even); + ASSERT_TRUE(odd); + } + } +} diff --git a/tests/seal/util/smallntt.cpp b/tests/seal/util/smallntt.cpp new file mode 100644 index 000000000..1cbf5a399 --- /dev/null +++ b/tests/seal/util/smallntt.cpp @@ -0,0 +1,134 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "gtest/gtest.h" +#include "seal/util/mempool.h" +#include "seal/util/uintcore.h" +#include "seal/util/polycore.h" +#include "seal/util/smallntt.h" +#include "seal/defaultparams.h" +#include "seal/util/numth.h" +#include +#include +#include + +using namespace seal; +using namespace seal::util; +using namespace std; + +namespace SEALTest +{ + namespace util + { + TEST(SmallNTTTablesTest, SmallNTTBasics) + { + MemoryPoolHandle pool = MemoryPoolHandle::Global(); + SmallNTTTables tables; + int coeff_count_power = 1; + int coeff_count = 1 << coeff_count_power; + SmallModulus modulus(small_mods_60bit(0)); + tables.generate(coeff_count_power, modulus); + ASSERT_EQ(2ULL, tables.coeff_count()); + ASSERT_TRUE(tables.is_generated()); + ASSERT_EQ(1, tables.coeff_count_power()); + + coeff_count_power = 2; + coeff_count = 1 << coeff_count_power; + modulus = small_mods_50bit(0); + tables.generate(coeff_count_power, modulus); + ASSERT_EQ(4ULL, tables.coeff_count()); + ASSERT_TRUE(tables.is_generated()); + ASSERT_EQ(2, tables.coeff_count_power()); + + coeff_count_power = 10; + coeff_count = 1 << coeff_count_power; + modulus = small_mods_40bit(0); + tables.generate(coeff_count_power, modulus); + ASSERT_EQ(1024ULL, tables.coeff_count()); + ASSERT_TRUE(tables.is_generated()); + ASSERT_EQ(10, tables.coeff_count_power()); + } + + TEST(SmallNTTTablesTest, SmallNTTPrimitiveRootsTest) + { + MemoryPoolHandle pool = MemoryPoolHandle::Global(); + SmallNTTTables tables; + + int coeff_count_power = 1; + SmallModulus modulus(0xffffffffffc0001ULL); + tables.generate(coeff_count_power, modulus); + ASSERT_EQ(1ULL, tables.get_from_root_powers(0)); + ASSERT_EQ(288794978602139552ULL, tables.get_from_root_powers(1)); + uint64_t inv; + try_mod_inverse(tables.get_from_root_powers(1), modulus.value(), inv); + ASSERT_EQ(inv, tables.get_from_inv_root_powers(1)); + + coeff_count_power = 2; + tables.generate(coeff_count_power, modulus); + ASSERT_EQ(1ULL, tables.get_from_root_powers(0)); + ASSERT_EQ(288794978602139552ULL, tables.get_from_root_powers(1)); + ASSERT_EQ(178930308976060547ULL, tables.get_from_root_powers(2)); + ASSERT_EQ(748001537669050592ULL, tables.get_from_root_powers(3)); + } + + TEST(SmallNTTTablesTest, NegacyclicSmallNTTTest) + { + MemoryPoolHandle pool = MemoryPoolHandle::Global(); + SmallNTTTables tables; + + int coeff_count_power = 1; + SmallModulus modulus(0xffffffffffc0001ULL); + tables.generate(coeff_count_power, modulus); + auto poly(allocate_poly(2, 1, pool)); + poly[0] = 0; + poly[1] = 0; + ntt_negacyclic_harvey(poly.get(), tables); + ASSERT_EQ(0ULL, poly[0]); + ASSERT_EQ(0ULL, poly[1]); + + poly[0] = 1; + poly[1] = 0; + ntt_negacyclic_harvey(poly.get(), tables); + ASSERT_EQ(1ULL, poly[0]); + ASSERT_EQ(1ULL, poly[1]); + + poly[0] = 1; + poly[1] = 1; + ntt_negacyclic_harvey(poly.get(), tables); + ASSERT_EQ(288794978602139553ULL, poly[0]); + ASSERT_EQ(864126526004445282ULL, poly[1]); + } + + TEST(SmallNTTTablesTest, InverseNegacyclicSmallNTTTest) + { + MemoryPoolHandle pool = MemoryPoolHandle::Global(); + SmallNTTTables tables; + + int coeff_count_power = 3; + SmallModulus modulus(0xffffffffffc0001ULL); + tables.generate(coeff_count_power, modulus); + auto poly(allocate_zero_poly(800, 1, pool)); + auto temp(allocate_zero_poly(800, 1, pool)); + + inverse_ntt_negacyclic_harvey(poly.get(), tables); + for (size_t i = 0; i < 800; i++) + { + ASSERT_EQ(0ULL, poly[i]); + } + + random_device rd; + for (size_t i = 0; i < 800; i++) + { + poly[i] = static_cast(rd()) % modulus.value(); + temp[i] = poly[i]; + } + + ntt_negacyclic_harvey(poly.get(), tables); + inverse_ntt_negacyclic_harvey(poly.get(), tables); + for (size_t i = 0; i < 800; i++) + { + ASSERT_EQ(temp[i], poly[i]); + } + } + } +} diff --git a/tests/seal/util/stringtouint64.cpp b/tests/seal/util/stringtouint64.cpp new file mode 100644 index 000000000..0999f65d3 --- /dev/null +++ b/tests/seal/util/stringtouint64.cpp @@ -0,0 +1,271 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "gtest/gtest.h" +#include "seal/util/common.h" +#include "seal/util/uintcore.h" +#include +#include + +using namespace seal::util; +using namespace std; + +namespace SEALTest +{ + namespace util + { + TEST(StringToUInt64, IsHexCharTest) + { + ASSERT_TRUE(is_hex_char('0')); + ASSERT_TRUE(is_hex_char('1')); + ASSERT_TRUE(is_hex_char('2')); + ASSERT_TRUE(is_hex_char('3')); + ASSERT_TRUE(is_hex_char('4')); + ASSERT_TRUE(is_hex_char('5')); + ASSERT_TRUE(is_hex_char('6')); + ASSERT_TRUE(is_hex_char('7')); + ASSERT_TRUE(is_hex_char('8')); + ASSERT_TRUE(is_hex_char('9')); + ASSERT_TRUE(is_hex_char('A')); + ASSERT_TRUE(is_hex_char('B')); + ASSERT_TRUE(is_hex_char('C')); + ASSERT_TRUE(is_hex_char('D')); + ASSERT_TRUE(is_hex_char('E')); + ASSERT_TRUE(is_hex_char('F')); + ASSERT_TRUE(is_hex_char('a')); + ASSERT_TRUE(is_hex_char('b')); + ASSERT_TRUE(is_hex_char('c')); + ASSERT_TRUE(is_hex_char('d')); + ASSERT_TRUE(is_hex_char('e')); + ASSERT_TRUE(is_hex_char('f')); + + ASSERT_FALSE(is_hex_char('/')); + ASSERT_FALSE(is_hex_char(' ')); + ASSERT_FALSE(is_hex_char('+')); + ASSERT_FALSE(is_hex_char('\\')); + ASSERT_FALSE(is_hex_char('G')); + ASSERT_FALSE(is_hex_char('g')); + ASSERT_FALSE(is_hex_char('Z')); + ASSERT_FALSE(is_hex_char('Z')); + } + + TEST(StringToUInt64, HexToNibbleTest) + { + ASSERT_EQ(0, hex_to_nibble('0')); + ASSERT_EQ(1, hex_to_nibble('1')); + ASSERT_EQ(2, hex_to_nibble('2')); + ASSERT_EQ(3, hex_to_nibble('3')); + ASSERT_EQ(4, hex_to_nibble('4')); + ASSERT_EQ(5, hex_to_nibble('5')); + ASSERT_EQ(6, hex_to_nibble('6')); + ASSERT_EQ(7, hex_to_nibble('7')); + ASSERT_EQ(8, hex_to_nibble('8')); + ASSERT_EQ(9, hex_to_nibble('9')); + ASSERT_EQ(10, hex_to_nibble('A')); + ASSERT_EQ(11, hex_to_nibble('B')); + ASSERT_EQ(12, hex_to_nibble('C')); + ASSERT_EQ(13, hex_to_nibble('D')); + ASSERT_EQ(14, hex_to_nibble('E')); + ASSERT_EQ(15, hex_to_nibble('F')); + ASSERT_EQ(10, hex_to_nibble('a')); + ASSERT_EQ(11, hex_to_nibble('b')); + ASSERT_EQ(12, hex_to_nibble('c')); + ASSERT_EQ(13, hex_to_nibble('d')); + ASSERT_EQ(14, hex_to_nibble('e')); + ASSERT_EQ(15, hex_to_nibble('f')); + } + + TEST(StringToUInt64, GetHexStringBitCount) + { + ASSERT_EQ(0, get_hex_string_bit_count(nullptr, 0)); + ASSERT_EQ(0, get_hex_string_bit_count("0", 1)); + ASSERT_EQ(0, get_hex_string_bit_count("000000000", 9)); + ASSERT_EQ(1, get_hex_string_bit_count("1", 1)); + ASSERT_EQ(1, get_hex_string_bit_count("00001", 5)); + ASSERT_EQ(2, get_hex_string_bit_count("2", 1)); + ASSERT_EQ(2, get_hex_string_bit_count("00002", 5)); + ASSERT_EQ(2, get_hex_string_bit_count("3", 1)); + ASSERT_EQ(2, get_hex_string_bit_count("0003", 4)); + ASSERT_EQ(3, get_hex_string_bit_count("4", 1)); + ASSERT_EQ(3, get_hex_string_bit_count("5", 1)); + ASSERT_EQ(3, get_hex_string_bit_count("6", 1)); + ASSERT_EQ(3, get_hex_string_bit_count("7", 1)); + ASSERT_EQ(4, get_hex_string_bit_count("8", 1)); + ASSERT_EQ(4, get_hex_string_bit_count("9", 1)); + ASSERT_EQ(4, get_hex_string_bit_count("A", 1)); + ASSERT_EQ(4, get_hex_string_bit_count("B", 1)); + ASSERT_EQ(4, get_hex_string_bit_count("C", 1)); + ASSERT_EQ(4, get_hex_string_bit_count("D", 1)); + ASSERT_EQ(4, get_hex_string_bit_count("E", 1)); + ASSERT_EQ(4, get_hex_string_bit_count("F", 1)); + ASSERT_EQ(5, get_hex_string_bit_count("10", 2)); + ASSERT_EQ(5, get_hex_string_bit_count("00010", 5)); + ASSERT_EQ(5, get_hex_string_bit_count("11", 2)); + ASSERT_EQ(5, get_hex_string_bit_count("1F", 2)); + ASSERT_EQ(6, get_hex_string_bit_count("20", 2)); + ASSERT_EQ(6, get_hex_string_bit_count("2F", 2)); + ASSERT_EQ(7, get_hex_string_bit_count("7F", 2)); + ASSERT_EQ(7, get_hex_string_bit_count("0007F", 5)); + ASSERT_EQ(8, get_hex_string_bit_count("80", 2)); + ASSERT_EQ(8, get_hex_string_bit_count("FF", 2)); + ASSERT_EQ(8, get_hex_string_bit_count("00FF", 4)); + ASSERT_EQ(9, get_hex_string_bit_count("100", 3)); + ASSERT_EQ(9, get_hex_string_bit_count("000100", 6)); + ASSERT_EQ(22, get_hex_string_bit_count("200000", 6)); + ASSERT_EQ(35, get_hex_string_bit_count("7FFF30001", 9)); + + ASSERT_EQ(15, get_hex_string_bit_count("7FFF30001", 4)); + ASSERT_EQ(3, get_hex_string_bit_count("7FFF30001", 1)); + ASSERT_EQ(0, get_hex_string_bit_count("7FFF30001", 0)); + } + + TEST(StringToUInt64, HexStringToUInt64) + { + uint64_t correct[3]; + uint64_t parsed[3]; + + correct[0] = 0; + correct[1] = 0; + correct[2] = 0; + parsed[0] = 0x123; + parsed[1] = 0x123; + parsed[2] = 0x123; + hex_string_to_uint("0", 1, 3, parsed); + ASSERT_EQ(0, memcmp(correct, parsed, 3 * sizeof(uint64_t))); + parsed[0] = 0x123; + parsed[1] = 0x123; + parsed[2] = 0x123; + hex_string_to_uint("0", 1, 1, parsed); + ASSERT_EQ(0, memcmp(correct, parsed, 1 * sizeof(uint64_t))); + parsed[0] = 0x123; + parsed[1] = 0x123; + parsed[2] = 0x123; + hex_string_to_uint(nullptr, 0, 3, parsed); + ASSERT_EQ(0, memcmp(correct, parsed, 3 * sizeof(uint64_t))); + + correct[0] = 1; + parsed[0] = 0x123; + parsed[1] = 0x123; + parsed[2] = 0x123; + hex_string_to_uint("1", 1, 3, parsed); + ASSERT_EQ(0, memcmp(correct, parsed, 3 * sizeof(uint64_t))); + parsed[0] = 0x123; + parsed[1] = 0x123; + parsed[2] = 0x123; + hex_string_to_uint("01", 2, 3, parsed); + ASSERT_EQ(0, memcmp(correct, parsed, 3 * sizeof(uint64_t))); + parsed[0] = 0x123; + parsed[1] = 0x123; + parsed[2] = 0x123; + hex_string_to_uint("001", 3, 1, parsed); + ASSERT_EQ(0, memcmp(correct, parsed, 1 * sizeof(uint64_t))); + + correct[0] = 0xF; + parsed[0] = 0x123; + parsed[1] = 0x123; + parsed[2] = 0x123; + hex_string_to_uint("F", 1, 3, parsed); + ASSERT_EQ(0, memcmp(correct, parsed, 3 * sizeof(uint64_t))); + + correct[0] = 0x10; + parsed[0] = 0x123; + parsed[1] = 0x123; + parsed[2] = 0x123; + hex_string_to_uint("10", 2, 3, parsed); + ASSERT_EQ(0, memcmp(correct, parsed, 3 * sizeof(uint64_t))); + parsed[0] = 0x123; + parsed[1] = 0x123; + parsed[2] = 0x123; + hex_string_to_uint("010", 3, 3, parsed); + ASSERT_EQ(0, memcmp(correct, parsed, 3 * sizeof(uint64_t))); + + correct[0] = 0x100; + parsed[0] = 0x123; + parsed[1] = 0x123; + parsed[2] = 0x123; + hex_string_to_uint("100", 3, 3, parsed); + ASSERT_EQ(0, memcmp(correct, parsed, 3 * sizeof(uint64_t))); + + correct[0] = 0x123; + parsed[0] = 0x123; + parsed[1] = 0x123; + parsed[2] = 0x123; + hex_string_to_uint("123", 3, 3, parsed); + ASSERT_EQ(0, memcmp(correct, parsed, 3 * sizeof(uint64_t))); + parsed[0] = 0x123; + parsed[1] = 0x123; + parsed[2] = 0x123; + hex_string_to_uint("00000123", 8, 3, parsed); + ASSERT_EQ(0, memcmp(correct, parsed, 3 * sizeof(uint64_t))); + + correct[0] = 0; + correct[1] = 1; + parsed[0] = 0x123; + parsed[1] = 0x123; + parsed[2] = 0x123; + hex_string_to_uint("10000000000000000", 17, 3, parsed); + ASSERT_EQ(0, memcmp(correct, parsed, 3 * sizeof(uint64_t))); + + correct[0] = 0x1123456789ABCDEF; + correct[1] = 0x1; + parsed[0] = 0x123; + parsed[1] = 0x123; + parsed[2] = 0x123; + hex_string_to_uint("11123456789ABCDEF", 17, 3, parsed); + ASSERT_EQ(0, memcmp(correct, parsed, 3 * sizeof(uint64_t))); + parsed[0] = 0x123; + parsed[1] = 0x123; + parsed[2] = 0x123; + hex_string_to_uint("000011123456789ABCDEF", 21, 3, parsed); + ASSERT_EQ(0, memcmp(correct, parsed, 3 * sizeof(uint64_t))); + + correct[0] = 0x3456789ABCDEF123; + correct[1] = 0x23456789ABCDEF12; + correct[2] = 0x123456789ABCDEF1; + parsed[0] = 0x123; + parsed[1] = 0x123; + parsed[2] = 0x123; + hex_string_to_uint("123456789ABCDEF123456789ABCDEF123456789ABCDEF123", 48, 3, parsed); + ASSERT_EQ(0, memcmp(correct, parsed, 3 * sizeof(uint64_t))); + + correct[0] = 0xFFFFFFFFFFFFFFFF; + correct[1] = 0xFFFFFFFFFFFFFFFF; + correct[2] = 0xFFFFFFFFFFFFFFFF; + parsed[0] = 0x123; + parsed[1] = 0x123; + parsed[2] = 0x123; + hex_string_to_uint("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF", 48, 3, parsed); + ASSERT_EQ(0, memcmp(correct, parsed, 3 * sizeof(uint64_t))); + + correct[0] = 0x100; + correct[1] = 0; + correct[2] = 0; + parsed[0] = 0x123; + parsed[1] = 0x123; + parsed[2] = 0x123; + hex_string_to_uint("100", 3, 3, parsed); + ASSERT_EQ(0, memcmp(correct, parsed, 3 * sizeof(uint64_t))); + + correct[0] = 0x10; + parsed[0] = 0x123; + parsed[1] = 0x123; + parsed[2] = 0x123; + hex_string_to_uint("100", 2, 3, parsed); + ASSERT_EQ(0, memcmp(correct, parsed, 3 * sizeof(uint64_t))); + + correct[0] = 0x1; + parsed[0] = 0x123; + parsed[1] = 0x123; + parsed[2] = 0x123; + hex_string_to_uint("100", 1, 3, parsed); + ASSERT_EQ(0, memcmp(correct, parsed, 3 * sizeof(uint64_t))); + + correct[0] = 0; + parsed[0] = 0x123; + parsed[1] = 0x123; + parsed[2] = 0x123; + hex_string_to_uint("100", 0, 3, parsed); + ASSERT_EQ(0, memcmp(correct, parsed, 3 * sizeof(uint64_t))); + } + } +} diff --git a/tests/seal/util/uint64tostring.cpp b/tests/seal/util/uint64tostring.cpp new file mode 100644 index 000000000..1ce133780 --- /dev/null +++ b/tests/seal/util/uint64tostring.cpp @@ -0,0 +1,181 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "gtest/gtest.h" +#include "seal/util/common.h" +#include "seal/util/uintcore.h" +#include "seal/util/polycore.h" +#include "seal/util/mempool.h" +#include + +using namespace seal::util; +using namespace std; + +namespace SEALTest +{ + namespace util + { + TEST(UInt64ToString, NibbleToUpperHexTest) + { + ASSERT_EQ('0', nibble_to_upper_hex(0)); + ASSERT_EQ('1', nibble_to_upper_hex(1)); + ASSERT_EQ('2', nibble_to_upper_hex(2)); + ASSERT_EQ('3', nibble_to_upper_hex(3)); + ASSERT_EQ('4', nibble_to_upper_hex(4)); + ASSERT_EQ('5', nibble_to_upper_hex(5)); + ASSERT_EQ('6', nibble_to_upper_hex(6)); + ASSERT_EQ('7', nibble_to_upper_hex(7)); + ASSERT_EQ('8', nibble_to_upper_hex(8)); + ASSERT_EQ('9', nibble_to_upper_hex(9)); + ASSERT_EQ('A', nibble_to_upper_hex(10)); + ASSERT_EQ('B', nibble_to_upper_hex(11)); + ASSERT_EQ('C', nibble_to_upper_hex(12)); + ASSERT_EQ('D', nibble_to_upper_hex(13)); + ASSERT_EQ('E', nibble_to_upper_hex(14)); + ASSERT_EQ('F', nibble_to_upper_hex(15)); + } + + TEST(UInt64ToString, UInt64ToHexString) + { + uint64_t number[] = { 0, 0, 0 }; + string correct = "0"; + ASSERT_EQ(correct, uint_to_hex_string(number, 3)); + ASSERT_EQ(correct, uint_to_hex_string(number, 1)); + ASSERT_EQ(correct, uint_to_hex_string(number, 0)); + ASSERT_EQ(correct, uint_to_hex_string(nullptr, 0)); + + number[0] = 1; + correct = "1"; + ASSERT_EQ(correct, uint_to_hex_string(number, 3)); + ASSERT_EQ(correct, uint_to_hex_string(number, 1)); + + number[0] = 0xF; + correct = "F"; + ASSERT_EQ(correct, uint_to_hex_string(number, 3)); + + number[0] = 0x10; + correct = "10"; + ASSERT_EQ(correct, uint_to_hex_string(number, 3)); + + number[0] = 0x100; + correct = "100"; + ASSERT_EQ(correct, uint_to_hex_string(number, 3)); + + number[0] = 0x123; + correct = "123"; + ASSERT_EQ(correct, uint_to_hex_string(number, 3)); + + number[0] = 0; + number[1] = 1; + correct = "10000000000000000"; + ASSERT_EQ(correct, uint_to_hex_string(number, 3)); + + number[0] = 0x1123456789ABCDEF; + number[1] = 0x1; + correct = "11123456789ABCDEF"; + ASSERT_EQ(correct, uint_to_hex_string(number, 3)); + + number[0] = 0x3456789ABCDEF123; + number[1] = 0x23456789ABCDEF12; + number[2] = 0x123456789ABCDEF1; + correct = "123456789ABCDEF123456789ABCDEF123456789ABCDEF123"; + ASSERT_EQ(correct, uint_to_hex_string(number, 3)); + + number[0] = 0xFFFFFFFFFFFFFFFF; + number[1] = 0xFFFFFFFFFFFFFFFF; + number[2] = 0xFFFFFFFFFFFFFFFF; + correct = "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF"; + ASSERT_EQ(correct, uint_to_hex_string(number, 3)); + } + + TEST(UInt64ToString, UInt64ToDecString) + { + uint64_t number[] = { 0, 0, 0 }; + string correct = "0"; + MemoryPool &pool = *global_variables::global_memory_pool; + ASSERT_EQ(correct, uint_to_dec_string(number, 3, pool)); + ASSERT_EQ(correct, uint_to_dec_string(number, 1, pool)); + ASSERT_EQ(correct, uint_to_dec_string(number, 0, pool)); + ASSERT_EQ(correct, uint_to_dec_string(nullptr, 0, pool)); + + number[0] = 1; + correct = "1"; + ASSERT_EQ(correct, uint_to_dec_string(number, 3, pool)); + ASSERT_EQ(correct, uint_to_dec_string(number, 1, pool)); + + number[0] = 9; + correct = "9"; + ASSERT_EQ(correct, uint_to_dec_string(number, 3, pool)); + + number[0] = 10; + correct = "10"; + ASSERT_EQ(correct, uint_to_dec_string(number, 3, pool)); + + number[0] = 123; + correct = "123"; + ASSERT_EQ(correct, uint_to_dec_string(number, 3, pool)); + + number[0] = 987654321; + correct = "987654321"; + ASSERT_EQ(correct, uint_to_dec_string(number, 3, pool)); + + number[0] = 0; + number[1] = 1; + correct = "18446744073709551616"; + ASSERT_EQ(correct, uint_to_dec_string(number, 3, pool)); + } + + TEST(UInt64ToString, PolyToHexString) + { + uint64_t number[] = { 0, 0, 0, 0 }; + string correct = "0"; + ASSERT_EQ(correct, poly_to_hex_string(number, 0, 1)); + ASSERT_EQ(correct, poly_to_hex_string(number, 4, 0)); + ASSERT_EQ(correct, poly_to_hex_string(number, 1, 1)); + ASSERT_EQ(correct, poly_to_hex_string(number, 4, 1)); + ASSERT_EQ(correct, poly_to_hex_string(number, 2, 2)); + ASSERT_EQ(correct, poly_to_hex_string(number, 1, 4)); + ASSERT_EQ(correct, poly_to_hex_string(nullptr, 0, 0)); + + number[0] = 1; + correct = "1"; + ASSERT_EQ(correct, poly_to_hex_string(number, 4, 1)); + ASSERT_EQ(correct, poly_to_hex_string(number, 2, 2)); + ASSERT_EQ(correct, poly_to_hex_string(number, 1, 4)); + + number[0] = 0; + number[1] = 1; + correct = "1x^1"; + ASSERT_EQ(correct, poly_to_hex_string(number, 4, 1)); + correct = "10000000000000000"; + ASSERT_EQ(correct, poly_to_hex_string(number, 2, 2)); + ASSERT_EQ(correct, poly_to_hex_string(number, 1, 4)); + + number[0] = 1; + number[1] = 0; + number[2] = 0; + number[3] = 1; + correct = "1x^3 + 1"; + ASSERT_EQ(correct, poly_to_hex_string(number, 4, 1)); + correct = "10000000000000000x^1 + 1"; + ASSERT_EQ(correct, poly_to_hex_string(number, 2, 2)); + correct = "1000000000000000000000000000000000000000000000001"; + ASSERT_EQ(correct, poly_to_hex_string(number, 1, 4)); + + number[0] = 0xF00000000000000F; + number[1] = 0xF0F0F0F0F0F0F0F0; + number[2] = 0; + number[3] = 0; + correct = "F0F0F0F0F0F0F0F0x^1 + F00000000000000F"; + ASSERT_EQ(correct, poly_to_hex_string(number, 4, 1)); + correct = "F0F0F0F0F0F0F0F0F00000000000000F"; + + number[2] = 0xF0FF0F0FF0F0FF0F; + number[3] = 0xBABABABABABABABA; + correct = "BABABABABABABABAF0FF0F0FF0F0FF0Fx^1 + F0F0F0F0F0F0F0F0F00000000000000F"; + ASSERT_EQ(correct, poly_to_hex_string(number, 2, 2)); + correct = "BABABABABABABABAx^3 + F0FF0F0FF0F0FF0Fx^2 + F0F0F0F0F0F0F0F0x^1 + F00000000000000F"; + ASSERT_EQ(correct, poly_to_hex_string(number, 4, 1)); + } + } +} diff --git a/tests/seal/util/uintarith.cpp b/tests/seal/util/uintarith.cpp new file mode 100644 index 000000000..a91bebaf2 --- /dev/null +++ b/tests/seal/util/uintarith.cpp @@ -0,0 +1,1414 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "gtest/gtest.h" +#include "seal/util/uintarith.h" +#include + +using namespace seal::util; +using namespace std; + +namespace SEALTest +{ + namespace util + { + TEST(UIntArith, AddUInt64Generic) + { + unsigned long long result; + ASSERT_FALSE(add_uint64_generic(0ULL, 0ULL, 0, &result)); + ASSERT_EQ(0ULL, result); + ASSERT_FALSE(add_uint64_generic(1ULL, 1ULL, 0, &result)); + ASSERT_EQ(2ULL, result); + ASSERT_FALSE(add_uint64_generic(1ULL, 0ULL, 1, &result)); + ASSERT_EQ(2ULL, result); + ASSERT_FALSE(add_uint64_generic(0ULL, 1ULL, 1, &result)); + ASSERT_EQ(2ULL, result); + ASSERT_FALSE(add_uint64_generic(1ULL, 1ULL, 1, &result)); + ASSERT_EQ(3ULL, result); + ASSERT_TRUE(add_uint64_generic(0xFFFFFFFFFFFFFFFFULL, 1ULL, 0, &result)); + ASSERT_EQ(0ULL, result); + ASSERT_TRUE(add_uint64_generic(1ULL, 0xFFFFFFFFFFFFFFFFULL, 0, &result)); + ASSERT_EQ(0ULL, result); + ASSERT_TRUE(add_uint64_generic(1ULL, 0xFFFFFFFFFFFFFFFFULL, 1, &result)); + ASSERT_EQ(1ULL, result); + ASSERT_TRUE(add_uint64_generic(2ULL, 0xFFFFFFFFFFFFFFFEULL, 0, &result)); + ASSERT_EQ(0ULL, result); + ASSERT_TRUE(add_uint64_generic(2ULL, 0xFFFFFFFFFFFFFFFEULL, 1, &result)); + ASSERT_EQ(1ULL, result); + ASSERT_FALSE(add_uint64_generic(0xF00F00F00F00F00FULL, 0x0FF0FF0FF0FF0FF0ULL, 0, &result)); + ASSERT_EQ(0xFFFFFFFFFFFFFFFFULL, result); + ASSERT_TRUE(add_uint64_generic(0xF00F00F00F00F00FULL, 0x0FF0FF0FF0FF0FF0ULL, 1, &result)); + ASSERT_EQ(0x0ULL, result); + } + +#if SEAL_COMPILER == SEAL_COMPILER_MSVC +#pragma optimize ("", off) +#elif SEAL_COMPILER == SEAL_COMPILER_GCC +#pragma GCC push_options +#pragma GCC optimize ("O0") +#elif SEAL_COMPILER == SEAL_COMPILER_CLANG +#pragma clang optimize off +#endif + + TEST(UIntArith, AddUInt64) + { + unsigned long long result; + ASSERT_FALSE(add_uint64(0ULL, 0ULL, 0, &result)); + ASSERT_EQ(0ULL, result); + ASSERT_FALSE(add_uint64(1ULL, 1ULL, 0, &result)); + ASSERT_EQ(2ULL, result); + ASSERT_FALSE(add_uint64(1ULL, 0ULL, 1, &result)); + ASSERT_EQ(2ULL, result); + ASSERT_FALSE(add_uint64(0ULL, 1ULL, 1, &result)); + ASSERT_EQ(2ULL, result); + ASSERT_FALSE(add_uint64(1ULL, 1ULL, 1, &result)); + ASSERT_EQ(3ULL, result); + ASSERT_TRUE(add_uint64(0xFFFFFFFFFFFFFFFFULL, 1ULL, 0, &result)); + ASSERT_EQ(0ULL, result); + ASSERT_TRUE(add_uint64(1ULL, 0xFFFFFFFFFFFFFFFFULL, 0, &result)); + ASSERT_EQ(0ULL, result); + ASSERT_TRUE(add_uint64(1ULL, 0xFFFFFFFFFFFFFFFFULL, 1, &result)); + ASSERT_EQ(1ULL, result); + ASSERT_TRUE(add_uint64(2ULL, 0xFFFFFFFFFFFFFFFEULL, 0, &result)); + ASSERT_EQ(0ULL, result); + ASSERT_TRUE(add_uint64(2ULL, 0xFFFFFFFFFFFFFFFEULL, 1, &result)); + ASSERT_EQ(1ULL, result); + ASSERT_FALSE(add_uint64(0xF00F00F00F00F00FULL, 0x0FF0FF0FF0FF0FF0ULL, 0, &result)); + ASSERT_EQ(0xFFFFFFFFFFFFFFFFULL, result); + ASSERT_TRUE(add_uint64(0xF00F00F00F00F00FULL, 0x0FF0FF0FF0FF0FF0ULL, 1, &result)); + ASSERT_EQ(0x0ULL, result); + } + +#if SEAL_COMPILER == SEAL_COMPILER_MSVC +#pragma optimize ("", on) +#elif SEAL_COMPILER == SEAL_COMPILER_GCC +#pragma GCC pop_options +#elif SEAL_COMPILER == SEAL_COMPILER_CLANG +#pragma clang optimize on +#endif + + TEST(UIntArith, SubUInt64Generic) + { + unsigned long long result; + ASSERT_FALSE(sub_uint64_generic(0ULL, 0ULL, 0, &result)); + ASSERT_EQ(0ULL, result); + ASSERT_FALSE(sub_uint64_generic(1ULL, 1ULL, 0, &result)); + ASSERT_EQ(0ULL, result); + ASSERT_FALSE(sub_uint64_generic(1ULL, 0ULL, 1, &result)); + ASSERT_EQ(0ULL, result); + ASSERT_TRUE(sub_uint64_generic(0ULL, 1ULL, 1, &result)); + ASSERT_EQ(0xFFFFFFFFFFFFFFFEULL, result); + ASSERT_TRUE(sub_uint64_generic(1ULL, 1ULL, 1, &result)); + ASSERT_EQ(0xFFFFFFFFFFFFFFFFULL, result); + ASSERT_FALSE(sub_uint64_generic(0xFFFFFFFFFFFFFFFFULL, 1ULL, 0, &result)); + ASSERT_EQ(0xFFFFFFFFFFFFFFFEULL, result); + ASSERT_TRUE(sub_uint64_generic(1ULL, 0xFFFFFFFFFFFFFFFFULL, 0, &result)); + ASSERT_EQ(2ULL, result); + ASSERT_TRUE(sub_uint64_generic(1ULL, 0xFFFFFFFFFFFFFFFFULL, 1, &result)); + ASSERT_EQ(1ULL, result); + ASSERT_TRUE(sub_uint64_generic(2ULL, 0xFFFFFFFFFFFFFFFEULL, 0, &result)); + ASSERT_EQ(4ULL, result); + ASSERT_TRUE(sub_uint64_generic(2ULL, 0xFFFFFFFFFFFFFFFEULL, 1, &result)); + ASSERT_EQ(3ULL, result); + ASSERT_FALSE(sub_uint64_generic(0xF00F00F00F00F00FULL, 0x0FF0FF0FF0FF0FF0ULL, 0, &result)); + ASSERT_EQ(0xE01E01E01E01E01FULL, result); + ASSERT_FALSE(sub_uint64_generic(0xF00F00F00F00F00FULL, 0x0FF0FF0FF0FF0FF0ULL, 1, &result)); + ASSERT_EQ(0xE01E01E01E01E01EULL, result); + } + + TEST(UIntArith, SubUInt64) + { + unsigned long long result; + ASSERT_FALSE(sub_uint64(0ULL, 0ULL, 0, &result)); + ASSERT_EQ(0ULL, result); + ASSERT_FALSE(sub_uint64(1ULL, 1ULL, 0, &result)); + ASSERT_EQ(0ULL, result); + ASSERT_FALSE(sub_uint64(1ULL, 0ULL, 1, &result)); + ASSERT_EQ(0ULL, result); + ASSERT_TRUE(sub_uint64(0ULL, 1ULL, 1, &result)); + ASSERT_EQ(0xFFFFFFFFFFFFFFFEULL, result); + ASSERT_TRUE(sub_uint64(1ULL, 1ULL, 1, &result)); + ASSERT_EQ(0xFFFFFFFFFFFFFFFFULL, result); + ASSERT_FALSE(sub_uint64(0xFFFFFFFFFFFFFFFFULL, 1ULL, 0, &result)); + ASSERT_EQ(0xFFFFFFFFFFFFFFFEULL, result); + ASSERT_TRUE(sub_uint64(1ULL, 0xFFFFFFFFFFFFFFFFULL, 0, &result)); + ASSERT_EQ(2ULL, result); + ASSERT_TRUE(sub_uint64(1ULL, 0xFFFFFFFFFFFFFFFFULL, 1, &result)); + ASSERT_EQ(1ULL, result); + ASSERT_TRUE(sub_uint64(2ULL, 0xFFFFFFFFFFFFFFFEULL, 0, &result)); + ASSERT_EQ(4ULL, result); + ASSERT_TRUE(sub_uint64(2ULL, 0xFFFFFFFFFFFFFFFEULL, 1, &result)); + ASSERT_EQ(3ULL, result); + ASSERT_FALSE(sub_uint64(0xF00F00F00F00F00FULL, 0x0FF0FF0FF0FF0FF0ULL, 0, &result)); + ASSERT_EQ(0xE01E01E01E01E01FULL, result); + ASSERT_FALSE(sub_uint64(0xF00F00F00F00F00FULL, 0x0FF0FF0FF0FF0FF0ULL, 1, &result)); + ASSERT_EQ(0xE01E01E01E01E01EULL, result); + } + + TEST(UIntArith, AddUIntUInt) + { + MemoryPool &pool = *global_variables::global_memory_pool; + auto ptr(allocate_uint(2, pool)); + auto ptr2(allocate_uint(2, pool)); + auto ptr3(allocate_uint(2, pool)); + ptr[0] = 0; + ptr[1] = 0; + ptr2[0] = 0; + ptr2[1] = 0; + ptr3[0] = 0xFFFFFFFFFFFFFFFF; + ptr3[1] = 0xFFFFFFFFFFFFFFFF; + ASSERT_FALSE(add_uint_uint(ptr.get(), ptr2.get(), 2, ptr3.get()) != 0); + ASSERT_EQ(static_cast(0), ptr3[0]); + ASSERT_EQ(static_cast(0), ptr3[1]); + + ptr[0] = 0xFFFFFFFFFFFFFFFF; + ptr[1] = 0xFFFFFFFFFFFFFFFF; + ptr2[0] = 0; + ptr2[1] = 0; + ptr3[0] = 0; + ptr3[1] = 0; + ASSERT_FALSE(add_uint_uint(ptr.get(), ptr2.get(), 2, ptr3.get()) != 0); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr3[0]); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr3[1]); + + ptr[0] = 0xFFFFFFFFFFFFFFFE; + ptr[1] = 0xFFFFFFFFFFFFFFFF; + ptr2[0] = 1; + ptr2[1] = 0; + ptr3[0] = 0; + ptr3[1] = 0; + ASSERT_FALSE(add_uint_uint(ptr.get(), ptr2.get(), 2, ptr3.get()) != 0); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr3[0]); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr3[1]); + + ptr[0] = 0xFFFFFFFFFFFFFFFF; + ptr[1] = 0xFFFFFFFFFFFFFFFF; + ptr2[0] = 1; + ptr2[1] = 0; + ptr3[0] = 0xFFFFFFFFFFFFFFFF; + ptr3[1] = 0xFFFFFFFFFFFFFFFF; + ASSERT_TRUE(add_uint_uint(ptr.get(), ptr2.get(), 2, ptr3.get()) != 0); + ASSERT_EQ(static_cast(0), ptr3[0]); + ASSERT_EQ(static_cast(0), ptr3[1]); + + ptr[0] = 0xFFFFFFFFFFFFFFFF; + ptr[1] = 0xFFFFFFFFFFFFFFFF; + ptr2[0] = 0xFFFFFFFFFFFFFFFF; + ptr2[1] = 0xFFFFFFFFFFFFFFFF; + ptr3[0] = 0; + ptr3[1] = 0; + + ASSERT_TRUE(add_uint_uint(ptr.get(), ptr2.get(), 2, ptr3.get()) != 0); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFE), ptr3[0]); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr3[1]); + ASSERT_TRUE(add_uint_uint(ptr.get(), ptr2.get(), 2, ptr.get()) != 0); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFE), ptr[0]); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr[1]); + + ptr[0] = 0xFFFFFFFFFFFFFFFF; + ptr[1] = 0; + ptr2[0] = 1; + ptr2[1] = 0; + ptr3[0] = 0; + ptr3[1] = 0; + ASSERT_FALSE(add_uint_uint(ptr.get(), ptr2.get(), 2, ptr3.get()) != 0); + ASSERT_EQ(static_cast(0), ptr3[0]); + ASSERT_EQ(1ULL, ptr3[1]); + + ptr[0] = 0xFFFFFFFFFFFFFFFF; + ptr[1] = 5; + ptr2[0] = 1; + ptr2[1] = 0; + ptr3[0] = 0; + ptr3[1] = 0; + ASSERT_FALSE(add_uint_uint(ptr.get(), 2, ptr2.get(), 1, false, 2, ptr3.get()) != 0); + ASSERT_EQ(static_cast(0), ptr3[0]); + ASSERT_EQ(static_cast(6), ptr3[1]); + ASSERT_FALSE(add_uint_uint(ptr.get(), 2, ptr2.get(), 1, true, 2, ptr3.get()) != 0); + ASSERT_EQ(1ULL, ptr3[0]); + ASSERT_EQ(static_cast(6), ptr3[1]); + } + + TEST(UIntArith, SubUIntUInt) + { + MemoryPool &pool = *global_variables::global_memory_pool; + auto ptr(allocate_uint(2, pool)); + auto ptr2(allocate_uint(2, pool)); + auto ptr3(allocate_uint(2, pool)); + ptr[0] = 0; + ptr[1] = 0; + ptr2[0] = 0; + ptr2[1] = 0; + ptr3[0] = 0xFFFFFFFFFFFFFFFF; + ptr3[1] = 0xFFFFFFFFFFFFFFFF; + ASSERT_FALSE(sub_uint_uint(ptr.get(), ptr2.get(), 2, ptr3.get()) != 0); + ASSERT_EQ(static_cast(0), ptr3[0]); + ASSERT_EQ(static_cast(0), ptr3[1]); + + ptr[0] = 0xFFFFFFFFFFFFFFFF; + ptr[1] = 0xFFFFFFFFFFFFFFFF; + ptr2[0] = 0; + ptr2[1] = 0; + ptr3[0] = 0; + ptr3[1] = 0; + ASSERT_FALSE(sub_uint_uint(ptr.get(), ptr2.get(), 2, ptr3.get()) != 0); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr3[0]); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr3[1]); + + ptr[0] = 0xFFFFFFFFFFFFFFFF; + ptr[1] = 0xFFFFFFFFFFFFFFFF; + ptr2[0] = 1; + ptr2[1] = 0; + ptr3[0] = 0; + ptr3[1] = 0; + ASSERT_FALSE(sub_uint_uint(ptr.get(), ptr2.get(), 2, ptr3.get()) != 0); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFE), ptr3[0]); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr3[1]); + + ptr[0] = 0; + ptr[1] = 0; + ptr2[0] = 1; + ptr2[1] = 0; + ptr3[0] = 0; + ptr3[1] = 0; + ASSERT_TRUE(sub_uint_uint(ptr.get(), ptr2.get(), 2, ptr3.get()) != 0); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr3[0]); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr3[1]); + ASSERT_TRUE(sub_uint_uint(ptr.get(), ptr2.get(), 2, ptr.get()) != 0); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr[0]); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr[1]); + + ptr[0] = 0xFFFFFFFFFFFFFFFF; + ptr[1] = 0xFFFFFFFFFFFFFFFF; + ptr2[0] = 0xFFFFFFFFFFFFFFFF; + ptr2[1] = 0xFFFFFFFFFFFFFFFF; + ptr3[0] = 0; + ptr3[1] = 0; + ASSERT_FALSE(sub_uint_uint(ptr.get(), ptr2.get(), 2, ptr3.get()) != 0); + ASSERT_EQ(static_cast(0), ptr3[0]); + ASSERT_EQ(static_cast(0), ptr3[1]); + ASSERT_FALSE(sub_uint_uint(ptr.get(), ptr2.get(), 2, ptr.get()) != 0); + ASSERT_EQ(static_cast(0), ptr[0]); + ASSERT_EQ(static_cast(0), ptr[1]); + + ptr[0] = 0xFFFFFFFFFFFFFFFE; + ptr[1] = 0xFFFFFFFFFFFFFFFF; + ptr2[0] = 0xFFFFFFFFFFFFFFFF; + ptr2[1] = 0xFFFFFFFFFFFFFFFF; + ptr3[0] = 0; + ptr3[1] = 0; + ASSERT_TRUE(sub_uint_uint(ptr.get(), ptr2.get(), 2, ptr3.get()) != 0); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr3[0]); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr3[1]); + + ptr[0] = 0; + ptr[1] = 1; + ptr2[0] = 1; + ptr2[1] = 0; + ptr3[0] = 0; + ptr3[1] = 0; + ASSERT_FALSE(sub_uint_uint(ptr.get(), ptr2.get(), 2, ptr3.get()) != 0); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr3[0]); + ASSERT_EQ(static_cast(0), ptr3[1]); + + ptr[0] = 0; + ptr[1] = 1; + ptr2[0] = 1; + ptr2[1] = 0; + ptr3[0] = 0; + ptr3[1] = 0; + ASSERT_FALSE(sub_uint_uint(ptr.get(), 2, ptr2.get(), 1, false, 2, ptr3.get()) != 0); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr3[0]); + ASSERT_EQ(static_cast(0), ptr3[1]); + ASSERT_FALSE(sub_uint_uint(ptr.get(), 2, ptr2.get(), 1, true, 2, ptr3.get()) != 0); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFE), ptr3[0]); + ASSERT_EQ(static_cast(0), ptr3[1]); + } + + TEST(UIntArith, AddUIntUInt64) + { + MemoryPool &pool = *global_variables::global_memory_pool; + auto ptr(allocate_uint(2, pool)); + auto ptr2(allocate_uint(2, pool)); + + ptr[0] = 0ULL; + ptr[1] = 0ULL; + ASSERT_FALSE(add_uint_uint64(ptr.get(), 0ULL, 2, ptr2.get())); + ASSERT_EQ(0ULL, ptr2[0]); + ASSERT_EQ(0ULL, ptr2[1]); + + ptr[0] = 0xFFFFFFFF00000000ULL; + ptr[1] = 0ULL; + ASSERT_FALSE(add_uint_uint64(ptr.get(), 0xFFFFFFFFULL, 2, ptr2.get())); + ASSERT_EQ(0xFFFFFFFFFFFFFFFFULL, ptr2[0]); + ASSERT_EQ(0ULL, ptr2[1]); + + ptr[0] = 0xFFFFFFFF00000000ULL; + ptr[1] = 0xFFFFFFFF00000000ULL; + ASSERT_FALSE(add_uint_uint64(ptr.get(), 0x100000000ULL, 2, ptr2.get())); + ASSERT_EQ(0ULL, ptr2[0]); + ASSERT_EQ(0xFFFFFFFF00000001ULL, ptr2[1]); + + ptr[0] = 0xFFFFFFFFFFFFFFFFULL; + ptr[1] = 0xFFFFFFFFFFFFFFFFULL; + ASSERT_TRUE(add_uint_uint64(ptr.get(), 1ULL, 2, ptr2.get())); + ASSERT_EQ(0ULL, ptr2[0]); + ASSERT_EQ(0ULL, ptr2[1]); + } + + TEST(UIntArith, SubUIntUInt64) + { + MemoryPool &pool = *global_variables::global_memory_pool; + auto ptr(allocate_uint(2, pool)); + auto ptr2(allocate_uint(2, pool)); + + ptr[0] = 0ULL; + ptr[1] = 0ULL; + ASSERT_FALSE(sub_uint_uint64(ptr.get(), 0ULL, 2, ptr2.get())); + ASSERT_EQ(0ULL, ptr2[0]); + ASSERT_EQ(0ULL, ptr2[1]); + + ptr[0] = 0ULL; + ptr[1] = 0ULL; + ASSERT_TRUE(sub_uint_uint64(ptr.get(), 1ULL, 2, ptr2.get())); + ASSERT_EQ(0xFFFFFFFFFFFFFFFFULL, ptr2[0]); + ASSERT_EQ(0xFFFFFFFFFFFFFFFFULL, ptr2[1]); + + ptr[0] = 1ULL; + ptr[1] = 0ULL; + ASSERT_TRUE(sub_uint_uint64(ptr.get(), 2ULL, 2, ptr2.get())); + ASSERT_EQ(0xFFFFFFFFFFFFFFFFULL, ptr2[0]); + ASSERT_EQ(0xFFFFFFFFFFFFFFFFULL, ptr2[1]); + + ptr[0] = 0xFFFFFFFF00000000ULL; + ptr[1] = 0ULL; + ASSERT_FALSE(sub_uint_uint64(ptr.get(), 0xFFFFFFFFULL, 2, ptr2.get())); + ASSERT_EQ(0xFFFFFFFE00000001ULL, ptr2[0]); + ASSERT_EQ(0ULL, ptr2[1]); + + ptr[0] = 0xFFFFFFFF00000000ULL; + ptr[1] = 0xFFFFFFFF00000000ULL; + ASSERT_FALSE(sub_uint_uint64(ptr.get(), 0x100000000ULL, 2, ptr2.get())); + ASSERT_EQ(0xFFFFFFFE00000000ULL, ptr2[0]); + ASSERT_EQ(0xFFFFFFFF00000000ULL, ptr2[1]); + + ptr[0] = 0xFFFFFFFFFFFFFFFFULL; + ptr[1] = 0xFFFFFFFFFFFFFFFFULL; + ASSERT_FALSE(sub_uint_uint64(ptr.get(), 1ULL, 2, ptr2.get())); + ASSERT_EQ(0xFFFFFFFFFFFFFFFEULL, ptr2[0]); + ASSERT_EQ(0xFFFFFFFFFFFFFFFFULL, ptr2[1]); + } + + TEST(UIntArith, IncrementUInt) + { + MemoryPool &pool = *global_variables::global_memory_pool; + auto ptr1(allocate_uint(2, pool)); + auto ptr2(allocate_uint(2, pool)); + ptr1[0] = 0; + ptr1[1] = 0; + ASSERT_FALSE(increment_uint(ptr1.get(), 2, ptr2.get()) != 0); + ASSERT_EQ(1ULL, ptr2[0]); + ASSERT_EQ(static_cast(0), ptr2[1]); + ASSERT_FALSE(increment_uint(ptr2.get(), 2, ptr1.get()) != 0); + ASSERT_EQ(static_cast(2), ptr1[0]); + ASSERT_EQ(static_cast(0), ptr1[1]); + + ptr1[0] = 0xFFFFFFFFFFFFFFFF; + ptr1[1] = 0; + ASSERT_FALSE(increment_uint(ptr1.get(), 2, ptr2.get()) != 0); + ASSERT_EQ(static_cast(0), ptr2[0]); + ASSERT_EQ(1ULL, ptr2[1]); + ASSERT_FALSE(increment_uint(ptr2.get(), 2, ptr1.get()) != 0); + ASSERT_EQ(1ULL, ptr1[0]); + ASSERT_EQ(1ULL, ptr1[1]); + + ptr1[0] = 0xFFFFFFFFFFFFFFFF; + ptr1[1] = 1; + ASSERT_FALSE(increment_uint(ptr1.get(), 2, ptr2.get()) != 0); + ASSERT_EQ(static_cast(0), ptr2[0]); + ASSERT_EQ(static_cast(2), ptr2[1]); + ASSERT_FALSE(increment_uint(ptr2.get(), 2, ptr1.get()) != 0); + ASSERT_EQ(1ULL, ptr1[0]); + ASSERT_EQ(static_cast(2), ptr1[1]); + + ptr1[0] = 0xFFFFFFFFFFFFFFFE; + ptr1[1] = 0xFFFFFFFFFFFFFFFF; + ASSERT_FALSE(increment_uint(ptr1.get(), 2, ptr2.get()) != 0); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr2[0]); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr2[1]); + ASSERT_TRUE(increment_uint(ptr2.get(), 2, ptr1.get()) != 0); + ASSERT_EQ(static_cast(0), ptr1[0]); + ASSERT_EQ(static_cast(0), ptr1[1]); + ASSERT_FALSE(increment_uint(ptr1.get(), 2, ptr2.get()) != 0); + ASSERT_EQ(1ULL, ptr2[0]); + ASSERT_EQ(static_cast(0), ptr2[1]); + } + + TEST(UIntArith, DecrementUInt) + { + MemoryPool &pool = *global_variables::global_memory_pool; + auto ptr1(allocate_uint(2, pool)); + auto ptr2(allocate_uint(2, pool)); + ptr1[0] = 2; + ptr1[1] = 2; + ASSERT_FALSE(decrement_uint(ptr1.get(), 2, ptr2.get()) != 0); + ASSERT_EQ(1ULL, ptr2[0]); + ASSERT_EQ(static_cast(2), ptr2[1]); + ASSERT_FALSE(decrement_uint(ptr2.get(), 2, ptr1.get()) != 0); + ASSERT_EQ(static_cast(0), ptr1[0]); + ASSERT_EQ(static_cast(2), ptr1[1]); + ASSERT_FALSE(decrement_uint(ptr1.get(), 2, ptr2.get()) != 0); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr2[0]); + ASSERT_EQ(1ULL, ptr2[1]); + ASSERT_FALSE(decrement_uint(ptr2.get(), 2, ptr1.get()) != 0); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFE), ptr1[0]); + ASSERT_EQ(1ULL, ptr1[1]); + + ptr1[0] = 2; + ptr1[1] = 1; + ASSERT_FALSE(decrement_uint(ptr1.get(), 2, ptr2.get()) != 0); + ASSERT_EQ(1ULL, ptr2[0]); + ASSERT_EQ(1ULL, ptr2[1]); + ASSERT_FALSE(decrement_uint(ptr2.get(), 2, ptr1.get()) != 0); + ASSERT_EQ(static_cast(0), ptr1[0]); + ASSERT_EQ(1ULL, ptr1[1]); + ASSERT_FALSE(decrement_uint(ptr1.get(), 2, ptr2.get()) != 0); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr2[0]); + ASSERT_EQ(static_cast(0), ptr2[1]); + ASSERT_FALSE(decrement_uint(ptr2.get(), 2, ptr1.get()) != 0); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFE), ptr1[0]); + ASSERT_EQ(static_cast(0), ptr1[1]); + + ptr1[0] = 2; + ptr1[1] = 0; + ASSERT_FALSE(decrement_uint(ptr1.get(), 2, ptr2.get()) != 0); + ASSERT_EQ(1ULL, ptr2[0]); + ASSERT_EQ(static_cast(0), ptr2[1]); + ASSERT_FALSE(decrement_uint(ptr2.get(), 2, ptr1.get()) != 0); + ASSERT_EQ(static_cast(0), ptr1[0]); + ASSERT_EQ(static_cast(0), ptr1[1]); + ASSERT_TRUE(decrement_uint(ptr1.get(), 2, ptr2.get()) != 0); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr2[0]); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr2[1]); + ASSERT_FALSE(decrement_uint(ptr2.get(), 2, ptr1.get()) != 0); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFE), ptr1[0]); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr1[1]); + } + + TEST(UIntArith, NegateUInt) + { + MemoryPool &pool = *global_variables::global_memory_pool; + auto ptr(allocate_uint(2, pool)); + ptr[0] = 0; + ptr[1] = 0; + negate_uint(ptr.get(), 2, ptr.get()); + ASSERT_EQ(static_cast(0), ptr[0]); + ASSERT_EQ(static_cast(0), ptr[1]); + + ptr[0] = 1; + ptr[1] = 0; + negate_uint(ptr.get(), 2, ptr.get()); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr[0]); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr[1]); + negate_uint(ptr.get(), 2, ptr.get()); + ASSERT_EQ(1ULL, ptr[0]); + ASSERT_EQ(static_cast(0), ptr[1]); + + ptr[0] = 2; + ptr[1] = 0; + negate_uint(ptr.get(), 2, ptr.get()); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFE), ptr[0]); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr[1]); + negate_uint(ptr.get(), 2, ptr.get()); + ASSERT_EQ(static_cast(2), ptr[0]); + ASSERT_EQ(static_cast(0), ptr[1]); + + ptr[0] = 0; + ptr[1] = 1; + negate_uint(ptr.get(), 2, ptr.get()); + ASSERT_EQ(static_cast(0), ptr[0]); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr[1]); + negate_uint(ptr.get(), 2, ptr.get()); + ASSERT_EQ(static_cast(0), ptr[0]); + ASSERT_EQ(1ULL, ptr[1]); + + ptr[0] = 0; + ptr[1] = 2; + negate_uint(ptr.get(), 2, ptr.get()); + ASSERT_EQ(static_cast(0), ptr[0]); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFE), ptr[1]); + negate_uint(ptr.get(), 2, ptr.get()); + ASSERT_EQ(static_cast(0), ptr[0]); + ASSERT_EQ(static_cast(2), ptr[1]); + + ptr[0] = 1; + ptr[1] = 1; + negate_uint(ptr.get(), 2, ptr.get()); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr[0]); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFE), ptr[1]); + negate_uint(ptr.get(), 2, ptr.get()); + ASSERT_EQ(1ULL, ptr[0]); + ASSERT_EQ(1ULL, ptr[1]); + } + + TEST(UIntArith, LeftShiftUInt) + { + MemoryPool &pool = *global_variables::global_memory_pool; + auto ptr(allocate_uint(2, pool)); + auto ptr2(allocate_uint(2, pool)); + ptr[0] = 0; + ptr[1] = 0; + ptr2[0] = 0xFFFFFFFFFFFFFFFF; + ptr2[1] = 0xFFFFFFFFFFFFFFFF; + left_shift_uint(ptr.get(), 0, 2, ptr2.get()); + ASSERT_EQ(static_cast(0), ptr2[0]); + ASSERT_EQ(static_cast(0), ptr2[1]); + ptr2[0] = 0xFFFFFFFFFFFFFFFF; + ptr2[1] = 0xFFFFFFFFFFFFFFFF; + left_shift_uint(ptr.get(), 10, 2, ptr2.get()); + ASSERT_EQ(static_cast(0), ptr2[0]); + ASSERT_EQ(static_cast(0), ptr2[1]); + left_shift_uint(ptr.get(), 10, 2, ptr.get()); + ASSERT_EQ(static_cast(0), ptr[0]); + ASSERT_EQ(static_cast(0), ptr[1]); + + ptr[0] = 0x5555555555555555; + ptr[1] = 0xAAAAAAAAAAAAAAAA; + left_shift_uint(ptr.get(), 0, 2, ptr2.get()); + ASSERT_EQ(static_cast(0x5555555555555555), ptr2[0]); + ASSERT_EQ(static_cast(0xAAAAAAAAAAAAAAAA), ptr2[1]); + left_shift_uint(ptr.get(), 0, 2, ptr.get()); + ASSERT_EQ(static_cast(0x5555555555555555), ptr[0]); + ASSERT_EQ(static_cast(0xAAAAAAAAAAAAAAAA), ptr[1]); + left_shift_uint(ptr.get(), 1, 2, ptr2.get()); + ASSERT_EQ(static_cast(0xAAAAAAAAAAAAAAAA), ptr2[0]); + ASSERT_EQ(static_cast(0x5555555555555554), ptr2[1]); + left_shift_uint(ptr.get(), 2, 2, ptr2.get()); + ASSERT_EQ(static_cast(0x5555555555555554), ptr2[0]); + ASSERT_EQ(static_cast(0xAAAAAAAAAAAAAAA9), ptr2[1]); + left_shift_uint(ptr.get(), 64, 2, ptr2.get()); + ASSERT_EQ(static_cast(0), ptr2[0]); + ASSERT_EQ(static_cast(0x5555555555555555), ptr2[1]); + left_shift_uint(ptr.get(), 65, 2, ptr2.get()); + ASSERT_EQ(static_cast(0), ptr2[0]); + ASSERT_EQ(static_cast(0xAAAAAAAAAAAAAAAA), ptr2[1]); + left_shift_uint(ptr.get(), 127, 2, ptr2.get()); + ASSERT_EQ(static_cast(0), ptr2[0]); + ASSERT_EQ(static_cast(0x8000000000000000), ptr2[1]); + left_shift_uint(ptr.get(), 128, 2, ptr2.get()); + ASSERT_EQ(static_cast(0), ptr2[0]); + ASSERT_EQ(static_cast(0), ptr2[1]); + + left_shift_uint(ptr.get(), 2, 2, ptr.get()); + ASSERT_EQ(static_cast(0x5555555555555554), ptr[0]); + ASSERT_EQ(static_cast(0xAAAAAAAAAAAAAAA9), ptr[1]); + left_shift_uint(ptr.get(), 64, 2, ptr.get()); + ASSERT_EQ(static_cast(0), ptr[0]); + ASSERT_EQ(static_cast(0x5555555555555554), ptr[1]); + } + + TEST(UIntArith, RightShiftUInt) + { + MemoryPool &pool = *global_variables::global_memory_pool; + auto ptr(allocate_uint(2, pool)); + auto ptr2(allocate_uint(2, pool)); + ptr[0] = 0; + ptr[1] = 0; + ptr2[0] = 0xFFFFFFFFFFFFFFFF; + ptr2[1] = 0xFFFFFFFFFFFFFFFF; + right_shift_uint(ptr.get(), 0, 2, ptr2.get()); + ASSERT_EQ(static_cast(0), ptr2[0]); + ASSERT_EQ(static_cast(0), ptr2[1]); + ptr2[0] = 0xFFFFFFFFFFFFFFFF; + ptr2[1] = 0xFFFFFFFFFFFFFFFF; + right_shift_uint(ptr.get(), 10, 2, ptr2.get()); + ASSERT_EQ(static_cast(0), ptr2[0]); + ASSERT_EQ(static_cast(0), ptr2[1]); + right_shift_uint(ptr.get(), 10, 2, ptr.get()); + ASSERT_EQ(static_cast(0), ptr[0]); + ASSERT_EQ(static_cast(0), ptr[1]); + + ptr[0] = 0x5555555555555555; + ptr[1] = 0xAAAAAAAAAAAAAAAA; + right_shift_uint(ptr.get(), 0, 2, ptr2.get()); + ASSERT_EQ(static_cast(0x5555555555555555), ptr2[0]); + ASSERT_EQ(static_cast(0xAAAAAAAAAAAAAAAA), ptr2[1]); + right_shift_uint(ptr.get(), 0, 2, ptr.get()); + ASSERT_EQ(static_cast(0x5555555555555555), ptr[0]); + ASSERT_EQ(static_cast(0xAAAAAAAAAAAAAAAA), ptr[1]); + right_shift_uint(ptr.get(), 1, 2, ptr2.get()); + ASSERT_EQ(static_cast(0x2AAAAAAAAAAAAAAA), ptr2[0]); + ASSERT_EQ(static_cast(0x5555555555555555), ptr2[1]); + right_shift_uint(ptr.get(), 2, 2, ptr2.get()); + ASSERT_EQ(static_cast(0x9555555555555555), ptr2[0]); + ASSERT_EQ(static_cast(0x2AAAAAAAAAAAAAAA), ptr2[1]); + right_shift_uint(ptr.get(), 64, 2, ptr2.get()); + ASSERT_EQ(static_cast(0xAAAAAAAAAAAAAAAA), ptr2[0]); + ASSERT_EQ(static_cast(0), ptr2[1]); + right_shift_uint(ptr.get(), 65, 2, ptr2.get()); + ASSERT_EQ(static_cast(0x5555555555555555), ptr2[0]); + ASSERT_EQ(static_cast(0), ptr2[1]); + right_shift_uint(ptr.get(), 127, 2, ptr2.get()); + ASSERT_EQ(1ULL, ptr2[0]); + ASSERT_EQ(static_cast(0), ptr2[1]); + right_shift_uint(ptr.get(), 128, 2, ptr2.get()); + ASSERT_EQ(static_cast(0), ptr2[0]); + ASSERT_EQ(static_cast(0), ptr2[1]); + + right_shift_uint(ptr.get(), 2, 2, ptr.get()); + ASSERT_EQ(static_cast(0x9555555555555555), ptr[0]); + ASSERT_EQ(static_cast(0x2AAAAAAAAAAAAAAA), ptr[1]); + right_shift_uint(ptr.get(), 64, 2, ptr.get()); + ASSERT_EQ(static_cast(0x2AAAAAAAAAAAAAAA), ptr[0]); + ASSERT_EQ(static_cast(0), ptr[1]); + } + + TEST(UIntArith, HalfRoundUpUInt) + { + half_round_up_uint(nullptr, 0, nullptr); + + MemoryPool &pool = *global_variables::global_memory_pool; + auto ptr(allocate_uint(2, pool)); + auto ptr2(allocate_uint(2, pool)); + ptr[0] = 0; + ptr[1] = 0; + ptr2[0] = 0xFFFFFFFFFFFFFFFF; + ptr2[1] = 0xFFFFFFFFFFFFFFFF; + half_round_up_uint(ptr.get(), 2, ptr2.get()); + ASSERT_EQ(static_cast(0), ptr2[0]); + ASSERT_EQ(static_cast(0), ptr2[1]); + half_round_up_uint(ptr.get(), 2, ptr.get()); + ASSERT_EQ(static_cast(0), ptr[0]); + ASSERT_EQ(static_cast(0), ptr[1]); + + ptr[0] = 1; + ptr[1] = 0; + ptr2[0] = 0xFFFFFFFFFFFFFFFF; + ptr2[1] = 0xFFFFFFFFFFFFFFFF; + half_round_up_uint(ptr.get(), 2, ptr2.get()); + ASSERT_EQ(1ULL, ptr2[0]); + ASSERT_EQ(static_cast(0), ptr2[1]); + half_round_up_uint(ptr.get(), 2, ptr.get()); + ASSERT_EQ(1ULL, ptr[0]); + ASSERT_EQ(static_cast(0), ptr[1]); + + ptr[0] = 2; + ptr[1] = 0; + ptr2[0] = 0xFFFFFFFFFFFFFFFF; + ptr2[1] = 0xFFFFFFFFFFFFFFFF; + half_round_up_uint(ptr.get(), 2, ptr2.get()); + ASSERT_EQ(1ULL, ptr2[0]); + ASSERT_EQ(static_cast(0), ptr2[1]); + half_round_up_uint(ptr.get(), 2, ptr.get()); + ASSERT_EQ(1ULL, ptr[0]); + ASSERT_EQ(static_cast(0), ptr[1]); + + ptr[0] = 3; + ptr[1] = 0; + ptr2[0] = 0xFFFFFFFFFFFFFFFF; + ptr2[1] = 0xFFFFFFFFFFFFFFFF; + half_round_up_uint(ptr.get(), 2, ptr2.get()); + ASSERT_EQ(static_cast(2), ptr2[0]); + ASSERT_EQ(static_cast(0), ptr2[1]); + + ptr[0] = 4; + ptr[1] = 0; + ptr2[0] = 0xFFFFFFFFFFFFFFFF; + ptr2[1] = 0xFFFFFFFFFFFFFFFF; + half_round_up_uint(ptr.get(), 2, ptr2.get()); + ASSERT_EQ(static_cast(2), ptr2[0]); + ASSERT_EQ(static_cast(0), ptr2[1]); + + ptr[0] = 0xFFFFFFFFFFFFFFFF; + ptr[1] = 0xFFFFFFFFFFFFFFFF; + ptr2[0] = 0xFFFFFFFFFFFFFFFF; + ptr2[1] = 0xFFFFFFFFFFFFFFFF; + half_round_up_uint(ptr.get(), 2, ptr2.get()); + ASSERT_EQ(static_cast(0), ptr2[0]); + ASSERT_EQ(static_cast(0x8000000000000000), ptr2[1]); + half_round_up_uint(ptr.get(), 2, ptr.get()); + ASSERT_EQ(static_cast(0), ptr[0]); + ASSERT_EQ(static_cast(0x8000000000000000), ptr[1]); + } + + TEST(UIntArith, NotUInt) + { + not_uint(nullptr, 0, nullptr); + + MemoryPool &pool = *global_variables::global_memory_pool; + auto ptr(allocate_uint(2, pool)); + ptr[0] = 0xFFFFFFFFFFFFFFFF; + ptr[1] = 0; + not_uint(ptr.get(), 2, ptr.get()); + ASSERT_EQ(static_cast(0), ptr[0]); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr[1]); + + ptr[0] = 0xFFFFFFFF00000000; + ptr[1] = 0xFFFF0000FFFF0000; + not_uint(ptr.get(), 2, ptr.get()); + ASSERT_EQ(static_cast(0x00000000FFFFFFFF), ptr[0]); + ASSERT_EQ(static_cast(0x0000FFFF0000FFFF), ptr[1]); + } + + TEST(UIntArith, AndUIntUInt) + { + and_uint_uint(nullptr, nullptr, 0, nullptr); + + MemoryPool &pool = *global_variables::global_memory_pool; + auto ptr(allocate_uint(2, pool)); + auto ptr2(allocate_uint(2, pool)); + auto ptr3(allocate_uint(2, pool)); + ptr[0] = 0xFFFFFFFFFFFFFFFF; + ptr[1] = 0; + ptr2[0] = 0; + ptr2[1] = 0xFFFFFFFFFFFFFFFF; + ptr3[0] = 0xFFFFFFFFFFFFFFFF; + ptr3[1] = 0xFFFFFFFFFFFFFFFF; + and_uint_uint(ptr.get(), ptr2.get(), 2, ptr3.get()); + ASSERT_EQ(static_cast(0), ptr3[0]); + ASSERT_EQ(static_cast(0), ptr3[1]); + + ptr[0] = 0xFFFFFFFF00000000; + ptr[1] = 0xFFFF0000FFFF0000; + ptr2[0] = 0x0000FFFF0000FFFF; + ptr2[1] = 0xFF00FF00FF00FF00; + ptr3[0] = 0; + ptr3[1] = 0; + and_uint_uint(ptr.get(), ptr2.get(), 2, ptr3.get()); + ASSERT_EQ(static_cast(0x0000FFFF00000000), ptr3[0]); + ASSERT_EQ(static_cast(0xFF000000FF000000), ptr3[1]); + and_uint_uint(ptr.get(), ptr2.get(), 2, ptr.get()); + ASSERT_EQ(static_cast(0x0000FFFF00000000), ptr[0]); + ASSERT_EQ(static_cast(0xFF000000FF000000), ptr[1]); + } + + TEST(UIntArith, OrUIntUInt) + { + or_uint_uint(nullptr, nullptr, 0, nullptr); + + MemoryPool &pool = *global_variables::global_memory_pool; + auto ptr(allocate_uint(2, pool)); + auto ptr2(allocate_uint(2, pool)); + auto ptr3(allocate_uint(2, pool)); + ptr[0] = 0xFFFFFFFFFFFFFFFF; + ptr[1] = 0; + ptr2[0] = 0; + ptr2[1] = 0xFFFFFFFFFFFFFFFF; + ptr3[0] = 0; + ptr3[1] = 0; + or_uint_uint(ptr.get(), ptr2.get(), 2, ptr3.get()); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr3[0]); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr3[1]); + + ptr[0] = 0xFFFFFFFF00000000; + ptr[1] = 0xFFFF0000FFFF0000; + ptr2[0] = 0x0000FFFF0000FFFF; + ptr2[1] = 0xFF00FF00FF00FF00; + ptr3[0] = 0; + ptr3[1] = 0; + or_uint_uint(ptr.get(), ptr2.get(), 2, ptr3.get()); + ASSERT_EQ(static_cast(0xFFFFFFFF0000FFFF), ptr3[0]); + ASSERT_EQ(static_cast(0xFFFFFF00FFFFFF00), ptr3[1]); + or_uint_uint(ptr.get(), ptr2.get(), 2, ptr.get()); + ASSERT_EQ(static_cast(0xFFFFFFFF0000FFFF), ptr[0]); + ASSERT_EQ(static_cast(0xFFFFFF00FFFFFF00), ptr[1]); + } + + TEST(UIntArith, XorUIntUInt) + { + xor_uint_uint(nullptr, nullptr, 0, nullptr); + + MemoryPool &pool = *global_variables::global_memory_pool; + auto ptr(allocate_uint(2, pool)); + auto ptr2(allocate_uint(2, pool)); + auto ptr3(allocate_uint(2, pool)); + ptr[0] = 0xFFFFFFFFFFFFFFFF; + ptr[1] = 0; + ptr2[0] = 0; + ptr2[1] = 0xFFFFFFFFFFFFFFFF; + ptr3[0] = 0; + ptr3[1] = 0; + xor_uint_uint(ptr.get(), ptr2.get(), 2, ptr3.get()); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr3[0]); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr3[1]); + + ptr[0] = 0xFFFFFFFF00000000; + ptr[1] = 0xFFFF0000FFFF0000; + ptr2[0] = 0x0000FFFF0000FFFF; + ptr2[1] = 0xFF00FF00FF00FF00; + ptr3[0] = 0; + ptr3[1] = 0; + xor_uint_uint(ptr.get(), ptr2.get(), 2, ptr3.get()); + ASSERT_EQ(static_cast(0xFFFF00000000FFFF), ptr3[0]); + ASSERT_EQ(static_cast(0x00FFFF0000FFFF00), ptr3[1]); + xor_uint_uint(ptr.get(), ptr2.get(), 2, ptr.get()); + ASSERT_EQ(static_cast(0xFFFF00000000FFFF), ptr[0]); + ASSERT_EQ(static_cast(0x00FFFF0000FFFF00), ptr[1]); + } + + TEST(UIntArith, MultiplyUInt64Generic) + { + unsigned long long result[2]; + + multiply_uint64_generic(0ULL, 0ULL, result); + ASSERT_EQ(0ULL, result[0]); + ASSERT_EQ(0ULL, result[1]); + multiply_uint64_generic(0ULL, 1ULL, result); + ASSERT_EQ(0ULL, result[0]); + ASSERT_EQ(0ULL, result[1]); + multiply_uint64_generic(1ULL, 0ULL, result); + ASSERT_EQ(0ULL, result[0]); + ASSERT_EQ(0ULL, result[1]); + multiply_uint64_generic(1ULL, 1ULL, result); + ASSERT_EQ(1ULL, result[0]); + ASSERT_EQ(0ULL, result[1]); + multiply_uint64_generic(0x100000000ULL, 0xFAFABABAULL, result); + ASSERT_EQ(0xFAFABABA00000000ULL, result[0]); + ASSERT_EQ(0ULL, result[1]); + multiply_uint64_generic(0x1000000000ULL, 0xFAFABABAULL, result); + ASSERT_EQ(0xAFABABA000000000ULL, result[0]); + ASSERT_EQ(0xFULL, result[1]); + multiply_uint64_generic(1111222233334444ULL, 5555666677778888ULL, result); + ASSERT_EQ(4140785562324247136ULL, result[0]); + ASSERT_EQ(334670460471ULL, result[1]); + } + + TEST(UIntArith, MultiplyUInt64) + { + unsigned long long result[2]; + + multiply_uint64(0ULL, 0ULL, result); + ASSERT_EQ(0ULL, result[0]); + ASSERT_EQ(0ULL, result[1]); + multiply_uint64(0ULL, 1ULL, result); + ASSERT_EQ(0ULL, result[0]); + ASSERT_EQ(0ULL, result[1]); + multiply_uint64(1ULL, 0ULL, result); + ASSERT_EQ(0ULL, result[0]); + ASSERT_EQ(0ULL, result[1]); + multiply_uint64(1ULL, 1ULL, result); + ASSERT_EQ(1ULL, result[0]); + ASSERT_EQ(0ULL, result[1]); + multiply_uint64(0x100000000ULL, 0xFAFABABAULL, result); + ASSERT_EQ(0xFAFABABA00000000ULL, result[0]); + ASSERT_EQ(0ULL, result[1]); + multiply_uint64(0x1000000000ULL, 0xFAFABABAULL, result); + ASSERT_EQ(0xAFABABA000000000ULL, result[0]); + ASSERT_EQ(0xFULL, result[1]); + multiply_uint64(1111222233334444ULL, 5555666677778888ULL, result); + ASSERT_EQ(4140785562324247136ULL, result[0]); + ASSERT_EQ(334670460471ULL, result[1]); + } + + TEST(UIntArith, MultiplyUInt64HW64Generic) + { + unsigned long long result; + + multiply_uint64_hw64_generic(0ULL, 0ULL, &result); + ASSERT_EQ(0ULL, result); + multiply_uint64_hw64_generic(0ULL, 1ULL, &result); + ASSERT_EQ(0ULL, result); + multiply_uint64_hw64_generic(1ULL, 0ULL, &result); + ASSERT_EQ(0ULL, result); + multiply_uint64_hw64_generic(1ULL, 1ULL, &result); + ASSERT_EQ(0ULL, result); + multiply_uint64_hw64_generic(0x100000000ULL, 0xFAFABABAULL, &result); + ASSERT_EQ(0ULL, result); + multiply_uint64_hw64_generic(0x1000000000ULL, 0xFAFABABAULL, &result); + ASSERT_EQ(0xFULL, result); + multiply_uint64_hw64_generic(1111222233334444ULL, 5555666677778888ULL, &result); + ASSERT_EQ(334670460471ULL, result); + } + + TEST(UIntArith, MultiplyUInt64HW64) + { + unsigned long long result; + + multiply_uint64_hw64(0ULL, 0ULL, &result); + ASSERT_EQ(0ULL, result); + multiply_uint64_hw64(0ULL, 1ULL, &result); + ASSERT_EQ(0ULL, result); + multiply_uint64_hw64(1ULL, 0ULL, &result); + ASSERT_EQ(0ULL, result); + multiply_uint64_hw64(1ULL, 1ULL, &result); + ASSERT_EQ(0ULL, result); + multiply_uint64_hw64(0x100000000ULL, 0xFAFABABAUL, &result); + ASSERT_EQ(0ULL, result); + multiply_uint64_hw64(0x1000000000ULL, 0xFAFABABAULL, &result); + ASSERT_EQ(0xFULL, result); + multiply_uint64_hw64(1111222233334444ULL, 5555666677778888ULL, &result); + ASSERT_EQ(334670460471ULL, result); + } + + TEST(UIntArith, MultiplyUIntUInt) + { + MemoryPool &pool = *global_variables::global_memory_pool; + auto ptr(allocate_uint(2, pool)); + auto ptr2(allocate_uint(2, pool)); + auto ptr3(allocate_uint(4, pool)); + ptr[0] = 0; + ptr[1] = 0; + ptr2[0] = 0; + ptr2[1] = 0; + ptr3[0] = 0xFFFFFFFFFFFFFFFF; + ptr3[1] = 0xFFFFFFFFFFFFFFFF; + ptr3[2] = 0xFFFFFFFFFFFFFFFF; + ptr3[3] = 0xFFFFFFFFFFFFFFFF; + multiply_uint_uint(ptr.get(), ptr2.get(), 2, ptr3.get()); + ASSERT_EQ(static_cast(0), ptr3[0]); + ASSERT_EQ(static_cast(0), ptr3[1]); + ASSERT_EQ(static_cast(0), ptr3[2]); + ASSERT_EQ(static_cast(0), ptr3[3]); + + ptr[0] = 0xFFFFFFFFFFFFFFFF; + ptr[1] = 0xFFFFFFFFFFFFFFFF; + ptr2[0] = 0; + ptr2[1] = 0; + ptr3[0] = 0xFFFFFFFFFFFFFFFF; + ptr3[1] = 0xFFFFFFFFFFFFFFFF; + ptr3[2] = 0xFFFFFFFFFFFFFFFF; + ptr3[3] = 0xFFFFFFFFFFFFFFFF; + multiply_uint_uint(ptr.get(), ptr2.get(), 2, ptr3.get()); + ASSERT_EQ(static_cast(0), ptr3[0]); + ASSERT_EQ(static_cast(0), ptr3[1]); + ASSERT_EQ(static_cast(0), ptr3[2]); + ASSERT_EQ(static_cast(0), ptr3[3]); + + ptr[0] = 0xFFFFFFFFFFFFFFFF; + ptr[1] = 0xFFFFFFFFFFFFFFFF; + ptr2[0] = 1; + ptr2[1] = 0; + ptr3[0] = 0; + ptr3[1] = 0; + ptr3[2] = 0; + ptr3[3] = 0; + multiply_uint_uint(ptr.get(), ptr2.get(), 2, ptr3.get()); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr3[0]); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr3[1]); + ASSERT_EQ(static_cast(0), ptr3[2]); + ASSERT_EQ(static_cast(0), ptr3[3]); + + ptr[0] = 0xFFFFFFFFFFFFFFFF; + ptr[1] = 0xFFFFFFFFFFFFFFFF; + ptr2[0] = 0; + ptr2[1] = 1; + ptr3[0] = 0; + ptr3[1] = 0; + ptr3[2] = 0; + ptr3[3] = 0; + multiply_uint_uint(ptr.get(), ptr2.get(), 2, ptr3.get()); + ASSERT_EQ(static_cast(0), ptr3[0]); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr3[1]); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr3[2]); + ASSERT_EQ(static_cast(0), ptr3[3]); + + ptr[0] = 0xFFFFFFFFFFFFFFFF; + ptr[1] = 0xFFFFFFFFFFFFFFFF; + ptr2[0] = 0xFFFFFFFFFFFFFFFF; + ptr2[1] = 0xFFFFFFFFFFFFFFFF; + ptr3[0] = 0; + ptr3[1] = 0; + ptr3[2] = 0; + ptr3[3] = 0; + multiply_uint_uint(ptr.get(), ptr2.get(), 2, ptr3.get()); + ASSERT_EQ(1ULL, ptr3[0]); + ASSERT_EQ(static_cast(0), ptr3[1]); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFE), ptr3[2]); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr3[3]); + + ptr[0] = 9756571004902751654ul; + ptr[1] = 731952007397389984; + ptr2[0] = 701538366196406307; + ptr2[1] = 1699883529753102283; + ptr3[0] = 0; + ptr3[1] = 0; + ptr3[2] = 0; + ptr3[3] = 0; + multiply_uint_uint(ptr.get(), ptr2.get(), 2, ptr3.get()); + ASSERT_EQ(static_cast(9585656442714717618ul), ptr3[0]); + ASSERT_EQ(static_cast(1817697005049051848), ptr3[1]); + ASSERT_EQ(static_cast(14447416709120365380ul), ptr3[2]); + ASSERT_EQ(static_cast(67450014862939159), ptr3[3]); + + ptr[0] = 0xFFFFFFFFFFFFFFFF; + ptr[1] = 0xFFFFFFFFFFFFFFFF; + ptr2[0] = 0xFFFFFFFFFFFFFFFF; + ptr2[1] = 0xFFFFFFFFFFFFFFFF; + ptr3[0] = 0; + ptr3[1] = 0; + ptr3[2] = 0; + ptr3[3] = 0; + multiply_uint_uint(ptr.get(), 2, ptr2.get(), 1, 2, ptr3.get()); + ASSERT_EQ(1ULL, ptr3[0]); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr3[1]); + ASSERT_EQ(static_cast(0), ptr3[2]); + ASSERT_EQ(static_cast(0), ptr3[3]); + + ptr[0] = 0xFFFFFFFFFFFFFFFF; + ptr[1] = 0xFFFFFFFFFFFFFFFF; + ptr2[0] = 0xFFFFFFFFFFFFFFFF; + ptr2[1] = 0xFFFFFFFFFFFFFFFF; + ptr3[0] = 0; + ptr3[1] = 0; + ptr3[2] = 0; + ptr3[3] = 0; + multiply_uint_uint(ptr.get(), 2, ptr2.get(), 1, 3, ptr3.get()); + ASSERT_EQ(1ULL, ptr3[0]); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr3[1]); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFE), ptr3[2]); + ASSERT_EQ(static_cast(0), ptr3[3]); + + ptr[0] = 0xFFFFFFFFFFFFFFFF; + ptr[1] = 0; + ptr2[0] = 0xFFFFFFFFFFFFFFFF; + ptr2[1] = 0xFFFFFFFFFFFFFFFF; + ptr3[0] = 0; + ptr3[1] = 0; + ptr3[2] = 0; + ptr3[3] = 0; + multiply_truncate_uint_uint(ptr.get(), ptr2.get(), 2, ptr3.get()); + ASSERT_EQ(1ULL, ptr3[0]); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr3[1]); + ASSERT_EQ(static_cast(0), ptr3[2]); + ASSERT_EQ(static_cast(0), ptr3[3]); + } + + TEST(UIntArith, MultiplyUIntUInt64) + { + MemoryPool &pool = *global_variables::global_memory_pool; + auto ptr(allocate_uint(3, pool)); + auto result(allocate_uint(4, pool)); + + ptr[0] = 0; + ptr[1] = 0; + ptr[2] = 0; + multiply_uint_uint64(ptr.get(), 3, 0ULL, 4, result.get()); + ASSERT_EQ(0ULL, result[0]); + ASSERT_EQ(0ULL, result[1]); + ASSERT_EQ(0ULL, result[2]); + ASSERT_EQ(0ULL, result[3]); + + ptr[0] = 0xFFFFFFFFF; + ptr[1] = 0xAAAAAAAAA; + ptr[2] = 0x111111111; + multiply_uint_uint64(ptr.get(), 3, 0ULL, 4, result.get()); + ASSERT_EQ(0ULL, result[0]); + ASSERT_EQ(0ULL, result[1]); + ASSERT_EQ(0ULL, result[2]); + ASSERT_EQ(0ULL, result[3]); + + ptr[0] = 0xFFFFFFFFF; + ptr[1] = 0xAAAAAAAAA; + ptr[2] = 0x111111111; + multiply_uint_uint64(ptr.get(), 3, 1ULL, 4, result.get()); + ASSERT_EQ(0xFFFFFFFFFULL, result[0]); + ASSERT_EQ(0xAAAAAAAAAULL, result[1]); + ASSERT_EQ(0x111111111ULL, result[2]); + ASSERT_EQ(0ULL, result[3]); + + ptr[0] = 0xFFFFFFFFF; + ptr[1] = 0xAAAAAAAAA; + ptr[2] = 0x111111111; + multiply_uint_uint64(ptr.get(), 3, 0x10000ULL, 4, result.get()); + ASSERT_EQ(0xFFFFFFFFF0000ULL, result[0]); + ASSERT_EQ(0xAAAAAAAAA0000ULL, result[1]); + ASSERT_EQ(0x1111111110000ULL, result[2]); + ASSERT_EQ(0ULL, result[3]); + + ptr[0] = 0xFFFFFFFFF; + ptr[1] = 0xAAAAAAAAA; + ptr[2] = 0x111111111; + multiply_uint_uint64(ptr.get(), 3, 0x100000000ULL, 4, result.get()); + ASSERT_EQ(0xFFFFFFFF00000000ULL, result[0]); + ASSERT_EQ(0xAAAAAAAA0000000FULL, result[1]); + ASSERT_EQ(0x111111110000000AULL, result[2]); + ASSERT_EQ(1ULL, result[3]); + + ptr[0] = 5656565656565656ULL; + ptr[1] = 3434343434343434ULL; + ptr[2] = 1212121212121212ULL; + multiply_uint_uint64(ptr.get(), 3, 7878787878787878ULL, 4, result.get()); + ASSERT_EQ(8891370032116156560ULL, result[0]); + ASSERT_EQ(127835914414679452ULL, result[1]); + ASSERT_EQ(9811042505314082702ULL, result[2]); + ASSERT_EQ(517709026347ULL, result[3]); + } + + TEST(UIntArith, DivideUIntUInt) + { + MemoryPool &pool = *global_variables::global_memory_pool; + divide_uint_uint_inplace(nullptr, nullptr, 0, nullptr, pool); + divide_uint_uint(nullptr, nullptr, 0, nullptr, nullptr, pool); + + auto ptr(allocate_uint(4, pool)); + auto ptr2(allocate_uint(4, pool)); + auto ptr3(allocate_uint(4, pool)); + auto ptr4(allocate_uint(4, pool)); + ptr[0] = 0; + ptr[1] = 0; + ptr2[0] = 0; + ptr2[1] = 1; + ptr3[0] = 0xFFFFFFFFFFFFFFFF; + ptr3[1] = 0xFFFFFFFFFFFFFFFF; + divide_uint_uint_inplace(ptr.get(), ptr2.get(), 2, ptr3.get(), pool); + ASSERT_EQ(static_cast(0), ptr[0]); + ASSERT_EQ(static_cast(0), ptr[1]); + ASSERT_EQ(static_cast(0), ptr3[0]); + ASSERT_EQ(static_cast(0), ptr3[1]); + + ptr[0] = 0; + ptr[1] = 0; + ptr2[0] = 0xFFFFFFFFFFFFFFFF; + ptr2[1] = 0xFFFFFFFFFFFFFFFF; + ptr3[0] = 0xFFFFFFFFFFFFFFFF; + ptr3[1] = 0xFFFFFFFFFFFFFFFF; + divide_uint_uint_inplace(ptr.get(), ptr2.get(), 2, ptr3.get(), pool); + ASSERT_EQ(static_cast(0), ptr[0]); + ASSERT_EQ(static_cast(0), ptr[1]); + ASSERT_EQ(static_cast(0), ptr3[0]); + ASSERT_EQ(static_cast(0), ptr3[1]); + + ptr[0] = 0xFFFFFFFFFFFFFFFE; + ptr[1] = 0xFFFFFFFFFFFFFFFF; + ptr2[0] = 0xFFFFFFFFFFFFFFFF; + ptr2[1] = 0xFFFFFFFFFFFFFFFF; + ptr3[0] = 0xFFFFFFFFFFFFFFFF; + ptr3[1] = 0xFFFFFFFFFFFFFFFF; + divide_uint_uint_inplace(ptr.get(), ptr2.get(), 2, ptr3.get(), pool); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFE), ptr[0]); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr[1]); + ASSERT_EQ(static_cast(0), ptr3[0]); + ASSERT_EQ(static_cast(0), ptr3[1]); + + ptr[0] = 0xFFFFFFFFFFFFFFFF; + ptr[1] = 0xFFFFFFFFFFFFFFFF; + ptr2[0] = 0xFFFFFFFFFFFFFFFF; + ptr2[1] = 0xFFFFFFFFFFFFFFFF; + ptr3[0] = 0xFFFFFFFFFFFFFFFF; + ptr3[1] = 0xFFFFFFFFFFFFFFFF; + divide_uint_uint_inplace(ptr.get(), ptr2.get(), 2, ptr3.get(), pool); + ASSERT_EQ(static_cast(0), ptr[0]); + ASSERT_EQ(static_cast(0), ptr[1]); + ASSERT_EQ(1ULL, ptr3[0]); + ASSERT_EQ(static_cast(0), ptr3[1]); + + ptr[0] = 14; + ptr[1] = 0; + ptr2[0] = 3; + ptr2[1] = 0; + ptr3[0] = 0xFFFFFFFFFFFFFFFF; + ptr3[1] = 0xFFFFFFFFFFFFFFFF; + divide_uint_uint_inplace(ptr.get(), ptr2.get(), 2, ptr3.get(), pool); + ASSERT_EQ(static_cast(2), ptr[0]); + ASSERT_EQ(static_cast(0), ptr[1]); + ASSERT_EQ(static_cast(4), ptr3[0]); + ASSERT_EQ(static_cast(0), ptr3[1]); + + ptr[0] = 9585656442714717620ul; + ptr[1] = 1817697005049051848; + ptr[2] = 14447416709120365380ul; + ptr[3] = 67450014862939159; + ptr2[0] = 701538366196406307; + ptr2[1] = 1699883529753102283; + ptr2[2] = 0; + ptr2[3] = 0; + ptr3[0] = 0xFFFFFFFFFFFFFFFF; + ptr3[1] = 0xFFFFFFFFFFFFFFFF; + ptr3[2] = 0xFFFFFFFFFFFFFFFF; + ptr3[3] = 0xFFFFFFFFFFFFFFFF; + ptr4[0] = 0xFFFFFFFFFFFFFFFF; + ptr4[1] = 0xFFFFFFFFFFFFFFFF; + ptr4[2] = 0xFFFFFFFFFFFFFFFF; + ptr4[3] = 0xFFFFFFFFFFFFFFFF; + divide_uint_uint(ptr.get(), ptr2.get(), 4, ptr3.get(), ptr4.get(), pool); + ASSERT_EQ(static_cast(2), ptr4[0]); + ASSERT_EQ(static_cast(0), ptr4[1]); + ASSERT_EQ(static_cast(0), ptr4[2]); + ASSERT_EQ(static_cast(0), ptr4[3]); + ASSERT_EQ(static_cast(9756571004902751654ul), ptr3[0]); + ASSERT_EQ(static_cast(731952007397389984), ptr3[1]); + ASSERT_EQ(static_cast(0), ptr3[2]); + ASSERT_EQ(static_cast(0), ptr3[3]); + + divide_uint_uint_inplace(ptr.get(), ptr2.get(), 4, ptr3.get(), pool); + ASSERT_EQ(static_cast(2), ptr[0]); + ASSERT_EQ(static_cast(0), ptr[1]); + ASSERT_EQ(static_cast(0), ptr[2]); + ASSERT_EQ(static_cast(0), ptr[3]); + ASSERT_EQ(static_cast(9756571004902751654ul), ptr3[0]); + ASSERT_EQ(static_cast(731952007397389984), ptr3[1]); + ASSERT_EQ(static_cast(0), ptr3[2]); + ASSERT_EQ(static_cast(0), ptr3[3]); + } + + TEST(UIntArith, DivideUInt128UInt64) + { + uint64_t input[2]; + uint64_t quotient[2]; + + input[0] = 0; + input[1] = 0; + divide_uint128_uint64_inplace(input, 1ULL, quotient); + ASSERT_EQ(0ULL, input[0]); + ASSERT_EQ(0ULL, input[1]); + ASSERT_EQ(0ULL, quotient[0]); + ASSERT_EQ(0ULL, quotient[1]); + + input[0] = 1; + input[1] = 0; + divide_uint128_uint64_inplace(input, 1ULL, quotient); + ASSERT_EQ(0ULL, input[0]); + ASSERT_EQ(0ULL, input[1]); + ASSERT_EQ(1ULL, quotient[0]); + ASSERT_EQ(0ULL, quotient[1]); + + input[0] = 0x10101010; + input[1] = 0x2B2B2B2B; + divide_uint128_uint64_inplace(input, 0x1000ULL, quotient); + ASSERT_EQ(0x10ULL, input[0]); + ASSERT_EQ(0ULL, input[1]); + ASSERT_EQ(0xB2B0000000010101ULL, quotient[0]); + ASSERT_EQ(0x2B2B2ULL, quotient[1]); + + input[0] = 1212121212121212ULL; + input[1] = 3434343434343434ULL; + divide_uint128_uint64_inplace(input, 5656565656565656ULL, quotient); + ASSERT_EQ(5252525252525252ULL, input[0]); + ASSERT_EQ(0ULL, input[1]); + ASSERT_EQ(11199808901895084909ULL, quotient[0]); + ASSERT_EQ(0ULL, quotient[1]); + } + + TEST(UIntArith, DivideUInt192UInt64) + { + uint64_t input[3]; + uint64_t quotient[3]; + + input[0] = 0; + input[1] = 0; + input[2] = 0; + divide_uint192_uint64_inplace(input, 1ULL, quotient); + ASSERT_EQ(0ULL, input[0]); + ASSERT_EQ(0ULL, input[1]); + ASSERT_EQ(0ULL, input[2]); + ASSERT_EQ(0ULL, quotient[0]); + ASSERT_EQ(0ULL, quotient[1]); + ASSERT_EQ(0ULL, quotient[2]); + + input[0] = 1; + input[1] = 0; + input[2] = 0; + divide_uint192_uint64_inplace(input, 1ULL, quotient); + ASSERT_EQ(0ULL, input[0]); + ASSERT_EQ(0ULL, input[1]); + ASSERT_EQ(0ULL, input[2]); + ASSERT_EQ(1ULL, quotient[0]); + ASSERT_EQ(0ULL, quotient[1]); + ASSERT_EQ(0ULL, quotient[2]); + + input[0] = 0x10101010; + input[1] = 0x2B2B2B2B; + input[2] = 0xF1F1F1F1; + divide_uint192_uint64_inplace(input, 0x1000ULL, quotient); + ASSERT_EQ(0x10ULL, input[0]); + ASSERT_EQ(0ULL, input[1]); + ASSERT_EQ(0ULL, input[2]); + ASSERT_EQ(0xB2B0000000010101ULL, quotient[0]); + ASSERT_EQ(0x1F1000000002B2B2ULL, quotient[1]); + ASSERT_EQ(0xF1F1FULL, quotient[2]); + + input[0] = 1212121212121212ULL; + input[1] = 3434343434343434ULL; + input[2] = 5656565656565656ULL; + divide_uint192_uint64_inplace(input, 7878787878787878ULL, quotient); + ASSERT_EQ(7272727272727272ULL, input[0]); + ASSERT_EQ(0ULL, input[1]); + ASSERT_EQ(0ULL, input[2]); + ASSERT_EQ(17027763760347278414ULL, quotient[0]); + ASSERT_EQ(13243816258047883211ULL, quotient[1]); + ASSERT_EQ(0ULL, quotient[2]); + } + + TEST(UIntArith, ExponentiateUInt) + { + MemoryPool &pool = *global_variables::global_memory_pool; + auto input(allocate_zero_uint(2, pool)); + auto result(allocate_zero_uint(8, pool)); + + result[0] = 1, result[1] = 2, result[2] = 3, result[3] = 4; + result[4] = 5, result[5] = 6, result[6] = 7, result[7] = 8; + + uint64_t exponent[2]{ 0, 0 }; + + input[0] = 0xFFF; + input[1] = 0; + exponentiate_uint(input.get(), 2, exponent, 1, 1, result.get(), pool); + ASSERT_EQ(1ULL, result[0]); + ASSERT_EQ(2ULL, result[1]); + + exponentiate_uint(input.get(), 2, exponent, 1, 2, result.get(), pool); + ASSERT_EQ(1ULL, result[0]); + ASSERT_EQ(0ULL, result[1]); + + exponentiate_uint(input.get(), 1, exponent, 1, 4, result.get(), pool); + ASSERT_EQ(1ULL, result[0]); + ASSERT_EQ(0ULL, result[1]); + ASSERT_EQ(0ULL, result[2]); + ASSERT_EQ(0ULL, result[3]); + + input[0] = 123; + exponent[0] = 5; + exponentiate_uint(input.get(), 1, exponent, 2, 2, result.get(), pool); + ASSERT_EQ(28153056843ULL, result[0]); + ASSERT_EQ(0ULL, result[1]); + + input[0] = 1; + exponent[0] = 1; + exponent[1] = 1; + exponentiate_uint(input.get(), 1, exponent, 2, 2, result.get(), pool); + ASSERT_EQ(1ULL, result[0]); + ASSERT_EQ(0ULL, result[1]); + + input[0] = 0; + input[1] = 1; + exponent[0] = 7; + exponent[1] = 0; + exponentiate_uint(input.get(), 2, exponent, 2, 8, result.get(), pool); + ASSERT_EQ(0ULL, result[0]); + ASSERT_EQ(0ULL, result[1]); + ASSERT_EQ(0ULL, result[2]); + ASSERT_EQ(0ULL, result[3]); + ASSERT_EQ(0ULL, result[4]); + ASSERT_EQ(0ULL, result[5]); + ASSERT_EQ(0ULL, result[6]); + ASSERT_EQ(1ULL, result[7]); + + input[0] = 121212; + input[1] = 343434; + exponent[0] = 3; + exponent[1] = 0; + exponentiate_uint(input.get(), 2, exponent, 2, 8, result.get(), pool); + ASSERT_EQ(1780889000200128ULL, result[0]); + ASSERT_EQ(15137556501701088ULL, result[1]); + ASSERT_EQ(42889743421486416ULL, result[2]); + ASSERT_EQ(40506979898070504ULL, result[3]); + ASSERT_EQ(0ULL, result[4]); + ASSERT_EQ(0ULL, result[5]); + ASSERT_EQ(0ULL, result[6]); + ASSERT_EQ(0ULL, result[7]); + } + + TEST(UIntArith, ExponentiateUInt64) + { + ASSERT_EQ(0ULL, exponentiate_uint64(0ULL, 1ULL)); + ASSERT_EQ(1ULL, exponentiate_uint64(1ULL, 0ULL)); + ASSERT_EQ(0ULL, exponentiate_uint64(0ULL, 0xFFFFFFFFFFFFFFFFULL)); + ASSERT_EQ(1ULL, exponentiate_uint64(0xFFFFFFFFFFFFFFFFULL, 0ULL)); + ASSERT_EQ(25ULL, exponentiate_uint64(5ULL, 2ULL)); + ASSERT_EQ(32ULL, exponentiate_uint64(2ULL, 5ULL)); + ASSERT_EQ(0x1000000000000000ULL, exponentiate_uint64(0x10ULL, 15ULL)); + ASSERT_EQ(0ULL, exponentiate_uint64(0x10ULL, 16ULL)); + ASSERT_EQ(12389286314587456613ULL, exponentiate_uint64(123456789ULL, 13ULL)); + } + } +} diff --git a/tests/seal/util/uintarithmod.cpp b/tests/seal/util/uintarithmod.cpp new file mode 100644 index 000000000..5dc6d2cb7 --- /dev/null +++ b/tests/seal/util/uintarithmod.cpp @@ -0,0 +1,353 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "gtest/gtest.h" +#include "seal/util/uintcore.h" +#include "seal/util/uintarithmod.h" +#include +#include + +using namespace seal::util; +using namespace std; + +namespace SEALTest +{ + namespace util + { + TEST(UIntArithMod, IncrementUIntMod) + { + MemoryPool &pool = *global_variables::global_memory_pool; + auto value(allocate_uint(2, pool)); + auto modulus(allocate_uint(2, pool)); + value[0] = 0; + value[1] = 0; + modulus[0] = 3; + modulus[1] = 0; + increment_uint_mod(value.get(), modulus.get(), 2, value.get()); + ASSERT_EQ(1ULL, value[0]); + ASSERT_EQ(static_cast(0), value[1]); + increment_uint_mod(value.get(), modulus.get(), 2, value.get()); + ASSERT_EQ(static_cast(2), value[0]); + ASSERT_EQ(static_cast(0), value[1]); + increment_uint_mod(value.get(), modulus.get(), 2, value.get()); + ASSERT_EQ(static_cast(0), value[0]); + ASSERT_EQ(static_cast(0), value[1]); + + value[0] = 0xFFFFFFFFFFFFFFFD; + value[1] = 0xFFFFFFFFFFFFFFFF; + modulus[0] = 0xFFFFFFFFFFFFFFFF; + modulus[1] = 0xFFFFFFFFFFFFFFFF; + increment_uint_mod(value.get(), modulus.get(), 2, value.get()); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFE), value[0]); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), value[1]); + increment_uint_mod(value.get(), modulus.get(), 2, value.get()); + ASSERT_EQ(static_cast(0), value[0]); + ASSERT_EQ(static_cast(0), value[1]); + increment_uint_mod(value.get(), modulus.get(), 2, value.get()); + ASSERT_EQ(1ULL, value[0]); + ASSERT_EQ(static_cast(0), value[1]); + } + + TEST(UIntArithMod, DecrementUIntMod) + { + MemoryPool &pool = *global_variables::global_memory_pool; + auto value(allocate_uint(2, pool)); + auto modulus(allocate_uint(2, pool)); + value[0] = 2; + value[1] = 0; + modulus[0] = 3; + modulus[1] = 0; + decrement_uint_mod(value.get(), modulus.get(), 2, value.get()); + ASSERT_EQ(1ULL, value[0]); + ASSERT_EQ(static_cast(0), value[1]); + decrement_uint_mod(value.get(), modulus.get(), 2, value.get()); + ASSERT_EQ(static_cast(0), value[0]); + ASSERT_EQ(static_cast(0), value[1]); + decrement_uint_mod(value.get(), modulus.get(), 2, value.get()); + ASSERT_EQ(static_cast(2), value[0]); + ASSERT_EQ(static_cast(0), value[1]); + + value[0] = 1; + value[1] = 0; + modulus[0] = 0xFFFFFFFFFFFFFFFF; + modulus[1] = 0xFFFFFFFFFFFFFFFF; + decrement_uint_mod(value.get(), modulus.get(), 2, value.get()); + ASSERT_EQ(static_cast(0), value[0]); + ASSERT_EQ(static_cast(0), value[1]); + decrement_uint_mod(value.get(), modulus.get(), 2, value.get()); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFE), value[0]); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), value[1]); + decrement_uint_mod(value.get(), modulus.get(), 2, value.get()); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFD), value[0]); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), value[1]); + } + + TEST(UIntArithMod, NegateUIntMod) + { + MemoryPool &pool = *global_variables::global_memory_pool; + auto value(allocate_uint(2, pool)); + auto modulus(allocate_uint(2, pool)); + value[0] = 0; + value[1] = 0; + modulus[0] = 3; + modulus[1] = 0; + negate_uint_mod(value.get(), modulus.get(), 2, value.get()); + ASSERT_EQ(static_cast(0), value[0]); + ASSERT_EQ(static_cast(0), value[1]); + + value[0] = 1; + value[1] = 0; + modulus[0] = 3; + modulus[1] = 0; + negate_uint_mod(value.get(), modulus.get(), 2, value.get()); + ASSERT_EQ(static_cast(2), value[0]); + ASSERT_EQ(static_cast(0), value[1]); + negate_uint_mod(value.get(), modulus.get(), 2, value.get()); + ASSERT_EQ(1ULL, value[0]); + ASSERT_EQ(static_cast(0), value[1]); + + value[0] = 2; + value[1] = 0; + modulus[0] = 0xFFFFFFFFFFFFFFFF; + modulus[1] = 0xFFFFFFFFFFFFFFFF; + negate_uint_mod(value.get(), modulus.get(), 2, value.get()); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFD), value[0]); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), value[1]); + negate_uint_mod(value.get(), modulus.get(), 2, value.get()); + ASSERT_EQ(static_cast(2), value[0]); + ASSERT_EQ(static_cast(0), value[1]); + } + + TEST(UIntArithMod, Div2UIntMod) + { + MemoryPool &pool = *global_variables::global_memory_pool; + auto value(allocate_uint(2, pool)); + auto modulus(allocate_uint(2, pool)); + value[0] = 0; + value[1] = 0; + modulus[0] = 3; + modulus[1] = 0; + div2_uint_mod(value.get(), modulus.get(), 2, value.get()); + ASSERT_EQ(0ULL, value[0]); + ASSERT_EQ(0ULL, value[1]); + + value[0] = 1; + value[1] = 0; + modulus[0] = 3; + modulus[1] = 0; + div2_uint_mod(value.get(), modulus.get(), 2, value.get()); + ASSERT_EQ(2ULL, value[0]); + ASSERT_EQ(0ULL, value[1]); + + value[0] = 8; + value[1] = 0; + modulus[0] = 17; + modulus[1] = 0; + div2_uint_mod(value.get(), modulus.get(), 2, value.get()); + ASSERT_EQ(4ULL, value[0]); + ASSERT_EQ(0ULL, value[1]); + + value[0] = 5; + value[1] = 0; + modulus[0] = 17; + modulus[1] = 0; + div2_uint_mod(value.get(), modulus.get(), 2, value.get()); + ASSERT_EQ(11ULL, value[0]); + ASSERT_EQ(0ULL, value[1]); + + value[0] = 1; + value[1] = 0; + modulus[0] = 0xFFFFFFFFFFFFFFFFULL; + modulus[1] = 0xFFFFFFFFFFFFFFFFULL; + div2_uint_mod(value.get(), modulus.get(), 2, value.get()); + ASSERT_EQ(0ULL, value[0]); + ASSERT_EQ(0x8000000000000000ULL, value[1]); + + value[0] = 3; + value[1] = 0; + modulus[0] = 0xFFFFFFFFFFFFFFFFULL; + modulus[1] = 0xFFFFFFFFFFFFFFFFULL; + div2_uint_mod(value.get(), modulus.get(), 2, value.get()); + ASSERT_EQ(1ULL, value[0]); + ASSERT_EQ(0x8000000000000000ULL, value[1]); + } + + TEST(UIntArithMod, AddUIntUIntMod) + { + MemoryPool &pool = *global_variables::global_memory_pool; + auto value1(allocate_uint(2, pool)); + auto value2(allocate_uint(2, pool)); + auto modulus(allocate_uint(2, pool)); + value1[0] = 0; + value1[1] = 0; + value2[0] = 0; + value2[1] = 0; + modulus[0] = 3; + modulus[1] = 0; + add_uint_uint_mod(value1.get(), value2.get(), modulus.get(), 2, value1.get()); + ASSERT_EQ(static_cast(0), value1[0]); + ASSERT_EQ(static_cast(0), value1[1]); + + value1[0] = 1; + value1[1] = 0; + value2[0] = 1; + value2[1] = 0; + modulus[0] = 3; + modulus[1] = 0; + add_uint_uint_mod(value1.get(), value2.get(), modulus.get(), 2, value1.get()); + ASSERT_EQ(static_cast(2), value1[0]); + ASSERT_EQ(static_cast(0), value1[1]); + + value1[0] = 1; + value1[1] = 0; + value2[0] = 2; + value2[1] = 0; + modulus[0] = 3; + modulus[1] = 0; + add_uint_uint_mod(value1.get(), value2.get(), modulus.get(), 2, value1.get()); + ASSERT_EQ(static_cast(0), value1[0]); + ASSERT_EQ(static_cast(0), value1[1]); + + value1[0] = 2; + value1[1] = 0; + value2[0] = 2; + value2[1] = 0; + modulus[0] = 3; + modulus[1] = 0; + add_uint_uint_mod(value1.get(), value2.get(), modulus.get(), 2, value1.get()); + ASSERT_EQ(1ULL, value1[0]); + ASSERT_EQ(static_cast(0), value1[1]); + + value1[0] = 0xFFFFFFFFFFFFFFFE; + value1[1] = 0xFFFFFFFFFFFFFFFF; + value2[0] = 0xFFFFFFFFFFFFFFFE; + value2[1] = 0xFFFFFFFFFFFFFFFF; + modulus[0] = 0xFFFFFFFFFFFFFFFF; + modulus[1] = 0xFFFFFFFFFFFFFFFF; + add_uint_uint_mod(value1.get(), value2.get(), modulus.get(), 2, value1.get()); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFD), value1[0]); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), value1[1]); + } + + TEST(UIntArithMod, SubUIntUIntMod) + { + MemoryPool &pool = *global_variables::global_memory_pool; + auto value1(allocate_uint(2, pool)); + auto value2(allocate_uint(2, pool)); + auto modulus(allocate_uint(2, pool)); + value1[0] = 0; + value1[1] = 0; + value2[0] = 0; + value2[1] = 0; + modulus[0] = 3; + modulus[1] = 0; + sub_uint_uint_mod(value1.get(), value2.get(), modulus.get(), 2, value1.get()); + ASSERT_EQ(static_cast(0), value1[0]); + ASSERT_EQ(static_cast(0), value1[1]); + + value1[0] = 2; + value1[1] = 0; + value2[0] = 1; + value2[1] = 0; + modulus[0] = 3; + modulus[1] = 0; + sub_uint_uint_mod(value1.get(), value2.get(), modulus.get(), 2, value1.get()); + ASSERT_EQ(1ULL, value1[0]); + ASSERT_EQ(static_cast(0), value1[1]); + + value1[0] = 1; + value1[1] = 0; + value2[0] = 2; + value2[1] = 0; + modulus[0] = 3; + modulus[1] = 0; + sub_uint_uint_mod(value1.get(), value2.get(), modulus.get(), 2, value1.get()); + ASSERT_EQ(static_cast(2), value1[0]); + ASSERT_EQ(static_cast(0), value1[1]); + + value1[0] = 2; + value1[1] = 0; + value2[0] = 2; + value2[1] = 0; + modulus[0] = 3; + modulus[1] = 0; + sub_uint_uint_mod(value1.get(), value2.get(), modulus.get(), 2, value1.get()); + ASSERT_EQ(static_cast(0), value1[0]); + ASSERT_EQ(static_cast(0), value1[1]); + + value1[0] = 1; + value1[1] = 0; + value2[0] = 0xFFFFFFFFFFFFFFFE; + value2[1] = 0xFFFFFFFFFFFFFFFF; + modulus[0] = 0xFFFFFFFFFFFFFFFF; + modulus[1] = 0xFFFFFFFFFFFFFFFF; + sub_uint_uint_mod(value1.get(), value2.get(), modulus.get(), 2, value1.get()); + ASSERT_EQ(static_cast(2), value1[0]); + ASSERT_EQ(static_cast(0), value1[1]); + } + + TEST(UIntArithMod, TryInvertUIntMod) + { + MemoryPool &pool = *global_variables::global_memory_pool; + auto value(allocate_uint(2, pool)); + auto modulus(allocate_uint(2, pool)); + value[0] = 0; + value[1] = 0; + modulus[0] = 5; + modulus[1] = 0; + ASSERT_FALSE(try_invert_uint_mod(value.get(), modulus.get(), 2, value.get(), pool)); + + value[0] = 1; + value[1] = 0; + modulus[0] = 5; + modulus[1] = 0; + ASSERT_TRUE(try_invert_uint_mod(value.get(), modulus.get(), 2, value.get(), pool)); + ASSERT_EQ(1ULL, value[0]); + ASSERT_EQ(static_cast(0), value[1]); + + value[0] = 2; + value[1] = 0; + modulus[0] = 5; + modulus[1] = 0; + ASSERT_TRUE(try_invert_uint_mod(value.get(), modulus.get(), 2, value.get(), pool)); + ASSERT_EQ(static_cast(3), value[0]); + ASSERT_EQ(static_cast(0), value[1]); + + value[0] = 3; + value[1] = 0; + modulus[0] = 5; + modulus[1] = 0; + ASSERT_TRUE(try_invert_uint_mod(value.get(), modulus.get(), 2, value.get(), pool)); + ASSERT_EQ(static_cast(2), value[0]); + ASSERT_EQ(static_cast(0), value[1]); + + value[0] = 4; + value[1] = 0; + modulus[0] = 5; + modulus[1] = 0; + ASSERT_TRUE(try_invert_uint_mod(value.get(), modulus.get(), 2, value.get(), pool)); + ASSERT_EQ(static_cast(4), value[0]); + ASSERT_EQ(static_cast(0), value[1]); + + value[0] = 2; + value[1] = 0; + modulus[0] = 6; + modulus[1] = 0; + ASSERT_FALSE(try_invert_uint_mod(value.get(), modulus.get(), 2, value.get(), pool)); + + value[0] = 3; + value[1] = 0; + modulus[0] = 6; + modulus[1] = 0; + ASSERT_FALSE(try_invert_uint_mod(value.get(), modulus.get(), 2, value.get(), pool)); + + value[0] = 331975426; + value[1] = 0; + modulus[0] = 1351315121; + modulus[1] = 0; + ASSERT_TRUE(try_invert_uint_mod(value.get(), modulus.get(), 2, value.get(), pool)); + ASSERT_EQ(static_cast(1052541512), value[0]); + ASSERT_EQ(static_cast(0), value[1]); + } + } +} diff --git a/tests/seal/util/uintarithsmallmod.cpp b/tests/seal/util/uintarithsmallmod.cpp new file mode 100644 index 000000000..fb1525557 --- /dev/null +++ b/tests/seal/util/uintarithsmallmod.cpp @@ -0,0 +1,390 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "gtest/gtest.h" +#include "seal/util/uintcore.h" +#include "seal/util/uintarithsmallmod.h" +#include "seal/smallmodulus.h" +#include "seal/memorymanager.h" + +using namespace seal::util; +using namespace seal; +using namespace std; + +namespace SEALTest +{ + namespace util + { + TEST(UIntArithSmallMod, IncrementUIntSmallMod) + { + SmallModulus mod(2); + ASSERT_EQ(1ULL, increment_uint_mod(0, mod)); + ASSERT_EQ(0ULL, increment_uint_mod(1ULL, mod)); + + mod = 0x10000; + ASSERT_EQ(1ULL, increment_uint_mod(0, mod)); + ASSERT_EQ(2ULL, increment_uint_mod(1ULL, mod)); + ASSERT_EQ(0ULL, increment_uint_mod(0xFFFFULL, mod)); + + mod = 4611686018427289601ULL; + ASSERT_EQ(1ULL, increment_uint_mod(0, mod)); + ASSERT_EQ(0ULL, increment_uint_mod(4611686018427289600ULL, mod)); + ASSERT_EQ(1ULL, increment_uint_mod(0, mod)); + } + + TEST(UIntArithSmallMod, DecrementUIntSmallMod) + { + SmallModulus mod(2); + ASSERT_EQ(0ULL, decrement_uint_mod(1, mod)); + ASSERT_EQ(1ULL, decrement_uint_mod(0ULL, mod)); + + mod = 0x10000; + ASSERT_EQ(0ULL, decrement_uint_mod(1, mod)); + ASSERT_EQ(1ULL, decrement_uint_mod(2ULL, mod)); + ASSERT_EQ(0xFFFFULL, decrement_uint_mod(0ULL, mod)); + + mod = 4611686018427289601ULL; + ASSERT_EQ(0ULL, decrement_uint_mod(1, mod)); + ASSERT_EQ(4611686018427289600ULL, decrement_uint_mod(0ULL, mod)); + ASSERT_EQ(0ULL, decrement_uint_mod(1, mod)); + } + + TEST(UIntArithSmallMod, NegateUIntSmallMod) + { + SmallModulus mod(2); + ASSERT_EQ(0ULL, negate_uint_mod(0, mod)); + ASSERT_EQ(1ULL, negate_uint_mod(1, mod)); + + mod = 0xFFFFULL; + ASSERT_EQ(0ULL, negate_uint_mod(0, mod)); + ASSERT_EQ(0xFFFEULL, negate_uint_mod(1, mod)); + ASSERT_EQ(0x1ULL, negate_uint_mod(0xFFFEULL, mod)); + + mod = 0x10000ULL; + ASSERT_EQ(0ULL, negate_uint_mod(0, mod)); + ASSERT_EQ(0xFFFFULL, negate_uint_mod(1, mod)); + ASSERT_EQ(0x1ULL, negate_uint_mod(0xFFFFULL, mod)); + + mod = 4611686018427289601ULL; + ASSERT_EQ(0ULL, negate_uint_mod(0, mod)); + ASSERT_EQ(4611686018427289600ULL, negate_uint_mod(1, mod)); + } + + TEST(UIntArithSmallMod, Div2UIntSmallMod) + { + SmallModulus mod(3); + ASSERT_EQ(0ULL, div2_uint_mod(0ULL, mod)); + ASSERT_EQ(2ULL, div2_uint_mod(1ULL, mod)); + + mod = 17; + ASSERT_EQ(11ULL, div2_uint_mod(5ULL, mod)); + ASSERT_EQ(4ULL, div2_uint_mod(8ULL, mod)); + + mod = 0xFFFFFFFFFFFFFFFULL; + ASSERT_EQ(0x800000000000000ULL, div2_uint_mod(1ULL, mod)); + ASSERT_EQ(0x800000000000001ULL, div2_uint_mod(3ULL, mod)); + } + + TEST(UIntArithSmallMod, AddUIntSmallMod) + { + SmallModulus mod(2); + ASSERT_EQ(0ULL, add_uint_uint_mod(0, 0, mod)); + ASSERT_EQ(1ULL, add_uint_uint_mod(0, 1, mod)); + ASSERT_EQ(1ULL, add_uint_uint_mod(1, 0, mod)); + ASSERT_EQ(0ULL, add_uint_uint_mod(1, 1, mod)); + + mod = 10; + ASSERT_EQ(0ULL, add_uint_uint_mod(0, 0, mod)); + ASSERT_EQ(1ULL, add_uint_uint_mod(0, 1, mod)); + ASSERT_EQ(1ULL, add_uint_uint_mod(1, 0, mod)); + ASSERT_EQ(2ULL, add_uint_uint_mod(1, 1, mod)); + ASSERT_EQ(4ULL, add_uint_uint_mod(7, 7, mod)); + ASSERT_EQ(3ULL, add_uint_uint_mod(6, 7, mod)); + + mod = 4611686018427289601; + ASSERT_EQ(0ULL, add_uint_uint_mod(0, 0, mod)); + ASSERT_EQ(1ULL, add_uint_uint_mod(0, 1, mod)); + ASSERT_EQ(1ULL, add_uint_uint_mod(1, 0, mod)); + ASSERT_EQ(2ULL, add_uint_uint_mod(1, 1, mod)); + ASSERT_EQ(0ULL, add_uint_uint_mod(2305843009213644800ULL, 2305843009213644801ULL, mod)); + ASSERT_EQ(1ULL, add_uint_uint_mod(2305843009213644801ULL, 2305843009213644801ULL, mod)); + ASSERT_EQ(4611686018427289599ULL, add_uint_uint_mod(4611686018427289600ULL, 4611686018427289600ULL, mod)); + } + + TEST(UIntArithSmallMod, SubUIntSmallMod) + { + SmallModulus mod(2); + ASSERT_EQ(0ULL, sub_uint_uint_mod(0, 0, mod)); + ASSERT_EQ(1ULL, sub_uint_uint_mod(0, 1, mod)); + ASSERT_EQ(1ULL, sub_uint_uint_mod(1, 0, mod)); + ASSERT_EQ(0ULL, sub_uint_uint_mod(1, 1, mod)); + + mod = 10; + ASSERT_EQ(0ULL, sub_uint_uint_mod(0, 0, mod)); + ASSERT_EQ(9ULL, sub_uint_uint_mod(0, 1, mod)); + ASSERT_EQ(1ULL, sub_uint_uint_mod(1, 0, mod)); + ASSERT_EQ(0ULL, sub_uint_uint_mod(1, 1, mod)); + ASSERT_EQ(0ULL, sub_uint_uint_mod(7, 7, mod)); + ASSERT_EQ(9ULL, sub_uint_uint_mod(6, 7, mod)); + ASSERT_EQ(1ULL, sub_uint_uint_mod(7, 6, mod)); + + mod = 4611686018427289601ULL; + ASSERT_EQ(0ULL, sub_uint_uint_mod(0, 0, mod)); + ASSERT_EQ(4611686018427289600ULL, sub_uint_uint_mod(0, 1, mod)); + ASSERT_EQ(1ULL, sub_uint_uint_mod(1, 0, mod)); + ASSERT_EQ(0ULL, sub_uint_uint_mod(1, 1, mod)); + ASSERT_EQ(4611686018427289600ULL, sub_uint_uint_mod(2305843009213644800ULL, 2305843009213644801ULL, mod)); + ASSERT_EQ(1ULL, sub_uint_uint_mod(2305843009213644801ULL, 2305843009213644800ULL, mod)); + ASSERT_EQ(0ULL, sub_uint_uint_mod(2305843009213644801ULL, 2305843009213644801ULL, mod)); + ASSERT_EQ(0ULL, sub_uint_uint_mod(4611686018427289600ULL, 4611686018427289600ULL, mod)); + } + + TEST(UIntArithSmallMod, BarrettReduce128) + { + uint64_t input[2]; + + SmallModulus mod(2); + input[0] = 0; + input[1] = 0; + ASSERT_EQ(0ULL, barrett_reduce_128(input, mod)); + input[0] = 1; + input[1] = 0; + ASSERT_EQ(1ULL, barrett_reduce_128(input, mod)); + input[0] = 0xFFFFFFFFFFFFFFFFULL; + input[1] = 0xFFFFFFFFFFFFFFFFULL; + ASSERT_EQ(1ULL, barrett_reduce_128(input, mod)); + + mod = 3; + input[0] = 0; + input[1] = 0; + ASSERT_EQ(0ULL, barrett_reduce_128(input, mod)); + input[0] = 1; + input[1] = 0; + ASSERT_EQ(1ULL, barrett_reduce_128(input, mod)); + input[0] = 123; + input[1] = 456; + ASSERT_EQ(0ULL, barrett_reduce_128(input, mod)); + input[0] = 0xFFFFFFFFFFFFFFFFULL; + input[1] = 0xFFFFFFFFFFFFFFFFULL; + ASSERT_EQ(0ULL, barrett_reduce_128(input, mod)); + + mod = 13131313131313ULL; + input[0] = 0; + input[1] = 0; + ASSERT_EQ(0ULL, barrett_reduce_128(input, mod)); + input[0] = 1; + input[1] = 0; + ASSERT_EQ(1ULL, barrett_reduce_128(input, mod)); + input[0] = 123; + input[1] = 456; + ASSERT_EQ(8722750765283ULL, barrett_reduce_128(input, mod)); + input[0] = 24242424242424; + input[1] = 79797979797979; + ASSERT_EQ(1010101010101ULL, barrett_reduce_128(input, mod)); + } + + TEST(UIntArithSmallMod, MultiplyUIntUIntSmallMod) + { + SmallModulus mod(2); + ASSERT_EQ(0ULL, multiply_uint_uint_mod(0, 0, mod)); + ASSERT_EQ(0ULL, multiply_uint_uint_mod(0, 1, mod)); + ASSERT_EQ(0ULL, multiply_uint_uint_mod(1, 0, mod)); + ASSERT_EQ(1ULL, multiply_uint_uint_mod(1, 1, mod)); + + mod = 10; + ASSERT_EQ(0ULL, multiply_uint_uint_mod(0, 0, mod)); + ASSERT_EQ(0ULL, multiply_uint_uint_mod(0, 1, mod)); + ASSERT_EQ(0ULL, multiply_uint_uint_mod(1, 0, mod)); + ASSERT_EQ(1ULL, multiply_uint_uint_mod(1, 1, mod)); + ASSERT_EQ(9ULL, multiply_uint_uint_mod(7, 7, mod)); + ASSERT_EQ(2ULL, multiply_uint_uint_mod(6, 7, mod)); + ASSERT_EQ(2ULL, multiply_uint_uint_mod(7, 6, mod)); + + mod = 4611686018427289601ULL; + ASSERT_EQ(0ULL, multiply_uint_uint_mod(0, 0, mod)); + ASSERT_EQ(0ULL, multiply_uint_uint_mod(0, 1, mod)); + ASSERT_EQ(0ULL, multiply_uint_uint_mod(1, 0, mod)); + ASSERT_EQ(1ULL, multiply_uint_uint_mod(1, 1, mod)); + ASSERT_EQ(1152921504606822400ULL, multiply_uint_uint_mod(2305843009213644800ULL, 2305843009213644801ULL, mod)); + ASSERT_EQ(1152921504606822400ULL, multiply_uint_uint_mod(2305843009213644801ULL, 2305843009213644800ULL, mod)); + ASSERT_EQ(3458764513820467201ULL, multiply_uint_uint_mod(2305843009213644801ULL, 2305843009213644801ULL, mod)); + ASSERT_EQ(1ULL, multiply_uint_uint_mod(4611686018427289600ULL, 4611686018427289600ULL, mod)); + } + + TEST(UIntArithSmallMod, ModuloUIntSmallMod) + { + MemoryPool &pool = *global_variables::global_memory_pool; + auto value(allocate_uint(4, pool)); + + SmallModulus mod(2); + value[0] = 0; + value[1] = 0; + value[2] = 0; + modulo_uint_inplace(value.get(), 3, mod); + ASSERT_EQ(0ULL, value[0]); + ASSERT_EQ(0ULL, value[1]); + ASSERT_EQ(0ULL, value[2]); + + value[0] = 1; + value[1] = 0; + value[2] = 0; + modulo_uint_inplace(value.get(), 3, mod); + ASSERT_EQ(1ULL, value[0]); + ASSERT_EQ(0ULL, value[1]); + ASSERT_EQ(0ULL, value[2]); + + value[0] = 2; + value[1] = 0; + value[2] = 0; + modulo_uint_inplace(value.get(), 3, mod); + ASSERT_EQ(0ULL, value[0]); + ASSERT_EQ(0ULL, value[1]); + ASSERT_EQ(0ULL, value[2]); + + value[0] = 3; + value[1] = 0; + value[2] = 0; + modulo_uint_inplace(value.get(), 3, mod); + ASSERT_EQ(1ULL, value[0]); + ASSERT_EQ(0ULL, value[1]); + ASSERT_EQ(0ULL, value[2]); + + mod = 0xFFFF; + value[0] = 9585656442714717620ul; + value[1] = 1817697005049051848; + value[2] = 0; + modulo_uint_inplace(value.get(), 3, mod); + ASSERT_EQ(65143ULL, value[0]); + ASSERT_EQ(0ULL, value[1]); + ASSERT_EQ(0ULL, value[2]); + + mod = 0x1000; + value[0] = 9585656442714717620ul; + value[1] = 1817697005049051848; + value[2] = 0; + modulo_uint_inplace(value.get(), 3, mod); + ASSERT_EQ(0xDB4ULL, value[0]); + ASSERT_EQ(0ULL, value[1]); + ASSERT_EQ(0ULL, value[2]); + + mod = 0xFFFFFFFFC001ULL; + value[0] = 9585656442714717620ul; + value[1] = 1817697005049051848; + value[2] = 14447416709120365380ul; + value[3] = 67450014862939159; + modulo_uint_inplace(value.get(), 4, mod); + ASSERT_EQ(124510066632001ULL, value[0]); + ASSERT_EQ(0ULL, value[1]); + ASSERT_EQ(0ULL, value[2]); + ASSERT_EQ(0ULL, value[3]); + } + + TEST(UIntArithSmallMod, TryInvertUIntSmallMod) + { + uint64_t result; + SmallModulus mod(5); + ASSERT_FALSE(try_invert_uint_mod(0, mod, result)); + ASSERT_TRUE(try_invert_uint_mod(1, mod, result)); + ASSERT_EQ(1ULL, result); + ASSERT_TRUE(try_invert_uint_mod(2, mod, result)); + ASSERT_EQ(3ULL, result); + ASSERT_TRUE(try_invert_uint_mod(3, mod, result)); + ASSERT_EQ(2ULL, result); + ASSERT_TRUE(try_invert_uint_mod(4, mod, result)); + ASSERT_EQ(4ULL, result); + + mod = 6; + ASSERT_FALSE(try_invert_uint_mod(2, mod, result)); + ASSERT_FALSE(try_invert_uint_mod(3, mod, result)); + ASSERT_TRUE(try_invert_uint_mod(5, mod, result)); + ASSERT_EQ(5ULL, result); + + mod = 1351315121; + ASSERT_TRUE(try_invert_uint_mod(331975426, mod, result)); + ASSERT_EQ(1052541512ULL, result); + } + + TEST(UIntArithSmallMod, TryPrimitiveRootSmallMod) + { + uint64_t result; + SmallModulus mod(11); + + ASSERT_TRUE(try_primitive_root(2, mod, result)); + ASSERT_EQ(10ULL, result); + + mod = 29; + ASSERT_TRUE(try_primitive_root(2, mod, result)); + ASSERT_EQ(28ULL, result); + + vector corrects{ 12, 17 }; + ASSERT_TRUE(try_primitive_root(4, mod, result)); + ASSERT_TRUE(std::find(corrects.begin(), corrects.end(), result) != corrects.end()); + + mod = 1234565441; + ASSERT_TRUE(try_primitive_root(2, mod, result)); + ASSERT_EQ(1234565440ULL, result); + corrects = { 984839708, 273658408, 249725733, 960907033 }; + ASSERT_TRUE(try_primitive_root(8, mod, result)); + ASSERT_TRUE(std::find(corrects.begin(), corrects.end(), result) != corrects.end()); + } + + TEST(UIntArithSmallMod, IsPrimitiveRootSmallMod) + { + SmallModulus mod(11); + ASSERT_TRUE(is_primitive_root(10, 2, mod)); + ASSERT_FALSE(is_primitive_root(9, 2, mod)); + ASSERT_FALSE(is_primitive_root(10, 4, mod)); + + mod = 29; + ASSERT_TRUE(is_primitive_root(28, 2, mod)); + ASSERT_TRUE(is_primitive_root(12, 4, mod)); + ASSERT_FALSE(is_primitive_root(12, 2, mod)); + ASSERT_FALSE(is_primitive_root(12, 8, mod)); + + + mod = 1234565441ULL; + ASSERT_TRUE(is_primitive_root(1234565440ULL, 2, mod)); + ASSERT_TRUE(is_primitive_root(960907033ULL, 8, mod)); + ASSERT_TRUE(is_primitive_root(1180581915ULL, 16, mod)); + ASSERT_FALSE(is_primitive_root(1180581915ULL, 32, mod)); + ASSERT_FALSE(is_primitive_root(1180581915ULL, 8, mod)); + ASSERT_FALSE(is_primitive_root(1180581915ULL, 2, mod)); + } + + TEST(UIntArithSmallMod, TryMinimalPrimitiveRootSmallMod) + { + SmallModulus mod(11); + + uint64_t result; + ASSERT_TRUE(try_minimal_primitive_root(2, mod, result)); + ASSERT_EQ(10ULL, result); + + mod = 29; + ASSERT_TRUE(try_minimal_primitive_root(2, mod, result)); + ASSERT_EQ(28ULL, result); + ASSERT_TRUE(try_minimal_primitive_root(4, mod, result)); + ASSERT_EQ(12ULL, result); + + mod = 1234565441; + ASSERT_TRUE(try_minimal_primitive_root(2, mod, result)); + ASSERT_EQ(1234565440ULL, result); + ASSERT_TRUE(try_minimal_primitive_root(8, mod, result)); + ASSERT_EQ(249725733ULL, result); + } + + TEST(UIntArithSmallMod, ExponentiateUIntSmallMod) + { + SmallModulus mod(5); + ASSERT_EQ(1ULL, exponentiate_uint_mod(1, 0, mod)); + ASSERT_EQ(1ULL, exponentiate_uint_mod(1, 0xFFFFFFFFFFFFFFFFULL, mod)); + ASSERT_EQ(3ULL, exponentiate_uint_mod(2, 0xFFFFFFFFFFFFFFFFULL, mod)); + + mod = 0x1000000000000000ULL; + ASSERT_EQ(0ULL, exponentiate_uint_mod(2, 60, mod)); + ASSERT_EQ(0x800000000000000ULL, exponentiate_uint_mod(2, 59, mod)); + + mod = 131313131313; + ASSERT_EQ(39418477653ULL, exponentiate_uint_mod(2424242424, 16, mod)); + } + } +} diff --git a/tests/seal/util/uintcore.cpp b/tests/seal/util/uintcore.cpp new file mode 100644 index 000000000..524d85ad5 --- /dev/null +++ b/tests/seal/util/uintcore.cpp @@ -0,0 +1,758 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "gtest/gtest.h" +#include "seal/util/uintcore.h" +#include + +using namespace seal::util; +using namespace std; + +namespace SEALTest +{ + namespace util + { + TEST(UIntCore, AllocateUInt) + { + MemoryPool &pool = *global_variables::global_memory_pool; + auto ptr(allocate_uint(0, pool)); + ASSERT_TRUE(nullptr == ptr.get()); + + ptr = allocate_uint(1, pool); + ASSERT_TRUE(nullptr != ptr.get()); + + ptr = allocate_uint(2, pool); + ASSERT_TRUE(nullptr != ptr.get()); + } + + TEST(UIntCore, SetZeroUInt) + { + set_zero_uint(0, nullptr); + + MemoryPool &pool = *global_variables::global_memory_pool; + auto ptr(allocate_uint(1, pool)); + ptr[0] = 0x1234567812345678; + set_zero_uint(1, ptr.get()); + ASSERT_EQ(static_cast(0), ptr[0]); + + ptr = allocate_uint(2, pool); + ptr[0] = 0x1234567812345678; + ptr[1] = 0x1234567812345678; + set_zero_uint(2, ptr.get()); + ASSERT_EQ(static_cast(0), ptr[0]); + ASSERT_EQ(static_cast(0), ptr[1]); + } + + TEST(UIntCore, AllocateZeroUInt) + { + MemoryPool &pool = *global_variables::global_memory_pool; + auto ptr(allocate_zero_uint(0, pool)); + ASSERT_TRUE(nullptr == ptr.get()); + + ptr = allocate_zero_uint(1, pool); + ASSERT_TRUE(nullptr != ptr.get()); + ASSERT_EQ(static_cast(0), ptr[0]); + + ptr = allocate_zero_uint(2, pool); + ASSERT_TRUE(nullptr != ptr.get()); + ASSERT_EQ(static_cast(0), ptr[0]); + ASSERT_EQ(static_cast(0), ptr[1]); + } + + TEST(UIntCore, SetUInt) + { + MemoryPool &pool = *global_variables::global_memory_pool; + auto ptr(allocate_uint(1, pool)); + ptr[0] = 0xFFFFFFFFFFFFFFFF; + set_uint(1, 1, ptr.get()); + ASSERT_EQ(1ULL, ptr[0]); + + ptr[0] = 0xFFFFFFFFFFFFFFFF; + set_uint(0x1234567812345678, 1, ptr.get()); + ASSERT_EQ(static_cast(0x1234567812345678), ptr[0]); + + ptr = allocate_uint(2, pool); + ptr[0] = 0xFFFFFFFFFFFFFFFF; + ptr[1] = 0xFFFFFFFFFFFFFFFF; + set_uint(1, 2, ptr.get()); + ASSERT_EQ(1ULL, ptr[0]); + ASSERT_EQ(static_cast(0), ptr[1]); + + ptr[0] = 0xFFFFFFFFFFFFFFFF; + ptr[1] = 0xFFFFFFFFFFFFFFFF; + set_uint(0x1234567812345678, 2, ptr.get()); + ASSERT_EQ(static_cast(0x1234567812345678), ptr[0]); + ASSERT_EQ(static_cast(0), ptr[1]); + } + + TEST(UIntCore, SetUIntUInt) + { + set_uint_uint(nullptr, 0, nullptr); + + MemoryPool &pool = *global_variables::global_memory_pool; + auto ptr1(allocate_uint(1, pool)); + ptr1[0] = 0x1234567887654321; + auto ptr2(allocate_uint(1, pool)); + ptr2[0] = 0xFFFFFFFFFFFFFFFF; + set_uint_uint(ptr1.get(), 1, ptr2.get()); + ASSERT_EQ(static_cast(0x1234567887654321), ptr2[0]); + + ptr1[0] = 0x1231231231231231; + set_uint_uint(ptr1.get(), 1, ptr1.get()); + ASSERT_EQ(static_cast(0x1231231231231231), ptr1[0]); + + ptr1 = allocate_uint(2, pool); + ptr2 = allocate_uint(2, pool); + ptr1[0] = 0x1234567887654321; + ptr1[1] = 0x8765432112345678; + ptr2[0] = 0xFFFFFFFFFFFFFFFF; + ptr2[1] = 0xFFFFFFFFFFFFFFFF; + set_uint_uint(ptr1.get(), 2, ptr2.get()); + ASSERT_EQ(static_cast(0x1234567887654321), ptr2[0]); + ASSERT_EQ(static_cast(0x8765432112345678), ptr2[1]); + + ptr1[0] = 0x1231231231231321; + ptr1[1] = 0x3213213213213211; + set_uint_uint(ptr1.get(), 2, ptr1.get()); + ASSERT_EQ(static_cast(0x1231231231231321), ptr1[0]); + ASSERT_EQ(static_cast(0x3213213213213211), ptr1[1]); + } + + TEST(UIntCore, SetUIntUInt2) + { + set_uint_uint(nullptr, 0, 0, nullptr); + + MemoryPool &pool = *global_variables::global_memory_pool; + auto ptr1(allocate_uint(1, pool)); + ptr1[0] = 0x1234567887654321; + set_uint_uint(nullptr, 0, 1, ptr1.get()); + ASSERT_EQ(static_cast(0), ptr1[0]); + + auto ptr2(allocate_uint(1, pool)); + ptr1[0] = 0x1234567887654321; + ptr2[0] = 0xFFFFFFFFFFFFFFFF; + set_uint_uint(ptr1.get(), 1, 1, ptr2.get()); + ASSERT_EQ(static_cast(0x1234567887654321), ptr2[0]); + + ptr1[0] = 0x1231231231231231; + set_uint_uint(ptr1.get(), 1, 1, ptr1.get()); + ASSERT_EQ(static_cast(0x1231231231231231), ptr1[0]); + + ptr1 = allocate_uint(2, pool); + ptr2 = allocate_uint(2, pool); + ptr1[0] = 0x1234567887654321; + ptr1[1] = 0x8765432112345678; + set_uint_uint(nullptr, 0, 2, ptr1.get()); + ASSERT_EQ(static_cast(0), ptr1[0]); + ASSERT_EQ(static_cast(0), ptr1[1]); + + ptr1[0] = 0x1234567887654321; + ptr1[1] = 0x8765432112345678; + ptr2[0] = 0xFFFFFFFFFFFFFFFF; + ptr2[1] = 0xFFFFFFFFFFFFFFFF; + set_uint_uint(ptr1.get(), 1, 2, ptr2.get()); + ASSERT_EQ(static_cast(0x1234567887654321), ptr2[0]); + ASSERT_EQ(static_cast(0), ptr2[1]); + + ptr2[0] = 0xFFFFFFFFFFFFFFFF; + ptr2[1] = 0xFFFFFFFFFFFFFFFF; + set_uint_uint(ptr1.get(), 2, 2, ptr2.get()); + ASSERT_EQ(static_cast(0x1234567887654321), ptr2[0]); + ASSERT_EQ(static_cast(0x8765432112345678), ptr2[1]); + + ptr1[0] = 0x1231231231231321; + ptr1[1] = 0x3213213213213211; + set_uint_uint(ptr1.get(), 2, 2, ptr1.get()); + ASSERT_EQ(static_cast(0x1231231231231321), ptr1[0]); + ASSERT_EQ(static_cast(0x3213213213213211), ptr1[1]); + + set_uint_uint(ptr1.get(), 1, 2, ptr1.get()); + ASSERT_EQ(static_cast(0x1231231231231321), ptr1[0]); + ASSERT_EQ(static_cast(0), ptr1[1]); + } + + TEST(UIntCore, IsZeroUInt) + { + ASSERT_TRUE(is_zero_uint(nullptr, 0)); + + MemoryPool &pool = *global_variables::global_memory_pool; + auto ptr(allocate_uint(1, pool)); + ptr[0] = 1; + ASSERT_FALSE(is_zero_uint(ptr.get(), 1)); + ptr[0] = 0; + ASSERT_TRUE(is_zero_uint(ptr.get(), 1)); + + ptr = allocate_uint(2, pool); + ptr[0] = 0x8000000000000000; + ptr[1] = 0x8000000000000000; + ASSERT_FALSE(is_zero_uint(ptr.get(), 2)); + ptr[0] = 0; + ptr[1] = 0x8000000000000000; + ASSERT_FALSE(is_zero_uint(ptr.get(), 2)); + ptr[0] = 0x8000000000000000; + ptr[1] = 0; + ASSERT_FALSE(is_zero_uint(ptr.get(), 2)); + ptr[0] = 0; + ptr[1] = 0; + ASSERT_TRUE(is_zero_uint(ptr.get(), 2)); + } + + TEST(UIntCore, IsEqualUInt) + { + MemoryPool &pool = *global_variables::global_memory_pool; + auto ptr(allocate_uint(1, pool)); + ptr[0] = 1; + ASSERT_TRUE(is_equal_uint(ptr.get(), 1, 1)); + ASSERT_FALSE(is_equal_uint(ptr.get(), 1, 0)); + ASSERT_FALSE(is_equal_uint(ptr.get(), 1, 2)); + + ptr = allocate_uint(2, pool); + ptr[0] = 1; + ptr[1] = 1; + ASSERT_FALSE(is_equal_uint(ptr.get(), 2, 1)); + ptr[0] = 1; + ptr[1] = 0; + ASSERT_TRUE(is_equal_uint(ptr.get(), 2, 1)); + ptr[0] = 0x1234567887654321; + ptr[1] = 0; + ASSERT_TRUE(is_equal_uint(ptr.get(), 2, 0x1234567887654321)); + ASSERT_FALSE(is_equal_uint(ptr.get(), 2, 0x2234567887654321)); + } + + TEST(UIntCore, IsBitSetUInt) + { + MemoryPool &pool = *global_variables::global_memory_pool; + auto ptr(allocate_uint(2, pool)); + ptr[0] = 0; + ptr[1] = 0; + for (int i = 0; i < 128; ++i) + { + ASSERT_FALSE(is_bit_set_uint(ptr.get(), 2, i)); + } + ptr[0] = 0xFFFFFFFFFFFFFFFF; + ptr[1] = 0xFFFFFFFFFFFFFFFF; + for (int i = 0; i < 128; ++i) + { + ASSERT_TRUE(is_bit_set_uint(ptr.get(), 2, i)); + } + + ptr[0] = 0x0000000000000001; + ptr[1] = 0x8000000000000000; + for (int i = 0; i < 128; ++i) + { + if (i == 0 || i == 127) + { + ASSERT_TRUE(is_bit_set_uint(ptr.get(), 2, i)); + } + else + { + ASSERT_FALSE(is_bit_set_uint(ptr.get(), 2, i)); + } + } + } + + TEST(UIntCore, IsHighBitSetUInt) + { + MemoryPool &pool = *global_variables::global_memory_pool; + auto ptr(allocate_uint(2, pool)); + ptr[0] = 0; + ptr[1] = 0; + ASSERT_FALSE(is_high_bit_set_uint(ptr.get(), 2)); + + ptr[0] = 0xFFFFFFFFFFFFFFFF; + ptr[1] = 0xFFFFFFFFFFFFFFFF; + ASSERT_TRUE(is_high_bit_set_uint(ptr.get(), 2)); + + ptr[0] = 0; + ptr[1] = 0x8000000000000000; + ASSERT_TRUE(is_high_bit_set_uint(ptr.get(), 2)); + + ptr[0] = 0xFFFFFFFFFFFFFFFF; + ptr[1] = 0x7FFFFFFFFFFFFFFF; + ASSERT_FALSE(is_high_bit_set_uint(ptr.get(), 2)); + } + + TEST(UIntCore, SetBitUInt) + { + MemoryPool &pool = *global_variables::global_memory_pool; + auto ptr(allocate_uint(2, pool)); + ptr[0] = 0; + ptr[1] = 0; + set_bit_uint(ptr.get(), 2, 0); + ASSERT_EQ(1ULL, ptr[0]); + ASSERT_EQ(static_cast(0), ptr[1]); + + set_bit_uint(ptr.get(), 2, 127); + ASSERT_EQ(1ULL, ptr[0]); + ASSERT_EQ(static_cast(0x8000000000000000), ptr[1]); + + set_bit_uint(ptr.get(), 2, 63); + ASSERT_EQ(static_cast(0x8000000000000001), ptr[0]); + ASSERT_EQ(static_cast(0x8000000000000000), ptr[1]); + + set_bit_uint(ptr.get(), 2, 64); + ASSERT_EQ(static_cast(0x8000000000000001), ptr[0]); + ASSERT_EQ(static_cast(0x8000000000000001), ptr[1]); + + set_bit_uint(ptr.get(), 2, 3); + ASSERT_EQ(static_cast(0x8000000000000009), ptr[0]); + ASSERT_EQ(static_cast(0x8000000000000001), ptr[1]); + } + + TEST(UIntCore, GetSignificantBitCountUInt) + { + MemoryPool &pool = *global_variables::global_memory_pool; + auto ptr(allocate_uint(2, pool)); + ptr[0] = 0; + ptr[1] = 0; + ASSERT_EQ(0, get_significant_bit_count_uint(ptr.get(), 2)); + + ptr[0] = 1; + ptr[1] = 0; + ASSERT_EQ(1, get_significant_bit_count_uint(ptr.get(), 2)); + + ptr[0] = 2; + ptr[1] = 0; + ASSERT_EQ(2, get_significant_bit_count_uint(ptr.get(), 2)); + + ptr[0] = 3; + ptr[1] = 0; + ASSERT_EQ(2, get_significant_bit_count_uint(ptr.get(), 2)); + + ptr[0] = 29; + ptr[1] = 0; + ASSERT_EQ(5, get_significant_bit_count_uint(ptr.get(), 2)); + + ptr[0] = 4; + ptr[1] = 0; + ASSERT_EQ(3, get_significant_bit_count_uint(ptr.get(), 2)); + + ptr[0] = 0xFFFFFFFFFFFFFFFF; + ptr[1] = 0; + ASSERT_EQ(64, get_significant_bit_count_uint(ptr.get(), 2)); + + ptr[0] = 0; + ptr[1] = 1; + ASSERT_EQ(65, get_significant_bit_count_uint(ptr.get(), 2)); + + ptr[0] = 0xFFFFFFFFFFFFFFFF; + ptr[1] = 1; + ASSERT_EQ(65, get_significant_bit_count_uint(ptr.get(), 2)); + + ptr[0] = 0xFFFFFFFFFFFFFFFF; + ptr[1] = 0x7000000000000000; + ASSERT_EQ(127, get_significant_bit_count_uint(ptr.get(), 2)); + + ptr[0] = 0xFFFFFFFFFFFFFFFF; + ptr[1] = 0x8000000000000000; + ASSERT_EQ(128, get_significant_bit_count_uint(ptr.get(), 2)); + + ptr[0] = 0xFFFFFFFFFFFFFFFF; + ptr[1] = 0xFFFFFFFFFFFFFFFF; + ASSERT_EQ(128, get_significant_bit_count_uint(ptr.get(), 2)); + } + + TEST(UIntCore, GetSignificantUInt64CountUInt) + { + MemoryPool &pool = *global_variables::global_memory_pool; + auto ptr(allocate_uint(2, pool)); + ptr[0] = 0; + ptr[1] = 0; + ASSERT_EQ(0ULL, get_significant_uint64_count_uint(ptr.get(), 2)); + + ptr[0] = 1; + ptr[1] = 0; + ASSERT_EQ(1ULL, get_significant_uint64_count_uint(ptr.get(), 2)); + + ptr[0] = 2; + ptr[1] = 0; + ASSERT_EQ(1ULL, get_significant_uint64_count_uint(ptr.get(), 2)); + + ptr[0] = 0xFFFFFFFFFFFFFFFF; + ptr[1] = 0; + ASSERT_EQ(1ULL, get_significant_uint64_count_uint(ptr.get(), 2)); + + ptr[0] = 0; + ptr[1] = 1; + ASSERT_EQ(2ULL, get_significant_uint64_count_uint(ptr.get(), 2)); + + ptr[0] = 0xFFFFFFFFFFFFFFFF; + ptr[1] = 1; + ASSERT_EQ(2ULL, get_significant_uint64_count_uint(ptr.get(), 2)); + + ptr[0] = 0xFFFFFFFFFFFFFFFF; + ptr[1] = 0x8000000000000000; + ASSERT_EQ(2ULL, get_significant_uint64_count_uint(ptr.get(), 2)); + + ptr[0] = 0xFFFFFFFFFFFFFFFF; + ptr[1] = 0xFFFFFFFFFFFFFFFF; + ASSERT_EQ(2ULL, get_significant_uint64_count_uint(ptr.get(), 2)); + } + + TEST(UIntCore, GetPowerOfTwoUInt) + { + MemoryPool &pool = *global_variables::global_memory_pool; + auto ptr(allocate_zero_uint(2, pool)); + ASSERT_EQ(-1, get_power_of_two_uint(ptr.get(), 1)); + ASSERT_EQ(-1, get_power_of_two_uint(ptr.get(), 2)); + + ptr[0] = 0xFFFFFFFFFFFFFFFF; + ptr[1] = 0xFFFFFFFFFFFFFFFF; + ASSERT_EQ(-1, get_power_of_two_uint(ptr.get(), 1)); + ASSERT_EQ(-1, get_power_of_two_uint(ptr.get(), 2)); + + ptr[0] = 0x0000000000000001; + ptr[1] = 0x0000000000000000; + ASSERT_EQ(0, get_power_of_two_uint(ptr.get(), 1)); + ASSERT_EQ(0, get_power_of_two_uint(ptr.get(), 2)); + + ptr[0] = 0x0000000000000001; + ptr[1] = 0x8000000000000000; + ASSERT_EQ(-1, get_power_of_two_uint(ptr.get(), 2)); + + ptr[0] = 0x0000000000000000; + ptr[1] = 0x8000000000000000; + ASSERT_EQ(127, get_power_of_two_uint(ptr.get(), 2)); + + ptr[0] = 0x8000000000000000; + ptr[1] = 0x0000000000000000; + ASSERT_EQ(63, get_power_of_two_uint(ptr.get(), 2)); + + ptr[0] = 0x9000000000000000; + ptr[1] = 0x0000000000000000; + ASSERT_EQ(-1, get_power_of_two_uint(ptr.get(), 2)); + + ptr[0] = 0x8000000000000001; + ptr[1] = 0x0000000000000000; + ASSERT_EQ(-1, get_power_of_two_uint(ptr.get(), 2)); + + ptr[0] = 0x0000000000000000; + ptr[1] = 0x0000000000000001; + ASSERT_EQ(64, get_power_of_two_uint(ptr.get(), 2)); + } + + TEST(UIntCore, GetPowerOfTwoMinusOneUInt) + { + MemoryPool &pool = *global_variables::global_memory_pool; + auto ptr(allocate_zero_uint(2, pool)); + ASSERT_EQ(0, get_power_of_two_minus_one_uint(ptr.get(), 1)); + ASSERT_EQ(0, get_power_of_two_minus_one_uint(ptr.get(), 2)); + + ptr[0] = 0xFFFFFFFFFFFFFFFF; + ptr[1] = 0xFFFFFFFFFFFFFFFF; + ASSERT_EQ(64, get_power_of_two_minus_one_uint(ptr.get(), 1)); + ASSERT_EQ(128, get_power_of_two_minus_one_uint(ptr.get(), 2)); + + ptr[0] = 0x0000000000000001; + ptr[1] = 0x0000000000000000; + ASSERT_EQ(1, get_power_of_two_minus_one_uint(ptr.get(), 1)); + ASSERT_EQ(1, get_power_of_two_minus_one_uint(ptr.get(), 2)); + + ptr[0] = 0x0000000000000001; + ptr[1] = 0x8000000000000000; + ASSERT_EQ(-1, get_power_of_two_minus_one_uint(ptr.get(), 2)); + + ptr[0] = 0x0000000000000000; + ptr[1] = 0x8000000000000000; + ASSERT_EQ(-1, get_power_of_two_minus_one_uint(ptr.get(), 2)); + + ptr[0] = 0xFFFFFFFFFFFFFFFF; + ptr[1] = 0x7FFFFFFFFFFFFFFF; + ASSERT_EQ(127, get_power_of_two_minus_one_uint(ptr.get(), 2)); + + ptr[0] = 0xFFFFFFFFFFFFFFFE; + ptr[1] = 0xFFFFFFFFFFFFFFFF; + ASSERT_EQ(-1, get_power_of_two_minus_one_uint(ptr.get(), 2)); + + ptr[0] = 0xFFFFFFFFFFFFFFFF; + ptr[1] = 0x0000000000000000; + ASSERT_EQ(64, get_power_of_two_minus_one_uint(ptr.get(), 2)); + + ptr[0] = 0xFFFFFFFFFFFFFFFE; + ptr[1] = 0x0000000000000000; + ASSERT_EQ(-1, get_power_of_two_minus_one_uint(ptr.get(), 2)); + + ptr[0] = 0xFFFFFFFFFFFFFFFF; + ptr[1] = 0x0000000000000001; + ASSERT_EQ(65, get_power_of_two_minus_one_uint(ptr.get(), 2)); + + ptr[0] = 0xFFFFFFFFFFFFFFFE; + ptr[1] = 0x0000000000000001; + ASSERT_EQ(-1, get_power_of_two_minus_one_uint(ptr.get(), 2)); + } + + TEST(UIntCore, FilterHighBitsUInt) + { + filter_highbits_uint(nullptr, 0, 0); + + MemoryPool &pool = *global_variables::global_memory_pool; + auto ptr(allocate_uint(2, pool)); + ptr[0] = 0xFFFFFFFFFFFFFFFF; + ptr[1] = 0xFFFFFFFFFFFFFFFF; + filter_highbits_uint(ptr.get(), 2, 0); + ASSERT_EQ(static_cast(0), ptr[0]); + ASSERT_EQ(static_cast(0), ptr[1]); + + ptr[0] = 0xFFFFFFFFFFFFFFFF; + ptr[1] = 0xFFFFFFFFFFFFFFFF; + filter_highbits_uint(ptr.get(), 2, 128); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr[0]); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr[1]); + filter_highbits_uint(ptr.get(), 2, 127); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr[0]); + ASSERT_EQ(static_cast(0x7FFFFFFFFFFFFFFF), ptr[1]); + filter_highbits_uint(ptr.get(), 2, 126); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr[0]); + ASSERT_EQ(static_cast(0x3FFFFFFFFFFFFFFF), ptr[1]); + filter_highbits_uint(ptr.get(), 2, 64); + ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr[0]); + ASSERT_EQ(static_cast(0), ptr[1]); + filter_highbits_uint(ptr.get(), 2, 63); + ASSERT_EQ(static_cast(0x7FFFFFFFFFFFFFFF), ptr[0]); + ASSERT_EQ(static_cast(0), ptr[1]); + filter_highbits_uint(ptr.get(), 2, 2); + ASSERT_EQ(static_cast(0x3), ptr[0]); + ASSERT_EQ(static_cast(0), ptr[1]); + filter_highbits_uint(ptr.get(), 2, 1); + ASSERT_EQ(static_cast(0x1), ptr[0]); + ASSERT_EQ(static_cast(0), ptr[1]); + filter_highbits_uint(ptr.get(), 2, 0); + ASSERT_EQ(static_cast(0), ptr[0]); + ASSERT_EQ(static_cast(0), ptr[1]); + + filter_highbits_uint(ptr.get(), 2, 128); + ASSERT_EQ(static_cast(0), ptr[0]); + ASSERT_EQ(static_cast(0), ptr[1]); + } + + TEST(UIntCore, CompareUIntUInt) + { + ASSERT_EQ(0, compare_uint_uint(nullptr, nullptr, 0)); + ASSERT_TRUE(is_equal_uint_uint(nullptr, nullptr, 0)); + ASSERT_FALSE(is_not_equal_uint_uint(nullptr, nullptr, 0)); + ASSERT_FALSE(is_greater_than_uint_uint(nullptr, nullptr, 0)); + ASSERT_FALSE(is_less_than_uint_uint(nullptr, nullptr, 0)); + ASSERT_TRUE(is_greater_than_or_equal_uint_uint(nullptr, nullptr, 0)); + ASSERT_TRUE(is_less_than_or_equal_uint_uint(nullptr, nullptr, 0)); + + MemoryPool &pool = *global_variables::global_memory_pool; + auto ptr1(allocate_uint(2, pool)); + auto ptr2(allocate_uint(2, pool)); + ptr1[0] = 0; + ptr1[1] = 0; + ptr2[0] = 0; + ptr2[1] = 0; + ASSERT_EQ(0, compare_uint_uint(ptr1.get(), ptr2.get(), 2)); + ASSERT_TRUE(is_equal_uint_uint(ptr1.get(), ptr2.get(), 2)); + ASSERT_FALSE(is_not_equal_uint_uint(ptr1.get(), ptr2.get(), 2)); + ASSERT_FALSE(is_greater_than_uint_uint(ptr1.get(), ptr2.get(), 2)); + ASSERT_FALSE(is_less_than_uint_uint(ptr1.get(), ptr2.get(), 2)); + ASSERT_TRUE(is_greater_than_or_equal_uint_uint(ptr1.get(), ptr2.get(), 2)); + ASSERT_TRUE(is_less_than_or_equal_uint_uint(ptr1.get(), ptr2.get(), 2)); + + ptr1[0] = 0x1234567887654321; + ptr1[1] = 0x8765432112345678; + ptr2[0] = 0x1234567887654321; + ptr2[1] = 0x8765432112345678; + ASSERT_EQ(0, compare_uint_uint(ptr1.get(), ptr2.get(), 2)); + ASSERT_TRUE(is_equal_uint_uint(ptr1.get(), ptr2.get(), 2)); + ASSERT_FALSE(is_not_equal_uint_uint(ptr1.get(), ptr2.get(), 2)); + ASSERT_FALSE(is_greater_than_uint_uint(ptr1.get(), ptr2.get(), 2)); + ASSERT_FALSE(is_less_than_uint_uint(ptr1.get(), ptr2.get(), 2)); + ASSERT_TRUE(is_greater_than_or_equal_uint_uint(ptr1.get(), ptr2.get(), 2)); + ASSERT_TRUE(is_less_than_or_equal_uint_uint(ptr1.get(), ptr2.get(), 2)); + + ptr1[0] = 1; + ptr1[1] = 0; + ptr2[0] = 2; + ptr2[1] = 0; + ASSERT_EQ(-1, compare_uint_uint(ptr1.get(), ptr2.get(), 2)); + ASSERT_FALSE(is_equal_uint_uint(ptr1.get(), ptr2.get(), 2)); + ASSERT_TRUE(is_not_equal_uint_uint(ptr1.get(), ptr2.get(), 2)); + ASSERT_FALSE(is_greater_than_uint_uint(ptr1.get(), ptr2.get(), 2)); + ASSERT_TRUE(is_less_than_uint_uint(ptr1.get(), ptr2.get(), 2)); + ASSERT_FALSE(is_greater_than_or_equal_uint_uint(ptr1.get(), ptr2.get(), 2)); + ASSERT_TRUE(is_less_than_or_equal_uint_uint(ptr1.get(), ptr2.get(), 2)); + + ptr1[0] = 1; + ptr1[1] = 0xFFFFFFFFFFFFFFFF; + ptr2[0] = 2; + ptr2[1] = 0xFFFFFFFFFFFFFFFF; + ASSERT_EQ(-1, compare_uint_uint(ptr1.get(), ptr2.get(), 2)); + ASSERT_FALSE(is_equal_uint_uint(ptr1.get(), ptr2.get(), 2)); + ASSERT_TRUE(is_not_equal_uint_uint(ptr1.get(), ptr2.get(), 2)); + ASSERT_FALSE(is_greater_than_uint_uint(ptr1.get(), ptr2.get(), 2)); + ASSERT_TRUE(is_less_than_uint_uint(ptr1.get(), ptr2.get(), 2)); + ASSERT_FALSE(is_greater_than_or_equal_uint_uint(ptr1.get(), ptr2.get(), 2)); + ASSERT_TRUE(is_less_than_or_equal_uint_uint(ptr1.get(), ptr2.get(), 2)); + + ptr1[0] = 0xFFFFFFFFFFFFFFFF; + ptr1[1] = 0x0000000000000001; + ptr2[0] = 0x0000000000000000; + ptr2[1] = 0x0000000000000002; + ASSERT_EQ(-1, compare_uint_uint(ptr1.get(), ptr2.get(), 2)); + ASSERT_FALSE(is_equal_uint_uint(ptr1.get(), ptr2.get(), 2)); + ASSERT_TRUE(is_not_equal_uint_uint(ptr1.get(), ptr2.get(), 2)); + ASSERT_FALSE(is_greater_than_uint_uint(ptr1.get(), ptr2.get(), 2)); + ASSERT_TRUE(is_less_than_uint_uint(ptr1.get(), ptr2.get(), 2)); + ASSERT_FALSE(is_greater_than_or_equal_uint_uint(ptr1.get(), ptr2.get(), 2)); + ASSERT_TRUE(is_less_than_or_equal_uint_uint(ptr1.get(), ptr2.get(), 2)); + + ptr1[0] = 2; + ptr1[1] = 0; + ptr2[0] = 1; + ptr2[1] = 0; + ASSERT_EQ(1, compare_uint_uint(ptr1.get(), ptr2.get(), 2)); + ASSERT_FALSE(is_equal_uint_uint(ptr1.get(), ptr2.get(), 2)); + ASSERT_TRUE(is_not_equal_uint_uint(ptr1.get(), ptr2.get(), 2)); + ASSERT_TRUE(is_greater_than_uint_uint(ptr1.get(), ptr2.get(), 2)); + ASSERT_FALSE(is_less_than_uint_uint(ptr1.get(), ptr2.get(), 2)); + ASSERT_TRUE(is_greater_than_or_equal_uint_uint(ptr1.get(), ptr2.get(), 2)); + ASSERT_FALSE(is_less_than_or_equal_uint_uint(ptr1.get(), ptr2.get(), 2)); + + ptr1[0] = 2; + ptr1[1] = 0xFFFFFFFFFFFFFFFF; + ptr2[0] = 1; + ptr2[1] = 0xFFFFFFFFFFFFFFFF; + ASSERT_EQ(1, compare_uint_uint(ptr1.get(), ptr2.get(), 2)); + ASSERT_FALSE(is_equal_uint_uint(ptr1.get(), ptr2.get(), 2)); + ASSERT_TRUE(is_not_equal_uint_uint(ptr1.get(), ptr2.get(), 2)); + ASSERT_TRUE(is_greater_than_uint_uint(ptr1.get(), ptr2.get(), 2)); + ASSERT_FALSE(is_less_than_uint_uint(ptr1.get(), ptr2.get(), 2)); + ASSERT_TRUE(is_greater_than_or_equal_uint_uint(ptr1.get(), ptr2.get(), 2)); + ASSERT_FALSE(is_less_than_or_equal_uint_uint(ptr1.get(), ptr2.get(), 2)); + + ptr1[0] = 0xFFFFFFFFFFFFFFFF; + ptr1[1] = 0x0000000000000003; + ptr2[0] = 0x0000000000000000; + ptr2[1] = 0x0000000000000002; + ASSERT_EQ(1, compare_uint_uint(ptr1.get(), ptr2.get(), 2)); + ASSERT_FALSE(is_equal_uint_uint(ptr1.get(), ptr2.get(), 2)); + ASSERT_TRUE(is_not_equal_uint_uint(ptr1.get(), ptr2.get(), 2)); + ASSERT_TRUE(is_greater_than_uint_uint(ptr1.get(), ptr2.get(), 2)); + ASSERT_FALSE(is_less_than_uint_uint(ptr1.get(), ptr2.get(), 2)); + ASSERT_TRUE(is_greater_than_or_equal_uint_uint(ptr1.get(), ptr2.get(), 2)); + ASSERT_FALSE(is_less_than_or_equal_uint_uint(ptr1.get(), ptr2.get(), 2)); + } + + TEST(UIntCore, GetPowerOfTwo) + { + ASSERT_EQ(-1, get_power_of_two(0)); + ASSERT_EQ(0, get_power_of_two(1)); + ASSERT_EQ(1, get_power_of_two(2)); + ASSERT_EQ(-1, get_power_of_two(3)); + ASSERT_EQ(2, get_power_of_two(4)); + ASSERT_EQ(-1, get_power_of_two(5)); + ASSERT_EQ(-1, get_power_of_two(6)); + ASSERT_EQ(-1, get_power_of_two(7)); + ASSERT_EQ(3, get_power_of_two(8)); + ASSERT_EQ(-1, get_power_of_two(15)); + ASSERT_EQ(4, get_power_of_two(16)); + ASSERT_EQ(-1, get_power_of_two(17)); + ASSERT_EQ(-1, get_power_of_two(255)); + ASSERT_EQ(8, get_power_of_two(256)); + ASSERT_EQ(-1, get_power_of_two(257)); + ASSERT_EQ(10, get_power_of_two(1 << 10)); + ASSERT_EQ(30, get_power_of_two(1 << 30)); + ASSERT_EQ(32, get_power_of_two(1ULL << 32)); + ASSERT_EQ(62, get_power_of_two(1ULL << 62)); + ASSERT_EQ(63, get_power_of_two(1ULL << 63)); + } + + TEST(UIntCore, GetPowerOfTwoMinusOne) + { + ASSERT_EQ(0, get_power_of_two_minus_one(0)); + ASSERT_EQ(1, get_power_of_two_minus_one(1)); + ASSERT_EQ(-1, get_power_of_two_minus_one(2)); + ASSERT_EQ(2, get_power_of_two_minus_one(3)); + ASSERT_EQ(-1, get_power_of_two_minus_one(4)); + ASSERT_EQ(-1, get_power_of_two_minus_one(5)); + ASSERT_EQ(-1, get_power_of_two_minus_one(6)); + ASSERT_EQ(3, get_power_of_two_minus_one(7)); + ASSERT_EQ(-1, get_power_of_two_minus_one(8)); + ASSERT_EQ(-1, get_power_of_two_minus_one(14)); + ASSERT_EQ(4, get_power_of_two_minus_one(15)); + ASSERT_EQ(-1, get_power_of_two_minus_one(16)); + ASSERT_EQ(8, get_power_of_two_minus_one(255)); + ASSERT_EQ(10, get_power_of_two_minus_one((1 << 10) - 1)); + ASSERT_EQ(30, get_power_of_two_minus_one((1 << 30) - 1)); + ASSERT_EQ(32, get_power_of_two_minus_one((1ULL << 32) - 1)); + ASSERT_EQ(63, get_power_of_two_minus_one((1ULL << 63) - 1)); + ASSERT_EQ(64, get_power_of_two_minus_one(~static_cast(0))); + } + + TEST(UIntCore, DuplicateUIntIfNeeded) + { + //MemoryPool &pool = *global_variables::global_memory_pool; + MemoryPoolST pool; + auto ptr(allocate_uint(2, pool)); + ptr[0] = 0xF0F0F0F0F0; + ptr[1] = 0xABABABABAB; + auto ptr2 = duplicate_uint_if_needed(ptr.get(), 0, 0, false, pool); + // No forcing and sizes are same (although zero) so just alias + ASSERT_TRUE(ptr2.get() == ptr.get()); + + ptr2 = duplicate_uint_if_needed(ptr.get(), 0, 0, true, pool); + // Forcing and size is zero so return nullptr + ASSERT_TRUE(ptr2.get() == nullptr); + + ptr2 = duplicate_uint_if_needed(ptr.get(), 1, 0, false, pool); + ASSERT_TRUE(ptr2.get() == ptr.get()); + + ptr2 = duplicate_uint_if_needed(ptr.get(), 1, 0, true, pool); + ASSERT_TRUE(ptr2.get() == nullptr); + + ptr2 = duplicate_uint_if_needed(ptr.get(), 1, 1, false, pool); + ASSERT_TRUE(ptr2.get() == ptr.get()); + + ptr2 = duplicate_uint_if_needed(ptr.get(), 1, 1, true, pool); + ASSERT_TRUE(ptr2.get() != ptr.get()); + ASSERT_EQ(ptr[0], ptr2[0]); + + ptr2 = duplicate_uint_if_needed(ptr.get(), 2, 2, true, pool); + ASSERT_TRUE(ptr2.get() != ptr.get()); + ASSERT_EQ(ptr[0], ptr2[0]); + ASSERT_EQ(ptr[1], ptr2[1]); + + ptr2 = duplicate_uint_if_needed(ptr.get(), 2, 2, false, pool); + ASSERT_TRUE(ptr2.get() == ptr.get()); + + ptr2 = duplicate_uint_if_needed(ptr.get(), 2, 1, false, pool); + ASSERT_TRUE(ptr2.get() == ptr.get()); + + ptr2 = duplicate_uint_if_needed(ptr.get(), 1, 2, false, pool); + ASSERT_TRUE(ptr2.get() != ptr.get()); + ASSERT_EQ(ptr[0], ptr2[0]); + ASSERT_EQ(0ULL, ptr2[1]); + + ptr2 = duplicate_uint_if_needed(ptr.get(), 1, 2, true, pool); + ASSERT_TRUE(ptr2.get() != ptr.get()); + ASSERT_EQ(ptr[0], ptr2[0]); + ASSERT_EQ(0ULL, ptr2[1]); + } + + TEST(UIntCore, HammingWeight) + { + ASSERT_EQ(0ULL, hamming_weight(0ULL)); + ASSERT_EQ(1ULL, hamming_weight(1ULL)); + ASSERT_EQ(1ULL, hamming_weight(0x10000ULL)); + ASSERT_EQ(2ULL, hamming_weight(0x10001ULL)); + ASSERT_EQ(32ULL, hamming_weight(0xFFFFFFFFULL)); + ASSERT_EQ(64ULL, hamming_weight(0xFFFFFFFFFFFFFFFFULL)); + ASSERT_EQ(32ULL, hamming_weight(0xF0F0F0F0F0F0F0F0ULL)); + ASSERT_EQ(16ULL, hamming_weight(0xA0A0A0A0A0A0A0A0ULL)); + } + + TEST(UIntCore, HammingWeightSplit) + { + ASSERT_EQ(0ULL, hamming_weight_split(0ULL)); + ASSERT_EQ(1ULL, hamming_weight_split(1ULL)); + ASSERT_EQ(0x10000ULL, hamming_weight_split(0x10000ULL)); + ASSERT_EQ(1ULL, hamming_weight_split(0x10001ULL)); + ASSERT_EQ(0xFFFFULL, hamming_weight_split(0xFFFFFFFFULL)); + ASSERT_EQ(0xFFFFFFFFULL, hamming_weight_split(0xFFFFFFFFFFFFFFFFULL)); + ASSERT_EQ(0xF0F0F00ULL, hamming_weight_split(0xF0F0F0000F0F0F00ULL)); + ASSERT_EQ(0xA0A0A0A0ULL, hamming_weight_split(0xA0A0A0A0A0A0A0A0ULL)); + } + } +} diff --git a/tools/Makefile b/tools/Makefile new file mode 100644 index 000000000..320388b4b --- /dev/null +++ b/tools/Makefile @@ -0,0 +1,15 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +SHELL=/bin/bash +COMPR_SYSTEM_INFO_ARCHIVE=../system_info.tar.gz + +.PHONY: system_info clean + +system_info: clean $(COMPR_SYSTEM_INFO_ARCHIVE) + +$(COMPR_SYSTEM_INFO_ARCHIVE): scripts/collect_system_info.sh + @$(SHELL) scripts/collect_system_info.sh + +clean: + @rm -f $(COMPR_SYSTEM_INFO_ARCHIVE) diff --git a/tools/config/packages.config b/tools/config/packages.config new file mode 100644 index 000000000..cfe2aac14 --- /dev/null +++ b/tools/config/packages.config @@ -0,0 +1,4 @@ + + + + diff --git a/tools/scripts/cmake_config.cmd b/tools/scripts/cmake_config.cmd new file mode 100644 index 000000000..fb4060f7f --- /dev/null +++ b/tools/scripts/cmake_config.cmd @@ -0,0 +1,50 @@ +@echo off + +rem Copyright (c) Microsoft Corporation. All rights reserved. +rem Licensed under the MIT license. + +setlocal + +rem The purpose of this script is to have CMake generate config.h for use by SEAL. +rem We assume that CMake was installed with Visual Studio, which should be the default +rem when the user installs the "Desktop Development with C++" workload. + +set PROJECTCONFIGURATION=%1 +set VSDEVENVDIR=%~2 +set INCLUDEPATH=%~3 + +echo Configuring SEAL through CMake +echo Looking for CMake + +if not exist "%VSDEVENVDIR%" ( + rem We may be running in the CI server. Try a standard VS path. + echo Did not find VS at provided location: "%VSDEVENVDIR%". + echo Trying standard location. + set VSDEVENVDIR="C:\Program Files (x86)\Microsoft Visual Studio\2017\Enterprise\Common7\IDE" +) + +set VSDEVENVDIR=%VSDEVENVDIR:"=% +set CMAKEPATH=%VSDEVENVDIR%\CommonExtensions\Microsoft\CMake\CMake\bin\cmake.exe + +if not exist "%CMAKEPATH%" ( + echo ************************************************************************************************************** + echo Did not find CMake at "%CMAKEPATH%" + echo Please make sure "Visual C++ Tools for CMake" are enabled in the "Desktop development with C++" workload. + echo ************************************************************************************************************** + exit 1 +) + +echo Found CMake at: %CMAKEPATH% + +%~d0 +cd %~dp0 +cd ..\..\src +if not exist ".config" ( + mkdir .config +) +cd .config + +echo Running CMake configuration in: +echo %cd% + +"%CMAKEPATH%" .. -G "Visual Studio 15 2017" -A x64 -DALLOW_COMMAND_LINE_BUILD=1 -DCMAKE_BUILD_TYPE=%PROJECTCONFIGURATION% -DMSGSL_INCLUDE_DIR="%INCLUDEPATH%" diff --git a/tools/scripts/collect_system_info.sh b/tools/scripts/collect_system_info.sh new file mode 100755 index 000000000..10801fb18 --- /dev/null +++ b/tools/scripts/collect_system_info.sh @@ -0,0 +1,78 @@ +#!/bin/bash + +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +CMAKE_CXX_COMPILER=cxx_compiler.out +CMAKE_ENV=cmake_env.out +CMAKE_SYSTEM_INFO=system_information.out + +CMAKE_CXX_COMPILER_CMD=`cmake -LA ../SEAL|sed -n 's/^CMAKE_CXX_COMPILER:FILEPATH=\(.*\)$/\1/p'` + +echo "Extracting: cmake -LA ../SEAL > $CMAKE_ENV" +cmake -LA ../SEAL > $CMAKE_ENV +echo "Extracting: cmake --system-information > $CMAKE_SYSTEM_INFO" +cmake --system-information > $CMAKE_SYSTEM_INFO +echo "Extracting: $CMAKE_CXX_COMPILER_CMD -v > $CMAKE_CXX_COMPILER 2>&1" +$CMAKE_CXX_COMPILER_CMD -v 2> $CMAKE_CXX_COMPILER + +ARCHIVE_NAME=../system_info.tar +SEALDIR=../SEAL +FILES=( + "$SEALDIR/seal/util/config.h" + "$SEALDIR/CMakeCache.txt" + "$SEALDIR/CMakeFiles/CMakeOutput.log" + "$SEALDIR/CMakeFiles/CMakeError.log" + "/proc/cpuinfo" + "$CMAKE_ENV" + "$CMAKE_SYSTEM_INFO" + "$CMAKE_CXX_COMPILER" +) + +print_collecting_filename() { + echo -e "\033[0mCollecting \033[1;32m$1\033[0m" +} + +print_skipping_filename() { + echo -e "\033[0mSkipping \033[1;31m$1\033[0m" +} + +add_to_archive() { + BASENAME=`basename $1` + cp -f $1 $BASENAME 2>/dev/null + if [ -s $ARCHIVE_NAME ] + then + tar -rf $ARCHIVE_NAME ./$BASENAME + else + tar -cf $ARCHIVE_NAME ./$BASENAME + fi + rm -f ./$BASENAME +} + +rm -f "$ARCHIVE_NAME.gz" + +for i in ${FILES[@]} +do + if [ -r $i ] + then + print_collecting_filename $i + add_to_archive $i + else + print_skipping_filename $i + fi +done + +gzip $ARCHIVE_NAME +if [ $? -eq 0 ] +then + echo "Created `realpath $ARCHIVE_NAME.gz`" +else + echo "Could not create `realpath $ARCHIVE_NAME.gz`" + rm -f $ARCHIVE_NAME.gz +fi + +echo -n "Cleaning up ... " +rm -f $CMAKE_ENV +rm -f $CMAKE_SYSTEM_INFO +rm -f $CMAKE_CXX_COMPILER +echo done.