Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,14 @@
* permissions and limitations under the License.
*/

package com.amplifyframework.aws.appsync.core.util
package com.amplifyframework.aws.appsync.core

import java.util.function.Supplier

fun interface LoggerProvider {
fun getLogger(namespace: String): Logger
}

/**
* A component which can emit logs.
*/
Expand All @@ -28,12 +32,6 @@ interface Logger {
*/
val thresholdLevel: LogLevel

/**
* Gets the namespace of the logger.
* @return namespace for logger
*/
val namespace: String

/**
* Logs a message at the [LogLevel.ERROR] level.
* @param message An error message
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
package com.amplifyframework.aws.appsync.events

import com.amplifyframework.aws.appsync.core.AppSyncAuthorizer
import com.amplifyframework.aws.appsync.core.util.Logger
import com.amplifyframework.aws.appsync.core.LoggerProvider
import com.amplifyframework.aws.appsync.events.data.ChannelAuthorizers
import com.amplifyframework.aws.appsync.events.data.EventsException
import com.amplifyframework.aws.appsync.events.data.PublishResult
Expand All @@ -41,7 +41,7 @@ class Events @VisibleForTesting internal constructor(
) {

data class Options(
val logger: Logger? = null
val loggerProvider: LoggerProvider? = null
)

/**
Expand Down Expand Up @@ -75,7 +75,7 @@ class Events @VisibleForTesting internal constructor(
connectAuthorizer,
okHttpClient,
json,
options.logger
options.loggerProvider
)

/**
Expand Down Expand Up @@ -133,6 +133,6 @@ class Events @VisibleForTesting internal constructor(
* @return a channel to manage subscriptions and publishes.
*/
suspend fun disconnect(flushEvents: Boolean = true): Unit = coroutineScope {
eventsWebSocketProvider.getExistingWebSocket()?.disconnect(flushEvents)
eventsWebSocketProvider.existingWebSocket?.disconnect(flushEvents)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,10 @@ class EventsChannel internal constructor(
emit(EventsMessage(it.event))
}
it is WebSocketMessage.Closed -> {
if (it.userInitiated) {
if (it.reason is DisconnectReason.UserInitiated) {
throw UserClosedConnectionException()
} else {
throw ConnectionClosedException(it.throwable)
throw ConnectionClosedException(it.reason.throwable)
}
}
else -> Unit
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@ package com.amplifyframework.aws.appsync.events

import com.amplifyframework.aws.appsync.core.AppSyncAuthorizer
import com.amplifyframework.aws.appsync.core.AppSyncRequest
import com.amplifyframework.aws.appsync.core.util.Logger
import com.amplifyframework.aws.appsync.core.LoggerProvider
import com.amplifyframework.aws.appsync.events.data.ConnectException
import com.amplifyframework.aws.appsync.events.data.EventsException
import com.amplifyframework.aws.appsync.events.data.WebSocketMessage
import com.amplifyframework.aws.appsync.events.utils.ConnectionTimeoutTimer
import com.amplifyframework.aws.appsync.events.utils.HeaderKeys
import com.amplifyframework.aws.appsync.events.utils.HeaderValues
import kotlinx.coroutines.async
Expand All @@ -35,22 +36,23 @@ import okhttp3.Request
import okhttp3.Response
import okhttp3.WebSocket
import okhttp3.WebSocketListener
import java.util.concurrent.atomic.AtomicBoolean

internal class EventsWebSocket(
private val eventsEndpoints: EventsEndpoints,
private val authorizer: AppSyncAuthorizer,
private val okHttpClient: OkHttpClient,
private val json: Json,
private val logger: Logger?
loggerProvider: LoggerProvider?
) : WebSocketListener() {

private val _events = MutableSharedFlow<WebSocketMessage>(extraBufferCapacity = Int.MAX_VALUE)
val events = _events.asSharedFlow() // publicly exposed as read-only shared flow

private lateinit var webSocket: WebSocket
internal val isClosed = AtomicBoolean(false)
private var userInitiatedDisconnect = false
@Volatile internal var isClosed = false
private var disconnectReason: DisconnectReason? = null
private val connectionTimeoutTimer = ConnectionTimeoutTimer(onTimeout = ::onTimeout)
private val logger = loggerProvider?.getLogger(TAG)

@Throws(ConnectException::class)
suspend fun connect() = coroutineScope {
Expand All @@ -71,7 +73,7 @@ internal class EventsWebSocket(
when (val connectionResponse = deferredConnectResponse.await()) {
is WebSocketMessage.Closed -> {
webSocket.cancel()
throw ConnectException(connectionResponse.throwable)
throw ConnectException(connectionResponse.reason.throwable)
}
is WebSocketMessage.Received.ConnectionError -> {
webSocket.cancel()
Expand All @@ -80,13 +82,16 @@ internal class EventsWebSocket(
?: EventsException.unknown()
)
}
else -> Unit // It isn't obvious here, but only other connect response type is ConnectionAck
is WebSocketMessage.Received.ConnectionAck -> {
connectionTimeoutTimer.resetTimeoutTimer(connectionResponse.connectionTimeoutMs)
}
else -> Unit // Not obvious here but this block should never run
}
logger?.debug("Websocket Connection Open")
}

suspend fun disconnect(flushEvents: Boolean) = coroutineScope {
userInitiatedDisconnect = true
disconnectReason = DisconnectReason.UserInitiated
val deferredClosedResponse = async { getClosedResponse() }
when (flushEvents) {
true -> webSocket.close(NORMAL_CLOSE_CODE, "User initiated disconnect")
Expand All @@ -96,12 +101,12 @@ internal class EventsWebSocket(
}

override fun onOpen(webSocket: WebSocket, response: Response) {
val connectionInitMessage = json.encodeToString(WebSocketMessage.Send.ConnectionInit())
logger?.debug { "$TAG onOpen: sending connection init" }
webSocket.send(connectionInitMessage)
logger?.debug { "onOpen: sending connection init" }
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Static logs like this can use direct function invocation instead of lambda form

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed for all the simple, non throwable providing log types.

send(WebSocketMessage.Send.ConnectionInit())
}

override fun onMessage(webSocket: WebSocket, text: String) {
connectionTimeoutTimer.resetTimeoutTimer()
logger?.debug { "Websocket onMessage: $text" }
try {
val eventMessage = json.decodeFromString<WebSocketMessage.Received>(text)
Expand All @@ -112,29 +117,37 @@ internal class EventsWebSocket(
}

override fun onFailure(webSocket: WebSocket, t: Throwable, response: Response?) {
logger?.error(t) { "$TAG onFailure" }
notifyClosed() // onClosed doesn't get called in failure. Treat this block the same as onClosed
logger?.error(t) { "onFailure" }
handleClosed() // onClosed doesn't get called in failure. Treat this block the same as onClosed
}

override fun onClosing(webSocket: WebSocket, code: Int, reason: String) {
logger?.debug("$TAG onClosing")
logger?.debug("onClosing")
}

override fun onClosed(webSocket: WebSocket, code: Int, reason: String) {
// Events api sends normal close code even in failure
// so inspecting code/reason isn't helpful as it should be
logger?.debug("$TAG onClosed: userInitiated = $userInitiatedDisconnect")
notifyClosed()
logger?.debug("onClosed: reason = $disconnectReason")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Logs with concatenation should probably be using lambda form

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed.

handleClosed()
}

private fun onTimeout() {
disconnectReason = DisconnectReason.Timeout
webSocket.cancel()
}

private fun notifyClosed() {
_events.tryEmit(WebSocketMessage.Closed(userInitiated = userInitiatedDisconnect))
isClosed.set(true)
private fun handleClosed() {
connectionTimeoutTimer.stop()
_events.tryEmit(
WebSocketMessage.Closed(reason = disconnectReason ?: DisconnectReason.Service())
)
isClosed = true
}

inline fun <reified T : WebSocketMessage> send(webSocketMessage: T) {
val message = json.encodeToString(webSocketMessage)
logger?.debug("$TAG send: $message")
logger?.debug("send: $message")
webSocket.send(message)
}

Expand Down Expand Up @@ -193,3 +206,9 @@ private class ConnectAppSyncRequest(
override val body: String
get() = "{}"
}

internal sealed class DisconnectReason(val throwable: Throwable?) {
data object UserInitiated : DisconnectReason(null)
data object Timeout : DisconnectReason(EventsException("Connection timed out."))
class Service(throwable: Throwable? = null) : DisconnectReason(throwable)
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
package com.amplifyframework.aws.appsync.events

import com.amplifyframework.aws.appsync.core.AppSyncAuthorizer
import com.amplifyframework.aws.appsync.core.util.Logger
import com.amplifyframework.aws.appsync.core.LoggerProvider
import java.util.concurrent.atomic.AtomicReference
import kotlinx.coroutines.Deferred
import kotlinx.coroutines.async
Expand All @@ -31,56 +31,58 @@ internal class EventsWebSocketProvider(
private val authorizer: AppSyncAuthorizer,
private val okHttpClient: OkHttpClient,
private val json: Json,
private val logger: Logger?
private val loggerProvider: LoggerProvider?
) {
private val mutex = Mutex()
private val _connectResult = AtomicReference<Result<EventsWebSocket>?>(null)
private val _connectionInProgress = AtomicReference<Deferred<Result<EventsWebSocket>>?>(null)
private val connectionResultReference = AtomicReference<Result<EventsWebSocket>?>(null)
private val connectionInProgressReference = AtomicReference<Deferred<Result<EventsWebSocket>>?>(null)

val existingWebSocket: EventsWebSocket?
get() = connectionResultReference.get()?.getOrNull()

fun getExistingWebSocket(): EventsWebSocket? = _connectResult.get()?.getOrNull()

suspend fun getConnectedWebSocket(): EventsWebSocket = getConnectedWebSocketResult().getOrThrow()

private suspend fun getConnectedWebSocketResult(): Result<EventsWebSocket> = coroutineScope {
// If connection is already established, return it
mutex.withLock {
val existingResult = _connectResult.get()
val existingResult = connectionResultReference.get()
val existingWebSocket = existingResult?.getOrNull()
if (existingWebSocket != null) {
if (existingWebSocket.isClosed.get()) {
_connectResult.set(null)
if (existingWebSocket.isClosed) {
connectionResultReference.set(null)
} else {
return@coroutineScope existingResult
}
}
}

val deferredInProgressConnection = _connectionInProgress.get()
val deferredInProgressConnection = connectionInProgressReference.get()
if (deferredInProgressConnection != null && !deferredInProgressConnection.isCompleted) {
return@coroutineScope deferredInProgressConnection.await()
}

mutex.withLock {
val existingResultInLock = _connectResult.get()
val existingResultInLock = connectionResultReference.get()
val existingWebSocket = existingResultInLock?.getOrNull()
if (existingWebSocket != null) {
if (existingWebSocket.isClosed.get()) {
_connectResult.set(null)
if (existingWebSocket.isClosed) {
connectionResultReference.set(null)
} else {
return@coroutineScope existingResultInLock
}
}

val deferredInProgressConnectionInLock = _connectionInProgress.get()
val deferredInProgressConnectionInLock = connectionInProgressReference.get()
if (deferredInProgressConnectionInLock != null && !deferredInProgressConnectionInLock.isCompleted) {
return@coroutineScope deferredInProgressConnectionInLock.await()
}

val newDeferredInProgressConnection = async { attemptConnection() }
_connectionInProgress.set(newDeferredInProgressConnection)
connectionInProgressReference.set(newDeferredInProgressConnection)
val connectionResult = newDeferredInProgressConnection.await()
_connectResult.set(connectionResult)
_connectionInProgress.set(null)
connectionResultReference.set(connectionResult)
connectionInProgressReference.set(null)
connectionResult
}
}
Expand All @@ -92,7 +94,7 @@ internal class EventsWebSocketProvider(
authorizer,
okHttpClient,
json,
logger
loggerProvider
)
eventsWebSocket.connect()
Result.success(eventsWebSocket)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@

package com.amplifyframework.aws.appsync.events.data

import com.amplifyframework.aws.appsync.events.DisconnectReason
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
import kotlinx.serialization.json.JsonArray
import kotlinx.serialization.json.JsonElement

@Serializable
Expand Down Expand Up @@ -51,7 +53,8 @@ internal sealed class WebSocketMessage {
internal data class Publish(
val id: String,
val channel: String,
val events: List<Boolean>
val events: JsonArray,
val authorization: Map<String, String>
) : Send() {
override val type = "publish"
}
Expand Down Expand Up @@ -97,22 +100,30 @@ internal sealed class WebSocketMessage {
override val id: String,
val errors: List<WebSocketError>
) : Subscription()

@Serializable @SerialName("publish_success")
internal data class PublishSuccess(
override val id: String,
@SerialName("successful") val successfulEvents: List<SuccessfulEvent>,
@SerialName("failed") val failedEvents: List<FailedEvent>
) : Subscription()
}

@Serializable @SerialName("error")
data class Error(val errors: List<WebSocketError>)
}

internal data class Closed(val userInitiated: Boolean, val throwable: Throwable? = null) : WebSocketMessage()
internal data class Closed(val reason: DisconnectReason) : WebSocketMessage()
}

@Serializable
data class WebSocketError(val errorType: String, val message: String? = null) {

// fallback message is only used if WebSocketError didn't provide a message
fun toEventsException(fallbackMessage: String? = null): EventsException {
val message = this.message ?: fallbackMessage
return when (errorType) {
"UnauthorizedException" -> UnauthorizedException(message ?: fallbackMessage)
"UnauthorizedException" -> UnauthorizedException(message)
else -> EventsException(message = "$errorType: $message")
}
}
Expand Down
Loading