Skip to content

Commit d53cfe6

Browse files
authored
Fix events test lockups (#3067)
1 parent b1844ee commit d53cfe6

File tree

4 files changed

+65
-36
lines changed

4 files changed

+65
-36
lines changed

appsync/aws-sdk-appsync-events/src/main/java/com/amazonaws/sdk/appsync/events/EventsWebSocket.kt

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,16 @@ import com.amazonaws.sdk.appsync.events.data.toEventsException
2525
import com.amazonaws.sdk.appsync.events.utils.ConnectionTimeoutTimer
2626
import com.amazonaws.sdk.appsync.events.utils.HeaderKeys
2727
import com.amazonaws.sdk.appsync.events.utils.HeaderValues
28+
import kotlinx.coroutines.CompletableDeferred
2829
import kotlinx.coroutines.CoroutineScope
2930
import kotlinx.coroutines.Dispatchers
3031
import kotlinx.coroutines.async
3132
import kotlinx.coroutines.coroutineScope
3233
import kotlinx.coroutines.flow.MutableSharedFlow
3334
import kotlinx.coroutines.flow.asSharedFlow
3435
import kotlinx.coroutines.flow.first
36+
import kotlinx.coroutines.flow.onCompletion
37+
import kotlinx.coroutines.flow.onStart
3538
import kotlinx.coroutines.withContext
3639
import kotlinx.serialization.encodeToString
3740
import kotlinx.serialization.json.Json
@@ -71,8 +74,12 @@ internal class EventsWebSocket(
7174
@Throws(ConnectException::class)
7275
suspend fun connect() = coroutineScope {
7376
logger?.debug("Opening Websocket Connection")
77+
7478
// Get deferred connect response. We need to listen before opening connection, but not block
75-
val deferredConnectResponse = async { getConnectResponse() }
79+
val listeningForConnectSignal = CompletableDeferred<Unit>()
80+
val deferredConnectResponse = async { getConnectResponse { listeningForConnectSignal.complete(Unit) } }
81+
listeningForConnectSignal.await()
82+
7683
// Create initial request without auth headers
7784
val preAuthRequest = createPreAuthConnectRequest(eventsEndpoints)
7885
// Fetch auth headers from authorizer
@@ -107,7 +114,11 @@ internal class EventsWebSocket(
107114
suspend fun disconnect(flushEvents: Boolean) = withContext(Dispatchers.IO) {
108115
if (isClosed) return@withContext
109116
disconnectReason = WebSocketDisconnectReason.UserInitiated
110-
val deferredClosedResponse = async { getClosedResponse() }
117+
118+
val listeningForClosedSignal = CompletableDeferred<Unit>()
119+
val deferredClosedResponse = async { getClosedResponse { listeningForClosedSignal.complete(Unit) } }
120+
listeningForClosedSignal.await()
121+
111122
when (flushEvents) {
112123
true -> webSocket.close(NORMAL_CLOSE_CODE, "User initiated disconnect")
113124
false -> webSocket.cancel()
@@ -231,21 +242,27 @@ internal class EventsWebSocket(
231242
}.build()
232243
}
233244

234-
private suspend fun getConnectResponse(): WebSocketMessage = events.first {
235-
when (it) {
236-
is WebSocketMessage.Received.ConnectionAck -> true
237-
is WebSocketMessage.Received.ConnectionError -> true
238-
is WebSocketMessage.Closed -> true
239-
else -> false
245+
private suspend fun getConnectResponse(onListening: () -> Unit): WebSocketMessage = events
246+
.onStart { onListening.invoke() }
247+
.onCompletion { onListening.invoke() } // just in case. No impact of calling onListening multiple times
248+
.first {
249+
when (it) {
250+
is WebSocketMessage.Received.ConnectionAck -> true
251+
is WebSocketMessage.Received.ConnectionError -> true
252+
is WebSocketMessage.Closed -> true
253+
else -> false
254+
}
240255
}
241-
}
242256

243-
private suspend fun getClosedResponse(): WebSocketMessage = events.first {
244-
when (it) {
245-
is WebSocketMessage.Closed -> true
246-
else -> false
257+
private suspend fun getClosedResponse(onListening: () -> Unit): WebSocketMessage = events
258+
.onStart { onListening.invoke() }
259+
.onCompletion { onListening.invoke() } // just in case. No impact of calling onListening multiple times
260+
.first {
261+
when (it) {
262+
is WebSocketMessage.Closed -> true
263+
else -> false
264+
}
247265
}
248-
}
249266
}
250267

251268
@VisibleForTesting

appsync/aws-sdk-appsync-events/src/main/java/com/amazonaws/sdk/appsync/events/EventsWebSocketClient.kt

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import com.amazonaws.sdk.appsync.events.data.UserClosedConnectionException
2323
import com.amazonaws.sdk.appsync.events.data.WebSocketMessage
2424
import com.amazonaws.sdk.appsync.events.data.toEventsException
2525
import com.amazonaws.sdk.appsync.events.utils.JsonUtils
26+
import kotlinx.coroutines.CompletableDeferred
2627
import java.util.UUID
2728
import kotlinx.coroutines.CoroutineDispatcher
2829
import kotlinx.coroutines.Dispatchers
@@ -168,7 +169,14 @@ class EventsWebSocketClient internal constructor(
168169
)
169170

170171
val webSocket = eventsWebSocketProvider.getConnectedWebSocket()
171-
val deferredResponse = async { getPublishResponse(webSocket, publishId) }
172+
173+
val listeningForPublishSignal = CompletableDeferred<Unit>()
174+
val deferredResponse = async {
175+
getPublishResponse(webSocket, publishId) {
176+
listeningForPublishSignal.complete(Unit)
177+
}
178+
}
179+
listeningForPublishSignal.await()
172180

173181
val queued = webSocket.sendWithAuthorizer(publishMessage, authorizer)
174182
if (!queued) {
@@ -226,8 +234,14 @@ class EventsWebSocketClient internal constructor(
226234
authorizer: AppSyncAuthorizer
227235
): Boolean = coroutineScope {
228236
// create a deferred holder for subscription response
237+
val listeningForSubscriptionSignal = CompletableDeferred<Unit>()
229238
val deferredSubscriptionResponse =
230-
async { getSubscriptionResponse(webSocket, subscriptionId) }
239+
async {
240+
getSubscriptionResponse(webSocket, subscriptionId) {
241+
listeningForSubscriptionSignal.complete(Unit)
242+
}
243+
}
244+
listeningForSubscriptionSignal.await()
231245

232246
// Publish subscription to websocket
233247
val queued = webSocket.sendWithAuthorizer(
@@ -262,8 +276,15 @@ class EventsWebSocketClient internal constructor(
262276
}
263277
}
264278

265-
private suspend fun getSubscriptionResponse(webSocket: EventsWebSocket, subscriptionId: String): WebSocketMessage =
266-
webSocket.events.first {
279+
private suspend fun getSubscriptionResponse(
280+
webSocket: EventsWebSocket,
281+
subscriptionId: String,
282+
onListening: () -> Unit
283+
): WebSocketMessage =
284+
webSocket.events
285+
.onStart { onListening.invoke() }
286+
.onCompletion { onListening.invoke() } // just in case. No impact of calling onListening multiple times
287+
.first {
267288
when {
268289
it is WebSocketMessage.Received.Subscription && it.id == subscriptionId -> true
269290
it is WebSocketMessage.ErrorContainer && it.id == subscriptionId -> true
@@ -272,8 +293,15 @@ class EventsWebSocketClient internal constructor(
272293
}
273294
}
274295

275-
private suspend fun getPublishResponse(webSocket: EventsWebSocket, publishId: String): WebSocketMessage =
276-
webSocket.events.first {
296+
private suspend fun getPublishResponse(
297+
webSocket: EventsWebSocket,
298+
publishId: String,
299+
onListening: () -> Unit
300+
): WebSocketMessage =
301+
webSocket.events
302+
.onStart { onListening.invoke() }
303+
.onCompletion { onListening.invoke() } // just in case. No impact of calling onListening multiple times
304+
.first {
277305
when {
278306
it is WebSocketMessage.Received.PublishSuccess && it.id == publishId -> true
279307
it is WebSocketMessage.ErrorContainer && it.id == publishId -> true

appsync/aws-sdk-appsync-events/src/test/java/com/amazonaws/sdk/appsync/events/EventsWebSocketClientTest.kt

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ import io.mockk.unmockkConstructor
3434
import kotlin.time.Duration
3535
import kotlinx.coroutines.CoroutineDispatcher
3636
import kotlinx.coroutines.coroutineScope
37-
import kotlinx.coroutines.delay
3837
import kotlinx.coroutines.flow.filter
3938
import kotlinx.coroutines.launch
4039
import kotlinx.coroutines.test.StandardTestDispatcher
@@ -117,7 +116,6 @@ internal class EventsWebSocketClientTest {
117116

118117
setupSendResult { _, _ ->
119118
launch {
120-
delay(1)
121119
websocketListenerSlot.captured.onClosed(websocket, 1000, "User initiated disconnect")
122120
}
123121
}
@@ -149,7 +147,6 @@ internal class EventsWebSocketClientTest {
149147
}
150148
""".trimIndent()
151149
backgroundScope.launch {
152-
delay(1)
153150
websocketListenerSlot.captured.onMessage(websocket, failedResult)
154151
}
155152
}
@@ -269,7 +266,6 @@ internal class EventsWebSocketClientTest {
269266

270267
setupSendResult { _, id ->
271268
launch {
272-
delay(1)
273269
websocketListenerSlot.captured.onClosed(websocket, 1000, "User initiated disconnect")
274270
}
275271
}
@@ -286,7 +282,6 @@ internal class EventsWebSocketClientTest {
286282
setupSendResult { _, id -> subscribeSuccessResult(id) }
287283
every { websocket.close(any(), any()) } answers {
288284
launch {
289-
delay(1)
290285
websocketListenerSlot.captured.onClosed(websocket, 1000, "User initiated disconnect")
291286
}
292287
true
@@ -311,7 +306,6 @@ internal class EventsWebSocketClientTest {
311306
setupSendResult { _, id -> subscribeSuccessResult(id) }
312307
every { websocket.cancel() } answers {
313308
launch {
314-
delay(1)
315309
websocketListenerSlot.captured.onFailure(websocket, Throwable("Cancelled"), null)
316310
}
317311
}
@@ -333,7 +327,6 @@ internal class EventsWebSocketClientTest {
333327
every { newWebSocket(any(), capture(websocketListenerSlot)) } answers {
334328
val ack = """ { "type": "connection_ack", "connectionTimeoutMs": 10000 } """
335329
backgroundScope.launch(testDispatcher) {
336-
delay(1)
337330
websocketListenerSlot.captured.onMessage(websocket, ack)
338331
}
339332
websocket
@@ -417,7 +410,6 @@ internal class EventsWebSocketClientTest {
417410
}
418411
""".trimIndent()
419412
backgroundScope.launch {
420-
delay(1)
421413
websocketListenerSlot.captured.onMessage(websocket, result)
422414
}
423415
}
@@ -430,7 +422,6 @@ internal class EventsWebSocketClientTest {
430422
}
431423
""".trimIndent()
432424
backgroundScope.launch {
433-
delay(1)
434425
websocketListenerSlot.captured.onMessage(websocket, result)
435426
}
436427
}
@@ -449,7 +440,6 @@ internal class EventsWebSocketClientTest {
449440
}
450441
""".trimIndent()
451442
backgroundScope.launch {
452-
delay(1)
453443
websocketListenerSlot.captured.onMessage(websocket, result)
454444
}
455445
}

appsync/aws-sdk-appsync-events/src/test/java/com/amazonaws/sdk/appsync/events/EventsWebSocketTest.kt

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ import io.mockk.slot
2828
import io.mockk.verify
2929
import java.net.UnknownHostException
3030
import kotlinx.coroutines.coroutineScope
31-
import kotlinx.coroutines.delay
3231
import kotlinx.coroutines.launch
3332
import kotlinx.coroutines.test.runTest
3433
import kotlinx.serialization.json.JsonArray
@@ -80,7 +79,6 @@ internal class EventsWebSocketTest {
8079
val listener = arg<WebSocketListener>(1)
8180

8281
launch {
83-
delay(1) // on virtual timer, just moves to back of queue
8482
listener.onMessage(websocket, ack)
8583
}
8684
websocket
@@ -126,7 +124,6 @@ internal class EventsWebSocketTest {
126124

127125
every { websocket.close(any(), any()) } answers {
128126
launch {
129-
delay(1)
130127
eventsWebSocket.onClosed(websocket, 1000, "User initiated disconnect")
131128
}
132129
true
@@ -153,7 +150,6 @@ internal class EventsWebSocketTest {
153150

154151
every { websocket.cancel() } answers {
155152
launch {
156-
delay(1)
157153
eventsWebSocket.onClosed(websocket, 1000, "User initiated disconnect")
158154
}
159155
}
@@ -262,7 +258,6 @@ internal class EventsWebSocketTest {
262258
every { okHttpClient.newWebSocket(any(), any()) } answers {
263259
val listener = arg<WebSocketListener>(1)
264260
launch {
265-
delay(1) // on virtual timer, just moves to back of queue
266261
listener.onMessage(websocket, ack)
267262
}
268263
websocket
@@ -275,7 +270,6 @@ internal class EventsWebSocketTest {
275270
every { okHttpClient.newWebSocket(any(), any()) } answers {
276271
val listener = arg<WebSocketListener>(1)
277272
launch {
278-
delay(1) // on virtual timer, just moves to back of queue
279273
listener.onFailure(websocket, cause, null)
280274
}
281275
websocket

0 commit comments

Comments
 (0)