Skip to content

Commit 7529f95

Browse files
tommytroentronghnkimtore
committed
feat: add support for form-urlencoding in request
* refactor: move handlers and types into separate types Co-authored-by: tronghn <[email protected]> Co-authored-by: kimtore <[email protected]>
1 parent 7c21188 commit 7529f95

File tree

3 files changed

+234
-201
lines changed

3 files changed

+234
-201
lines changed

src/handlers.rs

+167
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
use axum::extract::{FromRequest, Request};
2+
use axum::{async_trait, Form, RequestExt};
3+
4+
use crate::config::Config;
5+
use crate::identity_provider::*;
6+
use crate::types;
7+
use crate::types::{IdentityProvider, IntrospectRequest, TokenRequest, TokenResponse};
8+
use axum::extract::State;
9+
use axum::http::StatusCode;
10+
use axum::response::{IntoResponse, Response};
11+
use axum::Json;
12+
use jsonwebtoken as jwt;
13+
use jsonwebtoken::Algorithm::RS512;
14+
use jsonwebtoken::DecodingKey;
15+
use log::error;
16+
use std::sync::Arc;
17+
use axum::http::header::CONTENT_TYPE;
18+
use thiserror::Error;
19+
use tokio::sync::RwLock;
20+
21+
#[axum::debug_handler]
22+
pub async fn token(State(state): State<HandlerState>, JsonOrForm(request): JsonOrForm<TokenRequest>) -> Result<impl IntoResponse, ApiError> {
23+
let endpoint = state.token_endpoint(&request.identity_provider).await;
24+
let params = state.token_request(&request.identity_provider, request.target).await;
25+
26+
let client = reqwest::Client::new();
27+
let request_builder = client.post(endpoint)
28+
.header("accept", "application/json")
29+
.form(&params);
30+
31+
let response = request_builder
32+
.send()
33+
.await
34+
.map_err(ApiError::UpstreamRequest)?
35+
;
36+
37+
if response.status() >= StatusCode::BAD_REQUEST {
38+
let err: types::ErrorResponse = response.json().await.map_err(ApiError::JSON)?;
39+
return Err(ApiError::Upstream(err));
40+
}
41+
42+
let res: TokenResponse = response
43+
.json().await
44+
.inspect_err(|err| {
45+
error!("Identity provider returned invalid JSON: {:?}", err)
46+
})
47+
.map_err(ApiError::JSON)?
48+
;
49+
50+
Ok((StatusCode::OK, Json(res)))
51+
}
52+
53+
pub async fn introspection(State(state): State<HandlerState>, Json(request): Json<IntrospectRequest>) -> Result<impl IntoResponse, ApiError> {
54+
// Need to decode the token to get the issuer before we actually validate it.
55+
let mut validation = jwt::Validation::new(RS512);
56+
validation.validate_exp = false;
57+
validation.insecure_disable_signature_validation();
58+
let key = DecodingKey::from_secret(&[]);
59+
let token_data = jwt::decode::<Claims>(&request.token, &key, &validation).map_err(ApiError::Validate)?;
60+
61+
let claims = match token_data.claims.iss {
62+
s if s == state.cfg.maskinporten_issuer => state.maskinporten.write().await.introspect(request.token).await,
63+
_ => panic!("Unknown issuer: {}", token_data.claims.iss),
64+
};
65+
66+
Ok((StatusCode::OK, Json(claims)))
67+
}
68+
69+
#[derive(Clone)]
70+
pub struct HandlerState {
71+
pub cfg: Config,
72+
pub maskinporten: Arc<RwLock<Maskinporten>>,
73+
// TODO: other providers
74+
}
75+
76+
impl HandlerState {
77+
async fn token_request(&self, identity_provider: &IdentityProvider, target: String) -> Box<dyn erased_serde::Serialize + Send> {
78+
match identity_provider {
79+
IdentityProvider::EntraID => todo!(),
80+
IdentityProvider::TokenX => todo!(),
81+
IdentityProvider::Maskinporten => {
82+
Box::new(self.maskinporten.read().await.token_request(target))
83+
}
84+
}
85+
}
86+
87+
async fn token_endpoint(&self, identity_provider: &IdentityProvider) -> String {
88+
match identity_provider {
89+
IdentityProvider::EntraID => todo!(),
90+
IdentityProvider::TokenX => todo!(),
91+
IdentityProvider::Maskinporten => {
92+
self.maskinporten.read().await.token_endpoint()
93+
}
94+
}
95+
}
96+
}
97+
98+
#[derive(Debug, Error)]
99+
pub enum ApiError {
100+
#[error("identity provider error: {0}")]
101+
UpstreamRequest(reqwest::Error),
102+
103+
#[error("upstream error: {0}")]
104+
Upstream(types::ErrorResponse),
105+
106+
#[error("invalid JSON in token response: {0}")]
107+
JSON(reqwest::Error),
108+
109+
#[error("invalid token: {0}")]
110+
Validate(jwt::errors::Error),
111+
}
112+
113+
impl IntoResponse for ApiError {
114+
fn into_response(self) -> Response {
115+
match &self {
116+
ApiError::UpstreamRequest(err) => {
117+
(err.status().unwrap_or(StatusCode::INTERNAL_SERVER_ERROR), self.to_string())
118+
}
119+
ApiError::JSON(_) => {
120+
(StatusCode::INTERNAL_SERVER_ERROR, self.to_string())
121+
}
122+
ApiError::Upstream(_err) => {
123+
(StatusCode::INTERNAL_SERVER_ERROR, self.to_string())
124+
}
125+
ApiError::Validate(_) => {
126+
(StatusCode::BAD_REQUEST, self.to_string())
127+
}
128+
}.into_response()
129+
}
130+
}
131+
132+
#[derive(serde::Deserialize)]
133+
struct Claims {
134+
iss: String,
135+
}
136+
137+
pub struct JsonOrForm<T>(T);
138+
139+
#[async_trait]
140+
impl<S, T> FromRequest<S> for JsonOrForm<T>
141+
where
142+
S: Send + Sync,
143+
Json<T>: FromRequest<()>,
144+
Form<T>: FromRequest<()>,
145+
T: 'static,
146+
{
147+
type Rejection = Response;
148+
149+
async fn from_request(req: Request, _state: &S) -> Result<Self, Self::Rejection> {
150+
let content_type_header = req.headers().get(CONTENT_TYPE);
151+
let content_type = content_type_header.and_then(|value| value.to_str().ok());
152+
153+
if let Some(content_type) = content_type {
154+
if content_type.starts_with("application/json") {
155+
let Json(payload) = req.extract().await.map_err(IntoResponse::into_response)?;
156+
return Ok(Self(payload));
157+
}
158+
159+
if content_type.starts_with("application/x-www-form-urlencoded") {
160+
let Form(payload) = req.extract().await.map_err(IntoResponse::into_response)?;
161+
return Ok(Self(payload));
162+
}
163+
}
164+
165+
Err(StatusCode::UNSUPPORTED_MEDIA_TYPE.into_response())
166+
}
167+
}

src/main.rs

+2-201
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
pub mod identity_provider;
22
pub mod jwks;
3+
pub mod handlers;
4+
pub mod types;
35

46
use std::sync::Arc;
57
use crate::config::Config;
@@ -79,205 +81,4 @@ async fn main() {
7981
axum::serve(listener, app).await.unwrap();
8082
}
8183

82-
pub mod handlers {
83-
use std::sync::Arc;
84-
use crate::config::Config;
85-
use crate::identity_provider::*;
86-
use crate::types;
87-
use crate::types::{IdentityProvider, IntrospectRequest, TokenRequest, TokenResponse};
88-
use axum::extract::State;
89-
use axum::http::StatusCode;
90-
use axum::response::{IntoResponse, Response};
91-
use axum::Json;
92-
use jsonwebtoken as jwt;
93-
use jsonwebtoken::Algorithm::RS512;
94-
use jsonwebtoken::DecodingKey;
95-
use log::error;
96-
use thiserror::Error;
97-
use tokio::sync::{RwLock};
98-
99-
#[derive(Debug, Error)]
100-
pub enum ApiError {
101-
#[error("identity provider error: {0}")]
102-
UpstreamRequest(reqwest::Error),
103-
104-
#[error("upstream error: {0}")]
105-
Upstream(types::ErrorResponse),
106-
107-
#[error("invalid JSON in token response: {0}")]
108-
JSON(reqwest::Error),
109-
110-
#[error("invalid token: {0}")]
111-
Validate(jwt::errors::Error),
112-
}
113-
114-
impl IntoResponse for ApiError {
115-
fn into_response(self) -> Response {
116-
match &self {
117-
ApiError::UpstreamRequest(err) => {
118-
(err.status().unwrap_or(StatusCode::INTERNAL_SERVER_ERROR), self.to_string())
119-
}
120-
ApiError::JSON(_) => {
121-
(StatusCode::INTERNAL_SERVER_ERROR, self.to_string())
122-
}
123-
ApiError::Upstream(_err) => {
124-
(StatusCode::INTERNAL_SERVER_ERROR, self.to_string())
125-
}
126-
ApiError::Validate(_) => {
127-
(StatusCode::BAD_REQUEST, self.to_string())
128-
}
129-
}.into_response()
130-
}
131-
}
132-
133-
#[derive(Clone)]
134-
pub struct HandlerState {
135-
pub cfg: Config,
136-
pub maskinporten: Arc<RwLock<Maskinporten>>,
137-
// TODO: other providers
138-
}
139-
140-
impl HandlerState {
141-
async fn token_request(&self, identity_provider: &IdentityProvider, target: String ) -> Box<dyn erased_serde::Serialize + Send> {
142-
match identity_provider {
143-
IdentityProvider::EntraID => todo!(),
144-
IdentityProvider::TokenX => todo!(),
145-
IdentityProvider::Maskinporten => {
146-
Box::new(self.maskinporten.read().await.token_request(target))
147-
},
148-
}
149-
}
150-
151-
async fn token_endpoint(&self, identity_provider: &IdentityProvider) -> String {
152-
match identity_provider {
153-
IdentityProvider::EntraID => todo!(),
154-
IdentityProvider::TokenX => todo!(),
155-
IdentityProvider::Maskinporten => {
156-
self.maskinporten.read().await.token_endpoint()
157-
},
158-
}
159-
}
160-
}
161-
162-
#[axum::debug_handler]
163-
pub async fn token(State(state): State<HandlerState>, Json(request): Json<TokenRequest>) -> Result<impl IntoResponse, ApiError> {
164-
let endpoint = state.token_endpoint(&request.identity_provider).await;
165-
let params = state.token_request(&request.identity_provider, request.target).await;
166-
167-
let client = reqwest::Client::new();
168-
let request_builder = client.post(endpoint)
169-
.header("accept", "application/json")
170-
.form(&params);
171-
172-
let response = request_builder
173-
.send()
174-
.await
175-
.map_err(ApiError::UpstreamRequest)?
176-
;
177-
178-
if response.status() >= StatusCode::BAD_REQUEST {
179-
let err: types::ErrorResponse = response.json().await.map_err(ApiError::JSON)?;
180-
return Err(ApiError::Upstream(err));
181-
}
182-
183-
let res: TokenResponse = response
184-
.json().await
185-
.inspect_err(|err| {
186-
error!("Identity provider returned invalid JSON: {:?}", err)
187-
})
188-
.map_err(ApiError::JSON)?
189-
;
190-
191-
Ok((StatusCode::OK, Json(res)))
192-
}
193-
194-
pub async fn introspection(State(state): State<HandlerState>, Json(request): Json<IntrospectRequest>) -> Result<impl IntoResponse, ApiError> {
195-
// Need to decode the token to get the issuer before we actually validate it.
196-
let mut validation = jwt::Validation::new(RS512);
197-
validation.validate_exp = false;
198-
validation.insecure_disable_signature_validation();
199-
let key = DecodingKey::from_secret(&[]);
200-
let token_data = jwt::decode::<Claims>(&request.token, &key, &validation).map_err(ApiError::Validate)?;
201-
202-
let claims = match token_data.claims.iss {
203-
s if s == state.cfg.maskinporten_issuer => state.maskinporten.write().await.introspect(request.token).await,
204-
_ => panic!("Unknown issuer: {}", token_data.claims.iss),
205-
};
206-
207-
Ok((StatusCode::OK, Json(claims)))
208-
}
209-
210-
#[derive(serde::Deserialize)]
211-
struct Claims {
212-
iss: String,
213-
}
214-
}
215-
216-
pub mod types {
217-
use serde::{Deserialize, Serialize};
218-
use std::fmt::{Display, Formatter};
219-
220-
/// This is an upstream RFCXXXX token response.
221-
#[derive(Serialize, Deserialize)]
222-
pub struct TokenResponse {
223-
pub access_token: String,
224-
pub token_type: TokenType,
225-
#[serde(rename = "expires_in")]
226-
pub expires_in_seconds: usize,
227-
}
228-
229-
#[derive(Deserialize, Debug, Clone)]
230-
pub struct ErrorResponse {
231-
pub error: String,
232-
#[serde(rename = "error_description")]
233-
pub description: String,
234-
}
235-
236-
impl Display for ErrorResponse {
237-
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
238-
write!(f, "{}: {}", self.error, self.description)
239-
}
240-
}
241-
242-
/// This is the token request sent to our identity provider.
243-
/// TODO: hard coded parameters that only works with Maskinporten for now.
244-
#[derive(Serialize)]
245-
pub struct ClientTokenRequest {
246-
pub grant_type: String,
247-
pub assertion: String,
248-
}
249-
250-
/// For forwards API compatibility. Token type is always Bearer,
251-
/// but this might change in the future.
252-
#[derive(Deserialize, Serialize)]
253-
pub enum TokenType {
254-
Bearer
255-
}
256-
257-
/// This is a token request that comes from the application we are serving.
258-
#[derive(Deserialize)]
259-
pub struct TokenRequest {
260-
pub target: String, // typically <cluster>:<namespace>:<app>
261-
pub identity_provider: IdentityProvider,
262-
#[serde(skip_serializing_if = "Option::is_none")]
263-
pub user_token: Option<String>,
264-
#[serde(skip_serializing_if = "Option::is_none")]
265-
pub force: Option<bool>,
266-
}
267-
268-
#[derive(Deserialize)]
269-
pub struct IntrospectRequest {
270-
pub token: String,
271-
}
272-
273-
#[derive(Deserialize, Serialize)]
274-
pub enum IdentityProvider {
275-
#[serde(rename = "entra")]
276-
EntraID,
277-
#[serde(rename = "tokenx")]
278-
TokenX,
279-
#[serde(rename = "maskinporten")]
280-
Maskinporten,
281-
}
282-
}
28384

0 commit comments

Comments
 (0)