-
Notifications
You must be signed in to change notification settings - Fork 1.2k
KTOR-7644 Make re-auth status codes configurable #4420
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
33520dd
920f3f0
e3b442d
a226271
55f407f
cc7a176
4261ab3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,14 +1,14 @@ | ||
| /* | ||
| * Copyright 2014-2019 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. | ||
| * Copyright 2014-2024 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. | ||
| */ | ||
|
|
||
| package io.ktor.client.plugins.auth | ||
|
|
||
| import io.ktor.client.* | ||
| import io.ktor.client.call.* | ||
| import io.ktor.client.plugins.* | ||
| import io.ktor.client.plugins.api.* | ||
| import io.ktor.client.request.* | ||
| import io.ktor.client.statement.* | ||
| import io.ktor.http.* | ||
| import io.ktor.http.auth.* | ||
| import io.ktor.util.* | ||
|
|
@@ -23,9 +23,36 @@ private class AtomicCounter { | |
| val atomic = atomic(0) | ||
| } | ||
|
|
||
| /** | ||
| * Configuration used by [Auth] plugin. | ||
| */ | ||
| @KtorDsl | ||
| public class AuthConfig { | ||
| /** | ||
| * [AuthProvider] list to use. | ||
| */ | ||
| public val providers: MutableList<AuthProvider> = mutableListOf() | ||
|
|
||
| /** | ||
| * The currently set function to control whether a response is unauthorized and should trigger a refresh / re-auth. | ||
| * | ||
| * By default checks against HTTP status 401. | ||
| * | ||
| * You can set this value via [reAuthorizeOnResponse]. | ||
| */ | ||
| @InternalAPI | ||
| public var isUnauthorizedResponse: suspend (HttpResponse) -> Boolean = { it.status == HttpStatusCode.Unauthorized } | ||
| private set | ||
|
Comment on lines
+43
to
+45
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why is this an internal API? My intention was to allow access to the current value, so you can extend an existing rule.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We want to preserve the right to introduce breaking changes to this field. Keeping it public for those who’re ready to further breaking changes. |
||
|
|
||
| /** | ||
| * Sets a custom function to control whether a response is unauthorized and should trigger a refresh / re-auth. | ||
| * | ||
| * Use this to change the value of [isUnauthorizedResponse]. | ||
| */ | ||
| public fun reAuthorizeOnResponse(block: suspend (HttpResponse) -> Boolean) { | ||
| @OptIn(InternalAPI::class) | ||
| isUnauthorizedResponse = block | ||
| } | ||
| } | ||
|
|
||
| /** | ||
|
|
@@ -39,8 +66,9 @@ public val AuthCircuitBreaker: AttributeKey<Unit> = AttributeKey("auth-request") | |
| * | ||
| * You can learn more from [Authentication and authorization](https://ktor.io/docs/auth.html). | ||
| * | ||
| * [providers] - list of auth providers to use. | ||
| * @see [AuthConfig] for configuration options. | ||
| */ | ||
| @OptIn(InternalAPI::class) | ||
| public val Auth: ClientPlugin<AuthConfig> = createClientPlugin("Auth", ::AuthConfig) { | ||
| val providers = pluginConfig.providers.toList() | ||
|
|
||
|
|
@@ -50,7 +78,6 @@ public val Auth: ClientPlugin<AuthConfig> = createClientPlugin("Auth", ::AuthCon | |
| val tokenVersionsAttributeKey = | ||
| AttributeKey<MutableMap<AuthProvider, Int>>("ProviderVersionAttributeKey") | ||
|
|
||
| @OptIn(InternalAPI::class) | ||
| fun findProvider( | ||
| call: HttpClientCall, | ||
| candidateProviders: Set<AuthProvider> | ||
|
|
@@ -64,10 +91,10 @@ public val Auth: ClientPlugin<AuthConfig> = createClientPlugin("Auth", ::AuthCon | |
| } | ||
|
|
||
| authHeaders.isEmpty() -> { | ||
| LOGGER.trace( | ||
| "401 response ${call.request.url} has no or empty \"WWW-Authenticate\" header. " + | ||
| LOGGER.trace { | ||
| "Unauthorized response ${call.request.url} has no or empty \"WWW-Authenticate\" header. " + | ||
| "Can not add or refresh token" | ||
| ) | ||
| } | ||
| null | ||
| } | ||
|
|
||
|
|
@@ -88,9 +115,9 @@ public val Auth: ClientPlugin<AuthConfig> = createClientPlugin("Auth", ::AuthCon | |
| val requestTokenVersion = requestTokenVersions[provider] | ||
|
|
||
| if (requestTokenVersion != null && requestTokenVersion >= tokenVersion.atomic.value) { | ||
| LOGGER.trace("Refreshing token for ${call.request.url}") | ||
| LOGGER.trace { "Refreshing token for ${call.request.url}" } | ||
| if (!provider.refreshToken(call.response)) { | ||
| LOGGER.trace("Refreshing token failed for ${call.request.url}") | ||
| LOGGER.trace { "Refreshing token failed for ${call.request.url}" } | ||
| return false | ||
| } else { | ||
| requestTokenVersions[provider] = tokenVersion.atomic.incrementAndGet() | ||
|
|
@@ -99,7 +126,6 @@ public val Auth: ClientPlugin<AuthConfig> = createClientPlugin("Auth", ::AuthCon | |
| return true | ||
| } | ||
|
|
||
| @OptIn(InternalAPI::class) | ||
| suspend fun Send.Sender.executeWithNewToken( | ||
| call: HttpClientCall, | ||
| provider: AuthProvider, | ||
|
|
@@ -111,13 +137,13 @@ public val Auth: ClientPlugin<AuthConfig> = createClientPlugin("Auth", ::AuthCon | |
| provider.addRequestHeaders(request, authHeader) | ||
| request.attributes.put(AuthCircuitBreaker, Unit) | ||
|
|
||
| LOGGER.trace("Sending new request to ${call.request.url}") | ||
| LOGGER.trace { "Sending new request to ${call.request.url}" } | ||
| return proceed(request) | ||
| } | ||
|
|
||
| onRequest { request, _ -> | ||
| providers.filter { it.sendWithoutRequest(request) }.forEach { provider -> | ||
| LOGGER.trace("Adding auth headers for ${request.url} from provider $provider") | ||
| LOGGER.trace { "Adding auth headers for ${request.url} from provider $provider" } | ||
| val tokenVersion = tokenVersions.computeIfAbsent(provider) { AtomicCounter() } | ||
| val requestTokenVersions = request.attributes | ||
| .computeIfAbsent(tokenVersionsAttributeKey) { mutableMapOf() } | ||
|
|
@@ -128,22 +154,22 @@ public val Auth: ClientPlugin<AuthConfig> = createClientPlugin("Auth", ::AuthCon | |
|
|
||
| on(Send) { originalRequest -> | ||
| val origin = proceed(originalRequest) | ||
| if (origin.response.status != HttpStatusCode.Unauthorized) return@on origin | ||
| if (!pluginConfig.isUnauthorizedResponse(origin.response)) return@on origin | ||
| if (origin.request.attributes.contains(AuthCircuitBreaker)) return@on origin | ||
|
|
||
| var call = origin | ||
|
|
||
| val candidateProviders = HashSet(providers) | ||
|
|
||
| while (call.response.status == HttpStatusCode.Unauthorized) { | ||
| LOGGER.trace("Received 401 for ${call.request.url}") | ||
| while (pluginConfig.isUnauthorizedResponse(call.response)) { | ||
| LOGGER.trace { "Unauthorized response for ${call.request.url}" } | ||
|
|
||
| val (provider, authHeader) = findProvider(call, candidateProviders) ?: run { | ||
| LOGGER.trace("Can not find auth provider for ${call.request.url}") | ||
| LOGGER.trace { "Can not find auth provider for ${call.request.url}" } | ||
| return@on call | ||
| } | ||
|
|
||
| LOGGER.trace("Using provider $provider for ${call.request.url}") | ||
| LOGGER.trace { "Using provider $provider for ${call.request.url}" } | ||
|
|
||
| candidateProviders.remove(provider) | ||
| if (!refreshTokenIfNeeded(call, provider, originalRequest)) return@on call | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.