Skip to content

Commit f64361a

Browse files
authored
fix(Predictions): Fix Liveness InvalidSignatureException (#2729)
1 parent 956ba11 commit f64361a

File tree

1 file changed

+90
-66
lines changed

1 file changed

+90
-66
lines changed

aws-predictions/src/main/java/com/amplifyframework/predictions/aws/http/LivenessWebSocket.kt

Lines changed: 90 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,11 @@ import java.util.Date
5252
import java.util.Locale
5353
import java.util.UUID
5454
import kotlinx.coroutines.CoroutineScope
55+
import kotlinx.coroutines.CoroutineStart
5556
import kotlinx.coroutines.Dispatchers
57+
import kotlinx.coroutines.Job
58+
import kotlinx.coroutines.channels.Channel
59+
import kotlinx.coroutines.channels.consumeEach
5660
import kotlinx.coroutines.launch
5761
import kotlinx.serialization.encodeToString
5862
import kotlinx.serialization.json.Json
@@ -98,6 +102,15 @@ internal class LivenessWebSocket(
98102
internal var clientStoppedSession = false
99103
val json = Json { ignoreUnknownKeys = true }
100104

105+
// Sending events to the websocket requires processing synchronously because we rely on proper ordered
106+
// prior signatures. When sending events, we send each of these events to an async queue to process 1 at a time.
107+
private val sendEventScope = CoroutineScope(Job() + Dispatchers.IO)
108+
private val sendEventQueueChannel = Channel<Job>(capacity = Channel.UNLIMITED).apply {
109+
sendEventScope.launch {
110+
consumeEach { it.join() }
111+
}
112+
}
113+
101114
@VisibleForTesting
102115
internal var webSocketListener = object : WebSocketListener() {
103116
override fun onOpen(webSocket: WebSocket, response: Response) {
@@ -409,76 +422,87 @@ internal class LivenessWebSocket(
409422
}
410423

411424
private fun sendClientInfoEvent(clientInfoEvent: ClientSessionInformationEvent) {
412-
credentials?.let {
413-
val jsonString = Json.encodeToString(clientInfoEvent)
414-
val jsonPayload = jsonString.encodeUtf8().toByteArray()
415-
val encodedPayload = LivenessEventStream.encode(
416-
jsonPayload,
417-
mapOf(
418-
":event-type" to "ClientSessionInformationEvent",
419-
":message-type" to "event",
420-
":content-type" to "application/json"
421-
)
422-
)
423-
val eventDate = Date(adjustedDate())
424-
val signedPayload = signer.getSignedFrame(
425-
region,
426-
encodedPayload.array(),
427-
it.secretAccessKey,
428-
Pair(":date", eventDate)
429-
)
430-
val signedPayloadBytes = signedPayload.chunked(2).map { hexChar -> hexChar.toInt(16).toByte() }
431-
.toByteArray()
432-
val encodedRequest = LivenessEventStream.encode(
433-
encodedPayload.array(),
434-
mapOf(
435-
":date" to eventDate,
436-
":chunk-signature" to signedPayloadBytes
437-
)
438-
)
425+
// Add event to send queue to ensure proper ordering of signatures
426+
sendEventQueueChannel.trySend(
427+
sendEventScope.launch(start = CoroutineStart.LAZY) {
428+
credentials?.let {
429+
val jsonString = Json.encodeToString(clientInfoEvent)
430+
val jsonPayload = jsonString.encodeUtf8().toByteArray()
431+
val encodedPayload = LivenessEventStream.encode(
432+
jsonPayload,
433+
mapOf(
434+
":event-type" to "ClientSessionInformationEvent",
435+
":message-type" to "event",
436+
":content-type" to "application/json"
437+
)
438+
)
439+
val eventDate = Date(adjustedDate())
440+
val signedPayload = signer.getSignedFrame(
441+
region,
442+
encodedPayload.array(),
443+
it.secretAccessKey,
444+
Pair(":date", eventDate)
445+
)
446+
val signedPayloadBytes = signedPayload.chunked(2).map { hexChar ->
447+
hexChar.toInt(16).toByte()
448+
}.toByteArray()
449+
val encodedRequest = LivenessEventStream.encode(
450+
encodedPayload.array(),
451+
mapOf(
452+
":date" to eventDate,
453+
":chunk-signature" to signedPayloadBytes
454+
)
455+
)
439456

440-
webSocket?.send(ByteString.of(*encodedRequest.array()))
441-
}
457+
webSocket?.send(ByteString.of(*encodedRequest.array()))
458+
}
459+
}
460+
)
442461
}
443462

444463
fun sendVideoEvent(videoBytes: ByteArray, videoEventTime: Long) {
445-
if (videoBytes.isNotEmpty()) {
446-
videoEndTimestamp = adjustedDate(videoEventTime)
447-
}
448-
credentials?.let {
449-
val videoBuffer = ByteBuffer.wrap(videoBytes)
450-
val videoEvent = VideoEvent(
451-
timestampMillis = adjustedDate(videoEventTime),
452-
videoChunk = videoBuffer
453-
)
454-
val videoJsonString = Json.encodeToString(videoEvent)
455-
val videoJsonPayload = videoJsonString.encodeUtf8().toByteArray()
456-
val encodedVideoPayload = LivenessEventStream.encode(
457-
videoJsonPayload,
458-
mapOf(
459-
":event-type" to "VideoEvent",
460-
":message-type" to "event",
461-
":content-type" to "application/json"
462-
)
463-
)
464-
val videoEventDate = Date(adjustedDate())
465-
val signedVideoPayload = signer.getSignedFrame(
466-
region,
467-
encodedVideoPayload.array(),
468-
it.secretAccessKey,
469-
Pair(":date", videoEventDate)
470-
)
471-
val signedVideoPayloadBytes = signedVideoPayload.chunked(2)
472-
.map { hexChar -> hexChar.toInt(16).toByte() }.toByteArray()
473-
val encodedVideoRequest = LivenessEventStream.encode(
474-
encodedVideoPayload.array(),
475-
mapOf(
476-
":date" to videoEventDate,
477-
":chunk-signature" to signedVideoPayloadBytes
478-
)
479-
)
480-
webSocket?.send(ByteString.of(*encodedVideoRequest.array()))
481-
}
464+
// Add event to send queue to ensure proper ordering of signatures
465+
sendEventQueueChannel.trySend(
466+
sendEventScope.launch(start = CoroutineStart.LAZY) {
467+
if (videoBytes.isNotEmpty()) {
468+
videoEndTimestamp = adjustedDate(videoEventTime)
469+
}
470+
credentials?.let {
471+
val videoBuffer = ByteBuffer.wrap(videoBytes)
472+
val videoEvent = VideoEvent(
473+
timestampMillis = adjustedDate(videoEventTime),
474+
videoChunk = videoBuffer
475+
)
476+
val videoJsonString = Json.encodeToString(videoEvent)
477+
val videoJsonPayload = videoJsonString.encodeUtf8().toByteArray()
478+
val encodedVideoPayload = LivenessEventStream.encode(
479+
videoJsonPayload,
480+
mapOf(
481+
":event-type" to "VideoEvent",
482+
":message-type" to "event",
483+
":content-type" to "application/json"
484+
)
485+
)
486+
val videoEventDate = Date(adjustedDate())
487+
val signedVideoPayload = signer.getSignedFrame(
488+
region,
489+
encodedVideoPayload.array(),
490+
it.secretAccessKey,
491+
Pair(":date", videoEventDate)
492+
)
493+
val signedVideoPayloadBytes = signedVideoPayload.chunked(2)
494+
.map { hexChar -> hexChar.toInt(16).toByte() }.toByteArray()
495+
val encodedVideoRequest = LivenessEventStream.encode(
496+
encodedVideoPayload.array(),
497+
mapOf(
498+
":date" to videoEventDate,
499+
":chunk-signature" to signedVideoPayloadBytes
500+
)
501+
)
502+
webSocket?.send(ByteString.of(*encodedVideoRequest.array()))
503+
}
504+
}
505+
)
482506
}
483507

484508
fun destroy(reasonCode: Int = NORMAL_SOCKET_CLOSURE_STATUS_CODE) {

0 commit comments

Comments
 (0)