From 121dd136d7e3ecab98c939c0b02901c95e947bf8 Mon Sep 17 00:00:00 2001 From: Greg Hewett Date: Mon, 24 Jul 2023 15:26:10 -0500 Subject: [PATCH] adding import and export jwk keys into signatures --- CMakeLists.txt | 19 ++- alternatives/openssl_3/vcpkg.json | 3 +- include/mls/credential.h | 4 +- include/mls/crypto.h | 8 ++ lib/bytes/CMakeLists.txt | 6 +- lib/bytes/include/bytes/bytes.h | 12 ++ lib/bytes/src/bytes.cpp | 124 +++++++++++++++- lib/bytes/test/bytes.cpp | 28 ++++ lib/hpke/CMakeLists.txt | 7 +- lib/hpke/include/hpke/signature.h | 7 + lib/hpke/src/group.cpp | 228 +++++++++++++++++++++++++++++- lib/hpke/src/group.h | 9 ++ lib/hpke/src/signature.cpp | 121 ++++++++++++++++ lib/tls_syntax/CMakeLists.txt | 2 +- src/crypto.cpp | 30 ++++ test/crypto.cpp | 135 +++++++++++++++++- vcpkg.json | 3 +- 17 files changed, 730 insertions(+), 16 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index eebbfc10..4e06fc97 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -114,6 +114,8 @@ target_include_directories(${LIB_NAME} ${OPENSSL_INCLUDE_DIR} ) +install(TARGETS ${LIB_NAME} EXPORT mlspp-targets) + ### ### Tests ### @@ -125,9 +127,7 @@ endif() ### Exports ### set(CMAKE_EXPORT_PACKAGE_REGISTRY ON) -export(EXPORT mlspp-targets - NAMESPACE MLSPP:: - FILE ${CMAKE_CURRENT_BINARY_DIR}/mlspp-targets.cmake) +export(EXPORT mlspp-targets NAMESPACE MLSPP:: FILE mlspp-targets.cmake) export(PACKAGE MLSPP) configure_package_config_file(cmake/config.cmake.in @@ -144,8 +144,6 @@ write_basic_package_version_file( ### Install ### -install(TARGETS ${LIB_NAME} EXPORT mlspp-targets) - install( DIRECTORY include/ @@ -156,7 +154,16 @@ install( FILES ${CMAKE_CURRENT_BINARY_DIR}/mlspp-config.cmake ${CMAKE_CURRENT_BINARY_DIR}/mlspp-config-version.cmake - ${CMAKE_CURRENT_BINARY_DIR}/mlspp-targets.cmake + DESTINATION + ${CMAKE_INSTALL_DATADIR}/mlspp) + +install( + EXPORT + mlspp-targets + NAMESPACE + MLSPP:: + FILE + mlspp-targets.cmake DESTINATION ${CMAKE_INSTALL_DATADIR}/mlspp) diff --git a/alternatives/openssl_3/vcpkg.json b/alternatives/openssl_3/vcpkg.json index 4b4d7657..97c9cc3e 100644 --- a/alternatives/openssl_3/vcpkg.json +++ b/alternatives/openssl_3/vcpkg.json @@ -7,7 +7,8 @@ "name": "openssl", "version>=": "3.0.7" }, - "doctest" + "doctest", + "nlohmann-json" ], "builtin-baseline": "5908d702d61cea1429b223a0b7a10ab86bad4c78", "overrides": [ diff --git a/include/mls/credential.h b/include/mls/credential.h index e086e845..981610cb 100644 --- a/include/mls/credential.h +++ b/include/mls/credential.h @@ -11,9 +11,9 @@ namespace mls { // } BasicCredential; struct BasicCredential { - BasicCredential() = default; + BasicCredential() {} - explicit BasicCredential(bytes identity_in) + BasicCredential(bytes identity_in) : identity(std::move(identity_in)) { } diff --git a/include/mls/crypto.h b/include/mls/crypto.h index 42b89f67..d25653bc 100644 --- a/include/mls/crypto.h +++ b/include/mls/crypto.h @@ -210,6 +210,9 @@ extern const std::string multi_credential; struct SignaturePublicKey { + static SignaturePublicKey from_jwk(CipherSuite suite, + const std::string& json_str); + bytes data; bool verify(const CipherSuite& suite, @@ -217,6 +220,8 @@ struct SignaturePublicKey const bytes& message, const bytes& signature) const; + std::string to_jwk(CipherSuite suite) const; + TLS_SERIALIZABLE(data) }; @@ -225,6 +230,8 @@ struct SignaturePrivateKey static SignaturePrivateKey generate(CipherSuite suite); static SignaturePrivateKey parse(CipherSuite suite, const bytes& data); static SignaturePrivateKey derive(CipherSuite suite, const bytes& secret); + static SignaturePrivateKey from_jwk(CipherSuite suite, + const std::string& json_str); SignaturePrivateKey() = default; @@ -236,6 +243,7 @@ struct SignaturePrivateKey const bytes& message) const; void set_public_key(CipherSuite suite); + std::string to_jwk(CipherSuite suite) const; TLS_SERIALIZABLE(data) diff --git a/lib/bytes/CMakeLists.txt b/lib/bytes/CMakeLists.txt index ad4165c4..e420273b 100644 --- a/lib/bytes/CMakeLists.txt +++ b/lib/bytes/CMakeLists.txt @@ -9,7 +9,11 @@ file(GLOB_RECURSE LIB_SOURCES CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/src add_library(${CURRENT_LIB_NAME} ${LIB_HEADERS} ${LIB_SOURCES}) add_dependencies(${CURRENT_LIB_NAME} tls_syntax) -target_link_libraries(${CURRENT_LIB_NAME} tls_syntax) +target_link_libraries(${CURRENT_LIB_NAME} + PUBLIC + tls_syntax + PRIVATE + OpenSSL::Crypto) target_include_directories(${CURRENT_LIB_NAME} PUBLIC $ diff --git a/lib/bytes/include/bytes/bytes.h b/lib/bytes/include/bytes/bytes.h index ddb5dacf..49d7ead7 100644 --- a/lib/bytes/include/bytes/bytes.h +++ b/lib/bytes/include/bytes/bytes.h @@ -115,4 +115,16 @@ to_hex(const bytes& data); bytes from_hex(const std::string& hex); +std::string +to_base64(const bytes& data); + +std::string +to_base64url(const bytes& data); + +bytes +from_base64(const std::string& enc); + +bytes +from_base64url(const std::string& enc); + } // namespace bytes_ns diff --git a/lib/bytes/src/bytes.cpp b/lib/bytes/src/bytes.cpp index 509e2532..b63798cf 100644 --- a/lib/bytes/src/bytes.cpp +++ b/lib/bytes/src/bytes.cpp @@ -1,7 +1,10 @@ #include +#include #include -#include +#include +#include +#include #include #include @@ -137,4 +140,123 @@ operator!=(const std::vector& lhs, const bytes_ns::bytes& rhs) return rhs != lhs; } +std::string +to_base64(const bytes& data) +{ + bool done = false; + int result = 0; + + if (data.empty()) { + return ""; + } + + BIO* b64 = BIO_new(BIO_f_base64()); + BIO_set_flags(b64, BIO_FLAGS_BASE64_NO_NL); + BIO* out = BIO_new(BIO_s_mem()); + BIO_push(b64, out); + + while (!done) { + result = BIO_write(b64, data.data(), static_cast(data.size())); + + if (result <= 0) { + if (BIO_should_retry(b64)) { + continue; + } + throw std::runtime_error("base64 encode failed"); + } + done = true; + } + BIO_flush(b64); + char* string_ptr = nullptr; + // long string_len = BIO_get_mem_data(out, &string_ptr); + // BIO_get_mem_data failed clang-tidy + long string_len = BIO_ctrl(out, BIO_CTRL_INFO, 0, &string_ptr); + auto return_value = std::string(string_ptr, string_len); + + BIO_set_close(out, BIO_NOCLOSE); + BIO_free(b64); + BIO_free(out); + return return_value; +} + +std::string +to_base64url(const bytes& data) +{ + if (data.empty()) { + return ""; + } + + std::string return_value = to_base64(data); + + // remove the end padding + auto sz = return_value.find_first_of('='); + + if (sz != std::string::npos) { + return_value = return_value.substr(0, sz); + } + + // replace plus with hyphen + std::replace(return_value.begin(), return_value.end(), '+', '-'); + + // replace slash with underscore + std::replace(return_value.begin(), return_value.end(), '/', '_'); + return return_value; +} + +bytes +from_base64(const std::string& enc) +{ + if (enc.length() == 0) { + return {}; + } + + if (enc.length() % 4 != 0) { + throw std::runtime_error("Base64 length is not divisible by 4"); + } + bytes input = from_ascii(enc); + bytes output(input.size() / 4 * 3); + int output_buffer_length = static_cast(output.size()); + EVP_ENCODE_CTX* ctx = EVP_ENCODE_CTX_new(); + EVP_DecodeInit(ctx); + + int result = EVP_DecodeUpdate(ctx, + output.data(), + &output_buffer_length, + input.data(), + static_cast(input.size())); + + if (result == -1) { + auto code = ERR_get_error(); + throw std::runtime_error(ERR_error_string(code, nullptr)); + } + + if (result == 0 && enc.substr(enc.length() - 2, enc.length()) == "==") { + output = output.slice(0, output.size() - 2); + } else if (result == 0 && enc.substr(enc.length() - 1, enc.length()) == "=") { + output = output.slice(0, output.size() - 1); + } else if (result == 0) { + throw std::runtime_error("Base64 padding was malformed."); + } + EVP_DecodeFinal(ctx, output.data(), &output_buffer_length); + EVP_ENCODE_CTX_free(ctx); + return output; +} + +bytes +from_base64url(const std::string& enc) +{ + if (enc.empty()) { + return {}; + } + std::string enc_copy = enc; // copy + std::replace(enc_copy.begin(), enc_copy.end(), '-', '+'); + std::replace(enc_copy.begin(), enc_copy.end(), '_', '/'); + + while (enc_copy.length() % 4 != 0) { + enc_copy += "="; + } + bytes return_value = from_base64(enc_copy); + return return_value; +} + } // namespace bytes_ns diff --git a/lib/bytes/test/bytes.cpp b/lib/bytes/test/bytes.cpp index a28dbf9f..69a764df 100644 --- a/lib/bytes/test/bytes.cpp +++ b/lib/bytes/test/bytes.cpp @@ -2,6 +2,7 @@ #include #include #include +#include using namespace bytes_ns; using namespace std::literals::string_literals; @@ -40,6 +41,33 @@ TEST_CASE("To/from hex/ASCII") REQUIRE(from_ascii(str) == ascii); } +TEST_CASE("To Base64 / To Base64Url") +{ + struct KnownAnswerTest + { + bytes data; + std::string base64; + std::string base64u; + }; + + const std::vector cases{ + { from_ascii("hello there"), "aGVsbG8gdGhlcmU=", "aGVsbG8gdGhlcmU" }, + { from_ascii("A B C D E F "), "QSBCIEMgRCBFIEYg", "QSBCIEMgRCBFIEYg" }, + { from_ascii("hello\xfethere"), "aGVsbG/+dGhlcmU=", "aGVsbG_-dGhlcmU" }, + { from_ascii("\xfe"), "/g==", "_g" }, + { from_ascii("\x01\x02"), "AQI=", "AQI" }, + { from_ascii("\x01"), "AQ==", "AQ" }, + { from_ascii(""), "", "" }, + }; + + for (const auto& tc : cases) { + REQUIRE(to_base64(tc.data) == tc.base64); + REQUIRE(to_base64url(tc.data) == tc.base64u); + REQUIRE(from_base64(tc.base64) == tc.data); + REQUIRE(from_base64url(tc.base64u) == tc.data); + } +} + TEST_CASE("Operators") { const auto lhs = from_hex("00010203"); diff --git a/lib/hpke/CMakeLists.txt b/lib/hpke/CMakeLists.txt index 44806c4e..5f1f4610 100644 --- a/lib/hpke/CMakeLists.txt +++ b/lib/hpke/CMakeLists.txt @@ -3,6 +3,7 @@ set(CURRENT_LIB_NAME hpke) ### ### Dependencies ### +find_package(nlohmann_json REQUIRED) find_package(OpenSSL 1.1 REQUIRED) ### @@ -14,7 +15,11 @@ file(GLOB_RECURSE LIB_SOURCES CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/src add_library(${CURRENT_LIB_NAME} ${LIB_HEADERS} ${LIB_SOURCES}) add_dependencies(${CURRENT_LIB_NAME} bytes tls_syntax) -target_link_libraries(${CURRENT_LIB_NAME} PRIVATE bytes tls_syntax OpenSSL::Crypto) +target_link_libraries(${CURRENT_LIB_NAME} + PRIVATE + nlohmann_json::nlohmann_json OpenSSL::Crypto + PUBLIC + bytes tls_syntax) target_include_directories(${CURRENT_LIB_NAME} PUBLIC $ diff --git a/lib/hpke/include/hpke/signature.h b/lib/hpke/include/hpke/signature.h index 8ee0b39b..378f22a5 100644 --- a/lib/hpke/include/hpke/signature.h +++ b/lib/hpke/include/hpke/signature.h @@ -50,6 +50,13 @@ struct Signature virtual std::unique_ptr deserialize_private( const bytes& skm) const; + virtual std::unique_ptr import_jwk_private( + const std::string& json_str) const; + virtual std::unique_ptr import_jwk( + const std::string& json_str) const; + virtual std::string export_jwk_private(const bytes& env) const; + virtual std::string export_jwk(const bytes& env) const; + virtual bytes sign(const bytes& data, const PrivateKey& sk) const = 0; virtual bool verify(const bytes& data, const bytes& sig, diff --git a/lib/hpke/src/group.cpp b/lib/hpke/src/group.cpp index c6b42b3b..47dc441e 100644 --- a/lib/hpke/src/group.cpp +++ b/lib/hpke/src/group.cpp @@ -526,11 +526,170 @@ struct ECKeyGroup : public EVPGroup #endif } + // EC Key + void get_coordinates(const Group::PublicKey& pk, + bytes& x, + bytes& y) const override + { + auto bnX = make_typed_unique(BN_new()); + auto bnY = make_typed_unique(BN_new()); + const auto& rpk = dynamic_cast(pk); + +#if defined(WITH_OPENSSL3) + OSSL_PARAM* param = nullptr; + + if (1 != EVP_PKEY_todata(rpk.pkey.get(), EVP_PKEY_PUBLIC_KEY, ¶m)) { + throw openssl_error(); + } + auto param_ptr = make_typed_unique(param); + const OSSL_PARAM* pk_param = + OSSL_PARAM_locate_const(param_ptr.get(), OSSL_PKEY_PARAM_PUB_KEY); + + if (pk_param == nullptr) { + throw std::runtime_error("Failed to locate OSSL_PKEY_PARAM_PUB_KEY"); + } + size_t len = 0; + + if (1 != OSSL_PARAM_get_octet_string(pk_param, nullptr, 0, &len)) { + throw std::runtime_error("Failed to get OSSL_PKEY_PARAM_PUB_KEY len"); + } + bytes buf(len); + void* data_ptr = buf.data(); + + if (1 != OSSL_PARAM_get_octet_string(pk_param, &data_ptr, len, nullptr)) { + throw std::runtime_error("Failed to get OSSL_PKEY_PARAM_PUB_KEY data"); + } + auto group = make_typed_unique( + EC_GROUP_new_by_curve_name_ex(nullptr, nullptr, curve_nid)); + + if (group == nullptr) { + throw openssl_error(); + } + auto point = make_typed_unique(EC_POINT_new(group.get())); + + if (point == nullptr) { + throw openssl_error(); + } + const auto* oct_ptr = static_cast(data_ptr); + + if (1 != + EC_POINT_oct2point(group.get(), point.get(), oct_ptr, len, nullptr)) { + throw openssl_error(); + } + + if (1 != EC_POINT_get_affine_coordinates( + group.get(), point.get(), bnX.get(), bnY.get(), nullptr)) { + throw openssl_error(); + } +#else + auto* pub = EVP_PKEY_get0_EC_KEY(rpk.pkey.get()); + const auto* point = EC_KEY_get0_public_key(pub); + const auto* group = EC_KEY_get0_group(pub); + + if (1 != EC_POINT_get_affine_coordinates_GFp( + group, point, bnX.get(), bnY.get(), nullptr)) { + throw openssl_error(); + } +#endif + auto outX = bytes(BN_num_bytes(bnX.get())); + auto outY = bytes(BN_num_bytes(bnY.get())); + + if (BN_bn2bin(bnX.get(), outX.data()) != int(outX.size())) { + throw openssl_error(); + } + + if (BN_bn2bin(bnY.get(), outY.data()) != int(outY.size())) { + throw openssl_error(); + } + const auto zeros_neededX = dh_size - outX.size(); + const auto zeros_neededY = dh_size - outY.size(); + auto leading_zerosX = bytes(zeros_neededX, 0); + auto leading_zerosY = bytes(zeros_neededY, 0); + x = leading_zerosX + outX; + y = leading_zerosY + outY; + } + + // EC Key + std::unique_ptr set_coordinates( + const bytes& x, + const bytes& y) const override + { + auto bnX = make_typed_unique( + BN_bin2bn(x.data(), static_cast(x.size()), nullptr)); + auto bnY = make_typed_unique( + BN_bin2bn(y.data(), static_cast(y.size()), nullptr)); + + if (bnX == nullptr || bnY == nullptr) { + throw std::runtime_error("Failed to convert bnX or bnY"); + } + +#if defined(WITH_OPENSSL3) + auto* group = EC_GROUP_new_by_curve_name_ex(nullptr, nullptr, curve_nid); + auto group_ptr = make_typed_unique(group); + + auto* point = EC_POINT_new(group); + auto point_ptr = make_typed_unique(point); + + if (point == nullptr || group == nullptr) { + throw std::runtime_error("Failed to create EC_POINT or EC_GROUP"); + } + + if (1 != EC_POINT_set_affine_coordinates( + group, point, bnX.get(), bnY.get(), nullptr)) { + throw openssl_error(); + } + + const auto point_size = EC_POINT_point2oct( + group, point, POINT_CONVERSION_UNCOMPRESSED, nullptr, 0, nullptr); + + if (0 == point_size) { + throw openssl_error(); + } + bytes pub(point_size); + + if (EC_POINT_point2oct(group, + point, + POINT_CONVERSION_UNCOMPRESSED, + pub.data(), + point_size, + nullptr) != point_size) { + throw openssl_error(); + } + auto key = public_evp_key(pub); + return std::make_unique(key.release()); +#else + auto eckey = make_typed_unique(new_ec_key()); + + if (eckey == nullptr) { + throw std::runtime_error("Failed to create EC_KEY"); + } + + const auto* group = EC_KEY_get0_group(eckey.get()); + auto* point = EC_POINT_new(group); + auto point_ptr = make_typed_unique(point); + + if (1 != EC_POINT_set_affine_coordinates_GFp( + group, point, bnX.get(), bnY.get(), nullptr)) { + throw openssl_error(); + } + + if (1 != EC_KEY_set_public_key(eckey.get(), point)) { + throw openssl_error(); + } + return std::make_unique(to_pkey(eckey.release())); +#endif + } + private: int curve_nid; #if !defined(WITH_OPENSSL3) - EC_KEY* new_ec_key() const { return EC_KEY_new_by_curve_name(curve_nid); } + // clang-format off + EC_KEY* new_ec_key() const + { + return EC_KEY_new_by_curve_name(curve_nid); + } + // clang-format on static EVP_PKEY* to_pkey(EC_KEY* eckey) { @@ -648,6 +807,30 @@ struct RawKeyGroup : public EVPGroup return std::make_unique(pkey); } + // Raw Key + void get_coordinates(const Group::PublicKey& pk, + bytes& x, + bytes& /*unused*/) const override + { + const auto& rpk = dynamic_cast(pk); + auto raw = bytes(pk_size); + auto* data_ptr = raw.data(); + auto data_len = raw.size(); + + if (1 != EVP_PKEY_get_raw_public_key(rpk.pkey.get(), data_ptr, &data_len)) { + throw openssl_error(); + } + x = raw; + } + + // Raw Key + std::unique_ptr set_coordinates( + const bytes& x, + const bytes& /*unused*/) const override + { + return deserialize(x); + } + private: const int evp_type; @@ -809,11 +992,54 @@ group_sk_size(Group::ID group_id) } } +static inline std::string +group_jwt_curve_name(Group::ID group_id) +{ + switch (group_id) { + case Group::ID::P256: + return "P-256"; + case Group::ID::P384: + return "P-384"; + case Group::ID::P521: + return "P-521"; + case Group::ID::Ed25519: + return "Ed25519"; + case Group::ID::Ed448: + return "Ed448"; + case Group::ID::X25519: + return "X25519"; + case Group::ID::X448: + return "X448"; + default: + throw std::runtime_error("Unknown group"); + } +} + +static inline std::string +group_jwt_key_type(Group::ID group_id) +{ + switch (group_id) { + case Group::ID::P256: + case Group::ID::P384: + case Group::ID::P521: + return "EC"; + case Group::ID::Ed25519: + case Group::ID::Ed448: + case Group::ID::X25519: + case Group::ID::X448: + return "OKP"; + default: + throw std::runtime_error("Unknown group"); + } +} + Group::Group(ID group_id_in, const KDF& kdf_in) : id(group_id_in) , dh_size(group_dh_size(group_id_in)) , pk_size(group_pk_size(group_id_in)) , sk_size(group_sk_size(group_id_in)) + , jwt_key_type(group_jwt_key_type(group_id_in)) + , jwt_curve_name(group_jwt_curve_name(group_id_in)) , kdf(kdf_in) { } diff --git a/lib/hpke/src/group.h b/lib/hpke/src/group.h index ace8d7a9..efb33245 100644 --- a/lib/hpke/src/group.h +++ b/lib/hpke/src/group.h @@ -43,6 +43,8 @@ struct Group const size_t dh_size; const size_t pk_size; const size_t sk_size; + const std::string jwt_key_type; + const std::string jwt_curve_name; virtual std::unique_ptr generate_key_pair() const = 0; virtual std::unique_ptr derive_key_pair( @@ -63,6 +65,13 @@ struct Group const bytes& sig, const PublicKey& pk) const = 0; + virtual void get_coordinates(const Group::PublicKey& pk, + bytes& x, + bytes& y) const = 0; + virtual std::unique_ptr set_coordinates( + const bytes& x, + const bytes& y) const = 0; + protected: const KDF& kdf; diff --git a/lib/hpke/src/signature.cpp b/lib/hpke/src/signature.cpp index a79cafea..78b7d580 100644 --- a/lib/hpke/src/signature.cpp +++ b/lib/hpke/src/signature.cpp @@ -1,14 +1,21 @@ #include #include +#include #include "dhkem.h" #include "common.h" #include "group.h" #include "rsa.h" +#include +#include +#include +#include #include #include +using namespace nlohmann; + namespace hpke { struct GroupSignature : public Signature @@ -103,6 +110,96 @@ struct GroupSignature : public Signature return group.verify(data, sig, rpk); } + std::unique_ptr import_jwk_private( + const std::string& json_str) const override + { + // TODO(ghewett): handle failed parse + json jwk_json = json::parse(json_str); + + // TODO(ghewett): jwk_json should patch cipher suite + + // TODO(ghewett): handle the absense of 'd' + bytes d = from_base64url(jwk_json["d"]); + + return std::make_unique(group.deserialize_private(d).release()); + } + + std::unique_ptr import_jwk( + const std::string& json_str) const override + { + bytes x = bytes({}, 0); + bytes y = bytes({}, 0); + json jwk_json = json::parse(json_str); + + if (jwk_json.empty() || !jwk_json.contains("kty") || + !jwk_json.contains("crv") || !jwk_json.contains("x")) { + throw std::runtime_error("import_jwk: malformed json input"); + } + + if (jwk_json["kty"] != group.jwt_key_type) { + throw std::runtime_error("import_jwk: group keytype does not match json"); + } + + if (jwk_json["crv"] != group.jwt_curve_name) { + throw std::runtime_error("import_jwk: group curve does not match json"); + } + x = from_base64url(jwk_json["x"]); + + if (jwk_json.contains("y")) { + y = from_base64url(jwk_json["y"]); + } + return group.set_coordinates(x, y); + } + + std::string export_jwk(const bytes& enc) const override + { + bytes x; + bytes y; + json json_jwk; + json_jwk["crv"] = group.jwt_curve_name; + json_jwk["kty"] = group.jwt_key_type; + + std::unique_ptr pk = deserialize(enc); + const auto& rpk = + dynamic_cast(*(pk.release())); + group.get_coordinates(rpk, x, y); + + if (!x.empty()) { + json_jwk["x"] = to_base64url(x); + } + + if (!y.empty()) { + json_jwk["y"] = to_base64url(y); + } + return json_jwk.dump(); + } + + std::string export_jwk_private(const bytes& enc) const override + { + bytes x; + bytes y; + json json_jwk; + json_jwk["crv"] = group.jwt_curve_name; + json_jwk["kty"] = group.jwt_key_type; + + // encode the private key + json_jwk["d"] = to_base64url(enc); + + const auto priv = group.deserialize_private(enc); + const auto& rpk = + dynamic_cast(*(priv->public_key().release())); + group.get_coordinates(rpk, x, y); + + if (!x.empty()) { + json_jwk["x"] = to_base64url(x); + } + + if (!y.empty()) { + json_jwk["y"] = to_base64url(y); + } + return json_jwk.dump(); + } + private: const Group& group; }; @@ -182,6 +279,30 @@ Signature::serialize_private(const PrivateKey& /* unused */) const throw std::runtime_error("Not implemented"); } +std::unique_ptr +Signature::import_jwk(const std::string& /* unused */) const +{ + throw std::runtime_error("Not implemented."); +} + +std::unique_ptr +Signature::import_jwk_private(const std::string& /* unused */) const +{ + throw std::runtime_error("Not implemented."); +} + +std::string +Signature::export_jwk(const bytes& /* unused */) const +{ + throw std::runtime_error("Not implemented."); +} + +std::string +Signature::export_jwk_private(const bytes& /* unused */) const +{ + throw std::runtime_error("Not implemented."); +} + std::unique_ptr Signature::deserialize_private(const bytes& /* unused */) const { diff --git a/lib/tls_syntax/CMakeLists.txt b/lib/tls_syntax/CMakeLists.txt index 7561d426..0aae862a 100644 --- a/lib/tls_syntax/CMakeLists.txt +++ b/lib/tls_syntax/CMakeLists.txt @@ -9,7 +9,7 @@ file(GLOB_RECURSE LIB_SOURCES CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/src add_library(${CURRENT_LIB_NAME} ${LIB_HEADERS} ${LIB_SOURCES}) add_dependencies(${CURRENT_LIB_NAME} third_party) -target_link_libraries(${CURRENT_LIB_NAME} third_party) +target_link_libraries(${CURRENT_LIB_NAME} PUBLIC third_party) target_include_directories(${CURRENT_LIB_NAME} PUBLIC $ diff --git a/src/crypto.cpp b/src/crypto.cpp index 9611fbd5..1c161f53 100644 --- a/src/crypto.cpp +++ b/src/crypto.cpp @@ -384,6 +384,20 @@ SignaturePublicKey::verify(const CipherSuite& suite, return suite.sig().verify(content, signature, *pub); } +SignaturePublicKey +SignaturePublicKey::from_jwk(CipherSuite suite, const std::string& json_str) +{ + auto pub = suite.sig().import_jwk(json_str); + auto pub_data = suite.sig().serialize(*pub); + return SignaturePublicKey{ pub_data }; +} + +std::string +SignaturePublicKey::to_jwk(CipherSuite suite) const +{ + return suite.sig().export_jwk(data); +} + SignaturePrivateKey SignaturePrivateKey::generate(CipherSuite suite) { @@ -438,4 +452,20 @@ SignaturePrivateKey::set_public_key(CipherSuite suite) public_key.data = suite.sig().serialize(*pub); } +SignaturePrivateKey +SignaturePrivateKey::from_jwk(CipherSuite suite, const std::string& json_str) +{ + auto priv = suite.sig().import_jwk_private(json_str); + auto priv_data = suite.sig().serialize_private(*priv); + auto pub = priv->public_key(); + auto pub_data = suite.sig().serialize(*pub); + return { priv_data, pub_data }; +} + +std::string +SignaturePrivateKey::to_jwk(CipherSuite suite) const +{ + return suite.sig().export_jwk_private(data); +} + } // namespace mls diff --git a/test/crypto.cpp b/test/crypto.cpp index 8dd930c2..acbffa09 100644 --- a/test/crypto.cpp +++ b/test/crypto.cpp @@ -1,11 +1,12 @@ #include #include #include - +#include #include using namespace mls; using namespace mls_vectors; +using namespace nlohmann; TEST_CASE("Basic HPKE") { @@ -91,6 +92,138 @@ TEST_CASE("Signature Key Serializion") } } +TEST_CASE("Signature Key Serializion To JWK") +{ + + struct KnownAnswerTest + { + CipherSuite suite; + bool supported; + bytes pk; + std::string kty; + std::string crv; + std::string d; + std::string x; + std::string y; + }; + + std::vector cases{ + { CipherSuite::ID::P256_AES128GCM_SHA256_P256, + true, + from_hex( + "cae90bad54df6973c64f7e4116ee78409045ed43e9668d0d474948a510f38acf"), + "EC", + "P-256", + "yukLrVTfaXPGT35BFu54QJBF7UPpZo0NR0lIpRDzis8", + "nUV1xGxWcUobNQrV0DsSN_z7P8hwVivmUji8EIJnrGg", + "2TGu_-lIxa7fn8PW-3gMNod-CjwwoAiLIhkbcsHtSdw" }, + { CipherSuite::ID::X25519_AES128GCM_SHA256_Ed25519, + true, + from_hex( + "9f959eeebab856bede41bfcd985077f5eaae702dde01c76b48952c35c9a97618"), + "OKP", + "Ed25519", + "n5We7rq4Vr7eQb_NmFB39equcC3eAcdrSJUsNcmpdhg", + "NmQinNknsQjwPFpujKmLa09alb4kagXy1YJenH3Zs-I", + "" }, + { CipherSuite::ID::X25519_CHACHA20POLY1305_SHA256_Ed25519, + true, + from_hex( + "f6d9dfcfc3e7f2016df7894b959e3f922d01035292732da12158f0c08b6251ae"), + "OKP", + "Ed25519", + "9tnfz8Pn8gFt94lLlZ4_ki0BA1KScy2hIVjwwItiUa4", + "kcnJ4z9eHBgiuFSDGlsF8PyibD2seAMncB4iKamamSU", + "" }, + { CipherSuite::ID::X448_AES256GCM_SHA512_Ed448, + true, + from_hex("e8dfd869ebe67fe696f0a0a12e04111cf1e4744e1a045fa73b2285a0168f319" + "e66522c9ddec741a8dd8011d0fc4b72303053901540c36f1e89"), + "OKP", + "Ed448", + "6N_Yaevmf-aW8KChLgQRHPHkdE4aBF-" + "nOyKFoBaPMZ5mUiyd3sdBqN2AEdD8S3IwMFOQFUDDbx6J", + "5uf09bDIVeecX74gv2ljKmvf3eLUXYiB6Jbycwww8ijcbnM04rfJr1agpFC2TuVSm5d0iDCj" + "EDIA", + "" }, + { CipherSuite::ID::P521_AES256GCM_SHA512_P521, + true, + from_hex( + "01c58ae6621000da12b682f45248f88b4cef278743a4fa325fc234f8770648d440cab3" + "367e90a49293c02778732776bd3eb985415c5f9df77a212e2097f0026298b8"), + "EC", + "P-521", + "AcWK5mIQANoStoL0Ukj4i0zvJ4dDpPoyX8I0-" + "HcGSNRAyrM2fpCkkpPAJ3hzJ3a9PrmFQVxfnfd6IS4gl_ACYpi4", + "AFLfr4vhftq9G6axgJ8g6xdukrUFn2cD5HDIxp8uzSbYW_" + "QIjKdUV1pF2vzzcz7Vj185LE6kl1SqTX6Z551W38mC", + "AbPIkuJkgfBZCidxSFrJALD1_e8-tKE0Ygy1dF2PZXJMGcHQRPbnytg-" + "4iVVGbjVdcakGIuUq3aAO09NqLi8j81d" }, + { CipherSuite::ID::X448_CHACHA20POLY1305_SHA512_Ed448, + true, + from_hex("5535d624e127fed3bc20d24a51269ce842e1ce36d6a62002b7f59696fcd3d9e" + "7d865da15e8e690caf22c34bf04bd34bd761be1eacb26fec193"), + "OKP", + "Ed448", + "VTXWJOEn_tO8INJKUSac6ELhzjbWpiACt_" + "WWlvzT2efYZdoV6OaQyvIsNL8EvTS9dhvh6ssm_sGT", + "jfbh2FAWZ57XmEEgrlGLAk6Am-qZ1IibFy2qip1uU3zOfWJ-TXmq4Ty-" + "yssJdZ5c0niU3SNO7JkA", + "" }, + { CipherSuite::ID::P384_AES256GCM_SHA384_P384, + true, + from_hex("33500ad0e749f53707e1f5ebef7d80758f95923c5b02acd89c21ffb2eb9f4f0" + "ccc5db144cd92e1577963dfb1b4e3fa68"), + "EC", + "P-384", + "M1AK0OdJ9TcH4fXr732AdY-VkjxbAqzYnCH_suufTwzMXbFEzZLhV3lj37G04_po", + "FyXCw9vukrBkLD_Lu7HvZw6cr-gwvpldN4aqZgtjAuM1rRSL74Lfi3CBBD8LpB0A", + "UUd8Qs3VdkOTFJlP62TKaVBp0JZlD74b7TU2gNlkDX3o8EIfl4POCooLs920bCJf" } + }; + + for (const auto& tc : cases) { + const CipherSuite suite{ tc.suite }; + + if (!tc.supported) { + auto private_key = SignaturePrivateKey::generate(suite); + CHECK_THROWS_WITH(private_key.to_jwk(suite), "Unsupported group"); + continue; + } + + // Export Private Key + auto private_key = SignaturePrivateKey::parse(suite, tc.pk); + auto jwk_str = private_key.to_jwk(tc.suite); + auto jwk_json = json::parse(jwk_str); + REQUIRE(jwk_json["kty"] == tc.kty); + REQUIRE(jwk_json["crv"] == tc.crv); + REQUIRE(jwk_json["d"] == tc.d); + REQUIRE(jwk_json["x"] == tc.x); + + if (!tc.y.empty()) { + REQUIRE(jwk_json["y"] == tc.y); + } + + // Export Public Key + auto jwk_pk_str = private_key.public_key.to_jwk(tc.suite); + auto jwk_pk_json = json::parse(jwk_pk_str); + REQUIRE(jwk_pk_json["kty"] == tc.kty); + REQUIRE(jwk_pk_json["crv"] == tc.crv); + REQUIRE(jwk_pk_json["x"] == tc.x); + + if (!tc.y.empty()) { + REQUIRE(jwk_pk_json["y"] == tc.y); + } + + // Import Private Key + auto import_jwk_sk = SignaturePrivateKey::from_jwk(tc.suite, jwk_str); + REQUIRE(tc.pk == import_jwk_sk.data); + + // Import Public Key + auto import_jwk_pk = SignaturePublicKey::from_jwk(tc.suite, jwk_pk_str); + REQUIRE(private_key.public_key.data == import_jwk_pk.data); + } +} + TEST_CASE("Crypto Interop") { for (auto suite : all_supported_suites) { diff --git a/vcpkg.json b/vcpkg.json index f5870561..0ffd459e 100644 --- a/vcpkg.json +++ b/vcpkg.json @@ -7,7 +7,8 @@ "name": "openssl", "version>=": "1.1.1n" }, - "doctest" + "doctest", + "nlohmann-json" ], "builtin-baseline": "3b3bd424827a1f7f4813216f6b32b6c61e386b2e", "overrides": [