Skip to content

Commit 3ef0e81

Browse files
tommytroentronghnkimtore
committed
style: use cargo fmt for formatting
Co-authored-by: tronghn <[email protected]> Co-authored-by: kimtore <[email protected]>
1 parent be46f9b commit 3ef0e81

File tree

5 files changed

+91
-67
lines changed

5 files changed

+91
-67
lines changed

src/handlers.rs

+44-32
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use crate::identity_provider::*;
66
use crate::types;
77
use crate::types::{IdentityProvider, IntrospectRequest, TokenRequest, TokenResponse};
88
use axum::extract::State;
9+
use axum::http::header::CONTENT_TYPE;
910
use axum::http::StatusCode;
1011
use axum::response::{IntoResponse, Response};
1112
use axum::Json;
@@ -14,52 +15,65 @@ use jsonwebtoken::Algorithm::RS512;
1415
use jsonwebtoken::DecodingKey;
1516
use log::error;
1617
use std::sync::Arc;
17-
use axum::http::header::CONTENT_TYPE;
1818
use thiserror::Error;
1919
use tokio::sync::RwLock;
2020

2121
#[axum::debug_handler]
22-
pub async fn token(State(state): State<HandlerState>, JsonOrForm(request): JsonOrForm<TokenRequest>) -> Result<impl IntoResponse, ApiError> {
22+
pub async fn token(
23+
State(state): State<HandlerState>,
24+
JsonOrForm(request): JsonOrForm<TokenRequest>,
25+
) -> Result<impl IntoResponse, ApiError> {
2326
let endpoint = state.token_endpoint(&request.identity_provider).await;
24-
let params = state.token_request(&request.identity_provider, request.target).await;
27+
let params = state
28+
.token_request(&request.identity_provider, request.target)
29+
.await;
2530

2631
let client = reqwest::Client::new();
27-
let request_builder = client.post(endpoint)
32+
let request_builder = client
33+
.post(endpoint)
2834
.header("accept", "application/json")
2935
.form(&params);
3036

3137
let response = request_builder
3238
.send()
3339
.await
34-
.map_err(ApiError::UpstreamRequest)?
35-
;
40+
.map_err(ApiError::UpstreamRequest)?;
3641

3742
if response.status() >= StatusCode::BAD_REQUEST {
3843
let err: types::ErrorResponse = response.json().await.map_err(ApiError::JSON)?;
3944
return Err(ApiError::Upstream(err));
4045
}
4146

4247
let res: TokenResponse = response
43-
.json().await
44-
.inspect_err(|err| {
45-
error!("Identity provider returned invalid JSON: {:?}", err)
46-
})
47-
.map_err(ApiError::JSON)?
48-
;
48+
.json()
49+
.await
50+
.inspect_err(|err| error!("Identity provider returned invalid JSON: {:?}", err))
51+
.map_err(ApiError::JSON)?;
4952

5053
Ok((StatusCode::OK, Json(res)))
5154
}
5255

53-
pub async fn introspection(State(state): State<HandlerState>, Json(request): Json<IntrospectRequest>) -> Result<impl IntoResponse, ApiError> {
56+
pub async fn introspection(
57+
State(state): State<HandlerState>,
58+
Json(request): Json<IntrospectRequest>,
59+
) -> Result<impl IntoResponse, ApiError> {
5460
// Need to decode the token to get the issuer before we actually validate it.
5561
let mut validation = jwt::Validation::new(RS512);
5662
validation.validate_exp = false;
5763
validation.insecure_disable_signature_validation();
5864
let key = DecodingKey::from_secret(&[]);
59-
let token_data = jwt::decode::<Claims>(&request.token, &key, &validation).map_err(ApiError::Validate)?;
65+
let token_data =
66+
jwt::decode::<Claims>(&request.token, &key, &validation).map_err(ApiError::Validate)?;
6067

6168
let claims = match token_data.claims.iss {
62-
s if s == state.cfg.maskinporten_issuer => state.maskinporten.write().await.introspect(request.token).await,
69+
s if s == state.cfg.maskinporten_issuer => {
70+
state
71+
.maskinporten
72+
.write()
73+
.await
74+
.introspect(request.token)
75+
.await
76+
}
6377
_ => panic!("Unknown issuer: {}", token_data.claims.iss),
6478
};
6579

@@ -74,7 +88,11 @@ pub struct HandlerState {
7488
}
7589

7690
impl HandlerState {
77-
async fn token_request(&self, identity_provider: &IdentityProvider, target: String) -> Box<dyn erased_serde::Serialize + Send> {
91+
async fn token_request(
92+
&self,
93+
identity_provider: &IdentityProvider,
94+
target: String,
95+
) -> Box<dyn erased_serde::Serialize + Send> {
7896
match identity_provider {
7997
IdentityProvider::EntraID => todo!(),
8098
IdentityProvider::TokenX => todo!(),
@@ -88,9 +106,7 @@ impl HandlerState {
88106
match identity_provider {
89107
IdentityProvider::EntraID => todo!(),
90108
IdentityProvider::TokenX => todo!(),
91-
IdentityProvider::Maskinporten => {
92-
self.maskinporten.read().await.token_endpoint()
93-
}
109+
IdentityProvider::Maskinporten => self.maskinporten.read().await.token_endpoint(),
94110
}
95111
}
96112
}
@@ -113,19 +129,15 @@ pub enum ApiError {
113129
impl IntoResponse for ApiError {
114130
fn into_response(self) -> Response {
115131
match &self {
116-
ApiError::UpstreamRequest(err) => {
117-
(err.status().unwrap_or(StatusCode::INTERNAL_SERVER_ERROR), self.to_string())
118-
}
119-
ApiError::JSON(_) => {
120-
(StatusCode::INTERNAL_SERVER_ERROR, self.to_string())
121-
}
122-
ApiError::Upstream(_err) => {
123-
(StatusCode::INTERNAL_SERVER_ERROR, self.to_string())
124-
}
125-
ApiError::Validate(_) => {
126-
(StatusCode::BAD_REQUEST, self.to_string())
127-
}
128-
}.into_response()
132+
ApiError::UpstreamRequest(err) => (
133+
err.status().unwrap_or(StatusCode::INTERNAL_SERVER_ERROR),
134+
self.to_string(),
135+
),
136+
ApiError::JSON(_) => (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()),
137+
ApiError::Upstream(_err) => (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()),
138+
ApiError::Validate(_) => (StatusCode::BAD_REQUEST, self.to_string()),
139+
}
140+
.into_response()
129141
}
130142
}
131143

src/identity_provider.rs

+13-9
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@ use std::collections::HashMap;
99
pub trait Provider<T: Serialize> {
1010
fn token_request(&self, target: String) -> T;
1111
fn token_endpoint(&self) -> String;
12-
fn introspect(&mut self, token: String) -> impl std::future::Future<Output=HashMap<String, Value>> + Send;
12+
fn introspect(
13+
&mut self,
14+
token: String,
15+
) -> impl std::future::Future<Output = HashMap<String, Value>> + Send;
1316
}
1417

1518
#[derive(Clone)]
@@ -68,7 +71,10 @@ pub struct MaskinportenTokenRequest {
6871

6972
impl Provider<MaskinportenTokenRequest> for Maskinporten {
7073
fn token_request(&self, target: String) -> MaskinportenTokenRequest {
71-
let now = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_secs();
74+
let now = std::time::SystemTime::now()
75+
.duration_since(std::time::UNIX_EPOCH)
76+
.unwrap()
77+
.as_secs();
7278
let jti = uuid::Uuid::new_v4();
7379

7480
let claims = AssertionClaims {
@@ -80,11 +86,7 @@ impl Provider<MaskinportenTokenRequest> for Maskinporten {
8086
aud: self.cfg.maskinporten_issuer.to_string(),
8187
};
8288

83-
let token = jwt::encode(
84-
&self.client_assertion_header,
85-
&claims,
86-
&self.private_jwk,
87-
).unwrap();
89+
let token = jwt::encode(&self.client_assertion_header, &claims, &self.private_jwk).unwrap();
8890

8991
MaskinportenTokenRequest {
9092
grant_type: "urn:ietf:params:oauth:grant-type:jwt-bearer".to_string(),
@@ -97,15 +99,17 @@ impl Provider<MaskinportenTokenRequest> for Maskinporten {
9799
}
98100

99101
async fn introspect(&mut self, token: String) -> HashMap<String, Value> {
100-
self.upstream_jwks.validate(&token).await
102+
self.upstream_jwks
103+
.validate(&token)
104+
.await
101105
.map(|mut hashmap| {
102106
hashmap.insert("active".to_string(), Value::Bool(true));
103107
hashmap
104108
})
105109
.unwrap_or_else(|err| {
106110
HashMap::from([
107111
("active".to_string(), Value::Bool(false)),
108-
("error".to_string(), Value::String(format!("{:?}", err)))
112+
("error".to_string(), Value::String(format!("{:?}", err))),
109113
])
110114
})
111115
}

src/jwks.rs

+15-15
Original file line numberDiff line numberDiff line change
@@ -30,15 +30,15 @@ impl Jwks {
3030
}
3131

3232
let client = reqwest::Client::new();
33-
let request_builder = client.get(endpoint)
34-
.header("accept", "application/json");
33+
let request_builder = client.get(endpoint).header("accept", "application/json");
3534

3635
let response: Response = request_builder
37-
.send().await
36+
.send()
37+
.await
3838
.map_err(Error::Fetch)?
39-
.json().await
40-
.map_err(Error::JsonDecode)?
41-
;
39+
.json()
40+
.await
41+
.map_err(Error::JsonDecode)?;
4242

4343
let mut keys: HashMap<String, jwk::JsonWebKey> = HashMap::new();
4444
for key in response.keys {
@@ -61,10 +61,7 @@ impl Jwks {
6161

6262
/// Check a JWT against a JWKS.
6363
/// Returns the JWT's claims on success.
64-
pub async fn validate(
65-
&mut self,
66-
token: &str,
67-
) -> Result<HashMap<String, Value>, Error> {
64+
pub async fn validate(&mut self, token: &str) -> Result<HashMap<String, Value>, Error> {
6865
let alg = jwt::Algorithm::RS256;
6966
let mut validation = jwt::Validation::new(alg);
7067
validation.set_required_spec_claims(&["iss", "exp", "iat"]);
@@ -73,8 +70,8 @@ impl Jwks {
7370

7471
let key_id = jwt::decode_header(token)
7572
.map_err(Error::InvalidTokenHeader)?
76-
.kid.ok_or(Error::MissingKeyID)?
77-
;
73+
.kid
74+
.ok_or(Error::MissingKeyID)?;
7875

7976
// Refresh key store if needed before validating.
8077
let signing_key = match self.keys.get(&key_id) {
@@ -85,9 +82,12 @@ impl Jwks {
8582
Some(key) => key,
8683
};
8784

88-
Ok(jwt::decode::<HashMap<String, Value>>(token, &signing_key.key.to_decoding_key(), &validation)
89-
.map_err(InvalidToken)?
90-
.claims
85+
Ok(jwt::decode::<HashMap<String, Value>>(
86+
token,
87+
&signing_key.key.to_decoding_key(),
88+
&validation,
9189
)
90+
.map_err(InvalidToken)?
91+
.claims)
9292
}
9393
}

src/main.rs

+18-10
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
1+
pub mod handlers;
12
pub mod identity_provider;
23
pub mod jwks;
3-
pub mod handlers;
44
pub mod types;
55

6-
use std::sync::Arc;
76
use crate::config::Config;
87
use axum::routing::post;
9-
use axum::{Router};
8+
use axum::Router;
109
use clap::Parser;
1110
use dotenv::dotenv;
1211
use log::{info, LevelFilter};
12+
use std::sync::Arc;
1313
use tokio::sync::RwLock;
1414

1515
pub mod config {
@@ -52,7 +52,9 @@ fn print_texas_logo() {
5252

5353
#[tokio::main]
5454
async fn main() {
55-
env_logger::builder().filter_level(LevelFilter::Debug).init();
55+
env_logger::builder()
56+
.filter_level(LevelFilter::Debug)
57+
.init();
5658

5759
print_texas_logo();
5860

@@ -62,7 +64,9 @@ async fn main() {
6264

6365
let maskinporten = identity_provider::Maskinporten::new(
6466
cfg.clone(),
65-
jwks::Jwks::new(&cfg.maskinporten_issuer, &cfg.maskinporten_jwks_uri).await.unwrap(),
67+
jwks::Jwks::new(&cfg.maskinporten_issuer, &cfg.maskinporten_jwks_uri)
68+
.await
69+
.unwrap(),
6670
);
6771

6872
let state = handlers::HandlerState {
@@ -71,14 +75,18 @@ async fn main() {
7175
};
7276

7377
let app = Router::new()
74-
.route("/token", post(handlers::token)).with_state(state.clone())
75-
.route("/introspection", post(handlers::introspection).with_state(state.clone()));
78+
.route("/token", post(handlers::token))
79+
.with_state(state.clone())
80+
.route(
81+
"/introspection",
82+
post(handlers::introspection).with_state(state.clone()),
83+
);
7684

77-
let listener = tokio::net::TcpListener::bind(cfg.bind_address).await.unwrap();
85+
let listener = tokio::net::TcpListener::bind(cfg.bind_address)
86+
.await
87+
.unwrap();
7888

7989
info!("Serving on {:?}", listener.local_addr().unwrap());
8090

8191
axum::serve(listener, app).await.unwrap();
8292
}
83-
84-

src/types.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ pub struct ClientTokenRequest {
3535
/// but this might change in the future.
3636
#[derive(Deserialize, Serialize)]
3737
pub enum TokenType {
38-
Bearer
38+
Bearer,
3939
}
4040

4141
/// This is a token request that comes from the application we are serving.

0 commit comments

Comments
 (0)