Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ internal class EventsWebSocketClientTests {
"onOpen: sending connection init",
"onMessage: processed ${WebSocketMessage.Received.ConnectionAck::class.java}",
"onMessage: processed ${WebSocketMessage.Received.Subscription.SubscribeSuccess::class.java}",
"Successfully subscribed to: $defaultChannel",
"onMessage: processed ${WebSocketMessage.Received.Subscription.UnsubscribeSuccess::class.java}",
"emit ${WebSocketMessage.Closed::class.java}"
)
Expand All @@ -175,7 +176,7 @@ internal class EventsWebSocketClientTests {
webSocketClient.subscribe(defaultChannel).test(timeout = 5.seconds) {
// Wait for subscription to return success
webSocketLogCapture.messages.filter {
it == "onMessage: processed ${WebSocketMessage.Received.Subscription.SubscribeSuccess::class.java}"
it == "Successfully subscribed to: $defaultChannel"
}.testIn(backgroundScope, timeout = 5.seconds).apply {
awaitItem()
cancelAndIgnoreRemainingEvents()
Expand Down Expand Up @@ -217,8 +218,8 @@ internal class EventsWebSocketClientTests {
webSocketClient.subscribe(customChannel).test {
// Wait for subscription to return success
webSocketLogCapture.messages.filter {
it ==
"onMessage: processed ${WebSocketMessage.Received.Subscription.SubscribeSuccess::class.java}"
it == "Successfully subscribed to: $defaultChannel" ||
it == "Successfully subscribed to: $customChannel"
}.testIn(backgroundScope).apply {
awaitItem() // subscription 1
awaitItem() // subscription 2
Expand Down Expand Up @@ -271,7 +272,7 @@ internal class EventsWebSocketClientTests {
webSocketClient.subscribe(defaultChannel).test {
// Wait for subscription to return success
webSocketLogCapture.messages.filter {
it == "onMessage: processed ${WebSocketMessage.Received.Subscription.SubscribeSuccess::class.java}"
it == "Successfully subscribed to: $defaultChannel"
}.testIn(backgroundScope).apply {
awaitItem()
cancelAndIgnoreRemainingEvents()
Expand Down Expand Up @@ -304,7 +305,7 @@ internal class EventsWebSocketClientTests {
webSocketClient.subscribe(defaultChannel).test {
// Wait for subscription to return success
webSocketLogCapture.messages.filter {
it == "onMessage: processed ${WebSocketMessage.Received.Subscription.SubscribeSuccess::class.java}"
it == "Successfully subscribed to: $defaultChannel"
}.testIn(backgroundScope).apply {
awaitItem()
cancelAndIgnoreRemainingEvents()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ class EventsWebSocketClient internal constructor(
private val ioDispatcher: CoroutineDispatcher = Dispatchers.IO
) {

companion object {
const val TAG = "EventsWebSocketClient"
}

private val okHttpClient = OkHttpClient.Builder().apply {
options.okHttpConfigurationProvider?.applyConfiguration(this)
}.build()
Expand All @@ -64,6 +68,8 @@ class EventsWebSocketClient internal constructor(
options.loggerProvider
)

private val logger = options.loggerProvider?.getLogger(TAG)

/**
* Subscribe to a channel.
*
Expand All @@ -80,17 +86,26 @@ class EventsWebSocketClient internal constructor(
val newWebSocket = eventsWebSocketProvider.getConnectedWebSocket()
subscriptionHolder.webSocket = newWebSocket
// + send subscription. Returns true if successfully subscribed
subscriptionHolder.isSubscribed = initiateSubscription(
val isSubscribed = initiateSubscription(
channelName,
newWebSocket,
subscriptionHolder.id,
authorizer
)

subscriptionHolder.subscriptionState = if (isSubscribed) {
logger?.debug("Successfully subscribed to: $channelName")
SubscriptionHolder.SubscriptionState.SUBSCRIBED
} else {
SubscriptionHolder.SubscriptionState.CLOSED
}
}.flowOn(ioDispatcher) // io used for authorizers to pull headers asynchronously
.onCompletion {
// only unsubscribe if already subscribed and websocket is still open
val currentWebSocket = subscriptionHolder.webSocket
if (subscriptionHolder.isSubscribed && currentWebSocket != null) {
if (subscriptionHolder.subscriptionState != SubscriptionHolder.SubscriptionState.CLOSED &&
currentWebSocket != null
) {
completeSubscription(subscriptionHolder, it)
}
subscriptionHolder.webSocket = null
Expand Down Expand Up @@ -276,10 +291,10 @@ class EventsWebSocketClient internal constructor(
private fun completeSubscription(subscriptionHolder: SubscriptionHolder, throwable: Throwable?) {
// only unsubscribe if already subscribed and websocket is still open
val currentWebSocket = subscriptionHolder.webSocket
val isSubscribed = subscriptionHolder.isSubscribed
val isSubscriptionClosed = subscriptionHolder.subscriptionState == SubscriptionHolder.SubscriptionState.CLOSED
val isDisconnected = throwable is ConnectionClosedException || throwable is UserClosedConnectionException

if (currentWebSocket != null && isSubscribed && !isDisconnected) {
if (currentWebSocket != null && !isSubscriptionClosed && !isDisconnected) {
// Unsubscribe from channel when flow is completed
try {
currentWebSocket.send(
Expand All @@ -298,9 +313,15 @@ class EventsWebSocketClient internal constructor(
*/
internal data class SubscriptionHolder(
var webSocket: EventsWebSocket? = null,
var isSubscribed: Boolean = false
var subscriptionState: SubscriptionState = SubscriptionState.PENDING
) {
val id = UUID.randomUUID().toString()

enum class SubscriptionState {
PENDING,
SUBSCRIBED,
CLOSED
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import app.cash.turbine.turbineScope
import com.amazonaws.sdk.appsync.events.data.BadRequestException
import com.amazonaws.sdk.appsync.events.data.ConnectionClosedException
import com.amazonaws.sdk.appsync.events.data.PublishResult
import com.amazonaws.sdk.appsync.events.data.WebSocketMessage
import com.amazonaws.sdk.appsync.events.mocks.EventsLibraryLogCapture
import com.amazonaws.sdk.appsync.events.mocks.TestAuthorizer
import io.kotest.assertions.fail
Expand Down Expand Up @@ -187,7 +186,7 @@ internal class EventsWebSocketClientTest {

client.subscribe(channel).test {
webSocketLogCapture.messages.filter {
it == "emit ${WebSocketMessage.Received.Subscription.SubscribeSuccess::class.java}"
it == "Successfully subscribed to: $channel"
}.testIn(backgroundScope).apply {
awaitItem()
cancelAndIgnoreRemainingEvents()
Expand Down Expand Up @@ -229,7 +228,7 @@ internal class EventsWebSocketClientTest {

client.subscribe(channel, customAuthorizer).test {
webSocketLogCapture.messages.filter {
it == "emit ${WebSocketMessage.Received.Subscription.SubscribeSuccess::class.java}"
it == "Successfully subscribed to: $channel"
}.testIn(backgroundScope).apply {
awaitItem()
cancelAndIgnoreRemainingEvents()
Expand Down Expand Up @@ -295,7 +294,7 @@ internal class EventsWebSocketClientTest {

client.subscribe(channel).test {
webSocketLogCapture.messages.filter {
it == "emit ${WebSocketMessage.Received.Subscription.SubscribeSuccess::class.java}"
it == "Successfully subscribed to: $channel"
}.testIn(backgroundScope).apply {
awaitItem()
cancelAndIgnoreRemainingEvents()
Expand All @@ -319,7 +318,7 @@ internal class EventsWebSocketClientTest {

client.subscribe(channel).test {
webSocketLogCapture.messages.filter {
it == "emit ${WebSocketMessage.Received.Subscription.SubscribeSuccess::class.java}"
it == "Successfully subscribed to: $channel"
}.testIn(backgroundScope).apply {
awaitItem()
cancelAndIgnoreRemainingEvents()
Expand Down