@@ -17,10 +17,11 @@ package com.amplifyframework.aws.appsync.events
1717
1818import com.amplifyframework.aws.appsync.core.AppSyncAuthorizer
1919import com.amplifyframework.aws.appsync.core.AppSyncRequest
20- import com.amplifyframework.aws.appsync.core.util.Logger
20+ import com.amplifyframework.aws.appsync.core.LoggerProvider
2121import com.amplifyframework.aws.appsync.events.data.ConnectException
2222import com.amplifyframework.aws.appsync.events.data.EventsException
2323import com.amplifyframework.aws.appsync.events.data.WebSocketMessage
24+ import com.amplifyframework.aws.appsync.events.utils.ConnectionTimeoutTimer
2425import com.amplifyframework.aws.appsync.events.utils.HeaderKeys
2526import com.amplifyframework.aws.appsync.events.utils.HeaderValues
2627import kotlinx.coroutines.async
@@ -35,22 +36,23 @@ import okhttp3.Request
3536import okhttp3.Response
3637import okhttp3.WebSocket
3738import okhttp3.WebSocketListener
38- import java.util.concurrent.atomic.AtomicBoolean
3939
4040internal class EventsWebSocket (
4141 private val eventsEndpoints : EventsEndpoints ,
4242 private val authorizer : AppSyncAuthorizer ,
4343 private val okHttpClient : OkHttpClient ,
4444 private val json : Json ,
45- private val logger : Logger ?
45+ loggerProvider : LoggerProvider ?
4646) : WebSocketListener() {
4747
4848 private val _events = MutableSharedFlow <WebSocketMessage >(extraBufferCapacity = Int .MAX_VALUE )
4949 val events = _events .asSharedFlow() // publicly exposed as read-only shared flow
5050
5151 private lateinit var webSocket: WebSocket
52- internal val isClosed = AtomicBoolean (false )
53- private var userInitiatedDisconnect = false
52+ @Volatile internal var isClosed = false
53+ private var disconnectReason: DisconnectReason ? = null
54+ private val connectionTimeoutTimer = ConnectionTimeoutTimer (onTimeout = ::onTimeout)
55+ private val logger = loggerProvider?.getLogger(TAG )
5456
5557 @Throws(ConnectException ::class )
5658 suspend fun connect () = coroutineScope {
@@ -71,7 +73,7 @@ internal class EventsWebSocket(
7173 when (val connectionResponse = deferredConnectResponse.await()) {
7274 is WebSocketMessage .Closed -> {
7375 webSocket.cancel()
74- throw ConnectException (connectionResponse.throwable)
76+ throw ConnectException (connectionResponse.reason. throwable)
7577 }
7678 is WebSocketMessage .Received .ConnectionError -> {
7779 webSocket.cancel()
@@ -80,13 +82,16 @@ internal class EventsWebSocket(
8082 ? : EventsException .unknown()
8183 )
8284 }
83- else -> Unit // It isn't obvious here, but only other connect response type is ConnectionAck
85+ is WebSocketMessage .Received .ConnectionAck -> {
86+ connectionTimeoutTimer.resetTimeoutTimer(connectionResponse.connectionTimeoutMs)
87+ }
88+ else -> Unit // Not obvious here but this block should never run
8489 }
8590 logger?.debug(" Websocket Connection Open" )
8691 }
8792
8893 suspend fun disconnect (flushEvents : Boolean ) = coroutineScope {
89- userInitiatedDisconnect = true
94+ disconnectReason = DisconnectReason . UserInitiated
9095 val deferredClosedResponse = async { getClosedResponse() }
9196 when (flushEvents) {
9297 true -> webSocket.close(NORMAL_CLOSE_CODE , " User initiated disconnect" )
@@ -96,12 +101,12 @@ internal class EventsWebSocket(
96101 }
97102
98103 override fun onOpen (webSocket : WebSocket , response : Response ) {
99- val connectionInitMessage = json.encodeToString(WebSocketMessage .Send .ConnectionInit ())
100- logger?.debug { " $TAG onOpen: sending connection init" }
101- webSocket.send(connectionInitMessage)
104+ logger?.debug (" onOpen: sending connection init" )
105+ send(WebSocketMessage .Send .ConnectionInit ())
102106 }
103107
104108 override fun onMessage (webSocket : WebSocket , text : String ) {
109+ connectionTimeoutTimer.resetTimeoutTimer()
105110 logger?.debug { " Websocket onMessage: $text " }
106111 try {
107112 val eventMessage = json.decodeFromString<WebSocketMessage .Received >(text)
@@ -112,29 +117,37 @@ internal class EventsWebSocket(
112117 }
113118
114119 override fun onFailure (webSocket : WebSocket , t : Throwable , response : Response ? ) {
115- logger?.error(t) { " $TAG onFailure" }
116- notifyClosed () // onClosed doesn't get called in failure. Treat this block the same as onClosed
120+ logger?.error(t) { " onFailure" }
121+ handleClosed () // onClosed doesn't get called in failure. Treat this block the same as onClosed
117122 }
118123
119124 override fun onClosing (webSocket : WebSocket , code : Int , reason : String ) {
120- logger?.debug(" $TAG onClosing" )
125+ logger?.debug(" onClosing" )
121126 }
122127
123128 override fun onClosed (webSocket : WebSocket , code : Int , reason : String ) {
124129 // Events api sends normal close code even in failure
125130 // so inspecting code/reason isn't helpful as it should be
126- logger?.debug(" $TAG onClosed: userInitiated = $userInitiatedDisconnect " )
127- notifyClosed()
131+ logger?.debug {" onClosed: reason = $disconnectReason " }
132+ handleClosed()
133+ }
134+
135+ private fun onTimeout () {
136+ disconnectReason = DisconnectReason .Timeout
137+ webSocket.cancel()
128138 }
129139
130- private fun notifyClosed () {
131- _events .tryEmit(WebSocketMessage .Closed (userInitiated = userInitiatedDisconnect))
132- isClosed.set(true )
140+ private fun handleClosed () {
141+ connectionTimeoutTimer.stop()
142+ _events .tryEmit(
143+ WebSocketMessage .Closed (reason = disconnectReason ? : DisconnectReason .Service ())
144+ )
145+ isClosed = true
133146 }
134147
135148 inline fun <reified T : WebSocketMessage > send (webSocketMessage : T ) {
136149 val message = json.encodeToString(webSocketMessage)
137- logger?.debug( " $TAG send: $message " )
150+ logger?.debug { " send: ${webSocketMessage:: class .java} " }
138151 webSocket.send(message)
139152 }
140153
@@ -193,3 +206,9 @@ private class ConnectAppSyncRequest(
193206 override val body: String
194207 get() = " {}"
195208}
209+
210+ internal sealed class DisconnectReason (val throwable : Throwable ? ) {
211+ data object UserInitiated : DisconnectReason (null )
212+ data object Timeout : DisconnectReason (EventsException ("Connection timed out."))
213+ class Service (throwable : Throwable ? = null ) : DisconnectReason(throwable)
214+ }
0 commit comments