|
1 | 1 | pub mod identity_provider;
|
2 | 2 | pub mod jwks;
|
| 3 | +pub mod handlers; |
| 4 | +pub mod types; |
3 | 5 |
|
4 | 6 | use std::sync::Arc;
|
5 | 7 | use crate::config::Config;
|
@@ -79,205 +81,4 @@ async fn main() {
|
79 | 81 | axum::serve(listener, app).await.unwrap();
|
80 | 82 | }
|
81 | 83 |
|
82 |
| -pub mod handlers { |
83 |
| - use std::sync::Arc; |
84 |
| - use crate::config::Config; |
85 |
| - use crate::identity_provider::*; |
86 |
| - use crate::types; |
87 |
| - use crate::types::{IdentityProvider, IntrospectRequest, TokenRequest, TokenResponse}; |
88 |
| - use axum::extract::State; |
89 |
| - use axum::http::StatusCode; |
90 |
| - use axum::response::{IntoResponse, Response}; |
91 |
| - use axum::Json; |
92 |
| - use jsonwebtoken as jwt; |
93 |
| - use jsonwebtoken::Algorithm::RS512; |
94 |
| - use jsonwebtoken::DecodingKey; |
95 |
| - use log::error; |
96 |
| - use thiserror::Error; |
97 |
| - use tokio::sync::{RwLock}; |
98 |
| - |
99 |
| - #[derive(Debug, Error)] |
100 |
| - pub enum ApiError { |
101 |
| - #[error("identity provider error: {0}")] |
102 |
| - UpstreamRequest(reqwest::Error), |
103 |
| - |
104 |
| - #[error("upstream error: {0}")] |
105 |
| - Upstream(types::ErrorResponse), |
106 |
| - |
107 |
| - #[error("invalid JSON in token response: {0}")] |
108 |
| - JSON(reqwest::Error), |
109 |
| - |
110 |
| - #[error("invalid token: {0}")] |
111 |
| - Validate(jwt::errors::Error), |
112 |
| - } |
113 |
| - |
114 |
| - impl IntoResponse for ApiError { |
115 |
| - fn into_response(self) -> Response { |
116 |
| - match &self { |
117 |
| - ApiError::UpstreamRequest(err) => { |
118 |
| - (err.status().unwrap_or(StatusCode::INTERNAL_SERVER_ERROR), self.to_string()) |
119 |
| - } |
120 |
| - ApiError::JSON(_) => { |
121 |
| - (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()) |
122 |
| - } |
123 |
| - ApiError::Upstream(_err) => { |
124 |
| - (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()) |
125 |
| - } |
126 |
| - ApiError::Validate(_) => { |
127 |
| - (StatusCode::BAD_REQUEST, self.to_string()) |
128 |
| - } |
129 |
| - }.into_response() |
130 |
| - } |
131 |
| - } |
132 |
| - |
133 |
| - #[derive(Clone)] |
134 |
| - pub struct HandlerState { |
135 |
| - pub cfg: Config, |
136 |
| - pub maskinporten: Arc<RwLock<Maskinporten>>, |
137 |
| - // TODO: other providers |
138 |
| - } |
139 |
| - |
140 |
| - impl HandlerState { |
141 |
| - async fn token_request(&self, identity_provider: &IdentityProvider, target: String ) -> Box<dyn erased_serde::Serialize + Send> { |
142 |
| - match identity_provider { |
143 |
| - IdentityProvider::EntraID => todo!(), |
144 |
| - IdentityProvider::TokenX => todo!(), |
145 |
| - IdentityProvider::Maskinporten => { |
146 |
| - Box::new(self.maskinporten.read().await.token_request(target)) |
147 |
| - }, |
148 |
| - } |
149 |
| - } |
150 |
| - |
151 |
| - async fn token_endpoint(&self, identity_provider: &IdentityProvider) -> String { |
152 |
| - match identity_provider { |
153 |
| - IdentityProvider::EntraID => todo!(), |
154 |
| - IdentityProvider::TokenX => todo!(), |
155 |
| - IdentityProvider::Maskinporten => { |
156 |
| - self.maskinporten.read().await.token_endpoint() |
157 |
| - }, |
158 |
| - } |
159 |
| - } |
160 |
| - } |
161 |
| - |
162 |
| - #[axum::debug_handler] |
163 |
| - pub async fn token(State(state): State<HandlerState>, Json(request): Json<TokenRequest>) -> Result<impl IntoResponse, ApiError> { |
164 |
| - let endpoint = state.token_endpoint(&request.identity_provider).await; |
165 |
| - let params = state.token_request(&request.identity_provider, request.target).await; |
166 |
| - |
167 |
| - let client = reqwest::Client::new(); |
168 |
| - let request_builder = client.post(endpoint) |
169 |
| - .header("accept", "application/json") |
170 |
| - .form(¶ms); |
171 |
| - |
172 |
| - let response = request_builder |
173 |
| - .send() |
174 |
| - .await |
175 |
| - .map_err(ApiError::UpstreamRequest)? |
176 |
| - ; |
177 |
| - |
178 |
| - if response.status() >= StatusCode::BAD_REQUEST { |
179 |
| - let err: types::ErrorResponse = response.json().await.map_err(ApiError::JSON)?; |
180 |
| - return Err(ApiError::Upstream(err)); |
181 |
| - } |
182 |
| - |
183 |
| - let res: TokenResponse = response |
184 |
| - .json().await |
185 |
| - .inspect_err(|err| { |
186 |
| - error!("Identity provider returned invalid JSON: {:?}", err) |
187 |
| - }) |
188 |
| - .map_err(ApiError::JSON)? |
189 |
| - ; |
190 |
| - |
191 |
| - Ok((StatusCode::OK, Json(res))) |
192 |
| - } |
193 |
| - |
194 |
| - pub async fn introspection(State(state): State<HandlerState>, Json(request): Json<IntrospectRequest>) -> Result<impl IntoResponse, ApiError> { |
195 |
| - // Need to decode the token to get the issuer before we actually validate it. |
196 |
| - let mut validation = jwt::Validation::new(RS512); |
197 |
| - validation.validate_exp = false; |
198 |
| - validation.insecure_disable_signature_validation(); |
199 |
| - let key = DecodingKey::from_secret(&[]); |
200 |
| - let token_data = jwt::decode::<Claims>(&request.token, &key, &validation).map_err(ApiError::Validate)?; |
201 |
| - |
202 |
| - let claims = match token_data.claims.iss { |
203 |
| - s if s == state.cfg.maskinporten_issuer => state.maskinporten.write().await.introspect(request.token).await, |
204 |
| - _ => panic!("Unknown issuer: {}", token_data.claims.iss), |
205 |
| - }; |
206 |
| - |
207 |
| - Ok((StatusCode::OK, Json(claims))) |
208 |
| - } |
209 |
| - |
210 |
| - #[derive(serde::Deserialize)] |
211 |
| - struct Claims { |
212 |
| - iss: String, |
213 |
| - } |
214 |
| -} |
215 |
| - |
216 |
| -pub mod types { |
217 |
| - use serde::{Deserialize, Serialize}; |
218 |
| - use std::fmt::{Display, Formatter}; |
219 |
| - |
220 |
| - /// This is an upstream RFCXXXX token response. |
221 |
| - #[derive(Serialize, Deserialize)] |
222 |
| - pub struct TokenResponse { |
223 |
| - pub access_token: String, |
224 |
| - pub token_type: TokenType, |
225 |
| - #[serde(rename = "expires_in")] |
226 |
| - pub expires_in_seconds: usize, |
227 |
| - } |
228 |
| - |
229 |
| - #[derive(Deserialize, Debug, Clone)] |
230 |
| - pub struct ErrorResponse { |
231 |
| - pub error: String, |
232 |
| - #[serde(rename = "error_description")] |
233 |
| - pub description: String, |
234 |
| - } |
235 |
| - |
236 |
| - impl Display for ErrorResponse { |
237 |
| - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { |
238 |
| - write!(f, "{}: {}", self.error, self.description) |
239 |
| - } |
240 |
| - } |
241 |
| - |
242 |
| - /// This is the token request sent to our identity provider. |
243 |
| - /// TODO: hard coded parameters that only works with Maskinporten for now. |
244 |
| - #[derive(Serialize)] |
245 |
| - pub struct ClientTokenRequest { |
246 |
| - pub grant_type: String, |
247 |
| - pub assertion: String, |
248 |
| - } |
249 |
| - |
250 |
| - /// For forwards API compatibility. Token type is always Bearer, |
251 |
| - /// but this might change in the future. |
252 |
| - #[derive(Deserialize, Serialize)] |
253 |
| - pub enum TokenType { |
254 |
| - Bearer |
255 |
| - } |
256 |
| - |
257 |
| - /// This is a token request that comes from the application we are serving. |
258 |
| - #[derive(Deserialize)] |
259 |
| - pub struct TokenRequest { |
260 |
| - pub target: String, // typically <cluster>:<namespace>:<app> |
261 |
| - pub identity_provider: IdentityProvider, |
262 |
| - #[serde(skip_serializing_if = "Option::is_none")] |
263 |
| - pub user_token: Option<String>, |
264 |
| - #[serde(skip_serializing_if = "Option::is_none")] |
265 |
| - pub force: Option<bool>, |
266 |
| - } |
267 |
| - |
268 |
| - #[derive(Deserialize)] |
269 |
| - pub struct IntrospectRequest { |
270 |
| - pub token: String, |
271 |
| - } |
272 |
| - |
273 |
| - #[derive(Deserialize, Serialize)] |
274 |
| - pub enum IdentityProvider { |
275 |
| - #[serde(rename = "entra")] |
276 |
| - EntraID, |
277 |
| - #[serde(rename = "tokenx")] |
278 |
| - TokenX, |
279 |
| - #[serde(rename = "maskinporten")] |
280 |
| - Maskinporten, |
281 |
| - } |
282 |
| -} |
283 | 84 |
|
0 commit comments