|
8 | 8 |
|
9 | 9 | import matplotlib.pyplot as plt
|
10 | 10 | import numpy as np
|
| 11 | +from tabulate import tabulate |
11 | 12 |
|
12 | 13 | from openpilot.common.git import get_commit
|
13 | 14 | from openpilot.system.hardware import PC
|
|
30 | 31 | MODEL_REPLAY_BUCKET="model_replay_master"
|
31 | 32 | GITHUB = GithubUtils(API_TOKEN, DATA_TOKEN)
|
32 | 33 |
|
| 34 | +EXEC_TIMINGS = [ |
| 35 | + # model, instant max, average max |
| 36 | + ("modelV2", 0.03, 0.025), |
| 37 | + ("driverStateV2", 0.02, 0.015), |
| 38 | +] |
33 | 39 |
|
34 | 40 | def get_log_fn(test_route, ref="master"):
|
35 | 41 | return f"{test_route}_model_tici_{ref}.zst"
|
@@ -156,7 +162,33 @@ def model_replay(lr, frs):
|
156 | 162 | del frs['roadCameraState'].frames
|
157 | 163 | del frs['wideRoadCameraState'].frames
|
158 | 164 | dmonitoringmodeld_msgs = replay_process(dmonitoringmodeld, dmodeld_logs, frs)
|
159 |
| - return modeld_msgs + dmonitoringmodeld_msgs |
| 165 | + |
| 166 | + msgs = modeld_msgs + dmonitoringmodeld_msgs |
| 167 | + |
| 168 | + header = ['model', 'max instant', 'max instant allowed', 'average', 'max average allowed', 'test result'] |
| 169 | + rows = [] |
| 170 | + timings_ok = True |
| 171 | + for (s, instant_max, avg_max) in EXEC_TIMINGS: |
| 172 | + ts = [getattr(m, s).modelExecutionTime for m in msgs if m.which() == s] |
| 173 | + # TODO some init can happen in first iteration |
| 174 | + ts = ts[1:] |
| 175 | + |
| 176 | + errors = [] |
| 177 | + if np.max(ts) > instant_max: |
| 178 | + errors.append("❌ FAILED MAX TIMING CHECK ❌") |
| 179 | + if np.mean(ts) > avg_max: |
| 180 | + errors.append("❌ FAILED AVG TIMING CHECK ❌") |
| 181 | + |
| 182 | + timings_ok = not errors and timings_ok |
| 183 | + rows.append([s, np.max(ts), instant_max, np.mean(ts), avg_max, "\n".join(errors) or "✅"]) |
| 184 | + |
| 185 | + print("------------------------------------------------") |
| 186 | + print("----------------- Model Timing -----------------") |
| 187 | + print("------------------------------------------------") |
| 188 | + print(tabulate(rows, header, tablefmt="simple_grid", stralign="center", numalign="center", floatfmt=".4f")) |
| 189 | + assert timings_ok |
| 190 | + |
| 191 | + return msgs |
160 | 192 |
|
161 | 193 |
|
162 | 194 | def get_frames():
|
|
0 commit comments