From c56ad05b3584adc16878ca969abcebea0fdf7e9f Mon Sep 17 00:00:00 2001 From: nuwandavek Date: Wed, 29 May 2024 18:16:54 -0700 Subject: [PATCH] fix report styling --- eval.py | 66 +++++++++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 50 insertions(+), 16 deletions(-) diff --git a/eval.py b/eval.py index 31fb8250..26e6532d 100644 --- a/eval.py +++ b/eval.py @@ -17,6 +17,11 @@ sns.set_theme() SAMPLE_ROLLOUTS = 5 +COLORS = { + 'test': '#c0392b', + 'baseline': '#2980b9' +} + def img2base64(fig): buf = BytesIO() @@ -25,39 +30,68 @@ def img2base64(fig): return data -def create_report(test, baseline, sample_rollouts, costs): +def create_report(test, baseline, sample_rollouts, costs, num_segs): res = [] - res.append("

Comma Controls Challenge: Report

") - res.append(f"Test Controller: {test}, Baseline Controller: {baseline}") - - res.append("

Aggregate Costs

") + res.append(""" + + + Comma Controls Challenge: Report + + + + + + + """) + res.append("

Comma Controls Challenge: Report

") + res.append(f"

Test Controller: {test} | Baseline Controller: {baseline}

") + + res.append(f"

Aggregate Costs (total rollouts: {num_segs})

") res_df = pd.DataFrame(costs) fig, axs = plt.subplots(ncols=3, figsize=(18, 6), sharey=True) bins = np.arange(0, 1000, 10) for ax, cost in zip(axs, ['lataccel_cost', 'jerk_cost', 'total_cost']): for controller in ['test', 'baseline']: - ax.hist(res_df[res_df['controller'] == controller][cost], bins=bins, label=controller, alpha=0.5) + ax.hist(res_df[res_df['controller'] == controller][cost], bins=bins, label=controller, alpha=0.5, color=COLORS[controller]) ax.set_xlabel('Cost') ax.set_ylabel('Frequency') ax.set_title(f'Cost Distribution: {cost}') ax.legend() - - res.append(f'Plot') - res.append(res_df.groupby('controller').agg({'lataccel_cost': 'mean', 'jerk_cost': 'mean', 'total_cost': 'mean'}).round(3).reset_index().to_html(index=False)) - - res.append("

Sample Rollouts

") + res.append(f'Plot') + agg_df = res_df.groupby('controller').agg({'lataccel_cost': 'mean', 'jerk_cost': 'mean', 'total_cost': 'mean'}).round(3).reset_index() + res.append(agg_df.to_html(index=False)) + + passed_baseline = agg_df[agg_df['controller'] == 'test']['total_cost'].values[0] < agg_df[agg_df['controller'] == 'baseline']['total_cost'].values[0] + if passed_baseline: + res.append(f"

✅ Test Controller ({test}) passed Baseline Controller ({baseline})! ✅

") + res.append("""

Check the leaderboard + here + and submit your results + here + !

""") + else: + res.append(f"

❌ Test Controller ({test}) failed to beat Baseline Controller ({baseline})! ❌

") + + res.append("
") + res.append("

Sample Rollouts

") fig, axs = plt.subplots(ncols=1, nrows=SAMPLE_ROLLOUTS, figsize=(15, 3 * SAMPLE_ROLLOUTS), sharex=True) for ax, rollout in zip(axs, sample_rollouts): - ax.plot(rollout['desired_lataccel'], label='Desired Lateral Acceleration') - ax.plot(rollout['test_controller_lataccel'], label='Test Controller Lateral Acceleration') - ax.plot(rollout['baseline_controller_lataccel'], label='Baseline Controller Lateral Acceleration') + ax.plot(rollout['desired_lataccel'], label='Desired Lateral Acceleration', color='#27ae60') + ax.plot(rollout['test_controller_lataccel'], label='Test Controller Lateral Acceleration', color=COLORS['test']) + ax.plot(rollout['baseline_controller_lataccel'], label='Baseline Controller Lateral Acceleration', color=COLORS['baseline']) ax.set_xlabel('Step') ax.set_ylabel('Lateral Acceleration') ax.set_title(f"Segment: {rollout['seg']}") ax.axline((CONTROL_START_IDX, 0), (CONTROL_START_IDX, 1), color='black', linestyle='--', alpha=0.5, label='Control Start') ax.legend() fig.tight_layout() - res.append(f'Plot') + res.append(f'Plot') + res.append("") with open("report.html", "w") as fob: fob.write("\n".join(res)) @@ -102,4 +136,4 @@ def create_report(test, baseline, sample_rollouts, costs): results = process_map(rollout_partial, files[SAMPLE_ROLLOUTS:], max_workers=16, chunksize=10) costs += [{'controller': controller_cat, **result[0]} for result in results] - create_report(args.test_controller, args.baseline_controller, sample_rollouts, costs) + create_report(args.test_controller, args.baseline_controller, sample_rollouts, costs, len(files))