diff --git a/README.md b/README.md index 6e9d0346..bf955ea4 100644 --- a/README.md +++ b/README.md @@ -256,6 +256,35 @@ The OAuth2 token will be cached either per `trino.auth.OAuth2Authentication` ins ) ``` +### Keycloak Authentication + +The `KeycloakAuthentication` class can be used to connect to a Trino cluster that is configured with the [OAuth2 authentication type](https://trino.io/docs/current/security/oauth2.html) using an external OIDC identity provider (i.e Keycloak) + +It works by sending credentials to the OpenId identity provider and recieving a grant, then passing said grant to the Trino cluster secured using OAuth2 + +> [!WARNING] +> Client Authentication must be turned off (public access) as the flow does not send a client secret + +- DBAPI + + ```python + from trino.dbapi import connect + from trino.auth import KeycloakAuthentication + + conn = connect( + user="", + auth=KeycloakAuthentication( + username="", + password="", + keycloak_url="", + realm="", + client_id="", + ), + ... + ) + + ``` + ### Certificate authentication `CertificateAuthentication` class can be used to connect to Trino cluster configured with [certificate based authentication](https://trino.io/docs/current/security/certificate.html). `CertificateAuthentication` requires paths to a valid client certificate and private key. diff --git a/trino/auth.py b/trino/auth.py index c2155fd1..aee96501 100644 --- a/trino/auth.py +++ b/trino/auth.py @@ -31,6 +31,7 @@ from requests import Session from requests.auth import AuthBase from requests.auth import extract_cookies_to_jar +from requests.exceptions import JSONDecodeError import trino.logging from trino import exceptions @@ -47,6 +48,76 @@ def set_http_session(self, http_session: Session) -> Session: def get_exceptions(self) -> Tuple[Any, ...]: return tuple() + + +class KeycloakAuthentication(Authentication): + def __init__(self, username: str, password: str, keycloak_url: str, realm: str, client_id: str) -> None: + self._username = username + self._password = password + self._well_known_url = f"{keycloak_url.strip('/')}/realms/{realm}/.well-known/openid-configuration" + self._client_id = client_id + + def set_http_session(self, http_session: Session) -> Session: + open_id_configuration = http_session.get(self._well_known_url) + + open_id_configuration.raise_for_status() + + token_endpoint = open_id_configuration.json()["token_endpoint"] + + if token_endpoint is None: + raise exceptions.TrinoAuthError("token_endpoint not found in OpenID configuration") + + token_response = http_session.post( + url=token_endpoint, + data={ + "grant_type": "password", + "client_id": self._client_id, + "username": self._username, + "password": self._password, + "scope": "openid", + }, + ) + + try: + error_response = token_response.json().get("error") + + if error_response == "invalid_grant": + raise exceptions.TrinoAuthError("Invalid username or password") + + if error_response == "invalid_client": + raise exceptions.TrinoAuthError("Invalid client_id") + except JSONDecodeError: + pass + + token_response.raise_for_status() + + http_session.auth = KeycloakTokenBearer( + token=token_response.json()["access_token"] + ) + return http_session + + def get_exceptions(self) -> Tuple[Any, ...]: + return () + + def __eq__(self, other: object) -> bool: + if not isinstance(other, KeycloakAuthentication): + return False + + return ( + self._username == other._username + and self._password == other._password + and self._well_known_url == other._well_known_url + and self._client_id == other._client_id + ) + + +class KeycloakTokenBearer(AuthBase): + def __init__(self, token: str) -> None: + self._token = token + + def __call__(self, r: PreparedRequest) -> PreparedRequest: + r.headers["Authorization"] = f"Bearer {self._token}" + return r class KerberosAuthentication(Authentication):