-
Notifications
You must be signed in to change notification settings - Fork 4.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Replace python-jose with pyjwt (#1875)
* Replace python-jose with pyjwt. * Replace non-existent get_unverified_claims function * Change Exception to handle JWT-specific errors * Convert the public key to PEM format * Add pem format tests * Another test, plus autherror fixes --------- Co-authored-by: Pamela Fox <[email protected]>
- Loading branch information
Showing
4 changed files
with
207 additions
and
46 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,13 +1,20 @@ | ||
import base64 | ||
import json | ||
import re | ||
from datetime import datetime, timedelta | ||
|
||
import aiohttp | ||
import jwt | ||
import pytest | ||
from azure.core.credentials import AzureKeyCredential | ||
from azure.search.documents.aio import SearchClient | ||
from azure.search.documents.indexes.models import SearchField, SearchIndex | ||
from cryptography.hazmat.primitives import serialization | ||
from cryptography.hazmat.primitives.asymmetric import rsa | ||
|
||
from core.authentication import AuthenticationHelper, AuthError | ||
|
||
from .mocks import MockAsyncPageIterator | ||
from .mocks import MockAsyncPageIterator, MockResponse | ||
|
||
MockSearchIndex = SearchIndex( | ||
name="test", | ||
|
@@ -40,6 +47,36 @@ def create_search_client(): | |
return SearchClient(endpoint="", index_name="", credential=AzureKeyCredential("")) | ||
|
||
|
||
def create_mock_jwt(kid="mock_kid", oid="OID_X"): | ||
# Create a payload with necessary claims | ||
payload = { | ||
"iss": "https://login.microsoftonline.com/TENANT_ID/v2.0", | ||
"sub": "AaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaA", | ||
"aud": "SERVER_APP", | ||
"exp": int((datetime.utcnow() + timedelta(hours=1)).timestamp()), | ||
"iat": int(datetime.utcnow().timestamp()), | ||
"nbf": int(datetime.utcnow().timestamp()), | ||
"name": "John Doe", | ||
"oid": oid, | ||
"preferred_username": "[email protected]", | ||
"rh": "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA.", | ||
"tid": "22222222-2222-2222-2222-222222222222", | ||
"uti": "AbCdEfGhIjKlMnOp-ABCDEFG", | ||
"ver": "2.0", | ||
} | ||
|
||
# Create a header | ||
header = {"kid": kid, "alg": "RS256", "typ": "JWT"} | ||
|
||
# Create a mock private key (for signing) | ||
private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048) | ||
|
||
# Create the JWT | ||
token = jwt.encode(payload, private_key, algorithm="RS256", headers=header) | ||
|
||
return token, private_key.public_key(), payload | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_get_auth_claims_success(mock_confidential_client_success, mock_validate_token_success): | ||
helper = create_authentication_helper() | ||
|
@@ -479,3 +516,136 @@ async def mock_search(self, *args, **kwargs): | |
) | ||
assert filter is None | ||
assert called_search is False | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_create_pem_format(mock_confidential_client_success, mock_validate_token_success): | ||
helper = create_authentication_helper() | ||
mock_token, public_key, payload = create_mock_jwt(oid="OID_X") | ||
_, other_public_key, _ = create_mock_jwt(oid="OID_Y") | ||
mock_jwks = { | ||
"keys": [ | ||
# Include a key with a different KID to ensure the correct key is selected | ||
{ | ||
"kty": "RSA", | ||
"kid": "other_mock_kid", | ||
"use": "sig", | ||
"n": base64.urlsafe_b64encode( | ||
other_public_key.public_numbers().n.to_bytes( | ||
(other_public_key.public_numbers().n.bit_length() + 7) // 8, byteorder="big" | ||
) | ||
) | ||
.decode("utf-8") | ||
.rstrip("="), | ||
"e": base64.urlsafe_b64encode( | ||
other_public_key.public_numbers().e.to_bytes( | ||
(other_public_key.public_numbers().e.bit_length() + 7) // 8, byteorder="big" | ||
) | ||
) | ||
.decode("utf-8") | ||
.rstrip("="), | ||
}, | ||
{ | ||
"kty": "RSA", | ||
"kid": "mock_kid", | ||
"use": "sig", | ||
"n": base64.urlsafe_b64encode( | ||
public_key.public_numbers().n.to_bytes( | ||
(public_key.public_numbers().n.bit_length() + 7) // 8, byteorder="big" | ||
) | ||
) | ||
.decode("utf-8") | ||
.rstrip("="), | ||
"e": base64.urlsafe_b64encode( | ||
public_key.public_numbers().e.to_bytes( | ||
(public_key.public_numbers().e.bit_length() + 7) // 8, byteorder="big" | ||
) | ||
) | ||
.decode("utf-8") | ||
.rstrip("="), | ||
}, | ||
] | ||
} | ||
|
||
pem_key = await helper.create_pem_format(mock_jwks, mock_token) | ||
|
||
# Assert that the result is bytes | ||
assert isinstance(pem_key, bytes), "create_pem_format should return bytes" | ||
|
||
# Convert bytes to string for regex matching | ||
pem_str = pem_key.decode("utf-8") | ||
|
||
# Assert that the key starts and ends with the correct markers | ||
assert pem_str.startswith("-----BEGIN PUBLIC KEY-----"), "PEM key should start with the correct marker" | ||
assert pem_str.endswith("-----END PUBLIC KEY-----\n"), "PEM key should end with the correct marker" | ||
|
||
# Assert that the format matches the structure of a PEM key | ||
pem_regex = r"^-----BEGIN PUBLIC KEY-----\n([A-Za-z0-9+/\n]+={0,2})\n-----END PUBLIC KEY-----\n$" | ||
assert re.match(pem_regex, pem_str), "PEM key format is incorrect" | ||
|
||
# Verify that the key can be used to decode the token | ||
try: | ||
decoded = jwt.decode( | ||
mock_token, key=pem_key, algorithms=["RS256"], audience=payload["aud"], issuer=payload["iss"] | ||
) | ||
assert decoded["oid"] == payload["oid"], "Decoded token should contain correct OID" | ||
except Exception as e: | ||
pytest.fail(f"jwt.decode raised an unexpected exception: {str(e)}") | ||
|
||
# Try to load the key using cryptography library to ensure it's a valid PEM format | ||
try: | ||
loaded_public_key = serialization.load_pem_public_key(pem_key) | ||
assert isinstance(loaded_public_key, rsa.RSAPublicKey), "Loaded key should be an RSA public key" | ||
except Exception as e: | ||
pytest.fail(f"Failed to load PEM key: {str(e)}") | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_validate_access_token(monkeypatch, mock_confidential_client_success): | ||
mock_token, public_key, payload = create_mock_jwt(oid="OID_X") | ||
|
||
def mock_get(*args, **kwargs): | ||
return MockResponse( | ||
status=200, | ||
text=json.dumps( | ||
{ | ||
"keys": [ | ||
{ | ||
"kty": "RSA", | ||
"use": "sig", | ||
"kid": "23nt", | ||
"x5t": "23nt", | ||
"n": "hu2SJ", | ||
"e": "AQAB", | ||
"x5c": ["MIIC/jCC"], | ||
"issuer": "https://login.microsoftonline.com/TENANT_ID/v2.0", | ||
}, | ||
{ | ||
"kty": "RSA", | ||
"use": "sig", | ||
"kid": "MGLq", | ||
"x5t": "MGLq", | ||
"n": "yfNcG8", | ||
"e": "AQAB", | ||
"x5c": ["MIIC/jCC"], | ||
"issuer": "https://login.microsoftonline.com/TENANT_ID/v2.0", | ||
}, | ||
] | ||
} | ||
), | ||
) | ||
|
||
monkeypatch.setattr(aiohttp.ClientSession, "get", mock_get) | ||
|
||
def mock_decode(*args, **kwargs): | ||
return payload | ||
|
||
monkeypatch.setattr(jwt, "decode", mock_decode) | ||
|
||
async def mock_create_pem_format(*args, **kwargs): | ||
return public_key | ||
|
||
monkeypatch.setattr(AuthenticationHelper, "create_pem_format", mock_create_pem_format) | ||
|
||
helper = create_authentication_helper() | ||
await helper.validate_access_token(mock_token) |