Skip to content

Commit 52f935d

Browse files
model_replay: check modelExecutionTime (commaai#34457)
* metric * fix * format * table * test failure * cleanup * 3 * 4
1 parent 645418e commit 52f935d

File tree

1 file changed

+33
-1
lines changed

1 file changed

+33
-1
lines changed

selfdrive/test/process_replay/model_replay.py

+33-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import matplotlib.pyplot as plt
1010
import numpy as np
11+
from tabulate import tabulate
1112

1213
from openpilot.common.git import get_commit
1314
from openpilot.system.hardware import PC
@@ -30,6 +31,11 @@
3031
MODEL_REPLAY_BUCKET="model_replay_master"
3132
GITHUB = GithubUtils(API_TOKEN, DATA_TOKEN)
3233

34+
EXEC_TIMINGS = [
35+
# model, instant max, average max
36+
("modelV2", 0.03, 0.025),
37+
("driverStateV2", 0.02, 0.015),
38+
]
3339

3440
def get_log_fn(test_route, ref="master"):
3541
return f"{test_route}_model_tici_{ref}.zst"
@@ -156,7 +162,33 @@ def model_replay(lr, frs):
156162
del frs['roadCameraState'].frames
157163
del frs['wideRoadCameraState'].frames
158164
dmonitoringmodeld_msgs = replay_process(dmonitoringmodeld, dmodeld_logs, frs)
159-
return modeld_msgs + dmonitoringmodeld_msgs
165+
166+
msgs = modeld_msgs + dmonitoringmodeld_msgs
167+
168+
header = ['model', 'max instant', 'max instant allowed', 'average', 'max average allowed', 'test result']
169+
rows = []
170+
timings_ok = True
171+
for (s, instant_max, avg_max) in EXEC_TIMINGS:
172+
ts = [getattr(m, s).modelExecutionTime for m in msgs if m.which() == s]
173+
# TODO some init can happen in first iteration
174+
ts = ts[1:]
175+
176+
errors = []
177+
if np.max(ts) > instant_max:
178+
errors.append("❌ FAILED MAX TIMING CHECK ❌")
179+
if np.mean(ts) > avg_max:
180+
errors.append("❌ FAILED AVG TIMING CHECK ❌")
181+
182+
timings_ok = not errors and timings_ok
183+
rows.append([s, np.max(ts), instant_max, np.mean(ts), avg_max, "\n".join(errors) or "✅"])
184+
185+
print("------------------------------------------------")
186+
print("----------------- Model Timing -----------------")
187+
print("------------------------------------------------")
188+
print(tabulate(rows, header, tablefmt="simple_grid", stralign="center", numalign="center", floatfmt=".4f"))
189+
assert timings_ok
190+
191+
return msgs
160192

161193

162194
def get_frames():

0 commit comments

Comments
 (0)