Skip to content
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package com.amplifyframework.auth.cognito.actions

import aws.sdk.kotlin.services.cognitoidentity.model.GetCredentialsForIdentityRequest
import aws.sdk.kotlin.services.cognitoidentity.model.GetIdRequest
import aws.sdk.kotlin.services.cognitoidentityprovider.getTokensFromRefreshToken
import aws.sdk.kotlin.services.cognitoidentityprovider.initiateAuth
import aws.sdk.kotlin.services.cognitoidentityprovider.model.AuthFlowType
import aws.smithy.kotlin.runtime.time.Instant
Expand All @@ -40,44 +41,27 @@ import com.amplifyframework.statemachine.codegen.events.RefreshSessionEvent
import kotlin.time.Duration.Companion.seconds

internal object FetchAuthSessionCognitoActions : FetchAuthSessionActions {
private const val KEY_SECRET_HASH = "SECRET_HASH"
private const val KEY_REFRESH_TOKEN = "REFRESH_TOKEN"
private const val KEY_DEVICE_KEY = "DEVICE_KEY"

override fun refreshUserPoolTokensAction(signedInData: SignedInData) =
Action<AuthEnvironment>("RefreshUserPoolTokens") { id, dispatcher ->
logger.verbose("$id Starting execution")
val evt = try {
val username = signedInData.username
val tokens = signedInData.cognitoUserPoolTokens

val authParameters = mutableMapOf<String, String>()
val secretHash = AuthHelper.getSecretHash(
username,
configuration.userPool?.appClient,
configuration.userPool?.appClientSecret
)
tokens.refreshToken?.let { authParameters[KEY_REFRESH_TOKEN] = it }
secretHash?.let { authParameters[KEY_SECRET_HASH] = it }

val encodedContextData = getUserContextData(username)
val deviceMetadata: DeviceMetadata.Metadata? = getDeviceMetadata(username)
deviceMetadata?.let { authParameters[KEY_DEVICE_KEY] = it.deviceKey }
val pinpointEndpointId = getPinpointEndpointId()

val response = cognitoAuthService.cognitoIdentityProviderClient?.initiateAuth {
authFlow = AuthFlowType.RefreshToken
val response = cognitoAuthService.cognitoIdentityProviderClient?.getTokensFromRefreshToken {
refreshToken = tokens.refreshToken
clientId = configuration.userPool?.appClient
this.authParameters = authParameters
pinpointEndpointId?.let { analyticsMetadata { analyticsEndpointId = it } }
encodedContextData?.let { userContextData { encodedData = it } }
clientSecret = configuration.userPool?.appClientSecret
deviceKey = deviceMetadata?.deviceKey
}

val expiresIn = response?.authenticationResult?.expiresIn?.toLong() ?: 0
val refreshedUserPoolTokens = CognitoUserPoolTokens(
idToken = response?.authenticationResult?.idToken,
accessToken = response?.authenticationResult?.accessToken,
refreshToken = tokens.refreshToken,
refreshToken = response?.authenticationResult?.refreshToken ?: tokens.refreshToken,
expiration = Instant.now().plus(expiresIn.seconds).epochSeconds
)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
/*
* Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package com.amplifyframework.auth.cognito.actions

import androidx.test.core.app.ApplicationProvider
import aws.sdk.kotlin.services.cognitoidentityprovider.CognitoIdentityProviderClient
import aws.sdk.kotlin.services.cognitoidentityprovider.model.AuthenticationResultType
import aws.sdk.kotlin.services.cognitoidentityprovider.model.GetTokensFromRefreshTokenResponse
import aws.sdk.kotlin.services.cognitoidentityprovider.model.NotAuthorizedException
import com.amplifyframework.auth.cognito.AWSCognitoAuthService
import com.amplifyframework.auth.cognito.AuthConfiguration
import com.amplifyframework.auth.cognito.AuthEnvironment
import com.amplifyframework.auth.cognito.StoreClientBehavior
import com.amplifyframework.auth.cognito.mockSignedInData
import com.amplifyframework.logging.Logger
import com.amplifyframework.statemachine.EventDispatcher
import com.amplifyframework.statemachine.StateMachineEvent
import com.amplifyframework.statemachine.codegen.data.AmplifyCredential
import com.amplifyframework.statemachine.codegen.data.CognitoUserPoolTokens
import com.amplifyframework.statemachine.codegen.data.CredentialType
import com.amplifyframework.statemachine.codegen.data.DeviceMetadata
import com.amplifyframework.statemachine.codegen.data.UserPoolConfiguration
import com.amplifyframework.statemachine.codegen.events.AuthorizationEvent
import com.amplifyframework.statemachine.codegen.events.RefreshSessionEvent
import io.kotest.matchers.shouldBe
import io.kotest.matchers.types.shouldBeInstanceOf
import io.mockk.Runs
import io.mockk.coEvery
import io.mockk.every
import io.mockk.just
import io.mockk.mockk
import io.mockk.slot
import kotlinx.coroutines.test.runTest
import org.junit.Before
import org.junit.Test
import org.junit.runner.RunWith
import org.robolectric.RobolectricTestRunner

@RunWith(RobolectricTestRunner::class)
class FetchAuthSessionCognitoActionsTest {

private val pool = mockk<UserPoolConfiguration> {
every { appClient } returns "client"
every { appClientSecret } returns "secret"
every { region } returns "us-east-1"
every { poolId } returns "pool_id"
}
private val configuration = mockk<AuthConfiguration> {
every { userPool } returns pool
every { identityPool } returns null
}
private val cognitoAuthService = mockk<AWSCognitoAuthService>()
private val credentialStoreClient = mockk<StoreClientBehavior> {
coEvery { loadCredentials(any<CredentialType.Device>()) } returns AmplifyCredential.DeviceData(
DeviceMetadata.Metadata(deviceKey = "device_key", deviceGroupKey = "device_group")
)
}
private val logger = mockk<Logger>(relaxed = true)
private val cognitoIdentityProviderClientMock = mockk<CognitoIdentityProviderClient>()

private val capturedEvent = slot<StateMachineEvent>()
private val dispatcher = mockk<EventDispatcher> {
every { send(capture(capturedEvent)) } just Runs
}

private lateinit var authEnvironment: AuthEnvironment

@Before
fun setup() {
every { cognitoAuthService.cognitoIdentityProviderClient }.answers { cognitoIdentityProviderClientMock }
authEnvironment = AuthEnvironment(
ApplicationProvider.getApplicationContext(),
configuration,
cognitoAuthService,
credentialStoreClient,
null,
null,
logger
)
}

@Test
fun `refreshUserPoolTokensAction calls getTokensFromRefreshToken and handles token rotation`() = runTest {
val originalRefreshToken = "original_refresh_token"
val newRefreshToken = "new_refresh_token"

coEvery { cognitoIdentityProviderClientMock.getTokensFromRefreshToken(any()) } returns GetTokensFromRefreshTokenResponse {
authenticationResult = AuthenticationResultType {
this.accessToken = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJ1c2VySWQiLCJ1c2VybmFtZSI6InVzZXJuYW1lIiwiZXhwIjoxNTE2MjM5MDIyfQ.XbPfbIHMI6arZ3Y922BhjWgQzWXcXNrz0ogtVhfEd2o"
this.idToken = "id_token"
this.refreshToken = newRefreshToken
this.expiresIn = 3600
}
}

val signedInData = mockSignedInData(
username = "username",
cognitoUserPoolTokens = CognitoUserPoolTokens(
accessToken = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJ1c2VySWQiLCJ1c2VybmFtZSI6InVzZXJuYW1lIiwiZXhwIjoxNTE2MjM5MDIyfQ.XbPfbIHMI6arZ3Y922BhjWgQzWXcXNrz0ogtVhfEd2o",
idToken = "old_id",
refreshToken = originalRefreshToken,
expiration = 0
)
)

FetchAuthSessionCognitoActions.refreshUserPoolTokensAction(signedInData).execute(dispatcher, authEnvironment)

val event = capturedEvent.captured.shouldBeInstanceOf<RefreshSessionEvent>()
val refreshedData = event.eventType.shouldBeInstanceOf<RefreshSessionEvent.EventType.Refreshed>().signedInData
refreshedData.cognitoUserPoolTokens.refreshToken shouldBe newRefreshToken
}

@Test
fun `refreshUserPoolTokensAction falls back to original refresh token when rotation is not enabled`() = runTest {
val originalRefreshToken = "original_refresh_token"

coEvery { cognitoIdentityProviderClientMock.getTokensFromRefreshToken(any()) } returns GetTokensFromRefreshTokenResponse {
authenticationResult = AuthenticationResultType {
this.accessToken = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJ1c2VySWQiLCJ1c2VybmFtZSI6InVzZXJuYW1lIiwiZXhwIjoxNTE2MjM5MDIyfQ.XbPfbIHMI6arZ3Y922BhjWgQzWXcXNrz0ogtVhfEd2o"
this.idToken = "id_token"
this.refreshToken = null
this.expiresIn = 3600
}
}

val signedInData = mockSignedInData(
username = "username",
cognitoUserPoolTokens = CognitoUserPoolTokens(
accessToken = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJ1c2VySWQiLCJ1c2VybmFtZSI6InVzZXJuYW1lIiwiZXhwIjoxNTE2MjM5MDIyfQ.XbPfbIHMI6arZ3Y922BhjWgQzWXcXNrz0ogtVhfEd2o",
idToken = "old_id",
refreshToken = originalRefreshToken,
expiration = 0
)
)

FetchAuthSessionCognitoActions.refreshUserPoolTokensAction(signedInData).execute(dispatcher, authEnvironment)

val event = capturedEvent.captured.shouldBeInstanceOf<RefreshSessionEvent>()
val refreshedData = event.eventType.shouldBeInstanceOf<RefreshSessionEvent.EventType.Refreshed>().signedInData
refreshedData.cognitoUserPoolTokens.refreshToken shouldBe originalRefreshToken
}

@Test
fun `refreshUserPoolTokensAction handles NotAuthorizedException`() = runTest {
coEvery { cognitoIdentityProviderClientMock.getTokensFromRefreshToken(any()) } throws NotAuthorizedException {
message = "Token expired"
}

val signedInData = mockSignedInData(
username = "username",
cognitoUserPoolTokens = CognitoUserPoolTokens(
accessToken = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJ1c2VySWQiLCJ1c2VybmFtZSI6InVzZXJuYW1lIiwiZXhwIjoxNTE2MjM5MDIyfQ.XbPfbIHMI6arZ3Y922BhjWgQzWXcXNrz0ogtVhfEd2o",
idToken = "old_id",
refreshToken = "refresh_token",
expiration = 0
)
)

FetchAuthSessionCognitoActions.refreshUserPoolTokensAction(signedInData).execute(dispatcher, authEnvironment)

capturedEvent.captured.shouldBeInstanceOf<AuthorizationEvent>()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@ object FetchAuthSessionTestCaseGenerator : SerializableProvider {
TimeZone.setDefault(initialTimeZone)
}

private val mockedRefreshInitiateAuthResponse = MockResponse(
private val mockedRefreshGetTokensFromRefreshTokenResponse = MockResponse(
CognitoType.CognitoIdentityProvider,
"initiateAuth",
"getTokensFromRefreshToken",
ResponseType.Success,
mapOf(
"authenticationResult" to mapOf(
Expand Down Expand Up @@ -103,9 +103,9 @@ object FetchAuthSessionTestCaseGenerator : SerializableProvider {
}.toJsonElement()
)

private val mockedRefreshInitiateAuthFailureResponse = MockResponse(
private val mockedRefreshGetTokensFromRefreshTokenFailureResponse = MockResponse(
CognitoType.CognitoIdentityProvider,
"initiateAuth",
"getTokensFromRefreshToken",
ResponseType.Failure,
ResourceNotFoundException.invoke {
message = "Error type: Client, Protocol response: (empty response)"
Expand Down Expand Up @@ -225,7 +225,7 @@ object FetchAuthSessionTestCaseGenerator : SerializableProvider {
"authconfiguration.json",
"SignedIn_SessionEstablished.json",
mockedResponses = listOf(
mockedRefreshInitiateAuthResponse,
mockedRefreshGetTokensFromRefreshTokenResponse,
mockedRefreshGetIdResponse,
mockedRefreshGetAWSCredentialsResponse
)
Expand All @@ -243,7 +243,7 @@ object FetchAuthSessionTestCaseGenerator : SerializableProvider {
preConditions = PreConditions(
"authconfiguration.json",
"SignedIn_SessionEstablished.json",
mockedResponses = listOf(mockedRefreshInitiateAuthFailureResponse)
mockedResponses = listOf(mockedRefreshGetTokensFromRefreshTokenFailureResponse)
),
api = API(
name = AuthAPI.fetchAuthSession,
Expand All @@ -259,7 +259,7 @@ object FetchAuthSessionTestCaseGenerator : SerializableProvider {
"authconfiguration.json",
"SignedIn_SessionEstablished.json",
mockedResponses = listOf(
mockedRefreshInitiateAuthResponse,
mockedRefreshGetTokensFromRefreshTokenResponse,
mockedRefreshGetIdFailureResponse,
mockedRefreshGetAWSCredentialsFailureResponse
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ import aws.sdk.kotlin.services.cognitoidentityprovider.model.RespondToAuthChalle
import aws.sdk.kotlin.services.cognitoidentityprovider.model.RevokeTokenResponse
import aws.sdk.kotlin.services.cognitoidentityprovider.model.SignUpResponse
import aws.sdk.kotlin.services.cognitoidentityprovider.model.UpdateDeviceStatusResponse
import aws.sdk.kotlin.services.cognitoidentityprovider.model.GetTokensFromRefreshTokenResponse
import aws.smithy.kotlin.runtime.time.Instant
import com.amplifyframework.auth.cognito.featuretest.CognitoType
import com.amplifyframework.auth.cognito.featuretest.MockResponse
Expand Down Expand Up @@ -228,6 +229,16 @@ class CognitoMockFactory(
}
}
}
"getTokensFromRefreshToken" -> {
coEvery { mockCognitoIPClient.getTokensFromRefreshToken(any()) } coAnswers {
setupError(mockResponse, responseObject)
GetTokensFromRefreshTokenResponse.invoke {
this.authenticationResult = responseObject["authenticationResult"]?.let {
parseAuthenticationResult(it as JsonObject)
}
}
}
}
else -> throw Error("mock for ${mockResponse.apiName} not defined!")
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import aws.sdk.kotlin.services.cognitoidentityprovider.model.ConfirmSignUpReques
import aws.sdk.kotlin.services.cognitoidentityprovider.model.ForgotPasswordRequest
import aws.sdk.kotlin.services.cognitoidentityprovider.model.InitiateAuthRequest
import aws.sdk.kotlin.services.cognitoidentityprovider.model.SignUpRequest
import aws.sdk.kotlin.services.cognitoidentityprovider.model.GetTokensFromRefreshTokenRequest
import com.amplifyframework.auth.cognito.featuretest.ExpectationShapes
import com.amplifyframework.auth.cognito.helpers.AuthHelper
import kotlinx.serialization.json.Json
Expand Down Expand Up @@ -101,6 +102,20 @@ object CognitoRequestFactory {
SignUpRequest.invoke(expectedRequest)
}

"getTokensFromRefreshToken" -> {
val params = targetApi.request as JsonObject
val expectedRequestBuilder: GetTokensFromRefreshTokenRequest.Builder.() -> Unit = {
refreshToken = (params["refreshToken"] as JsonPrimitive).content
clientId = (params["clientId"] as JsonPrimitive).content
clientSecret = (params["clientSecret"] as? JsonPrimitive)?.content
deviceKey = (params["deviceKey"] as? JsonPrimitive)?.content
clientMetadata = params["clientMetadata"]?.let {
Json.decodeFromJsonElement<Map<String, String>>(it as JsonObject)
}
}
GetTokensFromRefreshTokenRequest.invoke(expectedRequestBuilder)
}

else -> error("Expected request for $targetApi for Cognito is not defined")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"mockedResponses": [
{
"type": "cognitoIdentityProvider",
"apiName": "initiateAuth",
"apiName": "getTokensFromRefreshToken",
"responseType": "success",
"response": {
"authenticationResult": {
Expand Down Expand Up @@ -39,8 +39,7 @@
},
"api": {
"name": "fetchAuthSession",
"params": {
},
"params": {},
"options": {
"forceRefresh": true
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"mockedResponses": [
{
"type": "cognitoIdentityProvider",
"apiName": "initiateAuth",
"apiName": "getTokensFromRefreshToken",
"responseType": "failure",
"response": {
"errorType": "ResourceNotFoundException",
Expand All @@ -17,8 +17,7 @@
},
"api": {
"name": "fetchAuthSession",
"params": {
},
"params": {},
"options": {
"forceRefresh": true
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"mockedResponses": [
{
"type": "cognitoIdentityProvider",
"apiName": "initiateAuth",
"apiName": "getTokensFromRefreshToken",
"responseType": "success",
"response": {
"authenticationResult": {
Expand Down
2 changes: 1 addition & 1 deletion gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ androidx-test-orchestrator = "1.4.2"
androidx-test-runner = "1.3.0"
androidx-workmanager = "2.9.1"
apollo = "4.3.1"
aws-kotlin = "1.5.0" # ensure proper aws-smithy version also set
aws-kotlin = "1.5.15" # ensure proper aws-smithy version also set
aws-sdk = "2.62.2"
aws-smithy = "1.5.1" # ensure proper aws-kotlin version also set
binary-compatibility-validator = "0.18.1"
Expand Down