Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,16 @@ import com.amazonaws.sdk.appsync.events.data.toEventsException
import com.amazonaws.sdk.appsync.events.utils.ConnectionTimeoutTimer
import com.amazonaws.sdk.appsync.events.utils.HeaderKeys
import com.amazonaws.sdk.appsync.events.utils.HeaderValues
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.async
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.flow.MutableSharedFlow
import kotlinx.coroutines.flow.asSharedFlow
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.flow.onCompletion
import kotlinx.coroutines.flow.onStart
import kotlinx.coroutines.withContext
import kotlinx.serialization.encodeToString
import kotlinx.serialization.json.Json
Expand Down Expand Up @@ -71,8 +74,12 @@ internal class EventsWebSocket(
@Throws(ConnectException::class)
suspend fun connect() = coroutineScope {
logger?.debug("Opening Websocket Connection")

// Get deferred connect response. We need to listen before opening connection, but not block
val deferredConnectResponse = async { getConnectResponse() }
val listeningForConnectSignal = CompletableDeferred<Unit>()
val deferredConnectResponse = async { getConnectResponse { listeningForConnectSignal.complete(Unit) } }
listeningForConnectSignal.await()

// Create initial request without auth headers
val preAuthRequest = createPreAuthConnectRequest(eventsEndpoints)
// Fetch auth headers from authorizer
Expand Down Expand Up @@ -107,7 +114,11 @@ internal class EventsWebSocket(
suspend fun disconnect(flushEvents: Boolean) = withContext(Dispatchers.IO) {
if (isClosed) return@withContext
disconnectReason = WebSocketDisconnectReason.UserInitiated
val deferredClosedResponse = async { getClosedResponse() }

val listeningForClosedSignal = CompletableDeferred<Unit>()
val deferredClosedResponse = async { getClosedResponse { listeningForClosedSignal.complete(Unit) } }
listeningForClosedSignal.await()

when (flushEvents) {
true -> webSocket.close(NORMAL_CLOSE_CODE, "User initiated disconnect")
false -> webSocket.cancel()
Expand Down Expand Up @@ -231,21 +242,27 @@ internal class EventsWebSocket(
}.build()
}

private suspend fun getConnectResponse(): WebSocketMessage = events.first {
when (it) {
is WebSocketMessage.Received.ConnectionAck -> true
is WebSocketMessage.Received.ConnectionError -> true
is WebSocketMessage.Closed -> true
else -> false
private suspend fun getConnectResponse(onListening: () -> Unit): WebSocketMessage = events
.onStart { onListening.invoke() }
.onCompletion { onListening.invoke() } // just in case. No impact of calling onListening multiple times
.first {
when (it) {
is WebSocketMessage.Received.ConnectionAck -> true
is WebSocketMessage.Received.ConnectionError -> true
is WebSocketMessage.Closed -> true
else -> false
}
}
}

private suspend fun getClosedResponse(): WebSocketMessage = events.first {
when (it) {
is WebSocketMessage.Closed -> true
else -> false
private suspend fun getClosedResponse(onListening: () -> Unit): WebSocketMessage = events
.onStart { onListening.invoke() }
.onCompletion { onListening.invoke() } // just in case. No impact of calling onListening multiple times
.first {
when (it) {
is WebSocketMessage.Closed -> true
else -> false
}
}
}
}

@VisibleForTesting
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import com.amazonaws.sdk.appsync.events.data.UserClosedConnectionException
import com.amazonaws.sdk.appsync.events.data.WebSocketMessage
import com.amazonaws.sdk.appsync.events.data.toEventsException
import com.amazonaws.sdk.appsync.events.utils.JsonUtils
import kotlinx.coroutines.CompletableDeferred
import java.util.UUID
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.Dispatchers
Expand Down Expand Up @@ -168,7 +169,14 @@ class EventsWebSocketClient internal constructor(
)

val webSocket = eventsWebSocketProvider.getConnectedWebSocket()
val deferredResponse = async { getPublishResponse(webSocket, publishId) }

val listeningForPublishSignal = CompletableDeferred<Unit>()
val deferredResponse = async {
getPublishResponse(webSocket, publishId) {
listeningForPublishSignal.complete(Unit)
}
}
listeningForPublishSignal.await()

val queued = webSocket.sendWithAuthorizer(publishMessage, authorizer)
if (!queued) {
Expand Down Expand Up @@ -226,8 +234,14 @@ class EventsWebSocketClient internal constructor(
authorizer: AppSyncAuthorizer
): Boolean = coroutineScope {
// create a deferred holder for subscription response
val listeningForSubscriptionSignal = CompletableDeferred<Unit>()
val deferredSubscriptionResponse =
async { getSubscriptionResponse(webSocket, subscriptionId) }
async {
getSubscriptionResponse(webSocket, subscriptionId) {
listeningForSubscriptionSignal.complete(Unit)
}
}
listeningForSubscriptionSignal.await()

// Publish subscription to websocket
val queued = webSocket.sendWithAuthorizer(
Expand Down Expand Up @@ -262,8 +276,15 @@ class EventsWebSocketClient internal constructor(
}
}

private suspend fun getSubscriptionResponse(webSocket: EventsWebSocket, subscriptionId: String): WebSocketMessage =
webSocket.events.first {
private suspend fun getSubscriptionResponse(
webSocket: EventsWebSocket,
subscriptionId: String,
onListening: () -> Unit
): WebSocketMessage =
webSocket.events
.onStart { onListening.invoke() }
.onCompletion { onListening.invoke() } // just in case. No impact of calling onListening multiple times
.first {
when {
it is WebSocketMessage.Received.Subscription && it.id == subscriptionId -> true
it is WebSocketMessage.ErrorContainer && it.id == subscriptionId -> true
Expand All @@ -272,8 +293,15 @@ class EventsWebSocketClient internal constructor(
}
}

private suspend fun getPublishResponse(webSocket: EventsWebSocket, publishId: String): WebSocketMessage =
webSocket.events.first {
private suspend fun getPublishResponse(
webSocket: EventsWebSocket,
publishId: String,
onListening: () -> Unit
): WebSocketMessage =
webSocket.events
.onStart { onListening.invoke() }
.onCompletion { onListening.invoke() } // just in case. No impact of calling onListening multiple times
.first {
when {
it is WebSocketMessage.Received.PublishSuccess && it.id == publishId -> true
it is WebSocketMessage.ErrorContainer && it.id == publishId -> true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ import io.mockk.unmockkConstructor
import kotlin.time.Duration
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.filter
import kotlinx.coroutines.launch
import kotlinx.coroutines.test.StandardTestDispatcher
Expand Down Expand Up @@ -117,7 +116,6 @@ internal class EventsWebSocketClientTest {

setupSendResult { _, _ ->
launch {
delay(1)
websocketListenerSlot.captured.onClosed(websocket, 1000, "User initiated disconnect")
}
}
Expand Down Expand Up @@ -149,7 +147,6 @@ internal class EventsWebSocketClientTest {
}
""".trimIndent()
backgroundScope.launch {
delay(1)
websocketListenerSlot.captured.onMessage(websocket, failedResult)
}
}
Expand Down Expand Up @@ -269,7 +266,6 @@ internal class EventsWebSocketClientTest {

setupSendResult { _, id ->
launch {
delay(1)
websocketListenerSlot.captured.onClosed(websocket, 1000, "User initiated disconnect")
}
}
Expand All @@ -286,7 +282,6 @@ internal class EventsWebSocketClientTest {
setupSendResult { _, id -> subscribeSuccessResult(id) }
every { websocket.close(any(), any()) } answers {
launch {
delay(1)
websocketListenerSlot.captured.onClosed(websocket, 1000, "User initiated disconnect")
}
true
Expand All @@ -311,7 +306,6 @@ internal class EventsWebSocketClientTest {
setupSendResult { _, id -> subscribeSuccessResult(id) }
every { websocket.cancel() } answers {
launch {
delay(1)
websocketListenerSlot.captured.onFailure(websocket, Throwable("Cancelled"), null)
}
}
Expand All @@ -333,7 +327,6 @@ internal class EventsWebSocketClientTest {
every { newWebSocket(any(), capture(websocketListenerSlot)) } answers {
val ack = """ { "type": "connection_ack", "connectionTimeoutMs": 10000 } """
backgroundScope.launch(testDispatcher) {
delay(1)
websocketListenerSlot.captured.onMessage(websocket, ack)
}
websocket
Expand Down Expand Up @@ -417,7 +410,6 @@ internal class EventsWebSocketClientTest {
}
""".trimIndent()
backgroundScope.launch {
delay(1)
websocketListenerSlot.captured.onMessage(websocket, result)
}
}
Expand All @@ -430,7 +422,6 @@ internal class EventsWebSocketClientTest {
}
""".trimIndent()
backgroundScope.launch {
delay(1)
websocketListenerSlot.captured.onMessage(websocket, result)
}
}
Expand All @@ -449,7 +440,6 @@ internal class EventsWebSocketClientTest {
}
""".trimIndent()
backgroundScope.launch {
delay(1)
websocketListenerSlot.captured.onMessage(websocket, result)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ import io.mockk.slot
import io.mockk.verify
import java.net.UnknownHostException
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
import kotlinx.coroutines.test.runTest
import kotlinx.serialization.json.JsonArray
Expand Down Expand Up @@ -80,7 +79,6 @@ internal class EventsWebSocketTest {
val listener = arg<WebSocketListener>(1)

launch {
delay(1) // on virtual timer, just moves to back of queue
listener.onMessage(websocket, ack)
}
websocket
Expand Down Expand Up @@ -126,7 +124,6 @@ internal class EventsWebSocketTest {

every { websocket.close(any(), any()) } answers {
launch {
delay(1)
eventsWebSocket.onClosed(websocket, 1000, "User initiated disconnect")
}
true
Expand All @@ -153,7 +150,6 @@ internal class EventsWebSocketTest {

every { websocket.cancel() } answers {
launch {
delay(1)
eventsWebSocket.onClosed(websocket, 1000, "User initiated disconnect")
}
}
Expand Down Expand Up @@ -262,7 +258,6 @@ internal class EventsWebSocketTest {
every { okHttpClient.newWebSocket(any(), any()) } answers {
val listener = arg<WebSocketListener>(1)
launch {
delay(1) // on virtual timer, just moves to back of queue
listener.onMessage(websocket, ack)
}
websocket
Expand All @@ -275,7 +270,6 @@ internal class EventsWebSocketTest {
every { okHttpClient.newWebSocket(any(), any()) } answers {
val listener = arg<WebSocketListener>(1)
launch {
delay(1) // on virtual timer, just moves to back of queue
listener.onFailure(websocket, cause, null)
}
websocket
Expand Down