Skip to content

Commit

Permalink
Merge pull request #10 from commaai/speedup
Browse files Browse the repository at this point in the history
Speedup report/eval
  • Loading branch information
nuwandavek authored May 20, 2024
2 parents 8a40605 + 8c3bd5d commit c319dcd
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 46 deletions.
8 changes: 8 additions & 0 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,11 @@ jobs:
- name: Run Simple controller rollout
run: |
python tinyphysics.py --model_path ./models/tinyphysics.onnx --data_path ./data/00000.csv --controller simple
- name: Run batch rollouts
run: |
python tinyphysics.py --model_path ./models/tinyphysics.onnx --data_path ./data --num_segs 20 --controller simple
- name: Run report
run: |
python eval.py --model_path ./models/tinyphysics.onnx --data_path ./data --num_segs 50 --test_controller open --baseline_controller simple
48 changes: 24 additions & 24 deletions eval.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
import argparse
import base64
import importlib
import numpy as np
import pandas as pd
import seaborn as sns


from functools import partial
from io import BytesIO
from matplotlib import pyplot as plt
from pathlib import Path
from tqdm import tqdm
from tqdm.contrib.concurrent import process_map

from tinyphysics import TinyPhysicsModel, TinyPhysicsSimulator, CONTROL_START_IDX, get_available_controllers
from tinyphysics import CONTROL_START_IDX, get_available_controllers, run_rollout

sns.set_theme()
SAMPLE_ROLLOUTS = 5
Expand Down Expand Up @@ -73,33 +74,32 @@ def create_report(test, baseline, sample_rollouts, costs):
parser.add_argument("--baseline_controller", default='simple', choices=available_controllers)
args = parser.parse_args()

tinyphysicsmodel = TinyPhysicsModel(args.model_path, debug=False)

data_path = Path(args.data_path)
assert data_path.is_dir(), "data_path should be a directory"

costs = []
sample_rollouts = []
files = sorted(data_path.iterdir())[:args.num_segs]
for d, data_file in enumerate(tqdm(files, total=len(files))):
test_controller = importlib.import_module(f'controllers.{args.test_controller}').Controller()
baseline_controller = importlib.import_module(f'controllers.{args.baseline_controller}').Controller()
test_sim = TinyPhysicsSimulator(tinyphysicsmodel, str(data_file), controller=test_controller, debug=False)
test_cost = test_sim.rollout()
baseline_sim = TinyPhysicsSimulator(tinyphysicsmodel, str(data_file), controller=baseline_controller, debug=False)
baseline_cost = baseline_sim.rollout()

if d < SAMPLE_ROLLOUTS:
sample_rollouts.append({
'seg': data_file.stem,
'test_controller': args.test_controller,
'baseline_controller': args.baseline_controller,
'desired_lataccel': test_sim.target_lataccel_history,
'test_controller_lataccel': test_sim.current_lataccel_history,
'baseline_controller_lataccel': baseline_sim.current_lataccel_history,
})

costs.append({'seg': data_file.stem, 'controller': 'test', **test_cost})
costs.append({'seg': data_file.stem, 'controller': 'baseline', **baseline_cost})
print("Running rollouts for visualizations...")
for d, data_file in enumerate(tqdm(files[:SAMPLE_ROLLOUTS], total=SAMPLE_ROLLOUTS)):
test_cost, test_target_lataccel, test_current_lataccel = run_rollout(data_file, args.test_controller, args.model_path, debug=False)
baseline_cost, baseline_target_lataccel, baseline_current_lataccel = run_rollout(data_file, args.baseline_controller, args.model_path, debug=False)
sample_rollouts.append({
'seg': data_file.stem,
'test_controller': args.test_controller,
'baseline_controller': args.baseline_controller,
'desired_lataccel': test_target_lataccel,
'test_controller_lataccel': test_current_lataccel,
'baseline_controller_lataccel': baseline_current_lataccel,
})

costs.append({'controller': 'test', **test_cost})
costs.append({'controller': 'baseline', **baseline_cost})

for controller_cat, controller_type in [('baseline', args.baseline_controller), ('test', args.test_controller)]:
print(f"Running batch rollouts => {controller_cat} controller: {controller_type}")
rollout_partial = partial(run_rollout, controller_type=controller_type, model_path=args.model_path, debug=False)
results = process_map(rollout_partial, files[SAMPLE_ROLLOUTS:], max_workers=16)
costs += [{'controller': controller_cat, **result[0]} for result in results]

create_report(args.test_controller, args.baseline_controller, sample_rollouts, costs)
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
numpy==1.25.2
onnxruntime-gpu==1.16.3
onnxruntime
pandas==2.1.2
matplotlib==3.8.1
seaborn==0.13.2
Expand Down
36 changes: 15 additions & 21 deletions tinyphysics.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@
import signal

from collections import namedtuple
from functools import partial
from hashlib import md5
from pathlib import Path
from typing import List, Union, Tuple
from tqdm import tqdm
from tqdm.contrib.concurrent import process_map

from controllers import BaseController

Expand Down Expand Up @@ -54,14 +55,7 @@ def __init__(self, model_path: str, debug: bool) -> None:
options.intra_op_num_threads = 1
options.inter_op_num_threads = 1
options.log_severity_level = 3
if 'CUDAExecutionProvider' in ort.get_available_providers():
if debug:
print("ONNX Runtime is using GPU")
provider = ('CUDAExecutionProvider', {'cudnn_conv_algo_search': 'DEFAULT'})
else:
if debug:
print("ONNX Runtime is using CPU")
provider = 'CPUExecutionProvider'
provider = 'CPUExecutionProvider'

with open(model_path, "rb") as f:
self.ort_session = ort.InferenceSession(f.read(), options, [provider])
Expand Down Expand Up @@ -198,6 +192,13 @@ def get_available_controllers():
return [f.stem for f in Path('controllers').iterdir() if f.is_file() and f.suffix == '.py' and f.stem != '__init__']


def run_rollout(data_path, controller_type, model_path, debug=False):
tinyphysicsmodel = TinyPhysicsModel(model_path, debug=debug)
controller = importlib.import_module(f'controllers.{controller_type}').Controller()
sim = TinyPhysicsSimulator(tinyphysicsmodel, str(data_path), controller=controller, debug=debug)
return sim.rollout(), sim.target_lataccel_history, sim.current_lataccel_history


if __name__ == "__main__":
available_controllers = get_available_controllers()
parser = argparse.ArgumentParser()
Expand All @@ -208,22 +209,15 @@ def get_available_controllers():
parser.add_argument("--controller", default='simple', choices=available_controllers)
args = parser.parse_args()

tinyphysicsmodel = TinyPhysicsModel(args.model_path, debug=args.debug)

data_path = Path(args.data_path)
if data_path.is_file():
controller = importlib.import_module(f'controllers.{args.controller}').Controller()
sim = TinyPhysicsSimulator(tinyphysicsmodel, args.data_path, controller=controller, debug=args.debug)
costs = sim.rollout()
print(f"\nAverage lataccel_cost: {costs['lataccel_cost']:>6.4}, average jerk_cost: {costs['jerk_cost']:>6.4}, average total_cost: {costs['total_cost']:>6.4}")
cost, _, _ = run_rollout(data_path, args.controller, args.model_path, debug=args.debug)
print(f"\nAverage lataccel_cost: {cost['lataccel_cost']:>6.4}, average jerk_cost: {cost['jerk_cost']:>6.4}, average total_cost: {cost['total_cost']:>6.4}")
elif data_path.is_dir():
costs = []
run_rollout_partial = partial(run_rollout, controller_type=args.controller, model_path=args.model_path, debug=False)
files = sorted(data_path.iterdir())[:args.num_segs]
for data_file in tqdm(files, total=len(files)):
controller = importlib.import_module(f'controllers.{args.controller}').Controller()
sim = TinyPhysicsSimulator(tinyphysicsmodel, str(data_file), controller=controller, debug=args.debug)
cost = sim.rollout()
costs.append(cost)
results = process_map(run_rollout_partial, files, max_workers=16)
costs = [result[0] for result in results]
costs_df = pd.DataFrame(costs)
print(f"\nAverage lataccel_cost: {np.mean(costs_df['lataccel_cost']):>6.4}, average jerk_cost: {np.mean(costs_df['jerk_cost']):>6.4}, average total_cost: {np.mean(costs_df['total_cost']):>6.4}")
for cost in costs_df.columns:
Expand Down

0 comments on commit c319dcd

Please sign in to comment.