Skip to content

Commit a2e3aab

Browse files
wkornewaldosipxd
andcommitted
KTOR-7644 Make re-auth status codes configurable (#4420)
Some services use 403 instead of 401. Changing them might be impossible. With this change Ktor can flexibly work with any broken service. --------- Co-authored-by: Osip Fatkullin <[email protected]>
1 parent 5269c0f commit a2e3aab

File tree

5 files changed

+72
-17
lines changed

5 files changed

+72
-17
lines changed

ktor-client/ktor-client-plugins/ktor-client-auth/api/ktor-client-auth.api

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
public final class io/ktor/client/plugins/auth/AuthConfig {
22
public fun <init> ()V
33
public final fun getProviders ()Ljava/util/List;
4+
public final fun isUnauthorizedResponse ()Lkotlin/jvm/functions/Function2;
5+
public final fun reAuthorizeOnResponse (Lkotlin/jvm/functions/Function2;)V
46
}
57

68
public final class io/ktor/client/plugins/auth/AuthKt {

ktor-client/ktor-client-plugins/ktor-client-auth/api/ktor-client-auth.klib.api

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,11 @@ final class io.ktor.client.plugins.auth/AuthConfig { // io.ktor.client.plugins.a
155155

156156
final val providers // io.ktor.client.plugins.auth/AuthConfig.providers|{}providers[0]
157157
final fun <get-providers>(): kotlin.collections/MutableList<io.ktor.client.plugins.auth/AuthProvider> // io.ktor.client.plugins.auth/AuthConfig.providers.<get-providers>|<get-providers>(){}[0]
158+
159+
final var isUnauthorizedResponse // io.ktor.client.plugins.auth/AuthConfig.isUnauthorizedResponse|{}isUnauthorizedResponse[0]
160+
final fun <get-isUnauthorizedResponse>(): kotlin.coroutines/SuspendFunction1<io.ktor.client.statement/HttpResponse, kotlin/Boolean> // io.ktor.client.plugins.auth/AuthConfig.isUnauthorizedResponse.<get-isUnauthorizedResponse>|<get-isUnauthorizedResponse>(){}[0]
161+
162+
final fun reAuthorizeOnResponse(kotlin.coroutines/SuspendFunction1<io.ktor.client.statement/HttpResponse, kotlin/Boolean>) // io.ktor.client.plugins.auth/AuthConfig.reAuthorizeOnResponse|reAuthorizeOnResponse(kotlin.coroutines.SuspendFunction1<io.ktor.client.statement.HttpResponse,kotlin.Boolean>){}[0]
158163
}
159164

160165
final val io.ktor.client.plugins.auth/Auth // io.ktor.client.plugins.auth/Auth|{}Auth[0]

ktor-client/ktor-client-plugins/ktor-client-auth/common/src/io/ktor/client/plugins/auth/Auth.kt

Lines changed: 42 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@ package io.ktor.client.plugins.auth
66

77
import io.ktor.client.*
88
import io.ktor.client.call.*
9-
import io.ktor.client.plugins.*
109
import io.ktor.client.plugins.api.*
1110
import io.ktor.client.request.*
11+
import io.ktor.client.statement.*
1212
import io.ktor.http.*
1313
import io.ktor.http.auth.*
1414
import io.ktor.util.*
@@ -23,9 +23,36 @@ private class AtomicCounter {
2323
val atomic = atomic(0)
2424
}
2525

26+
/**
27+
* Configuration used by [Auth] plugin.
28+
*/
2629
@KtorDsl
2730
public class AuthConfig {
31+
/**
32+
* [AuthProvider] list to use.
33+
*/
2834
public val providers: MutableList<AuthProvider> = mutableListOf()
35+
36+
/**
37+
* The currently set function to control whether a response is unauthorized and should trigger a refresh / re-auth.
38+
*
39+
* By default checks against HTTP status 401.
40+
*
41+
* You can set this value via [reAuthorizeOnResponse].
42+
*/
43+
@InternalAPI
44+
public var isUnauthorizedResponse: suspend (HttpResponse) -> Boolean = { it.status == HttpStatusCode.Unauthorized }
45+
private set
46+
47+
/**
48+
* Sets a custom function to control whether a response is unauthorized and should trigger a refresh / re-auth.
49+
*
50+
* Use this to change the value of [isUnauthorizedResponse].
51+
*/
52+
public fun reAuthorizeOnResponse(block: suspend (HttpResponse) -> Boolean) {
53+
@OptIn(InternalAPI::class)
54+
isUnauthorizedResponse = block
55+
}
2956
}
3057

3158
/**
@@ -39,8 +66,9 @@ public val AuthCircuitBreaker: AttributeKey<Unit> = AttributeKey("auth-request")
3966
*
4067
* You can learn more from [Authentication and authorization](https://ktor.io/docs/auth.html).
4168
*
42-
* [providers] - list of auth providers to use.
69+
* @see [AuthConfig] for configuration options.
4370
*/
71+
@OptIn(InternalAPI::class)
4472
public val Auth: ClientPlugin<AuthConfig> = createClientPlugin("Auth", ::AuthConfig) {
4573
val providers = pluginConfig.providers.toList()
4674

@@ -50,7 +78,6 @@ public val Auth: ClientPlugin<AuthConfig> = createClientPlugin("Auth", ::AuthCon
5078
val tokenVersionsAttributeKey =
5179
AttributeKey<MutableMap<AuthProvider, Int>>("ProviderVersionAttributeKey")
5280

53-
@OptIn(InternalAPI::class)
5481
fun findProvider(
5582
call: HttpClientCall,
5683
candidateProviders: Set<AuthProvider>
@@ -64,10 +91,10 @@ public val Auth: ClientPlugin<AuthConfig> = createClientPlugin("Auth", ::AuthCon
6491
}
6592

6693
authHeaders.isEmpty() -> {
67-
LOGGER.trace(
68-
"401 response ${call.request.url} has no or empty \"WWW-Authenticate\" header. " +
94+
LOGGER.trace {
95+
"Unauthorized response ${call.request.url} has no or empty \"WWW-Authenticate\" header. " +
6996
"Can not add or refresh token"
70-
)
97+
}
7198
null
7299
}
73100

@@ -88,9 +115,9 @@ public val Auth: ClientPlugin<AuthConfig> = createClientPlugin("Auth", ::AuthCon
88115
val requestTokenVersion = requestTokenVersions[provider]
89116

90117
if (requestTokenVersion != null && requestTokenVersion >= tokenVersion.atomic.value) {
91-
LOGGER.trace("Refreshing token for ${call.request.url}")
118+
LOGGER.trace { "Refreshing token for ${call.request.url}" }
92119
if (!provider.refreshToken(call.response)) {
93-
LOGGER.trace("Refreshing token failed for ${call.request.url}")
120+
LOGGER.trace { "Refreshing token failed for ${call.request.url}" }
94121
return false
95122
} else {
96123
requestTokenVersions[provider] = tokenVersion.atomic.incrementAndGet()
@@ -99,7 +126,6 @@ public val Auth: ClientPlugin<AuthConfig> = createClientPlugin("Auth", ::AuthCon
99126
return true
100127
}
101128

102-
@OptIn(InternalAPI::class)
103129
suspend fun Send.Sender.executeWithNewToken(
104130
call: HttpClientCall,
105131
provider: AuthProvider,
@@ -111,13 +137,13 @@ public val Auth: ClientPlugin<AuthConfig> = createClientPlugin("Auth", ::AuthCon
111137
provider.addRequestHeaders(request, authHeader)
112138
request.attributes.put(AuthCircuitBreaker, Unit)
113139

114-
LOGGER.trace("Sending new request to ${call.request.url}")
140+
LOGGER.trace { "Sending new request to ${call.request.url}" }
115141
return proceed(request)
116142
}
117143

118144
onRequest { request, _ ->
119145
providers.filter { it.sendWithoutRequest(request) }.forEach { provider ->
120-
LOGGER.trace("Adding auth headers for ${request.url} from provider $provider")
146+
LOGGER.trace { "Adding auth headers for ${request.url} from provider $provider" }
121147
val tokenVersion = tokenVersions.computeIfAbsent(provider) { AtomicCounter() }
122148
val requestTokenVersions = request.attributes
123149
.computeIfAbsent(tokenVersionsAttributeKey) { mutableMapOf() }
@@ -128,22 +154,22 @@ public val Auth: ClientPlugin<AuthConfig> = createClientPlugin("Auth", ::AuthCon
128154

129155
on(Send) { originalRequest ->
130156
val origin = proceed(originalRequest)
131-
if (origin.response.status != HttpStatusCode.Unauthorized) return@on origin
157+
if (!pluginConfig.isUnauthorizedResponse(origin.response)) return@on origin
132158
if (origin.request.attributes.contains(AuthCircuitBreaker)) return@on origin
133159

134160
var call = origin
135161

136162
val candidateProviders = HashSet(providers)
137163

138-
while (call.response.status == HttpStatusCode.Unauthorized) {
139-
LOGGER.trace("Received 401 for ${call.request.url}")
164+
while (pluginConfig.isUnauthorizedResponse(call.response)) {
165+
LOGGER.trace { "Unauthorized response for ${call.request.url}" }
140166

141167
val (provider, authHeader) = findProvider(call, candidateProviders) ?: run {
142-
LOGGER.trace("Can not find auth provider for ${call.request.url}")
168+
LOGGER.trace { "Can not find auth provider for ${call.request.url}" }
143169
return@on call
144170
}
145171

146-
LOGGER.trace("Using provider $provider for ${call.request.url}")
172+
LOGGER.trace { "Using provider $provider for ${call.request.url}" }
147173

148174
candidateProviders.remove(provider)
149175
if (!refreshTokenIfNeeded(call, provider, originalRequest)) return@on call

ktor-client/ktor-client-plugins/ktor-client-auth/common/test/io/ktor/client/plugins/auth/AuthTest.kt

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,27 @@ class AuthTest : ClientLoader() {
403403
}
404404
}
405405

406+
@Test
407+
fun testForbiddenBearerAuthWithInvalidAccessAndValidRefreshTokens() = clientTests {
408+
config {
409+
install(Auth) {
410+
reAuthorizeOnResponse { it.status == HttpStatusCode.Forbidden }
411+
bearer {
412+
refreshTokens { BearerTokens("valid", "refresh") }
413+
loadTokens { BearerTokens("invalid", "refresh") }
414+
}
415+
}
416+
417+
expectSuccess = false
418+
}
419+
420+
test { client ->
421+
client.prepareGet("$TEST_SERVER/auth/bearer/test-refresh?status=403").execute {
422+
assertEquals(HttpStatusCode.OK, it.status)
423+
}
424+
}
425+
}
426+
406427
// The return of refreshTokenFun is null, cause it should not be called at all, if loadTokensFun returns valid tokens
407428
@Test
408429
fun testUnauthorizedBearerAuthWithValidAccessTokenAndInvalidRefreshToken() = clientTests {

ktor-test-server/src/main/kotlin/test/server/tests/Auth.kt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,8 @@ internal fun Application.authTestServer() {
129129
val token = call.request.headers["Authorization"]
130130
if (token.isNullOrEmpty() || token.contains("invalid")) {
131131
call.response.header(HttpHeaders.WWWAuthenticate, "Bearer realm=\"TestServer\"")
132-
call.respond(HttpStatusCode.Unauthorized)
132+
val status = call.request.queryParameters["status"]?.toIntOrNull() ?: 401
133+
call.respond(HttpStatusCode.fromValue(status))
133134
return@get
134135
}
135136

0 commit comments

Comments
 (0)