Skip to content

Commit ecb8021

Browse files
tjleingThomas Leingtylerjroach
authored
fix(liveness): correct websocket retry logic (#2634)
Co-authored-by: Thomas Leing <[email protected]> Co-authored-by: tjroach <[email protected]>
1 parent 13d7fa0 commit ecb8021

File tree

4 files changed

+104
-53
lines changed

4 files changed

+104
-53
lines changed

aws-predictions/src/main/java/com/amplifyframework/predictions/aws/http/LivenessWebSocket.kt

Lines changed: 56 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ import com.amplifyframework.predictions.aws.models.liveness.ColorDisplayed
3737
import com.amplifyframework.predictions.aws.models.liveness.FaceMovementAndLightClientChallenge
3838
import com.amplifyframework.predictions.aws.models.liveness.FreshnessColor
3939
import com.amplifyframework.predictions.aws.models.liveness.InitialFace
40+
import com.amplifyframework.predictions.aws.models.liveness.InvalidSignatureException
4041
import com.amplifyframework.predictions.aws.models.liveness.LivenessResponseStream
4142
import com.amplifyframework.predictions.aws.models.liveness.SessionInformation
4243
import com.amplifyframework.predictions.aws.models.liveness.TargetFace
@@ -78,23 +79,13 @@ internal class LivenessWebSocket(
7879
private val signer = AWSV4Signer()
7980
private var credentials: Credentials? = null
8081

81-
internal var offset = 0L
82-
internal enum class ReconnectState {
83-
INITIAL,
84-
RECONNECTING,
85-
RECONNECTING_AGAIN;
86-
87-
companion object {
88-
fun next(state: ReconnectState): ReconnectState {
89-
return when (state) {
90-
INITIAL -> RECONNECTING
91-
RECONNECTING -> RECONNECTING_AGAIN
92-
RECONNECTING_AGAIN -> RECONNECTING_AGAIN
93-
}
94-
}
95-
}
82+
// The reported time difference between the server and client. Only set if diff is higher than 4 minutes
83+
internal var timeDiffOffsetInMillis = 0L
84+
internal enum class ConnectionState {
85+
NORMAL,
86+
ATTEMPT_RECONNECT,
9687
}
97-
internal var reconnectState = ReconnectState.INITIAL
88+
internal var reconnectState = ConnectionState.NORMAL
9889

9990
@VisibleForTesting
10091
internal var webSocket: WebSocket? = null
@@ -103,7 +94,7 @@ internal class LivenessWebSocket(
10394
private var faceDetectedStart = 0L
10495
private var videoStartTimestamp = 0L
10596
private var videoEndTimestamp = 0L
106-
private var webSocketError: PredictionsException? = null
97+
@VisibleForTesting internal var webSocketError: PredictionsException? = null
10798
internal var clientStoppedSession = false
10899
val json = Json { ignoreUnknownKeys = true }
109100

@@ -119,15 +110,15 @@ internal class LivenessWebSocket(
119110
date.time - adjustedDate()
120111
} else 0
121112

122-
reconnectState = ReconnectState.next(reconnectState)
123-
// if offset is > 5 minutes, server will reject the request
124-
if (kotlin.math.abs(tempOffset) < FIVE_MINUTES) {
125-
super.onOpen(webSocket, response)
126-
this@LivenessWebSocket.webSocket = webSocket
127-
} else {
128-
// server will close this websocket, don't report that failure back
129-
offset = tempOffset
130-
start()
113+
super.onOpen(webSocket, response)
114+
115+
this@LivenessWebSocket.webSocket = webSocket
116+
117+
// If offset is > 4 minutes, server may reject the request
118+
// The real allowed diff from serer is < 5 but we check for 4 to add a buffer
119+
if (!isTimeDiffSafe(tempOffset)) {
120+
LOG.info("Server reported a time difference between client and server of > 4 minutes")
121+
timeDiffOffsetInMillis = tempOffset
131122
}
132123
}
133124

@@ -169,14 +160,29 @@ internal class LivenessWebSocket(
169160
override fun onClosed(webSocket: WebSocket, code: Int, reason: String) {
170161
LOG.debug("WebSocket onClosed")
171162
super.onClosed(webSocket, code, reason)
172-
if (reconnectState == ReconnectState.RECONNECTING) {
173-
// do nothing; we expected the server to close the connection
174-
} else if (code != NORMAL_SOCKET_CLOSURE_STATUS_CODE && !clientStoppedSession) {
175-
val faceLivenessException = webSocketError ?: PredictionsException(
176-
"An error occurred during the face liveness check.",
177-
reason
178-
)
179-
onErrorReceived.accept(faceLivenessException)
163+
if (code != NORMAL_SOCKET_CLOSURE_STATUS_CODE && !clientStoppedSession) {
164+
val recordedError = webSocketError
165+
166+
/*
167+
If the server reports an invalid signature due to a time difference between the local clock and the
168+
server clock, AND we haven't already tried to reconnect, then we should try to reconnect with an offset
169+
*/
170+
if (reconnectState == ConnectionState.NORMAL &&
171+
!isTimeDiffSafe(timeDiffOffsetInMillis) &&
172+
recordedError is PredictionsException &&
173+
recordedError.cause is InvalidSignatureException
174+
) {
175+
LOG.info("The server rejected the connection due to a likely time difference. Attempting reconnect")
176+
reconnectState = ConnectionState.ATTEMPT_RECONNECT
177+
webSocketError = null
178+
start()
179+
} else {
180+
val faceLivenessException = recordedError ?: PredictionsException(
181+
"An error occurred during the face liveness check.",
182+
reason
183+
)
184+
onErrorReceived.accept(faceLivenessException)
185+
}
180186
} else {
181187
onComplete.call()
182188
}
@@ -197,14 +203,6 @@ internal class LivenessWebSocket(
197203
}
198204

199205
fun start() {
200-
if (reconnectState == ReconnectState.RECONNECTING_AGAIN) {
201-
onErrorReceived.accept(
202-
PredictionsException(
203-
"Invalid device time",
204-
"Too many attempts were made to correct device time"
205-
)
206-
)
207-
}
208206
val userAgent = getUserAgent()
209207

210208
val okHttpClient = OkHttpClient.Builder()
@@ -312,6 +310,18 @@ internal class LivenessWebSocket(
312310
AccessDeniedException(
313311
cause = livenessResponse.accessDeniedException
314312
)
313+
} else if (livenessResponse.unrecognizedClientException != null) {
314+
PredictionsException(
315+
"Unrecognized client",
316+
livenessResponse.unrecognizedClientException,
317+
"Please check your credentials"
318+
)
319+
} else if (livenessResponse.invalidSignatureException != null) {
320+
PredictionsException(
321+
"Invalid signature",
322+
livenessResponse.invalidSignatureException,
323+
"Please check your credentials"
324+
)
315325
} else {
316326
PredictionsException(
317327
"An unknown error occurred during the Liveness flow.",
@@ -477,12 +487,14 @@ internal class LivenessWebSocket(
477487
}
478488

479489
fun adjustedDate(date: Long = Date().time): Long {
480-
return date + offset
490+
return date + timeDiffOffsetInMillis
481491
}
482492

493+
private fun isTimeDiffSafe(diffInMillis: Long) = kotlin.math.abs(diffInMillis) < FOUR_MINUTES
494+
483495
companion object {
484496
private const val NORMAL_SOCKET_CLOSURE_STATUS_CODE = 1000
485-
private val FIVE_MINUTES = 1000 * 60 * 5
497+
private const val FOUR_MINUTES = 1000 * 60 * 4
486498
@VisibleForTesting val datePattern = "EEE, d MMM yyyy HH:mm:ss z"
487499
private val LOG = Amplify.Logging.logger(CategoryType.PREDICTIONS, "amplify:aws-predictions")
488500
}

aws-predictions/src/main/java/com/amplifyframework/predictions/aws/models/liveness/LivenessResponseStream.kt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,5 +30,6 @@ internal data class LivenessResponseStream(
3030
@SerialName("ServiceUnavailableException") val serviceUnavailableException: ServiceUnavailableException? = null,
3131
@SerialName("SessionNotFoundException") val sessionNotFoundException: SessionNotFoundException? = null,
3232
@SerialName("AccessDeniedException") val accessDeniedException: AccessDeniedException? = null,
33-
@SerialName("InvalidSignatureException") val invalidSignatureException: InvalidSignatureException? = null
33+
@SerialName("InvalidSignatureException") val invalidSignatureException: InvalidSignatureException? = null,
34+
@SerialName("UnrecognizedClientException") val unrecognizedClientException: UnrecognizedClientException? = null
3435
)
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
/*
2+
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License").
5+
* You may not use this file except in compliance with the License.
6+
* A copy of the License is located at
7+
*
8+
* http://aws.amazon.com/apache2.0
9+
*
10+
* or in the "license" file accompanying this file. This file is distributed
11+
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
12+
* express or implied. See the License for the specific language governing
13+
* permissions and limitations under the License.
14+
*/
15+
package com.amplifyframework.predictions.aws.models.liveness
16+
17+
import kotlinx.serialization.SerialName
18+
import kotlinx.serialization.Serializable
19+
20+
/**
21+
* Constructs a new UnrecognizedClientException with the specified error message.
22+
*
23+
* @param message Describes the error encountered.
24+
*/
25+
@Serializable
26+
internal data class UnrecognizedClientException(
27+
@SerialName("Message") override val message: String
28+
) : Exception(message)

aws-predictions/src/test/java/com/amplifyframework/predictions/aws/http/LivenessWebSocketTest.kt

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import com.amplifyframework.predictions.aws.models.liveness.ColorSequence
2828
import com.amplifyframework.predictions.aws.models.liveness.DisconnectionEvent
2929
import com.amplifyframework.predictions.aws.models.liveness.FaceMovementAndLightServerChallenge
3030
import com.amplifyframework.predictions.aws.models.liveness.FreshnessColor
31+
import com.amplifyframework.predictions.aws.models.liveness.InvalidSignatureException
3132
import com.amplifyframework.predictions.aws.models.liveness.LightChallengeType
3233
import com.amplifyframework.predictions.aws.models.liveness.OvalParameters
3334
import com.amplifyframework.predictions.aws.models.liveness.ServerChallenge
@@ -326,7 +327,9 @@ internal class LivenessWebSocketTest {
326327
fun `web socket detects clock skew from server response`() {
327328
val livenessWebSocket = createLivenessWebSocket()
328329
mockkConstructor(WebSocket::class)
329-
val socket: WebSocket = mockk()
330+
val socket: WebSocket = mockk {
331+
every { close(any(), any()) } returns true
332+
}
330333
livenessWebSocket.webSocket = socket
331334
val sdf = SimpleDateFormat(LivenessWebSocket.datePattern, Locale.US)
332335

@@ -339,7 +342,7 @@ internal class LivenessWebSocketTest {
339342
livenessWebSocket.webSocketListener.onOpen(socket, response)
340343

341344
// now we should restart the websocket with an adjusted time
342-
val openLatch = CountDownLatch(1)
345+
val openLatch = CountDownLatch(2)
343346
val latchingListener = LatchingWebSocketResponseListener(
344347
livenessWebSocket.webSocketListener,
345348
openLatch = openLatch
@@ -349,25 +352,32 @@ internal class LivenessWebSocketTest {
349352
server.enqueue(MockResponse().withWebSocketUpgrade(ServerWebSocketListener()))
350353
server.start()
351354

355+
livenessWebSocket.webSocketError = PredictionsException(
356+
"invalid signature",
357+
InvalidSignatureException("invalid signature"),
358+
"invalid signature"
359+
)
360+
livenessWebSocket.webSocketListener.onClosed(mockk(), 1011, "closing")
361+
352362
openLatch.await(3, TimeUnit.SECONDS)
353363

354364
assertTrue(livenessWebSocket.webSocket != null)
355-
val originalRequest = livenessWebSocket.webSocket!!.request()
365+
val reconnectRequest = livenessWebSocket.webSocket!!.request()
356366

357367
// make sure that followup request sends offset date
358368
val sdfGMT = SimpleDateFormat("yyyyMMdd'T'HHmmss'Z'", Locale.US)
359369
sdfGMT.timeZone = TimeZone.getTimeZone("GMT")
360-
val sentDate = originalRequest.url.queryParameter("X-Amz-Date") ?.let { sdfGMT.parse(it) }
370+
val sentDate = reconnectRequest.url.queryParameter("X-Amz-Date") ?.let { sdfGMT.parse(it) }
361371
val diff = abs(Date().time - sentDate?.time!!)
362372
assert(oneHour - 10000 < diff && diff < oneHour + 10000)
363373

364374
// also make sure that followup request is valid
365375
assertTrue(
366-
originalRequest.url.queryParameter("X-Amz-Credential")!!.endsWith("//rekognition/aws4_request")
376+
reconnectRequest.url.queryParameter("X-Amz-Credential")!!.endsWith("//rekognition/aws4_request")
367377
)
368-
assertEquals("299", originalRequest.url.queryParameter("X-Amz-Expires"))
369-
assertEquals("host", originalRequest.url.queryParameter("X-Amz-SignedHeaders"))
370-
assertEquals("AWS4-HMAC-SHA256", originalRequest.url.queryParameter("X-Amz-Algorithm"))
378+
assertEquals("299", reconnectRequest.url.queryParameter("X-Amz-Expires"))
379+
assertEquals("host", reconnectRequest.url.queryParameter("X-Amz-SignedHeaders"))
380+
assertEquals("AWS4-HMAC-SHA256", reconnectRequest.url.queryParameter("X-Amz-Algorithm"))
371381
}
372382

373383
@Test

0 commit comments

Comments
 (0)