Skip to content

Commit

Permalink
massive speedup
Browse files Browse the repository at this point in the history
  • Loading branch information
nuwandavek committed May 18, 2024
1 parent 8a40605 commit 760d9b2
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 46 deletions.
48 changes: 24 additions & 24 deletions eval.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
import argparse
import base64
import importlib
import numpy as np
import pandas as pd
import seaborn as sns


from functools import partial
from io import BytesIO
from matplotlib import pyplot as plt
from pathlib import Path
from tqdm import tqdm
from tqdm.contrib.concurrent import process_map

from tinyphysics import TinyPhysicsModel, TinyPhysicsSimulator, CONTROL_START_IDX, get_available_controllers
from tinyphysics import CONTROL_START_IDX, get_available_controllers, run_rollout

sns.set_theme()
SAMPLE_ROLLOUTS = 5
Expand Down Expand Up @@ -73,33 +74,32 @@ def create_report(test, baseline, sample_rollouts, costs):
parser.add_argument("--baseline_controller", default='simple', choices=available_controllers)
args = parser.parse_args()

tinyphysicsmodel = TinyPhysicsModel(args.model_path, debug=False)

data_path = Path(args.data_path)
assert data_path.is_dir(), "data_path should be a directory"

costs = []
sample_rollouts = []
files = sorted(data_path.iterdir())[:args.num_segs]
for d, data_file in enumerate(tqdm(files, total=len(files))):
test_controller = importlib.import_module(f'controllers.{args.test_controller}').Controller()
baseline_controller = importlib.import_module(f'controllers.{args.baseline_controller}').Controller()
test_sim = TinyPhysicsSimulator(tinyphysicsmodel, str(data_file), controller=test_controller, debug=False)
test_cost = test_sim.rollout()
baseline_sim = TinyPhysicsSimulator(tinyphysicsmodel, str(data_file), controller=baseline_controller, debug=False)
baseline_cost = baseline_sim.rollout()

if d < SAMPLE_ROLLOUTS:
sample_rollouts.append({
'seg': data_file.stem,
'test_controller': args.test_controller,
'baseline_controller': args.baseline_controller,
'desired_lataccel': test_sim.target_lataccel_history,
'test_controller_lataccel': test_sim.current_lataccel_history,
'baseline_controller_lataccel': baseline_sim.current_lataccel_history,
})

costs.append({'seg': data_file.stem, 'controller': 'test', **test_cost})
costs.append({'seg': data_file.stem, 'controller': 'baseline', **baseline_cost})
print("Running rollouts for visualizations...")
for d, data_file in enumerate(tqdm(files[:SAMPLE_ROLLOUTS], total=SAMPLE_ROLLOUTS)):
test_cost, test_target_lataccel, test_current_lataccel = run_rollout(data_file, args.test_controller, args.model_path, debug=False)
baseline_cost, baseline_target_lataccel, baseline_current_lataccel = run_rollout(data_file, args.baseline_controller, args.model_path, debug=False)
sample_rollouts.append({
'seg': data_file.stem,
'test_controller': args.test_controller,
'baseline_controller': args.baseline_controller,
'desired_lataccel': test_target_lataccel,
'test_controller_lataccel': test_current_lataccel,
'baseline_controller_lataccel': baseline_current_lataccel,
})

costs.append({'controller': 'test', **test_cost})
costs.append({'controller': 'baseline', **baseline_cost})

for controller_cat, controller_type in [('baseline', args.baseline_controller), ('test', args.test_controller)]:
print(f"Running batch rollouts => {controller_cat} controller: {controller_type}")
rollout_partial = partial(run_rollout, controller_type=controller_type, model_path=args.model_path, debug=False)
results = process_map(rollout_partial, files[SAMPLE_ROLLOUTS:], max_workers=16)
costs += [{'controller': controller_cat, **result[0]} for result in results]

create_report(args.test_controller, args.baseline_controller, sample_rollouts, costs)
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
numpy==1.25.2
onnxruntime-gpu==1.16.3
onnxruntime
pandas==2.1.2
matplotlib==3.8.1
seaborn==0.13.2
Expand Down
36 changes: 15 additions & 21 deletions tinyphysics.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@
import signal

from collections import namedtuple
from functools import partial
from hashlib import md5
from pathlib import Path
from typing import List, Union, Tuple
from tqdm import tqdm
from tqdm.contrib.concurrent import process_map

from controllers import BaseController

Expand Down Expand Up @@ -54,14 +55,7 @@ def __init__(self, model_path: str, debug: bool) -> None:
options.intra_op_num_threads = 1
options.inter_op_num_threads = 1
options.log_severity_level = 3
if 'CUDAExecutionProvider' in ort.get_available_providers():
if debug:
print("ONNX Runtime is using GPU")
provider = ('CUDAExecutionProvider', {'cudnn_conv_algo_search': 'DEFAULT'})
else:
if debug:
print("ONNX Runtime is using CPU")
provider = 'CPUExecutionProvider'
provider = 'CPUExecutionProvider'

with open(model_path, "rb") as f:
self.ort_session = ort.InferenceSession(f.read(), options, [provider])
Expand Down Expand Up @@ -198,6 +192,13 @@ def get_available_controllers():
return [f.stem for f in Path('controllers').iterdir() if f.is_file() and f.suffix == '.py' and f.stem != '__init__']


def run_rollout(data_path, controller_type, model_path, debug=False):
tinyphysicsmodel = TinyPhysicsModel(model_path, debug=debug)
controller = importlib.import_module(f'controllers.{controller_type}').Controller()
sim = TinyPhysicsSimulator(tinyphysicsmodel, str(data_path), controller=controller, debug=debug)
return sim.rollout(), sim.target_lataccel_history, sim.current_lataccel_history


if __name__ == "__main__":
available_controllers = get_available_controllers()
parser = argparse.ArgumentParser()
Expand All @@ -208,22 +209,15 @@ def get_available_controllers():
parser.add_argument("--controller", default='simple', choices=available_controllers)
args = parser.parse_args()

tinyphysicsmodel = TinyPhysicsModel(args.model_path, debug=args.debug)

data_path = Path(args.data_path)
if data_path.is_file():
controller = importlib.import_module(f'controllers.{args.controller}').Controller()
sim = TinyPhysicsSimulator(tinyphysicsmodel, args.data_path, controller=controller, debug=args.debug)
costs = sim.rollout()
print(f"\nAverage lataccel_cost: {costs['lataccel_cost']:>6.4}, average jerk_cost: {costs['jerk_cost']:>6.4}, average total_cost: {costs['total_cost']:>6.4}")
cost, _, _ = run_rollout(data_path, args.controller, args.model_path, debug=args.debug)
print(f"\nAverage lataccel_cost: {cost['lataccel_cost']:>6.4}, average jerk_cost: {cost['jerk_cost']:>6.4}, average total_cost: {cost['total_cost']:>6.4}")
elif data_path.is_dir():
costs = []
run_rollout_partial = partial(run_rollout, controller_type=args.controller, model_path=args.model_path, debug=False)
files = sorted(data_path.iterdir())[:args.num_segs]
for data_file in tqdm(files, total=len(files)):
controller = importlib.import_module(f'controllers.{args.controller}').Controller()
sim = TinyPhysicsSimulator(tinyphysicsmodel, str(data_file), controller=controller, debug=args.debug)
cost = sim.rollout()
costs.append(cost)
results = process_map(run_rollout_partial, files, max_workers=16)
costs = [result[0] for result in results]
costs_df = pd.DataFrame(costs)
print(f"\nAverage lataccel_cost: {np.mean(costs_df['lataccel_cost']):>6.4}, average jerk_cost: {np.mean(costs_df['jerk_cost']):>6.4}, average total_cost: {np.mean(costs_df['total_cost']):>6.4}")
for cost in costs_df.columns:
Expand Down

0 comments on commit 760d9b2

Please sign in to comment.