Skip to content

Fixes Root cause of sharding race condition and allows driver to run on a custom port #2339

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions maestro-cli/src/main/java/maestro/cli/App.kt
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ class App {
@Option(names = ["--port"], hidden = true)
var port: Int? = null

@Option(names = ["--driver-host-port"], hidden = true)
var driverHostPort: Int? = null

@Option(
names = ["--device", "--udid"],
description = ["(Optional) Device ID to run on explicitly, can be a comma separated list of IDs: --device \"Emulator_1,Emulator_2\" "],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class PrintHierarchyCommand : Runnable {
MaestroSessionManager.newSession(
host = parent?.host,
port = parent?.port,
driverHostPort = null,
driverHostPort = parent?.driverHostPort,
deviceId = parent?.deviceId,
platform = parent?.platform,
) { session ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class QueryCommand : Runnable {
MaestroSessionManager.newSession(
host = parent?.host,
port = parent?.port,
driverHostPort = null,
driverHostPort = parent?.driverHostPort,
deviceId = parent?.deviceId,
platform = parent?.platform,
) { session ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ class RecordCommand : Callable<Int> {
return MaestroSessionManager.newSession(
host = parent?.host,
port = parent?.port,
driverHostPort = null,
driverHostPort = parent?.driverHostPort,
deviceId = deviceId,
platform = parent?.platform,
) { session ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ class StartDeviceCommand : Callable<Int> {
PrintUtils.message(if (p == Platform.IOS) "Launching simulator..." else "Launching emulator...")
DeviceService.startDevice(
device = device,
driverHostPort = parent?.port
driverHostPort = parent?.driverHostPort,
)
}
} catch (e: LocaleValidationIosException) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class StudioCommand : Callable<Int> {
MaestroSessionManager.newSession(
host = parent?.host,
port = parent?.port,
driverHostPort = null,
driverHostPort = parent?.driverHostPort,
deviceId = parent?.deviceId,
platform = parent?.platform,
isStudio = true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ class TestCommand : Callable<Int> {
}

private fun selectPort(effectiveShards: Int): Int =
if (effectiveShards == 1) 7001
if (effectiveShards == 1) parent?.driverHostPort ?: 7001
else (7001..7128).shuffled().find { port ->
usedPorts.putIfAbsent(port, true) == null
} ?: error("No available ports found")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,7 @@ object MaestroSessionManager {
val heartbeatFuture = executor.scheduleAtFixedRate(
{
try {
Thread.sleep(1000) // Add a 1-second delay here for fixing race condition
SessionStore.heartbeat(sessionId, selectedDevice.platform)
SessionStore.heartbeat(sessionId, selectedDevice.platform, selectedDevice.deviceId)
} catch (e: Exception) {
logger.error("Failed to record heartbeat", e)
}
Expand All @@ -85,19 +84,19 @@ object MaestroSessionManager {

val session = createMaestro(
selectedDevice = selectedDevice,
connectToExistingSession = SessionStore.hasActiveSessions(
sessionId,
selectedDevice.platform
connectToExistingSession = SessionStore.hasActiveSessionOnDevice(
selectedDevice.platform,
selectedDevice.deviceId
),
isStudio = isStudio,
isHeadless = isHeadless,
driverHostPort = driverHostPort,
)
Runtime.getRuntime().addShutdownHook(thread(start = false) {
heartbeatFuture.cancel(true)
SessionStore.delete(sessionId, selectedDevice.platform)
SessionStore.delete(sessionId, selectedDevice.platform, selectedDevice.deviceId)
runCatching { ScreenReporter.reportMaxDepth() }
if (SessionStore.activeSessions().isEmpty()) {
if (!SessionStore.hasActiveSessionOnDevice(selectedDevice.platform, selectedDevice.deviceId)) {
session.close()
}
})
Expand All @@ -124,6 +123,7 @@ object MaestroSessionManager {
return SelectedDevice(
platform = device.platform,
device = device,
deviceId = deviceId,
)
}

Expand Down
37 changes: 16 additions & 21 deletions maestro-cli/src/main/java/maestro/cli/session/SessionStore.kt
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ object SessionStore {
)
}

fun heartbeat(sessionId: String, platform: Platform) {
fun heartbeat(sessionId: String, platform: Platform, deviceId: String?) {
synchronized(keyValueStore) {
keyValueStore.set(
key = key(sessionId, platform),
key = key(sessionId, platform, deviceId),
value = System.currentTimeMillis().toString(),
)

Expand All @@ -37,37 +37,32 @@ object SessionStore {
}
}

fun delete(sessionId: String, platform: Platform) {
fun delete(sessionId: String, platform: Platform, deviceId: String?) {
synchronized(keyValueStore) {
keyValueStore.delete(
key(sessionId, platform)
key(sessionId, platform, deviceId)
)
}
}

fun activeSessions(): List<String> {
fun hasActiveSessionOnDevice(
platform: Platform,
deviceId: String?
): Boolean {
synchronized(keyValueStore) {
return keyValueStore
.keys()
.filter { key ->
val lastHeartbeat = keyValueStore.get(key)?.toLongOrNull()
lastHeartbeat != null && System.currentTimeMillis() - lastHeartbeat < TimeUnit.SECONDS.toMillis(21)
}
return keyValueStore.keys().any { key ->
key.contains("${platform}_$deviceId") && isHeartbeatActive(key)
}
}
}

fun hasActiveSessions(
sessionId: String,
platform: Platform
): Boolean {
synchronized(keyValueStore) {
return activeSessions()
.any { it != key(sessionId, platform) }
}
private fun isHeartbeatActive(key: String): Boolean {
val lastHeartbeat = keyValueStore.get(key)?.toLongOrNull()
return lastHeartbeat != null && System.currentTimeMillis() - lastHeartbeat < TimeUnit.SECONDS.toMillis(21)
}

private fun key(sessionId: String, platform: Platform): String {
return "${platform}_$sessionId"
private fun key(sessionId: String, platform: Platform, deviceId: String?): String {
return "${platform}_${sessionId}_$deviceId"
}

}
Loading