Skip to content

Commit cd8b379

Browse files
committed
feat: replace custom Ktor route plugin with AuthenticationProvider
This is more semantically correct and also plays more nicely with the native Ktor authentication constructs, e.g. custom authorization plugins with `on(AuthenticationChecked)`, `ApplicationCall.principal()` and so on.
1 parent 70eeff5 commit cd8b379

File tree

5 files changed

+233
-163
lines changed

5 files changed

+233
-163
lines changed

build.gradle.kts

+1
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ subprojects {
6464
dependencies {
6565
implementation(kotlin("stdlib"))
6666
implementation("io.ktor:ktor-server:${ktorVersion}")
67+
implementation("io.ktor:ktor-server-auth:${ktorVersion}")
6768
implementation("io.ktor:ktor-server-cio:${ktorVersion}")
6869
implementation("io.ktor:ktor-server-content-negotiation:${ktorVersion}")
6970
implementation("io.ktor:ktor-serialization-jackson:${ktorVersion}")

wonderwalled-azure/src/main/kotlin/io/nais/Wonderwalled.kt

+52-45
Original file line numberDiff line numberDiff line change
@@ -1,75 +1,82 @@
11
package io.nais
22

33
import io.ktor.http.HttpStatusCode
4+
import io.ktor.server.application.install
5+
import io.ktor.server.auth.Authentication
6+
import io.ktor.server.auth.authenticate
7+
import io.ktor.server.auth.principal
48
import io.ktor.server.response.respond
59
import io.ktor.server.routing.get
610
import io.ktor.server.routing.route
711
import io.ktor.server.routing.routing
812
import io.nais.common.AuthClient
913
import io.nais.common.IdentityProvider
10-
import io.nais.common.NaisAuth
14+
import io.nais.common.TexasPrincipal
1115
import io.nais.common.TokenResponse
1216
import io.nais.common.bearerToken
1317
import io.nais.common.requestHeaders
1418
import io.nais.common.server
19+
import io.nais.common.texas
1520

1621
fun main() {
1722
server { config ->
1823
val azure = AuthClient(config.auth, IdentityProvider.AZURE_AD)
1924

20-
routing {
21-
route("api") {
22-
install(NaisAuth) {
23-
client = azure
24-
ingress = config.ingress
25-
}
26-
27-
get("headers") {
28-
call.respond(call.requestHeaders())
29-
}
25+
install(Authentication) {
26+
texas {
27+
client = azure
28+
ingress = config.ingress
29+
}
30+
}
3031

31-
get("me") {
32-
val token = call.bearerToken()
33-
if (token == null) {
34-
call.respond(HttpStatusCode.Unauthorized, "missing bearer token in Authorization header")
35-
return@get
32+
routing {
33+
authenticate {
34+
route("api") {
35+
get("headers") {
36+
call.respond(call.requestHeaders())
3637
}
3738

38-
val introspection = azure.introspect(token)
39-
call.respond(introspection)
40-
}
41-
42-
get("obo") {
43-
val token = call.bearerToken()
44-
if (token == null) {
45-
call.respond(HttpStatusCode.Unauthorized, "missing bearer token in Authorization header")
46-
return@get
39+
get("me") {
40+
val principal = call.principal<TexasPrincipal>()
41+
if (principal == null) {
42+
call.respond(HttpStatusCode.Unauthorized, "missing principal")
43+
return@get
44+
}
45+
call.respond(principal.claims)
4746
}
4847

49-
val audience = call.request.queryParameters["aud"]
50-
if (audience == null) {
51-
call.respond(HttpStatusCode.BadRequest, "missing 'aud' query parameter")
52-
return@get
53-
}
48+
get("obo") {
49+
val token = call.bearerToken()
50+
if (token == null) {
51+
call.respond(HttpStatusCode.Unauthorized, "missing bearer token in Authorization header")
52+
return@get
53+
}
5454

55-
val target = audience.toScope()
56-
when (val response = azure.exchange(target, token)) {
57-
is TokenResponse.Success -> call.respond(response)
58-
is TokenResponse.Error -> call.respond(response.status, response.error)
59-
}
60-
}
55+
val audience = call.request.queryParameters["aud"]
56+
if (audience == null) {
57+
call.respond(HttpStatusCode.BadRequest, "missing 'aud' query parameter")
58+
return@get
59+
}
6160

62-
get("m2m") {
63-
val audience = call.request.queryParameters["aud"]
64-
if (audience == null) {
65-
call.respond(HttpStatusCode.BadRequest, "missing 'aud' query parameter")
66-
return@get
61+
val target = audience.toScope()
62+
when (val response = azure.exchange(target, token)) {
63+
is TokenResponse.Success -> call.respond(response)
64+
is TokenResponse.Error -> call.respond(response.status, response.error)
65+
}
6766
}
6867

69-
val target = audience.toScope()
70-
when (val response = azure.token(target)) {
71-
is TokenResponse.Success -> call.respond(response)
72-
is TokenResponse.Error -> call.respond(response.status, response.error)
68+
get("m2m") {
69+
val audience = call.request.queryParameters["aud"]
70+
if (audience == null) {
71+
call.respond(HttpStatusCode.BadRequest, "missing 'aud' query parameter")
72+
return@get
73+
}
74+
75+
val target = audience.toScope()
76+
when (val response = azure.token(target)) {
77+
is TokenResponse.Success -> call.respond(response)
78+
is TokenResponse.Error -> call.respond(response.status, response.error)
79+
}
7380
}
7481
}
7582
}

wonderwalled-common/src/main/kotlin/io/nais/common/Auth.kt

+92-55
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@ import io.ktor.client.request.forms.submitForm
1212
import io.ktor.http.HttpStatusCode
1313
import io.ktor.http.parameters
1414
import io.ktor.server.application.ApplicationCall
15-
import io.ktor.server.application.createRouteScopedPlugin
15+
import io.ktor.server.auth.AuthenticationConfig
16+
import io.ktor.server.auth.AuthenticationContext
17+
import io.ktor.server.auth.AuthenticationFailedCause
18+
import io.ktor.server.auth.AuthenticationProvider
1619
import io.ktor.server.request.host
1720
import io.ktor.server.request.uri
1821
import io.ktor.server.response.respondRedirect
@@ -57,8 +60,8 @@ data class TokenErrorResponse(
5760
*
5861
* The `other` field is a generic map that contains an arbitrary combination of additional claims contained in the token.
5962
*
60-
* If you know the exact claims to expect, you should instead explicitly define these as fields on the
61-
* data class itself, for example:
63+
* TODO(user): If you know the exact claims to expect, you should instead explicitly define these as fields on the
64+
* data class itself and ignore everything else, for example:
6265
*
6366
* ```kotlin
6467
* data class TokenIntrospectionResponse(
@@ -79,6 +82,9 @@ data class TokenIntrospectionResponse(
7982
val other: Map<String, Any?> = mutableMapOf(),
8083
)
8184

85+
/**
86+
* AuthClient is a client that interacts with Texas.
87+
*/
8288
class AuthClient(
8389
private val config: Config.Auth,
8490
private val provider: IdentityProvider,
@@ -150,66 +156,97 @@ class AuthClient(
150156
}
151157
}
152158

153-
class AuthPluginConfiguration(
154-
var client: AuthClient? = null,
155-
var ingress: String? = null,
156-
var logger: Logger = LoggerFactory.getLogger("io.nais.common.ktor.NaisAuth"),
157-
)
159+
fun AuthenticationConfig.texas(
160+
name: String? = null,
161+
configure: TexasAuthenticationProvider.Config.() -> Unit,
162+
) {
163+
register(TexasAuthenticationProvider.Config(name).apply(configure).build())
164+
}
158165

159-
val NaisAuth =
160-
createRouteScopedPlugin(
161-
name = "NaisAuth",
162-
createConfiguration = ::AuthPluginConfiguration,
163-
) {
164-
val logger = pluginConfig.logger
165-
val client = pluginConfig.client ?: throw IllegalArgumentException("NaisAuth plugin: client must be set")
166-
val ingress = pluginConfig.ingress ?: ""
167-
168-
val challenge: suspend (ApplicationCall) -> Unit = { call ->
169-
val target = call.loginUrl(ingress)
170-
logger.info("unauthenticated: redirecting to '$target'")
171-
call.respondRedirect(target)
172-
}
166+
/**
167+
* TexasAuthenticationProvider is an [io.ktor.server.auth.AuthenticationProvider] that validates tokens by using Texas's introspection endpoint.
168+
*/
169+
class TexasAuthenticationProvider(
170+
config: Config,
171+
) : AuthenticationProvider(config) {
172+
class Config internal constructor(
173+
name: String?,
174+
) : AuthenticationProvider.Config(name) {
175+
lateinit var client: AuthClient
176+
var logger: Logger = LoggerFactory.getLogger("io.nais.common.TexasAuthenticationProvider")
177+
var ingress: String = ""
178+
179+
internal fun build() = TexasAuthenticationProvider(this)
180+
}
173181

174-
pluginConfig.apply {
175-
onCall { call ->
176-
val token = call.bearerToken()
177-
if (token == null) {
178-
logger.warn("unauthenticated: no Bearer token found in Authorization header")
179-
challenge(call)
180-
return@onCall
181-
}
182+
private val client = config.client
183+
private val logger = config.logger
184+
private val ingress = config.ingress
182185

183-
val introspectResponse =
184-
try {
185-
client.introspect(token)
186-
} catch (e: Exception) {
187-
logger.error("unauthenticated: introspect request failed: ${e.message}")
188-
challenge(call)
189-
return@onCall
190-
}
191-
192-
if (introspectResponse.active) {
193-
logger.info("authenticated - claims='${introspectResponse.other}'")
194-
return@onCall
195-
}
186+
override suspend fun onAuthenticate(context: AuthenticationContext) {
187+
val applicationCall = context.call
188+
val token = applicationCall.bearerToken()
196189

197-
logger.warn("unauthenticated: ${introspectResponse.error}")
198-
challenge(call)
199-
return@onCall
190+
if (token == null) {
191+
logger.warn("unauthenticated: no Bearer token found in Authorization header")
192+
context.loginChallenge(AuthenticationFailedCause.NoCredentials)
193+
return
194+
}
195+
196+
val introspectResponse =
197+
try {
198+
client.introspect(token)
199+
} catch (e: Exception) {
200+
// TODO(user): You should handle the specific exceptions that can be thrown by the HTTP client, e.g. retry on network errors and so on
201+
logger.error("unauthenticated: introspect request failed: ${e.message}")
202+
context.loginChallenge(AuthenticationFailedCause.Error(e.message ?: "introspect request failed"))
203+
return
200204
}
205+
206+
if (!introspectResponse.active) {
207+
logger.warn("unauthenticated: ${introspectResponse.error}")
208+
context.loginChallenge(AuthenticationFailedCause.InvalidCredentials)
209+
return
201210
}
202211

203-
logger.info("NaisAuth plugin loaded.")
212+
logger.info("authenticated - claims='${introspectResponse.other}'")
213+
context.principal(
214+
TexasPrincipal(
215+
claims = introspectResponse.other,
216+
token = token,
217+
),
218+
)
204219
}
205220

206-
// loginUrl constructs a URL string that points to the login endpoint (Wonderwall) for redirecting a request.
207-
// It also indicates that the user should be redirected back to the original request path after authentication
208-
private fun ApplicationCall.loginUrl(defaultHost: String): String {
209-
val host =
210-
defaultHost.ifEmpty(defaultValue = {
211-
"${this.request.local.scheme}://${this.request.host()}"
212-
})
221+
private fun AuthenticationContext.loginChallenge(cause: AuthenticationFailedCause) {
222+
challenge("Texas", cause) { authenticationProcedureChallenge, call ->
223+
val target = call.loginUrl()
224+
logger.info("unauthenticated: redirecting to '$target'")
225+
call.respondRedirect(target)
226+
authenticationProcedureChallenge.complete()
227+
}
228+
}
213229

214-
return "$host/oauth2/login?redirect=${this.request.uri}"
230+
/**
231+
* loginUrl constructs a URL string that points to the login endpoint (Wonderwall) for redirecting a request.
232+
* It also indicates that the user should be redirected back to the original request path after authentication
233+
*/
234+
private fun ApplicationCall.loginUrl(): String {
235+
val host =
236+
ingress.ifEmpty(defaultValue = {
237+
"${this.request.local.scheme}://${this.request.host()}"
238+
})
239+
240+
return "$host/oauth2/login?redirect=${this.request.uri}"
241+
}
215242
}
243+
244+
/**
245+
* TexasPrincipal represents the authenticated principal.
246+
* The `claims` field is a map of arbitrary claims from the [TokenIntrospectionResponse].
247+
* TODO(user): You should explicitly define expected claims as fields on the data class itself instead of using a generic map.
248+
*/
249+
data class TexasPrincipal(
250+
val claims: Map<String, Any?>,
251+
val token: String,
252+
)

0 commit comments

Comments
 (0)