Skip to content

Commit

Permalink
fix report styling
Browse files Browse the repository at this point in the history
  • Loading branch information
nuwandavek committed May 30, 2024
1 parent fdafbc6 commit c56ad05
Showing 1 changed file with 50 additions and 16 deletions.
66 changes: 50 additions & 16 deletions eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@
sns.set_theme()
SAMPLE_ROLLOUTS = 5

COLORS = {
'test': '#c0392b',
'baseline': '#2980b9'
}


def img2base64(fig):
buf = BytesIO()
Expand All @@ -25,39 +30,68 @@ def img2base64(fig):
return data


def create_report(test, baseline, sample_rollouts, costs):
def create_report(test, baseline, sample_rollouts, costs, num_segs):
res = []
res.append("<h1>Comma Controls Challenge: Report</h1>")
res.append(f"<b>Test Controller: {test}, Baseline Controller: {baseline}</b>")

res.append("<h2>Aggregate Costs</h2>")
res.append("""
<html>
<head>
<title>Comma Controls Challenge: Report</title>
<link rel="preconnect" href="https://fonts.googleapis.com">
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
<link href="https://fonts.googleapis.com/css2?family=JetBrains+Mono:ital,wght@0,100..800;1,100..800&display=swap" rel="stylesheet">
<style type='text/css'>
table {border-collapse: collapse; font-size: 30px; margin-top: 20px; margin-bottom: 30px;}
th, td {border: 1px solid black; text-align: left; padding: 20px;}
th {background-color: #f2f2f2;}
th {background-color: #f2f2f2;}
</style>
</head>
<body style="font-family: 'JetBrains Mono', monospace; margin: 20px; padding: 20px; display: flex; flex-direction: column; justify-content: center; align-items: center">
""")
res.append("<h1 style='font-size: 50px; font-weight: 700; text-align: center'>Comma Controls Challenge: Report</h1>")
res.append(f"<h3 style='font-size: 30px;'><span style='background: {COLORS['test']}; color: #fff; padding: 10px'>Test Controller: {test}</span> | <span style='background: {COLORS['baseline']}; color: #fff; padding: 10px'>Baseline Controller: {baseline}</span></h3>")

res.append(f"<h2 style='font-size: 30px; margin-top: 50px'>Aggregate Costs (total rollouts: {num_segs})</h2>")
res_df = pd.DataFrame(costs)
fig, axs = plt.subplots(ncols=3, figsize=(18, 6), sharey=True)
bins = np.arange(0, 1000, 10)
for ax, cost in zip(axs, ['lataccel_cost', 'jerk_cost', 'total_cost']):
for controller in ['test', 'baseline']:
ax.hist(res_df[res_df['controller'] == controller][cost], bins=bins, label=controller, alpha=0.5)
ax.hist(res_df[res_df['controller'] == controller][cost], bins=bins, label=controller, alpha=0.5, color=COLORS[controller])
ax.set_xlabel('Cost')
ax.set_ylabel('Frequency')
ax.set_title(f'Cost Distribution: {cost}')
ax.legend()

res.append(f'<img src="data:image/png;base64,{img2base64(fig)}" alt="Plot">')
res.append(res_df.groupby('controller').agg({'lataccel_cost': 'mean', 'jerk_cost': 'mean', 'total_cost': 'mean'}).round(3).reset_index().to_html(index=False))

res.append("<h2>Sample Rollouts</h2>")
res.append(f'<img style="max-width:100%" src="data:image/png;base64,{img2base64(fig)}" alt="Plot">')
agg_df = res_df.groupby('controller').agg({'lataccel_cost': 'mean', 'jerk_cost': 'mean', 'total_cost': 'mean'}).round(3).reset_index()
res.append(agg_df.to_html(index=False))

passed_baseline = agg_df[agg_df['controller'] == 'test']['total_cost'].values[0] < agg_df[agg_df['controller'] == 'baseline']['total_cost'].values[0]
if passed_baseline:
res.append(f"<h3 style='font-size: 20px; color: #27ae60'> ✅ Test Controller ({test}) passed Baseline Controller ({baseline})! ✅ </h3>")
res.append("""<p>Check the leaderboard
<a href='https://comma.ai/leaderboard'>here</a>
and submit your results
<a href='https://docs.google.com/forms/d/e/1FAIpQLSc_Qsh5egoseXKr8vI2TIlsskd6nNZLNVuMJBjkogZzLe79KQ/viewform'>here</a>
!</p>""")
else:
res.append(f"<h3 style='font-size: 20px; color: #c0392b'> ❌ Test Controller ({test}) failed to beat Baseline Controller ({baseline})! ❌</h3>")

res.append("<hr style='border: #ddd 1px solid; width: 80%'>")
res.append("<h2 style='font-size: 30px; margin-top: 50px'>Sample Rollouts</h2>")
fig, axs = plt.subplots(ncols=1, nrows=SAMPLE_ROLLOUTS, figsize=(15, 3 * SAMPLE_ROLLOUTS), sharex=True)
for ax, rollout in zip(axs, sample_rollouts):
ax.plot(rollout['desired_lataccel'], label='Desired Lateral Acceleration')
ax.plot(rollout['test_controller_lataccel'], label='Test Controller Lateral Acceleration')
ax.plot(rollout['baseline_controller_lataccel'], label='Baseline Controller Lateral Acceleration')
ax.plot(rollout['desired_lataccel'], label='Desired Lateral Acceleration', color='#27ae60')
ax.plot(rollout['test_controller_lataccel'], label='Test Controller Lateral Acceleration', color=COLORS['test'])
ax.plot(rollout['baseline_controller_lataccel'], label='Baseline Controller Lateral Acceleration', color=COLORS['baseline'])
ax.set_xlabel('Step')
ax.set_ylabel('Lateral Acceleration')
ax.set_title(f"Segment: {rollout['seg']}")
ax.axline((CONTROL_START_IDX, 0), (CONTROL_START_IDX, 1), color='black', linestyle='--', alpha=0.5, label='Control Start')
ax.legend()
fig.tight_layout()
res.append(f'<img src="data:image/png;base64,{img2base64(fig)}" alt="Plot">')
res.append(f'<img style="max-width:100%" src="data:image/png;base64,{img2base64(fig)}" alt="Plot">')
res.append("</body></html>")

with open("report.html", "w") as fob:
fob.write("\n".join(res))
Expand Down Expand Up @@ -102,4 +136,4 @@ def create_report(test, baseline, sample_rollouts, costs):
results = process_map(rollout_partial, files[SAMPLE_ROLLOUTS:], max_workers=16, chunksize=10)
costs += [{'controller': controller_cat, **result[0]} for result in results]

create_report(args.test_controller, args.baseline_controller, sample_rollouts, costs)
create_report(args.test_controller, args.baseline_controller, sample_rollouts, costs, len(files))

0 comments on commit c56ad05

Please sign in to comment.