Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package com.amplifyframework.aws.appsync.core.authorizers

import com.amplifyframework.aws.appsync.core.AppSyncRequest
import com.amplifyframework.aws.appsync.core.util.AppSyncRequestSigner
import io.kotest.assertions.throwables.shouldThrow
import io.kotest.matchers.maps.shouldContainExactly
import io.mockk.coEvery
import io.mockk.mockk
Expand All @@ -38,4 +39,20 @@ class AmplifyIamAuthorizerTest {

authorizer.getAuthorizationHeaders(request) shouldContainExactly mapOf("Authorization" to "test-signature")
}

@Test
fun `iam authorizer throws if failed to fetch token from amplify`() = runTest {
val request = mockk<AppSyncRequest>()
val signer = mockk<AppSyncRequestSigner> {
coEvery {
signAppSyncRequest(request, region)
} throws IllegalStateException()
}

val authorizer = AmplifyIamAuthorizer(region, signer)

shouldThrow<IllegalStateException> {
authorizer.getAuthorizationHeaders(request)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package com.amplifyframework.aws.appsync.core.authorizers

import com.amplifyframework.auth.AuthCredentialsProvider
import com.amplifyframework.core.Consumer
import io.kotest.assertions.throwables.shouldThrow
import io.kotest.matchers.maps.shouldContainExactly
import io.mockk.CapturingSlot
import io.mockk.every
Expand All @@ -39,4 +40,17 @@ class AmplifyUserPoolAuthorizerTest {

authorizer.getAuthorizationHeaders(mockk()) shouldContainExactly mapOf("Authorization" to expectedValue)
}

@Test
fun `user pool authorizer throws if failed to fetch token from amplify`() = runTest {
val cognitoCredentialsProvider = mockk<AuthCredentialsProvider> {
every { getAccessToken(any(), any()) } throws IllegalStateException()
}
val accessTokenProvider = AccessTokenProvider(cognitoCredentialsProvider)
val authorizer = AmplifyUserPoolAuthorizer(accessTokenProvider)

shouldThrow<IllegalStateException> {
authorizer.getAuthorizationHeaders(mockk())
}
}
}
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
/*
* Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/
package com.amplifyframework.aws.appsync.core.util

import aws.smithy.kotlin.runtime.InternalApi
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,4 @@ internal object HeaderKeys {
const val AMAZON_DATE = "x-amz-date"
const val API_KEY = "x-api-key"
const val AUTHORIZATION = "Authorization"
const val HOST = "host"
const val ACCEPT = "accept"
const val CONTENT_TYPE = "content-type"
const val SEC_WEBSOCKET_PROTOCOL = "Sec-WebSocket-Protocol"
}

internal object HeaderValues {
const val ACCEPT_APPLICATION_JSON = "application/json, text/javascript"
const val CONTENT_TYPE_APPLICATION_JSON = "application/json; charset=UTF-8"
const val SEC_WEBSOCKET_PROTOCOL_APPSYNC_EVENTS = "aws-appsync-event-ws"
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/
package com.amplifyframework.aws.appsync.core

import io.kotest.matchers.shouldBe
import org.junit.Test

class AppSyncRequestTest {

@Test
fun `test request implementation`() {
val testRequest = object : AppSyncRequest {
override val method = AppSyncRequest.HttpMethod.POST
override val url = "https://amazon.com"
override val headers = mapOf(
HeaderKeys.API_KEY to "123",
HeaderKeys.AUTHORIZATION to "345",
HeaderKeys.AMAZON_DATE to "2025"
)
override val body = "b"
}

testRequest.method shouldBe AppSyncRequest.HttpMethod.POST
testRequest.url shouldBe "https://amazon.com"
testRequest.headers shouldBe mapOf(
HeaderKeys.API_KEY to "123",
HeaderKeys.AUTHORIZATION to "345",
HeaderKeys.AMAZON_DATE to "2025"
)
testRequest.body shouldBe "b"
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
/*
* Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package com.amplifyframework.aws.appsync.core

import io.kotest.matchers.collections.shouldHaveSize
import io.kotest.matchers.equals.shouldBeEqual
import org.junit.Test

// We only need to test the suppliers as the other logs levels don't react to thresholds. That is up to the
// class that implements Logger and writes the implementation for each log message type.
class LoggerTest {

private val errorLog = "error"
private val errorLogWithThrowable = Pair(errorLog, IllegalStateException())
private val warnLog = "warn"
private val warnLogWithThrowable = Pair(warnLog, IllegalStateException())
private val infoLog = "info"
private val debugLog = "debug"
private val verboseLog = "verbose"

@Test
fun `test suppliers with none threshold`() {
val logger = TestSupplierLogger(LogLevel.NONE)

writeTestLogs(logger)

logger.warnLogs shouldHaveSize 0
logger.warnLogs shouldHaveSize 0
logger.infoLogs shouldHaveSize 0
logger.debugLogs shouldHaveSize 0
logger.verboseLogs shouldHaveSize 0
}

@Test
fun `test suppliers with error threshold`() {
val logger = TestSupplierLogger(LogLevel.ERROR)

writeTestLogs(logger)

logger.errorLogs shouldBeEqual listOf(Pair(errorLog, null), errorLogWithThrowable)
logger.warnLogs shouldHaveSize 0
logger.infoLogs shouldHaveSize 0
logger.debugLogs shouldHaveSize 0
logger.verboseLogs shouldHaveSize 0
}

@Test
fun `test suppliers with warn threshold`() {
val logger = TestSupplierLogger(LogLevel.WARN)

writeTestLogs(logger)

logger.errorLogs shouldBeEqual listOf(Pair(errorLog, null), errorLogWithThrowable)
logger.warnLogs shouldBeEqual listOf(Pair(warnLog, null), warnLogWithThrowable)
logger.infoLogs shouldHaveSize 0
logger.debugLogs shouldHaveSize 0
logger.verboseLogs shouldHaveSize 0
}

@Test
fun `test suppliers with info threshold`() {
val logger = TestSupplierLogger(LogLevel.INFO)

writeTestLogs(logger)

logger.errorLogs shouldBeEqual listOf(Pair(errorLog, null), errorLogWithThrowable)
logger.warnLogs shouldBeEqual listOf(Pair(warnLog, null), warnLogWithThrowable)
logger.infoLogs shouldBeEqual listOf(infoLog)
logger.debugLogs shouldHaveSize 0
logger.verboseLogs shouldHaveSize 0
}

@Test
fun `test suppliers with debug threshold`() {
val logger = TestSupplierLogger(LogLevel.DEBUG)

writeTestLogs(logger)

logger.errorLogs shouldBeEqual listOf(Pair(errorLog, null), errorLogWithThrowable)
logger.warnLogs shouldBeEqual listOf(Pair(warnLog, null), warnLogWithThrowable)
logger.infoLogs shouldBeEqual listOf(infoLog)
logger.debugLogs shouldBeEqual listOf(debugLog)
logger.verboseLogs shouldHaveSize 0
}

@Test
fun `test suppliers with verbose threshold`() {
val logger = TestSupplierLogger(LogLevel.VERBOSE)

writeTestLogs(logger)

logger.errorLogs shouldBeEqual listOf(Pair(errorLog, null), errorLogWithThrowable)
logger.warnLogs shouldBeEqual listOf(Pair(warnLog, null), warnLogWithThrowable)
logger.infoLogs shouldBeEqual listOf(infoLog)
logger.debugLogs shouldBeEqual listOf(debugLog)
logger.verboseLogs shouldBeEqual listOf(verboseLog)
}

private fun writeTestLogs(logger: Logger) {
logger.error { errorLog }
logger.error(errorLogWithThrowable.second) { errorLogWithThrowable.first }
logger.warn { warnLog }
logger.warn(warnLogWithThrowable.second) { warnLogWithThrowable.first }
logger.info { infoLog }
logger.debug { debugLog }
logger.verbose { verboseLog }
}
}

private class TestSupplierLogger(override val thresholdLevel: LogLevel) : Logger {
val errorLogs = mutableListOf<Pair<String, Throwable?>>()
val warnLogs = mutableListOf<Pair<String, Throwable?>>()
val infoLogs = mutableListOf<String>()
val debugLogs = mutableListOf<String>()
val verboseLogs = mutableListOf<String>()

override fun error(message: String) {
errorLogs.add(Pair(message, null))
}

override fun error(message: String, error: Throwable?) {
errorLogs.add(Pair(message, error))
}

override fun warn(message: String) {
warnLogs.add(Pair(message, null))
}

override fun warn(message: String, issue: Throwable?) {
warnLogs.add(Pair(message, issue))
}

override fun info(message: String) {
infoLogs.add(message)
}

override fun debug(message: String) {
debugLogs.add(message)
}

override fun verbose(message: String) {
verboseLogs.add(message)
}
}
1 change: 1 addition & 0 deletions appsync/aws-appsync-events/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -53,5 +53,6 @@ dependencies {
testImplementation(libs.test.mockk)
testImplementation(libs.test.kotlin.coroutines)
testImplementation(libs.test.kotest.assertions)
testImplementation(libs.test.kotest.assertions.json)
testImplementation(libs.test.mockwebserver)
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@ import com.amplifyframework.aws.appsync.events.data.ChannelAuthorizers
import com.amplifyframework.aws.appsync.events.data.EventsException
import com.amplifyframework.aws.appsync.events.data.PublishResult
import com.amplifyframework.aws.appsync.events.data.toEventsException
import kotlinx.coroutines.coroutineScope
import kotlinx.serialization.json.Json
import com.amplifyframework.aws.appsync.events.utils.JsonUtils
import kotlinx.serialization.json.JsonElement
import okhttp3.OkHttpClient

Expand Down Expand Up @@ -52,10 +51,7 @@ class Events(
* @param defaultChannelAuthorizers passed to created channels if not overridden.
*/

private val json = Json {
encodeDefaults = true
ignoreUnknownKeys = true
}
private val json = JsonUtils.createJsonForLibrary()
private val endpoints = EventsEndpoints(endpoint)
private val okHttpClient = OkHttpClient.Builder().apply {
options.okHttpConfigurationProvider?.applyConfiguration(this)
Expand Down Expand Up @@ -121,15 +117,15 @@ class Events(
fun channel(
channelName: String,
authorizers: ChannelAuthorizers = this.defaultChannelAuthorizers,
) = EventsChannel(channelName, authorizers, eventsWebSocketProvider)
) = EventsChannel(channelName, authorizers, eventsWebSocketProvider, json)

/**
* Method to disconnect from all channels.
*
* @param flushEvents set to true (default) to allow all pending publish calls to succeed before disconnecting.
* Setting to false will immediately disconnect, cancelling any in-progress or queued event publishes.
*/
suspend fun disconnect(flushEvents: Boolean = true): Unit = coroutineScope {
suspend fun disconnect(flushEvents: Boolean = true) {
eventsWebSocketProvider.existingWebSocket?.disconnect(flushEvents)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.flow.flowOn
import kotlinx.coroutines.flow.onCompletion
import kotlinx.coroutines.flow.onStart
import kotlinx.coroutines.withContext
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.JsonArray
import kotlinx.serialization.json.JsonElement
import kotlinx.serialization.json.JsonPrimitive
Expand All @@ -47,7 +49,8 @@ import kotlinx.serialization.json.JsonPrimitive
class EventsChannel internal constructor(
val name: String,
val authorizers: ChannelAuthorizers,
private val eventsWebSocketProvider: EventsWebSocketProvider
private val eventsWebSocketProvider: EventsWebSocketProvider,
private val json: Json
) {

/**
Expand Down Expand Up @@ -118,7 +121,7 @@ class EventsChannel internal constructor(
private suspend fun publishToWebSocket(
events: List<JsonElement>,
authorizer: AppSyncAuthorizer
): WebSocketMessage.Received.PublishSuccess = coroutineScope {
): WebSocketMessage.Received.PublishSuccess = withContext(Dispatchers.IO) {
val publishId = UUID.randomUUID().toString()
val publishMessage = WebSocketMessage.Send.Publish(
id = publishId,
Expand All @@ -134,7 +137,7 @@ class EventsChannel internal constructor(
throw webSocket.disconnectReason?.toCloseException() ?: ConnectionClosedException()
}

return@coroutineScope when (val response = deferredResponse.await()) {
return@withContext when (val response = deferredResponse.await()) {
is WebSocketMessage.Received.PublishSuccess -> {
response
}
Expand Down
Loading