diff --git a/CHANGELOG.md b/CHANGELOG.md index 015ca25f7047..404e24837d3f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -53,6 +53,18 @@ and this project adheres Fto [Semantic Versioning](http://semver.org/spec/v2.0.0 - All definitions in CCF's public headers are now under the `ccf::` namespace. Any application code which references any of these types directly (notably `StartupConfig`, `http_status`, `LoggerLevel`), they will now need to be prefixed with the `ccf::` namespace. - `cchost` now requires `--config`. +### Removed + +- Previously deprecated `/gov/jwt_keys/all` has been removed. The `/gov/service/jwk` endpoint to be used instead. + +### Changed + +- JWT authentication now supports raw public keys along with certificates (#6601). + - Public key information ('n' and 'e' claims) now have a priority if defined in JWK set, 'x5c' remains as a backup option. + - Has same side-effects as #5809 does please see the changelog entry for that change for more details. In short: + - stale JWKs may be used for JWT validation on older nodes during the upgrade. + - old tables are not cleaned up, #6222 is tracking those. + ## [6.0.0-dev7] [6.0.0-dev7]: https://github.com/microsoft/CCF/releases/tag/6.0.0-dev7 diff --git a/doc/schemas/gov_openapi.json b/doc/schemas/gov_openapi.json index 90ab5ed30d61..2a5e0519f8c8 100644 --- a/doc/schemas/gov_openapi.json +++ b/doc/schemas/gov_openapi.json @@ -291,27 +291,6 @@ }, "type": "object" }, - "KeyIdInfo": { - "properties": { - "cert": { - "$ref": "#/components/schemas/Pem" - }, - "issuer": { - "$ref": "#/components/schemas/string" - } - }, - "required": [ - "issuer", - "cert" - ], - "type": "object" - }, - "KeyIdInfo_array": { - "items": { - "$ref": "#/components/schemas/KeyIdInfo" - }, - "type": "array" - }, "MDType": { "enum": [ "NONE", @@ -799,6 +778,24 @@ "type": "string" }, "OpenIDJWKMetadata": { + "properties": { + "constraint": { + "$ref": "#/components/schemas/string" + }, + "issuer": { + "$ref": "#/components/schemas/string" + }, + "public_key": { + "$ref": "#/components/schemas/base64string" + } + }, + "required": [ + "issuer", + "public_key" + ], + "type": "object" + }, + "OpenIDJWKMetadataLegacy": { "properties": { "cert": { "$ref": "#/components/schemas/base64string" @@ -811,11 +808,17 @@ } }, "required": [ - "cert", - "issuer" + "issuer", + "cert" ], "type": "object" }, + "OpenIDJWKMetadataLegacy_array": { + "items": { + "$ref": "#/components/schemas/OpenIDJWKMetadataLegacy" + }, + "type": "array" + }, "OpenIDJWKMetadata_array": { "items": { "$ref": "#/components/schemas/OpenIDJWKMetadata" @@ -1222,9 +1225,9 @@ }, "type": "object" }, - "string_to_KeyIdInfo_array": { + "string_to_OpenIDJWKMetadataLegacy_array": { "additionalProperties": { - "$ref": "#/components/schemas/KeyIdInfo_array" + "$ref": "#/components/schemas/OpenIDJWKMetadataLegacy_array" }, "type": "object" }, @@ -1473,31 +1476,6 @@ } ] }, - "/gov/jwt_keys/all": { - "get": { - "deprecated": true, - "description": "This endpoint is deprecated from 5.0.0. It is replaced by POST /gov/service/jwk", - "operationId": "GetGovJwtKeysAll", - "responses": { - "200": { - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/string_to_KeyIdInfo_array" - } - } - }, - "description": "Default response description" - }, - "default": { - "$ref": "#/components/responses/default" - } - }, - "x-ccf-forwarding": { - "$ref": "#/components/x-ccf-forwarding/always" - } - } - }, "/gov/kv/constitution": { "get": { "deprecated": true, @@ -1752,6 +1730,31 @@ "get": { "deprecated": true, "operationId": "GetGovKvJwtPublicSigningKeysMetadata", + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/string_to_OpenIDJWKMetadataLegacy_array" + } + } + }, + "description": "Default response description" + }, + "default": { + "$ref": "#/components/responses/default" + } + }, + "summary": "This route is auto-generated from the KV schema.", + "x-ccf-forwarding": { + "$ref": "#/components/x-ccf-forwarding/sometimes" + } + } + }, + "/gov/kv/jwt/public_signing_keys_metadata_v2": { + "get": { + "deprecated": true, + "operationId": "GetGovKvJwtPublicSigningKeysMetadataV2", "responses": { "200": { "content": { diff --git a/include/ccf/crypto/jwk.h b/include/ccf/crypto/jwk.h index 1b4886cb1a22..ae0f1f5b9ab0 100644 --- a/include/ccf/crypto/jwk.h +++ b/include/ccf/crypto/jwk.h @@ -27,13 +27,12 @@ namespace ccf::crypto JsonWebKeyType kty; std::optional<std::string> kid = std::nullopt; std::optional<std::vector<std::string>> x5c = std::nullopt; - std::optional<std::string> issuer = std::nullopt; bool operator==(const JsonWebKey&) const = default; }; DECLARE_JSON_TYPE_WITH_OPTIONAL_FIELDS(JsonWebKey); DECLARE_JSON_REQUIRED_FIELDS(JsonWebKey, kty); - DECLARE_JSON_OPTIONAL_FIELDS(JsonWebKey, kid, x5c, issuer); + DECLARE_JSON_OPTIONAL_FIELDS(JsonWebKey, kid, x5c); enum class JsonWebKeyECCurve { @@ -47,6 +46,25 @@ namespace ccf::crypto {JsonWebKeyECCurve::P384, "P-384"}, {JsonWebKeyECCurve::P521, "P-521"}}); + struct JsonWebKeyData + { + JsonWebKeyType kty; + std::optional<std::string> kid = std::nullopt; + std::optional<std::vector<std::string>> x5c = std::nullopt; + std::optional<std::string> n = std::nullopt; + std::optional<std::string> e = std::nullopt; + std::optional<std::string> x = std::nullopt; + std::optional<std::string> y = std::nullopt; + std::optional<JsonWebKeyECCurve> crv = std::nullopt; + std::optional<std::string> issuer = std::nullopt; + + bool operator==(const JsonWebKeyData&) const = default; + }; + DECLARE_JSON_TYPE_WITH_OPTIONAL_FIELDS(JsonWebKeyData); + DECLARE_JSON_REQUIRED_FIELDS(JsonWebKeyData, kty); + DECLARE_JSON_OPTIONAL_FIELDS( + JsonWebKeyData, kid, x5c, n, e, x, y, crv, issuer); + static JsonWebKeyECCurve curve_id_to_jwk_curve(CurveID curve_id) { switch (curve_id) diff --git a/include/ccf/crypto/rsa_public_key.h b/include/ccf/crypto/rsa_public_key.h index cd62eba0e7f4..1fcd81dc6d43 100644 --- a/include/ccf/crypto/rsa_public_key.h +++ b/include/ccf/crypto/rsa_public_key.h @@ -84,6 +84,13 @@ namespace ccf::crypto MDType md_type = MDType::NONE, size_t salt_legth = 0) = 0; + virtual bool verify_pkcs1( + const uint8_t* contents, + size_t contents_size, + const uint8_t* signature, + size_t signature_size, + MDType md_type = MDType::NONE) = 0; + struct Components { std::vector<uint8_t> n; diff --git a/include/ccf/endpoints/authentication/jwt_auth.h b/include/ccf/endpoints/authentication/jwt_auth.h index 3a44ee55a73a..70d4f3f7c813 100644 --- a/include/ccf/endpoints/authentication/jwt_auth.h +++ b/include/ccf/endpoints/authentication/jwt_auth.h @@ -17,7 +17,7 @@ namespace ccf nlohmann::json payload; }; - struct VerifiersCache; + struct PublicKeysCache; bool validate_issuer( const std::string& iss, @@ -28,7 +28,7 @@ namespace ccf { protected: static const OpenAPISecuritySchema security_schema; - std::unique_ptr<VerifiersCache> verifiers; + std::unique_ptr<PublicKeysCache> keys_cache; public: static constexpr auto SECURITY_SCHEME_NAME = "jwt"; diff --git a/include/ccf/service/tables/jwt.h b/include/ccf/service/tables/jwt.h index 23ebe5268499..8b21448bf58a 100644 --- a/include/ccf/service/tables/jwt.h +++ b/include/ccf/service/tables/jwt.h @@ -37,27 +37,42 @@ namespace ccf using JwtIssuer = std::string; using JwtKeyId = std::string; using Cert = std::vector<uint8_t>; + using PublicKey = std::vector<uint8_t>; struct OpenIDJWKMetadata { - Cert cert; + PublicKey public_key; JwtIssuer issuer; std::optional<JwtIssuer> constraint; }; DECLARE_JSON_TYPE_WITH_OPTIONAL_FIELDS(OpenIDJWKMetadata); - DECLARE_JSON_REQUIRED_FIELDS(OpenIDJWKMetadata, cert, issuer); + DECLARE_JSON_REQUIRED_FIELDS(OpenIDJWKMetadata, issuer, public_key); DECLARE_JSON_OPTIONAL_FIELDS(OpenIDJWKMetadata, constraint); - using JwtIssuers = ServiceMap<JwtIssuer, JwtIssuerMetadata>; - using JwtPublicSigningKeys = + using JwtPublicSigningKeysMetadata = ServiceMap<JwtKeyId, std::vector<OpenIDJWKMetadata>>; + struct OpenIDJWKMetadataLegacy + { + Cert cert; + JwtIssuer issuer; + std::optional<JwtIssuer> constraint; + }; + DECLARE_JSON_TYPE_WITH_OPTIONAL_FIELDS(OpenIDJWKMetadataLegacy); + DECLARE_JSON_REQUIRED_FIELDS(OpenIDJWKMetadataLegacy, issuer, cert); + DECLARE_JSON_OPTIONAL_FIELDS(OpenIDJWKMetadataLegacy, constraint); + + using JwtPublicSigningKeysMetadataLegacy = + ServiceMap<JwtKeyId, std::vector<OpenIDJWKMetadataLegacy>>; + + using JwtIssuers = ServiceMap<JwtIssuer, JwtIssuerMetadata>; + namespace Tables { static constexpr auto JWT_ISSUERS = "public:ccf.gov.jwt.issuers"; static constexpr auto JWT_PUBLIC_SIGNING_KEYS_METADATA = - "public:ccf.gov.jwt.public_signing_keys_metadata"; + "public:ccf.gov.jwt.public_signing_keys_metadata_v2"; namespace Legacy { @@ -65,6 +80,8 @@ namespace ccf "public:ccf.gov.jwt.public_signing_key"; static constexpr auto JWT_PUBLIC_SIGNING_KEY_ISSUER = "public:ccf.gov.jwt.public_signing_key_issuer"; + static constexpr auto JWT_PUBLIC_SIGNING_KEYS_METADATA = + "public:ccf.gov.jwt.public_signing_keys_metadata"; using JwtPublicSigningKeys = ccf::kv::RawCopySerialisedMap<JwtKeyId, Cert>; @@ -75,7 +92,7 @@ namespace ccf struct JsonWebKeySet { - std::vector<ccf::crypto::JsonWebKey> keys; + std::vector<ccf::crypto::JsonWebKeyData> keys; bool operator!=(const JsonWebKeySet& rhs) const { diff --git a/samples/constitutions/default/actions.js b/samples/constitutions/default/actions.js index 654ecc065326..6ac5527e5aca 100644 --- a/samples/constitutions/default/actions.js +++ b/samples/constitutions/default/actions.js @@ -130,15 +130,26 @@ function checkJwks(value, field) { for (const [i, jwk] of value.keys.entries()) { checkType(jwk.kid, "string", `${field}.keys[${i}].kid`); checkType(jwk.kty, "string", `${field}.keys[${i}].kty`); - checkType(jwk.x5c, "array", `${field}.keys[${i}].x5c`); - checkLength(jwk.x5c, 1, null, `${field}.keys[${i}].x5c`); - for (const [j, b64der] of jwk.x5c.entries()) { - checkType(b64der, "string", `${field}.keys[${i}].x5c[${j}]`); - const pem = - "-----BEGIN CERTIFICATE-----\n" + - b64der + - "\n-----END CERTIFICATE-----"; - checkX509CertBundle(pem, `${field}.keys[${i}].x5c[${j}]`); + if (jwk.x5c) { + checkType(jwk.x5c, "array", `${field}.keys[${i}].x5c`); + checkLength(jwk.x5c, 1, null, `${field}.keys[${i}].x5c`); + for (const [j, b64der] of jwk.x5c.entries()) { + checkType(b64der, "string", `${field}.keys[${i}].x5c[${j}]`); + const pem = + "-----BEGIN CERTIFICATE-----\n" + + b64der + + "\n-----END CERTIFICATE-----"; + checkX509CertBundle(pem, `${field}.keys[${i}].x5c[${j}]`); + } + } else if (jwk.n && jwk.e) { + checkType(jwk.n, "string", `${field}.keys[${i}].n`); + checkType(jwk.e, "string", `${field}.keys[${i}].e`); + } else if (jwk.x && jwk.y) { + checkType(jwk.x, "string", `${field}.keys[${i}].x`); + checkType(jwk.y, "string", `${field}.keys[${i}].y`); + checkType(jwk.crv, "string", `${field}.keys[${i}].crv`); + } else { + throw new Error("JWK must contain either x5c, or n/e for RSA key type, or x/y/crv for EC key type"); } } } diff --git a/src/crypto/openssl/rsa_public_key.cpp b/src/crypto/openssl/rsa_public_key.cpp index b8fb2f61be58..6679ed9dcc64 100644 --- a/src/crypto/openssl/rsa_public_key.cpp +++ b/src/crypto/openssl/rsa_public_key.cpp @@ -54,6 +54,17 @@ namespace ccf::crypto auto msg = OpenSSL::error_string(ec); throw std::runtime_error(fmt::format("OpenSSL error: {}", msg)); } + +// As it's a common patter to rely on successful key wrapper construction as a +// confirmation of a concrete key type, this must fail for non-RSA keys. +#if defined(OPENSSL_VERSION_MAJOR) && OPENSSL_VERSION_MAJOR >= 3 + if (!key || EVP_PKEY_get_base_id(key) != EVP_PKEY_RSA) +#else + if (!key || !EVP_PKEY_get0_RSA(key)) +#endif + { + throw std::logic_error("invalid RSA key"); + } } std::pair<Unique_BIGNUM, Unique_BIGNUM> get_modulus_and_exponent( @@ -208,6 +219,22 @@ namespace ccf::crypto pctx, signature, signature_size, hash.data(), hash.size()) == 1; } + bool RSAPublicKey_OpenSSL::verify_pkcs1( + const uint8_t* contents, + size_t contents_size, + const uint8_t* signature, + size_t signature_size, + MDType md_type) + { + auto hash = OpenSSLHashProvider().Hash(contents, contents_size, md_type); + Unique_EVP_PKEY_CTX pctx(key); + CHECK1(EVP_PKEY_verify_init(pctx)); + CHECK1(EVP_PKEY_CTX_set_rsa_padding(pctx, RSA_PKCS1_PADDING)); + CHECK1(EVP_PKEY_CTX_set_signature_md(pctx, get_md_type(md_type))); + return EVP_PKEY_verify( + pctx, signature, signature_size, hash.data(), hash.size()) == 1; + } + std::vector<uint8_t> RSAPublicKey_OpenSSL::bn_bytes(const BIGNUM* bn) { std::vector<uint8_t> r(BN_num_bytes(bn)); diff --git a/src/crypto/openssl/rsa_public_key.h b/src/crypto/openssl/rsa_public_key.h index 061ba053ad80..abe43fcf758a 100644 --- a/src/crypto/openssl/rsa_public_key.h +++ b/src/crypto/openssl/rsa_public_key.h @@ -55,6 +55,13 @@ namespace ccf::crypto MDType md_type = MDType::NONE, size_t salt_length = 0) override; + virtual bool verify_pkcs1( + const uint8_t* contents, + size_t contents_size, + const uint8_t* signature, + size_t signature_size, + MDType md_type = MDType::NONE) override; + virtual Components components() const override; static std::vector<uint8_t> bn_bytes(const BIGNUM* bn); diff --git a/src/endpoints/authentication/jwt_auth.cpp b/src/endpoints/authentication/jwt_auth.cpp index 05ceb862ff2d..d617b4b0f556 100644 --- a/src/endpoints/authentication/jwt_auth.cpp +++ b/src/endpoints/authentication/jwt_auth.cpp @@ -3,6 +3,9 @@ #include "ccf/endpoints/authentication/jwt_auth.h" +#include "ccf/crypto/ecdsa.h" +#include "ccf/crypto/public_key.h" +#include "ccf/crypto/rsa_key_pair.h" #include "ccf/ds/nonstd.h" #include "ccf/pal/locking.h" #include "ccf/rpc_context.h" @@ -82,34 +85,77 @@ namespace ccf return tenant_id && tid && *tid == *tenant_id; } - struct VerifiersCache + struct PublicKeysCache { - static constexpr size_t DEFAULT_MAX_VERIFIERS = 10; + static constexpr size_t DEFAULT_MAX_KEYS = 10; using DER = std::vector<uint8_t>; - ccf::pal::Mutex verifiers_lock; - LRU<DER, ccf::crypto::VerifierPtr> verifiers; + using KeyVariant = + std::variant<ccf::crypto::RSAPublicKeyPtr, ccf::crypto::PublicKeyPtr>; + ccf::pal::Mutex keys_lock; + LRU<DER, KeyVariant> keys; - VerifiersCache(size_t max_verifiers = DEFAULT_MAX_VERIFIERS) : - verifiers(max_verifiers) - {} + PublicKeysCache(size_t max_keys = DEFAULT_MAX_KEYS) : keys(max_keys) {} - ccf::crypto::VerifierPtr get_verifier(const DER& der) + bool verify( + const uint8_t* contents, + size_t contents_size, + const uint8_t* signature, + size_t signature_size, + const DER& der) { - std::lock_guard<ccf::pal::Mutex> guard(verifiers_lock); + std::lock_guard<ccf::pal::Mutex> guard(keys_lock); - auto it = verifiers.find(der); - if (it == verifiers.end()) + auto it = keys.find(der); + if (it == keys.end()) { - it = verifiers.insert(der, ccf::crypto::make_unique_verifier(der)); + try + { + it = keys.insert(der, ccf::crypto::make_rsa_public_key(der)); + } + catch (const std::exception&) + { + it = keys.insert(der, ccf::crypto::make_public_key(der)); + } } - return it->second; + const auto& key = it->second; + if (std::holds_alternative<ccf::crypto::RSAPublicKeyPtr>(key)) + { + LOG_DEBUG_FMT("Verify der: {} as RSA key", der); + + // Obsolote PKCS1 padding is chosen for JWT, as explained in details in + // https://github.com/microsoft/CCF/issues/6601#issuecomment-2512059875. + return std::get<ccf::crypto::RSAPublicKeyPtr>(key)->verify_pkcs1( + contents, + contents_size, + signature, + signature_size, + ccf::crypto::MDType::SHA256); + } + else if (std::holds_alternative<ccf::crypto::PublicKeyPtr>(key)) + { + LOG_DEBUG_FMT("Verify der: {} as EC key", der); + + const auto sig_der = ccf::crypto::ecdsa_sig_p1363_to_der( + std::vector<uint8_t>(signature, signature + signature_size)); + return std::get<ccf::crypto::PublicKeyPtr>(key)->verify( + contents, + contents_size, + sig_der.data(), + sig_der.size(), + ccf::crypto::MDType::SHA256); + } + else + { + LOG_DEBUG_FMT("Key not found for der: {}", der); + return false; + } } }; JwtAuthnPolicy::JwtAuthnPolicy() : - verifiers(std::make_unique<VerifiersCache>()) + keys_cache(std::make_unique<PublicKeysCache>()) {} JwtAuthnPolicy::~JwtAuthnPolicy() = default; @@ -129,11 +175,42 @@ namespace ccf } auto& token = token_opt.value(); - auto keys = tx.ro<JwtPublicSigningKeys>( + auto keys = tx.ro<JwtPublicSigningKeysMetadata>( ccf::Tables::JWT_PUBLIC_SIGNING_KEYS_METADATA); const auto key_id = token.header_typed.kid; auto token_keys = keys->get(key_id); + // For metadata KID->(cert,issuer,constraint). + // + // Note, that Legacy keys are stored as certs, new approach is raw keys, so + // conversion from cert to raw key is needed. + if (!token_keys) + { + auto fallback_certs = tx.ro<JwtPublicSigningKeysMetadataLegacy>( + ccf::Tables::Legacy::JWT_PUBLIC_SIGNING_KEYS_METADATA); + auto fallback_data = fallback_certs->get(key_id); + if (fallback_data) + { + auto new_keys = std::vector<OpenIDJWKMetadata>(); + for (const auto& metadata : *fallback_data) + { + auto verifier = ccf::crypto::make_unique_verifier(metadata.cert); + new_keys.push_back(OpenIDJWKMetadata{ + .public_key = verifier->public_key_der(), + .issuer = metadata.issuer, + .constraint = metadata.constraint}); + } + if (!new_keys.empty()) + { + token_keys = new_keys; + } + } + } + + // For metadata as two separate tables, KID->JwtIssuer and KID->Cert. + // + // Note, that Legacy keys are stored as certs, new approach is raw keys, so + // conversion from certs to keys is needed. if (!token_keys) { auto fallback_keys = tx.ro<Tables::Legacy::JwtPublicSigningKeys>( @@ -141,11 +218,12 @@ namespace ccf auto fallback_issuers = tx.ro<Tables::Legacy::JwtPublicSigningKeyIssuer>( ccf::Tables::Legacy::JWT_PUBLIC_SIGNING_KEY_ISSUER); - auto fallback_key = fallback_keys->get(key_id); - if (fallback_key) + auto fallback_cert = fallback_keys->get(key_id); + if (fallback_cert) { + auto verifier = ccf::crypto::make_unique_verifier(*fallback_cert); token_keys = std::vector<OpenIDJWKMetadata>{OpenIDJWKMetadata{ - .cert = *fallback_key, + .public_key = verifier->public_key_der(), .issuer = *fallback_issuers->get(key_id), .constraint = std::nullopt}}; } @@ -160,8 +238,12 @@ namespace ccf for (const auto& metadata : *token_keys) { - auto verifier = verifiers->get_verifier(metadata.cert); - if (!::http::JwtVerifier::validate_token_signature(token, verifier)) + if (!keys_cache->verify( + (uint8_t*)token.signed_content.data(), + token.signed_content.size(), + token.signature.data(), + token.signature.size(), + metadata.public_key)) { error_reason = "Signature verification failed"; continue; @@ -171,7 +253,7 @@ namespace ccf const size_t time_now = std::chrono::duration_cast<std::chrono::seconds>( ccf::get_enclave_time()) .count(); - if (time_now < token.payload_typed.nbf) + if (token.payload_typed.nbf && time_now < *token.payload_typed.nbf) { error_reason = fmt::format( "Current time {} is before token's Not Before (nbf) claim {}", diff --git a/src/http/http_jwt.h b/src/http/http_jwt.h index aecf1c71074e..09d688400cb7 100644 --- a/src/http/http_jwt.h +++ b/src/http/http_jwt.h @@ -16,9 +16,13 @@ namespace http { enum class JwtCryptoAlgorithm { - RS256 + RS256, + ES256, }; - DECLARE_JSON_ENUM(JwtCryptoAlgorithm, {{JwtCryptoAlgorithm::RS256, "RS256"}}); + DECLARE_JSON_ENUM( + JwtCryptoAlgorithm, + {{JwtCryptoAlgorithm::RS256, "RS256"}, + {JwtCryptoAlgorithm::ES256, "ES256"}}); struct JwtHeader { @@ -30,14 +34,14 @@ namespace http struct JwtPayload { - size_t nbf; size_t exp; std::string iss; + std::optional<size_t> nbf; std::optional<std::string> tid; }; DECLARE_JSON_TYPE_WITH_OPTIONAL_FIELDS(JwtPayload) - DECLARE_JSON_REQUIRED_FIELDS(JwtPayload, nbf, exp, iss); - DECLARE_JSON_OPTIONAL_FIELDS(JwtPayload, tid) + DECLARE_JSON_REQUIRED_FIELDS(JwtPayload, exp, iss); + DECLARE_JSON_OPTIONAL_FIELDS(JwtPayload, nbf, tid); class JwtVerifier { diff --git a/src/node/gov/handlers/service_state.h b/src/node/gov/handlers/service_state.h index eabd6d65ea3a..6992fabb3f1c 100644 --- a/src/node/gov/handlers/service_state.h +++ b/src/node/gov/handlers/service_state.h @@ -600,7 +600,7 @@ namespace ccf::gov::endpoints auto keys = nlohmann::json::object(); auto jwt_keys_handle = - ctx.tx.template ro<ccf::JwtPublicSigningKeys>( + ctx.tx.template ro<ccf::JwtPublicSigningKeysMetadata>( ccf::Tables::JWT_PUBLIC_SIGNING_KEYS_METADATA); jwt_keys_handle->foreach( @@ -612,11 +612,10 @@ namespace ccf::gov::endpoints { auto info = nlohmann::json::object(); - // cert is stored as DER - convert to PEM for API - const auto cert_pem = - ccf::crypto::cert_der_to_pem(metadata.cert); - info["certificate"] = cert_pem.str(); - + info["publicKey"] = + ccf::crypto::make_rsa_public_key(metadata.public_key) + ->public_key_pem() + .str(); info["issuer"] = metadata.issuer; info["constraint"] = metadata.constraint; diff --git a/src/node/rpc/jwt_management.h b/src/node/rpc/jwt_management.h index af7c011ac8ef..3aebb17f6c70 100644 --- a/src/node/rpc/jwt_management.h +++ b/src/node/rpc/jwt_management.h @@ -2,6 +2,7 @@ // Licensed under the Apache 2.0 License. #pragma once +#include "ccf/crypto/rsa_key_pair.h" #include "ccf/crypto/verifier.h" #include "ccf/ds/hex.h" #include "ccf/service/tables/jwt.h" @@ -12,13 +13,120 @@ #include <set> #include <sstream> +namespace +{ + std::vector<uint8_t> try_parse_raw_rsa(const ccf::crypto::JsonWebKeyData& jwk) + { + if (!jwk.e || jwk.e->empty() || !jwk.n || jwk.n->empty()) + { + return {}; + } + + std::vector<uint8_t> der; + ccf::crypto::JsonWebKeyRSAPublic data; + data.kty = ccf::crypto::JsonWebKeyType::RSA; + data.kid = jwk.kid.value(); + data.n = jwk.n.value(); + data.e = jwk.e.value(); + try + { + const auto pubkey = ccf::crypto::make_rsa_public_key(data); + return pubkey->public_key_der(); + } + catch (const std::invalid_argument& exc) + { + throw std::logic_error( + fmt::format("Failed to construct RSA public key: {}", exc.what())); + } + } + + std::vector<uint8_t> try_parse_raw_ec(const ccf::crypto::JsonWebKeyData& jwk) + { + if (!jwk.x || jwk.x->empty() || !jwk.y || jwk.y->empty() || !jwk.crv) + { + return {}; + } + + ccf::crypto::JsonWebKeyECPublic data; + data.kty = ccf::crypto::JsonWebKeyType::EC; + data.kid = jwk.kid.value(); + data.crv = jwk.crv.value(); + data.x = jwk.x.value(); + data.y = jwk.y.value(); + try + { + const auto pubkey = ccf::crypto::make_public_key(data); + return pubkey->public_key_der(); + } + catch (const std::invalid_argument& exc) + { + throw std::logic_error( + fmt::format("Failed to construct EC public key: {}", exc.what())); + } + } + + std::vector<uint8_t> try_parse_x5c(const ccf::crypto::JsonWebKeyData& jwk) + { + if (!jwk.x5c || jwk.x5c->empty()) + { + return {}; + } + + const auto& kid = jwk.kid.value(); + auto& der_base64 = jwk.x5c.value()[0]; + ccf::Cert der; + try + { + der = ccf::crypto::raw_from_b64(der_base64); + } + catch (const std::invalid_argument& e) + { + throw std::logic_error( + fmt::format("Could not parse x5c of key id {}: {}", kid, e.what())); + } + try + { + auto verifier = ccf::crypto::make_unique_verifier(der); + return verifier->public_key_der(); + } + catch (std::invalid_argument& exc) + { + throw std::logic_error(fmt::format( + "JWKS kid {} has an invalid X.509 certificate: {}", kid, exc.what())); + } + } + + std::vector<uint8_t> try_parse_jwk(const ccf::crypto::JsonWebKeyData& jwk) + { + const auto& kid = jwk.kid.value(); + auto key = try_parse_raw_rsa(jwk); + if (!key.empty()) + { + return key; + } + key = try_parse_raw_ec(jwk); + if (!key.empty()) + { + return key; + } + key = try_parse_x5c(jwk); + if (!key.empty()) + { + return key; + } + + throw std::logic_error( + fmt::format("JWKS kid {} has neither RSA/EC public key or x5c", kid)); + } +} + namespace ccf { static void legacy_remove_jwt_public_signing_keys( ccf::kv::Tx& tx, std::string issuer) { - auto keys = - tx.rw<JwtPublicSigningKeys>(Tables::Legacy::JWT_PUBLIC_SIGNING_KEYS); + auto keys = tx.rw<Tables::Legacy::JwtPublicSigningKeys>( + Tables::Legacy::JWT_PUBLIC_SIGNING_KEYS); auto key_issuer = tx.rw<Tables::Legacy::JwtPublicSigningKeyIssuer>( Tables::Legacy::JWT_PUBLIC_SIGNING_KEY_ISSUER); @@ -31,14 +139,38 @@ namespace ccf } return true; }); + + auto metadata = tx.rw<JwtPublicSigningKeysMetadataLegacy>( + Tables::Legacy::JWT_PUBLIC_SIGNING_KEYS_METADATA); + metadata->foreach([&issuer, &metadata](const auto& k, const auto& v) { + std::vector<OpenIDJWKMetadataLegacy> updated; + for (const auto& key : v) + { + if (key.issuer != issuer) + { + updated.push_back(key); + } + } + + if (updated.empty()) + { + metadata->remove(k); + } + else if (updated.size() < v.size()) + { + metadata->put(k, updated); + } + + return true; + }); } static bool check_issuer_constraint( const std::string& issuer, const std::string& constraint) { // Only accept key constraints for the same (sub)domain. This is to avoid - // setting keys from issuer A which will be used to validate iss claims for - // issuer B, so this doesn't make sense (at least for now). + // setting keys from issuer A which will be used to validate iss claims + // for issuer B, so this doesn't make sense (at least for now). const auto issuer_domain = ::http::parse_url_full(issuer).host; const auto constraint_domain = ::http::parse_url_full(constraint).host; @@ -48,13 +180,13 @@ namespace ccf return false; } - // Either constraint's domain == issuer's domain or it is a subdomain, e.g.: - // limited.facebook.com + // Either constraint's domain == issuer's domain or it is a subdomain, + // e.g.: limited.facebook.com // .facebook.com // // It may make sense to support vice-versa too, but we haven't found any - // instances of that so far, so leaveing it only-way only for facebook-like - // cases. + // instances of that so far, so leaveing it only-way only for + // facebook-like cases. if (issuer_domain != constraint_domain) { const auto pattern = "." + constraint_domain; @@ -68,12 +200,12 @@ namespace ccf ccf::kv::Tx& tx, std::string issuer) { // Unlike resetting JWT keys for a particular issuer, removing keys can be - // safely done on both table revisions, as soon as the application shouldn't - // use them anyway after being ask about that explicitly. + // safely done on both table revisions, as soon as the application + // shouldn't use them anyway after being ask about that explicitly. legacy_remove_jwt_public_signing_keys(tx, issuer); - auto keys = - tx.rw<JwtPublicSigningKeys>(Tables::JWT_PUBLIC_SIGNING_KEYS_METADATA); + auto keys = tx.rw<JwtPublicSigningKeysMetadata>( + Tables::JWT_PUBLIC_SIGNING_KEYS_METADATA); keys->foreach([&issuer, &keys](const auto& k, const auto& v) { auto it = find_if(v.begin(), v.end(), [&](const auto& metadata) { @@ -105,82 +237,53 @@ namespace ccf const JwtIssuerMetadata& issuer_metadata, const JsonWebKeySet& jwks) { - auto keys = - tx.rw<JwtPublicSigningKeys>(Tables::JWT_PUBLIC_SIGNING_KEYS_METADATA); + auto keys = tx.rw<JwtPublicSigningKeysMetadata>( + Tables::JWT_PUBLIC_SIGNING_KEYS_METADATA); // add keys if (jwks.keys.empty()) { LOG_FAIL_FMT("{}: JWKS has no keys", log_prefix); return false; } - std::map<std::string, std::vector<uint8_t>> new_keys; + std::map<std::string, PublicKey> new_keys; std::map<std::string, JwtIssuer> issuer_constraints; - for (auto& jwk : jwks.keys) - { - if (!jwk.kid.has_value()) - { - LOG_FAIL_FMT("No kid for JWT signing key"); - return false; - } - - if (!jwk.x5c.has_value() && jwk.x5c->empty()) - { - LOG_FAIL_FMT("{}: JWKS is invalid (empty x5c)", log_prefix); - return false; - } - - auto& der_base64 = jwk.x5c.value()[0]; - ccf::Cert der; - auto const& kid = jwk.kid.value(); - try - { - der = ccf::crypto::raw_from_b64(der_base64); - } - catch (const std::invalid_argument& e) - { - LOG_FAIL_FMT( - "{}: Could not parse x5c of key id {}: {}", - log_prefix, - kid, - e.what()); - return false; - } - try - { - ccf::crypto::make_unique_verifier( - (std::vector<uint8_t>)der); // throws on error - } - catch (std::invalid_argument& exc) + try + { + for (auto& jwk : jwks.keys) { - LOG_FAIL_FMT( - "{}: JWKS kid {} has an invalid X.509 certificate: {}", - log_prefix, - kid, - exc.what()); - return false; - } + if (!jwk.kid.has_value()) + { + throw(std::logic_error("Missing kid for JWT signing key")); + } - LOG_INFO_FMT("{}: Storing JWT signing key with kid {}", log_prefix, kid); - new_keys.emplace(kid, der); + const auto& kid = jwk.kid.value(); + auto key_der = try_parse_jwk(jwk); - if (jwk.issuer) - { - if (!check_issuer_constraint(issuer, *jwk.issuer)) + if (jwk.issuer) { - LOG_FAIL_FMT( - "{}: JWKS kid {} with issuer constraint {} fails validation " - "against issuer {}", - log_prefix, - kid, - *jwk.issuer, - issuer); - return false; + if (!check_issuer_constraint(issuer, *jwk.issuer)) + { + throw std::logic_error(fmt::format( + "JWKS kid {} with issuer constraint {} fails validation " + "against " + "issuer {}", + kid, + *jwk.issuer, + issuer)); + } + + issuer_constraints.emplace(kid, *jwk.issuer); } - issuer_constraints.emplace(kid, *jwk.issuer); + new_keys.emplace(kid, key_der); } } + catch (const std::exception& exc) + { + LOG_FAIL_FMT("{}: {}", log_prefix, exc.what()); + return false; + } if (new_keys.empty()) { @@ -203,7 +306,10 @@ namespace ccf for (auto& [kid, der] : new_keys) { - OpenIDJWKMetadata value{der, issuer, std::nullopt}; + OpenIDJWKMetadata value{ + .public_key = der, .issuer = issuer, .constraint = std::nullopt}; + value.public_key = der; + const auto it = issuer_constraints.find(kid); if (it != issuer_constraints.end()) { @@ -218,7 +324,7 @@ namespace ccf keys_for_kid->begin(), keys_for_kid->end(), [&value](const auto& metadata) { - return metadata.cert == value.cert && + return metadata.public_key == value.public_key && metadata.issuer == value.issuer && metadata.constraint == value.constraint; }) != keys_for_kid->end()) diff --git a/src/node/rpc/member_frontend.h b/src/node/rpc/member_frontend.h index 05086a79eb71..4e490f30e61f 100644 --- a/src/node/rpc/member_frontend.h +++ b/src/node/rpc/member_frontend.h @@ -67,14 +67,6 @@ namespace ccf DECLARE_JSON_TYPE(JsBundle) DECLARE_JSON_REQUIRED_FIELDS(JsBundle, metadata, modules) - struct KeyIdInfo - { - JwtIssuer issuer; - ccf::crypto::Pem cert; - }; - DECLARE_JSON_TYPE(KeyIdInfo) - DECLARE_JSON_REQUIRED_FIELDS(KeyIdInfo, issuer, cert) - struct FullMemberDetails : public ccf::MemberDetails { ccf::crypto::Pem cert; @@ -1098,30 +1090,6 @@ namespace ccf "5.0.0", "POST /gov/recovery/members/{memberId}:recover") .install(); - using JWTKeyMap = std::map<JwtKeyId, std::vector<KeyIdInfo>>; - - auto get_jwt_keys = [this](auto& ctx, nlohmann::json&& body) { - auto keys = ctx.tx.ro(network.jwt_public_signing_keys_metadata); - JWTKeyMap kmap; - keys->foreach([&kmap](const auto& k, const auto& v) { - std::vector<KeyIdInfo> info; - for (const auto& metadata : v) - { - info.push_back(KeyIdInfo{ - metadata.issuer, ccf::crypto::cert_der_to_pem(metadata.cert)}); - } - kmap.emplace(k, std::move(info)); - return true; - }); - - return make_success(kmap); - }; - make_endpoint( - "/jwt_keys/all", HTTP_GET, json_adapter(get_jwt_keys), no_auth_required) - .set_auto_schema<void, JWTKeyMap>() - .set_openapi_deprecated_replaced("5.0.0", "POST /gov/service/jwk") - .install(); - auto post_proposals_js = [this](ccf::endpoints::EndpointContext& ctx) { std::optional<ccf::MemberCOSESign1AuthnIdentity> cose_auth_id = std::nullopt; diff --git a/src/service/network_tables.h b/src/service/network_tables.h index b7d92d338aee..45b3d7de316a 100644 --- a/src/service/network_tables.h +++ b/src/service/network_tables.h @@ -154,8 +154,11 @@ namespace ccf // const CACertBundlePEMs ca_cert_bundles = {Tables::CA_CERT_BUNDLE_PEMS}; const JwtIssuers jwt_issuers = {Tables::JWT_ISSUERS}; - const JwtPublicSigningKeys jwt_public_signing_keys_metadata = { + const JwtPublicSigningKeysMetadata jwt_public_signing_keys_metadata = { Tables::JWT_PUBLIC_SIGNING_KEYS_METADATA}; + const JwtPublicSigningKeysMetadataLegacy + legacy_jwt_public_signing_keys_metadata = { + Tables::Legacy::JWT_PUBLIC_SIGNING_KEYS_METADATA}; const Tables::Legacy::JwtPublicSigningKeys legacy_jwt_public_signing_keys = {Tables::Legacy::JWT_PUBLIC_SIGNING_KEYS}; const Tables::Legacy::JwtPublicSigningKeyIssuer @@ -168,6 +171,7 @@ namespace ccf ca_cert_bundles, jwt_issuers, jwt_public_signing_keys_metadata, + legacy_jwt_public_signing_keys_metadata, legacy_jwt_public_signing_keys, legacy_jwt_public_signing_key_issuer); } diff --git a/tests/infra/crypto.py b/tests/infra/crypto.py index 23fbc8039fe3..2947ff2567a9 100644 --- a/tests/infra/crypto.py +++ b/tests/infra/crypto.py @@ -307,10 +307,8 @@ def pub_key_pem_to_der(pem: str) -> bytes: return cert.public_bytes(Encoding.DER, PublicFormat.SubjectPublicKeyInfo) -def create_jwt(body_claims: dict, key_priv_pem: str, key_id: str) -> str: - return jwt.encode( - body_claims, key_priv_pem, algorithm="RS256", headers={"kid": key_id} - ) +def create_jwt(body_claims: dict, key_priv_pem: str, key_id: str, alg="RS256") -> str: + return jwt.encode(body_claims, key_priv_pem, algorithm=alg, headers={"kid": key_id}) def cert_pem_to_der(pem: str) -> bytes: diff --git a/tests/infra/jwt_issuer.py b/tests/infra/jwt_issuer.py index 1882e57b02f5..9766220d923c 100644 --- a/tests/infra/jwt_issuer.py +++ b/tests/infra/jwt_issuer.py @@ -11,8 +11,23 @@ import json import time import uuid + from infra.log_capture import flush_info from loguru import logger as LOG +from enum import Enum +from cryptography.x509 import load_pem_x509_certificate +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives.asymmetric import ec + + +class JwtAlg(Enum): + RS256 = "RS256" # RSA using SHA-256 + ES256 = "ES256" # ECDSA using P-256 and SHA-256 + + +class JwtAuthType(Enum): + CERT = 1 + KEY = 2 def make_bearer_header(jwt): @@ -107,17 +122,50 @@ def __exit__(self, exc_type, exc_value, traceback): self.stop() +def get_jwt_issuers(args, node): + with node.api_versioned_client(api_version=args.gov_api_version) as c: + r = c.get("/gov/service/jwk") + assert r.status_code == HTTPStatus.OK, r + body = r.body.json() + return body["issuers"] + + +def get_jwt_keys(args, node): + with node.api_versioned_client(api_version=args.gov_api_version) as c: + r = c.get("/gov/service/jwk") + assert r.status_code == HTTPStatus.OK, r + body = r.body.json() + return body["keys"] + + +def to_b64(number: int): + as_bytes = number.to_bytes((number.bit_length() + 7) // 8, "big") + return base64.b64encode(as_bytes).decode("ascii") + + class JwtIssuer: TEST_JWT_ISSUER_NAME = "https://example.issuer" TEST_CA_BUNDLE_NAME = "test_ca_bundle_name" - def _generate_cert(self, cn=None): - key_priv, key_pub = infra.crypto.generate_rsa_keypair(2048) + def _generate_auth_data(self, cn=None): + if self._alg == JwtAlg.RS256: + key_priv, key_pub = infra.crypto.generate_rsa_keypair(2048) + elif self._alg == JwtAlg.ES256: + key_priv, key_pub = infra.crypto.generate_ec_keypair(ec.SECP256R1) + else: + raise ValueError(f"Unsupported algorithm: {self._alg}") + cert = infra.crypto.generate_cert(key_priv, cn=cn) return (key_priv, key_pub), cert def __init__( - self, name=TEST_JWT_ISSUER_NAME, cert=None, refresh_interval=3, cn=None + self, + name=TEST_JWT_ISSUER_NAME, + cert=None, + refresh_interval=3, + cn=None, + auth_type=JwtAuthType.CERT, + alg=JwtAlg.RS256, ): self.name = name self.default_kid = f"{uuid.uuid4()}" @@ -126,7 +174,9 @@ def __init__( # Auto-refresh ON if issuer name starts with "https://" self.auto_refresh = self.name.startswith("https://") stripped_host = self.name[len("https://") :] if self.auto_refresh else None - (self.tls_priv, _), self.tls_cert = self._generate_cert( + self._auth_type = auth_type + self._alg = alg + (self.tls_priv, _), self.tls_cert = self._generate_auth_data( cn or stripped_host or name ) if not cert: @@ -134,6 +184,11 @@ def __init__( else: self.cert_pem = cert + @property + def public_key(self): + cert = load_pem_x509_certificate(self.cert_pem.encode(), default_backend()) + return cert.public_key() + @property def issuer_url(self): name = f"{self.name}" @@ -141,25 +196,53 @@ def issuer_url(self): name += f":{self.server.bind_port}" return name - def refresh_keys(self, kid=None): + def refresh_keys(self, kid=None, send_update=True): if not kid: self.default_kid = f"{uuid.uuid4()}" kid_ = kid or self.default_kid - (self.key_priv_pem, self.key_pub_pem), self.cert_pem = self._generate_cert() - if self.server: + (self.key_priv_pem, self.key_pub_pem), self.cert_pem = ( + self._generate_auth_data() + ) + if self.server and send_update: self.server.set_jwks(self.create_jwks(kid_)) - def _create_jwks(self, kid, test_invalid_is_key=False): - der_b64 = base64.b64encode( - infra.crypto.cert_pem_to_der(self.cert_pem) - if not test_invalid_is_key - else infra.crypto.pub_key_pem_to_der(self.key_pub_pem) - ).decode("ascii") + def _create_jwks_with_cert(self, kid): + der_b64 = base64.b64encode(infra.crypto.cert_pem_to_der(self.cert_pem)).decode( + "ascii" + ) return {"kty": "RSA", "kid": kid, "x5c": [der_b64], "issuer": self.name[::]} - def create_jwks(self, kid=None, test_invalid_is_key=False): + def _create_jwks_with_raw_key(self, kid): + pubkey = self.public_key + if self._alg == JwtAlg.RS256: + n = to_b64(pubkey.public_numbers().n) + e = to_b64(pubkey.public_numbers().e) + return {"kty": "RSA", "kid": kid, "n": n, "e": e, "issuer": self.name[::]} + elif self._alg == JwtAlg.ES256: + x = to_b64(pubkey.public_numbers().x) + y = to_b64(pubkey.public_numbers().y) + return { + "kty": "EC", + "kid": kid, + "x": x, + "y": y, + "crv": "P-256", + "issuer": self.name[::], + } + else: + raise ValueError(f"Unsupported algorithm: {self._alg}") + + def _create_jwks(self, kid): + if self._auth_type == JwtAuthType.KEY: + return self._create_jwks_with_raw_key(kid) + elif self._auth_type == JwtAuthType.CERT: + return self._create_jwks_with_cert(kid) + else: + raise ValueError(f"Unsupported auth type: {self._auth_type}") + + def create_jwks(self, kid=None): kid_ = kid or self.default_kid - return {"keys": [self._create_jwks(kid_, test_invalid_is_key)]} + return {"keys": [self._create_jwks(kid_)]} def create_jwks_for_kids(self, kids): jwks = {} @@ -217,7 +300,8 @@ def issue_jwt(self, kid=None, claims=None): claims["exp"] = now + 3600 if "iss" not in claims: claims["iss"] = self.name - return infra.crypto.create_jwt(claims, self.key_priv_pem, kid_) + + return infra.crypto.create_jwt(claims, self.key_priv_pem, kid_, self._alg.value) def wait_for_refresh(self, network, args, kid=None): timeout = self.refresh_interval * 3 @@ -237,10 +321,16 @@ def wait_for_refresh(self, network, args, kid=None): LOG.warning(body) keys = body["keys"] if kid_ in keys: - stored_cert = keys[kid_][0]["certificate"] - if self.cert_pem == stored_cert: - flush_info(logs) - return + if "publicKey" in keys[kid_][0]: + stored_key = keys[kid_][0]["publicKey"] + if self.key_pub_pem == stored_key: + flush_info(logs) + return + else: + stored_cert = keys[kid_][0]["certificate"] + if self.cert_pem == stored_cert: + flush_info(logs) + return time.sleep(0.1) else: with primary.client( diff --git a/tests/js-custom-authorization/custom_authorization.py b/tests/js-custom-authorization/custom_authorization.py index 2dbd949e4ec0..ea4d03924637 100644 --- a/tests/js-custom-authorization/custom_authorization.py +++ b/tests/js-custom-authorization/custom_authorization.py @@ -13,7 +13,7 @@ import base64 import json import time -import infra.jwt_issuer +from infra.jwt_issuer import JwtAlg, JwtAuthType, JwtIssuer, make_bearer_header import datetime import re import uuid @@ -111,7 +111,7 @@ def try_auth(primary, issuer, kid, iss, tid): LOG.info(f"Creating JWT with kid={kid} iss={iss} tenant={tid}") return c.get( "/app/jwt", - headers=infra.jwt_issuer.make_bearer_header( + headers=make_bearer_header( issuer.issue_jwt(kid, claims={"iss": iss, "tid": tid}) ), ) @@ -344,7 +344,7 @@ def create_keypair(local_id, valid_from, validity_days): def test_jwt_auth(network, args): primary, _ = network.find_nodes() - issuer = infra.jwt_issuer.JwtIssuer("https://example.issuer") + issuer = JwtIssuer("https://example.issuer") jwt_kid = "my_key_id" @@ -354,26 +354,26 @@ def test_jwt_auth(network, args): LOG.info("Calling jwt endpoint after storing keys") with primary.client("user0") as c: - r = c.get("/app/jwt", headers=infra.jwt_issuer.make_bearer_header("garbage")) + r = c.get("/app/jwt", headers=make_bearer_header("garbage")) assert r.status_code == HTTPStatus.UNAUTHORIZED, r.status_code assert "Malformed JWT" in parse_error_message(r), r jwt_mismatching_key_priv_pem, _ = infra.crypto.generate_rsa_keypair(2048) jwt = infra.crypto.create_jwt({}, jwt_mismatching_key_priv_pem, jwt_kid) - r = c.get("/app/jwt", headers=infra.jwt_issuer.make_bearer_header(jwt)) + r = c.get("/app/jwt", headers=make_bearer_header(jwt)) assert r.status_code == HTTPStatus.UNAUTHORIZED, r.status_code assert "JWT payload is missing required field" in parse_error_message(r), r r = c.get( "/app/jwt", - headers=infra.jwt_issuer.make_bearer_header(issuer.issue_jwt(jwt_kid)), + headers=make_bearer_header(issuer.issue_jwt(jwt_kid)), ) assert r.status_code == HTTPStatus.OK, r.status_code LOG.info("Calling JWT with too-late nbf") r = c.get( "/app/jwt", - headers=infra.jwt_issuer.make_bearer_header( + headers=make_bearer_header( issuer.issue_jwt(jwt_kid, claims={"nbf": time.time() + 60}) ), ) @@ -383,7 +383,7 @@ def test_jwt_auth(network, args): LOG.info("Calling JWT with too-early exp") r = c.get( "/app/jwt", - headers=infra.jwt_issuer.make_bearer_header( + headers=make_bearer_header( issuer.issue_jwt(jwt_kid, claims={"exp": time.time() - 60}) ), ) @@ -394,6 +394,37 @@ def test_jwt_auth(network, args): return network +@reqs.description("JWT authentication as by OpenID spec with raw public key") +def test_jwt_auth_raw_key(network, args): + primary, _ = network.find_nodes() + + for alg in [JwtAlg.RS256, JwtAlg.ES256]: + issuer = JwtIssuer("noautorefresh://issuer", alg=alg, auth_type=JwtAuthType.KEY) + jwt_kid = "my_key_id" + issuer.register(network, kid=jwt_kid) + + LOG.info("Calling jwt endpoint after storing keys") + with primary.client("user0") as c: + token = issuer.issue_jwt(jwt_kid) + r = c.get( + "/app/jwt", + headers=make_bearer_header(token), + ) + assert r.status_code == HTTPStatus.OK, r.status_code + + # Change client's key only, new token shouldn't pass validation. + issuer.refresh_keys(kid=jwt_kid, send_update=False) + token = issuer.issue_jwt(jwt_kid) + r = c.get( + "/app/jwt", + headers=make_bearer_header(token), + ) + assert r.status_code == HTTPStatus.UNAUTHORIZED, r.status_code + + network.consortium.remove_jwt_issuer(primary, issuer.name) + return network + + @reqs.description("JWT authentication as by MSFT Entra (single tenant)") def test_jwt_auth_msft_single_tenant(network, args): """For a specific tenant, only tokens with this issuer+tenant can auth.""" @@ -405,7 +436,7 @@ def test_jwt_auth_msft_single_tenant(network, args): "https://login.microsoftonline.com/9188050d-6c67-4c5b-b112-36a304b66da/v2.0" ) - issuer = infra.jwt_issuer.JwtIssuer(name="https://login.microsoftonline.com") + issuer = JwtIssuer(name="https://login.microsoftonline.com") jwt_kid = "my_key_id" set_issuer_with_a_key(primary, network, issuer, jwt_kid, ISSUER_TENANT) @@ -443,7 +474,7 @@ def test_jwt_auth_msft_multitenancy(network, args): ANOTHER_TENANT_ID = "deadbeef-6c67-4c5b-b112-36a304b66da" ISSUER_ANOTHER = f"https://login.microsoftonline.com/{ANOTHER_TENANT_ID}/v2.0" - issuer = infra.jwt_issuer.JwtIssuer(name="https://login.microsoftonline.com") + issuer = JwtIssuer(name="https://login.microsoftonline.com") jwt_kid_1 = "my_key_id_1" jwt_kid_2 = "my_key_id_2" @@ -520,8 +551,8 @@ def test_jwt_auth_msft_same_kids_different_issuers(network, args): ANOTHER_TENANT_ID = "deadbeef-6c67-4c5b-b112-36a304b66da" ISSUER_ANOTHER = f"https://login.microsoftonline.com/{ANOTHER_TENANT_ID}/v2.0" - issuer = infra.jwt_issuer.JwtIssuer(name=ISSUER_TENANT) - another = infra.jwt_issuer.JwtIssuer(name=ISSUER_ANOTHER) + issuer = JwtIssuer(name=ISSUER_TENANT) + another = JwtIssuer(name=ISSUER_ANOTHER) # Immitate same key sharing another.cert_pem, another.key_priv_pem = issuer.cert_pem, issuer.key_priv_pem @@ -582,7 +613,7 @@ def test_jwt_auth_msft_same_kids_overwrite_constraint(network, args): ANOTHER_TENANT_ID = "deadbeef-6c67-4c5b-b112-36a304b66da" ISSUER_ANOTHER = f"https://login.microsoftonline.com/{ANOTHER_TENANT_ID}/v2.0" - issuer = infra.jwt_issuer.JwtIssuer(name=ISSUER_TENANT) + issuer = JwtIssuer(name=ISSUER_TENANT) jwt_kid = "my_key_id" set_issuer_with_a_key(primary, network, issuer, jwt_kid, COMMNON_ISSUER) @@ -708,6 +739,7 @@ def run_authn(args): network.start_and_open(args) network = test_cert_auth(network, args) network = test_jwt_auth(network, args) + network = test_jwt_auth_raw_key(network, args) network = test_jwt_auth_msft_single_tenant(network, args) network = test_jwt_auth_msft_multitenancy(network, args) network = test_jwt_auth_msft_same_kids_different_issuers(network, args) diff --git a/tests/jwt_test.py b/tests/jwt_test.py index ef0e861fd3f6..8149f59bfb3c 100644 --- a/tests/jwt_test.py +++ b/tests/jwt_test.py @@ -4,6 +4,7 @@ import tempfile import json import time +import base64 import infra.network import infra.path import infra.proc @@ -12,33 +13,16 @@ import infra.e2e_args import infra.proposal import suite.test_requirements as reqs -import infra.jwt_issuer +from infra.jwt_issuer import get_jwt_issuers, get_jwt_keys from infra.runner import ConcurrentRunner import ca_certs import ccf.ledger from ccf.tx_id import TxID import infra.clients -import http from loguru import logger as LOG -def get_jwt_issuers(args, node): - with node.api_versioned_client(api_version=args.gov_api_version) as c: - r = c.get("/gov/service/jwk") - assert r.status_code == http.HTTPStatus.OK, r - body = r.body.json() - return body["issuers"] - - -def get_jwt_keys(args, node): - with node.api_versioned_client(api_version=args.gov_api_version) as c: - r = c.get("/gov/service/jwk") - assert r.status_code == http.HTTPStatus.OK, r - body = r.body.json() - return body["keys"] - - def set_issuer_with_keys(network, primary, issuer, kids): with tempfile.NamedTemporaryFile(prefix="ccf", mode="w+") as metadata_fp: json.dump({"issuer": issuer.name}, metadata_fp) @@ -213,7 +197,7 @@ def test_jwt_endpoint(network, args): assert kid in service_keys, service_keys assert service_keys[kid][0]["issuer"] == issuer.name assert service_keys[kid][0]["constraint"] == issuer.name - assert service_keys[kid][0]["certificate"] == issuer.cert_pem + assert service_keys[kid][0]["publicKey"] == issuer.key_pub_pem @reqs.description("JWT without key policy") @@ -246,7 +230,12 @@ def test_jwt_without_key_policy(network, args): LOG.info("Try to add a public key instead of a certificate") with tempfile.NamedTemporaryFile(prefix="ccf", mode="w+") as jwks_fp: - json.dump(issuer.create_jwks(kid, test_invalid_is_key=True), jwks_fp) + jwks = issuer.create_jwks(kid) + der_b64 = base64.b64encode( + infra.crypto.pub_key_pem_to_der(issuer.key_pub_pem) + ).decode("ascii") + jwks["keys"][0]["x5c"] = [der_b64] + json.dump(jwks, jwks_fp) jwks_fp.flush() try: network.consortium.set_jwt_public_signing_keys( @@ -266,9 +255,9 @@ def test_jwt_without_key_policy(network, args): ) keys = get_jwt_keys(args, primary) - stored_cert = keys[kid][0]["certificate"] + stored_key = keys[kid][0]["publicKey"] - assert stored_cert == issuer.cert_pem, "input cert is not equal to stored cert" + assert stored_key == issuer.key_pub_pem, "input key is not equal to stored key" LOG.info("Remove JWT issuer") network.consortium.remove_jwt_issuer(primary, issuer.name) @@ -285,9 +274,9 @@ def test_jwt_without_key_policy(network, args): network.consortium.set_jwt_issuer(primary, metadata_fp.name) keys = get_jwt_keys(args, primary) - stored_cert = keys[kid][0]["certificate"] + stored_key = keys[kid][0]["publicKey"] - assert stored_cert == issuer.cert_pem, "input cert is not equal to stored cert" + assert stored_key == issuer.key_pub_pem, "input key is not equal to stored key" return network @@ -320,18 +309,18 @@ def make_attested_cert(network, args): return pem -def check_kv_jwt_key_matches(args, network, kid, cert_pem): +def check_kv_jwt_key_matches(args, network, kid, key_pem): primary, _ = network.find_nodes() latest_jwt_signing_keys = get_jwt_keys(args, primary) - if cert_pem is None: + if key_pem is None: assert kid not in latest_jwt_signing_keys else: # Necessary to get an AssertionError if the key is not found yet, # when used from with_timeout() assert kid in latest_jwt_signing_keys - stored_cert = latest_jwt_signing_keys[kid][0]["certificate"] - assert stored_cert == cert_pem, "input cert is not equal to stored cert" + stored_key = latest_jwt_signing_keys[kid][0]["publicKey"] + assert stored_key == key_pem, "input cert is not equal to stored cert" def check_kv_jwt_keys_not_empty(args, network, issuer): @@ -405,7 +394,9 @@ def test_jwt_key_auto_refresh(network, args): LOG.info("Check that keys got refreshed") # Note: refresh interval is set to 1s, see network args below. with_timeout( - lambda: check_kv_jwt_key_matches(args, network, kid, issuer.cert_pem), + lambda: check_kv_jwt_key_matches( + args, network, kid, issuer.key_pub_pem + ), timeout=5, ) @@ -438,7 +429,7 @@ def check_has_failures(): with_timeout( lambda: check_kv_jwt_key_matches(args, network, kid, None), timeout=5 ) - check_kv_jwt_key_matches(args, network, kid2, issuer.cert_pem) + check_kv_jwt_key_matches(args, network, kid2, issuer.key_pub_pem) return network @@ -482,7 +473,9 @@ def test_jwt_key_auto_refresh_entries(network, args): LOG.info("Check that keys got refreshed") # Note: refresh interval is set to 1s, see network args below. with_timeout( - lambda: check_kv_jwt_key_matches(args, network, kid, issuer.cert_pem), + lambda: check_kv_jwt_key_matches( + args, network, kid, issuer.key_pub_pem + ), timeout=5, ) @@ -512,8 +505,10 @@ def test_jwt_key_auto_refresh_entries(network, args): for tx in chunk: txid = TxID(tx.gcm_header.view, tx.gcm_header.seqno) tables = tx.get_public_domain().get_tables() - if "public:ccf.gov.jwt.public_signing_keys_metadata" in tables: - pub_keys = tables["public:ccf.gov.jwt.public_signing_keys_metadata"] + if "public:ccf.gov.jwt.public_signing_keys_metadata_v2" in tables: + pub_keys = tables[ + "public:ccf.gov.jwt.public_signing_keys_metadata_v2" + ] if kid.encode() in pub_keys: if last_key_refresh is None: LOG.info(f"Refresh found for kid: {kid} at {txid}") @@ -567,7 +562,7 @@ def test_jwt_key_initial_refresh(network, args): # Auto-refresh interval has been set to a large value so that it doesn't happen within the timeout. # This is testing the one-off refresh after adding a new issuer. with_timeout( - lambda: check_kv_jwt_key_matches(args, network, kid, issuer.cert_pem), + lambda: check_kv_jwt_key_matches(args, network, kid, issuer.key_pub_pem), timeout=5, )