Skip to content

Commit

Permalink
Merge pull request commaai#9 from commaai/controllers-refactor
Browse files Browse the repository at this point in the history
Refactor controllers to be separate files
  • Loading branch information
nuwandavek authored May 18, 2024
2 parents 35d09f3 + 04804d4 commit 8a40605
Show file tree
Hide file tree
Showing 9 changed files with 45 additions and 33 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@ on:
branches:
- master
workflow_dispatch:
pull_request:

jobs:
run-script:
rollout:
runs-on: ubuntu-20.04

steps:
Expand Down
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ We'll be using a synthetic dataset based on the [comma-steering-control](https:/
bash ./download_dataset.sh
# install required packages
# recommended python==3.11
pip install -r requirements.txt
# test this works
Expand All @@ -30,11 +31,11 @@ python eval.py --model_path ./models/tinyphysics.onnx --data_path ./data --num_s
You can also use the notebook at [`experiment.ipynb`](https://github.com/commaai/controls_challenge/blob/master/experiment.ipynb) for exploration.

## TinyPhysics
This is a "simulated car" that has been trained to mimic a very simple physics model (bicycle model) based simulator, given realistic driving noise. It is an autoregressive model similar to [ML Controls Sim](https://blog.comma.ai/096release/#ml-controls-sim) in architecture. It's inputs are the car velocity (`v_ego`), forward acceleration (`a_ego`), lateral acceleration due to road roll (`road_lataccel`), current car lateral acceleration (`current_lataccel`) and a steer input (`steer_action`) and predicts the resultant lateral acceleration fo the car.
This is a "simulated car" that has been trained to mimic a very simple physics model (bicycle model) based simulator, given realistic driving noise. It is an autoregressive model similar to [ML Controls Sim](https://blog.comma.ai/096release/#ml-controls-sim) in architecture. It's inputs are the car velocity (`v_ego`), forward acceleration (`a_ego`), lateral acceleration due to road roll (`road_lataccel`), current car lateral acceleration (`current_lataccel`) and a steer input (`steer_action`) and predicts the resultant lateral acceleration of the car.


## Controllers
Your controller should implement an [update function](https://github.com/commaai/controls_challenge/blob/1a25ee200f5466cb7dc1ab0bf6b7d0c67a2481db/controllers.py#L2) that returns the `steer_action`. This controller is then run in-loop, in the simulator to autoregressively predict the car's response.
Your controller should implement a new [controller](https://github.com/commaai/controls_challenge/tree/master/controllers). This controller can be passed as an arg to run in-loop in the simulator to autoregressively predict the car's response.


## Evaluation
Expand Down
19 changes: 0 additions & 19 deletions controllers.py

This file was deleted.

3 changes: 3 additions & 0 deletions controllers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
class BaseController:
def update(self, target_lataccel, current_lataccel, state):
raise NotImplementedError
9 changes: 9 additions & 0 deletions controllers/open.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from . import BaseController


class Controller(BaseController):
"""
An open-loop controller
"""
def update(self, target_lataccel, current_lataccel, state):
return target_lataccel
9 changes: 9 additions & 0 deletions controllers/simple.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from . import BaseController


class Controller(BaseController):
"""
A simple controller that is the error between the target and current lateral acceleration times some factor
"""
def update(self, target_lataccel, current_lataccel, state):
return (target_lataccel - current_lataccel) * 0.3
12 changes: 7 additions & 5 deletions eval.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import argparse
import base64
import importlib
import numpy as np
import pandas as pd
import seaborn as sns
Expand All @@ -10,7 +11,7 @@
from pathlib import Path
from tqdm import tqdm

from tinyphysics import TinyPhysicsModel, TinyPhysicsSimulator, CONTROLLERS, CONTROL_START_IDX
from tinyphysics import TinyPhysicsModel, TinyPhysicsSimulator, CONTROL_START_IDX, get_available_controllers

sns.set_theme()
SAMPLE_ROLLOUTS = 5
Expand Down Expand Up @@ -63,12 +64,13 @@ def create_report(test, baseline, sample_rollouts, costs):


if __name__ == "__main__":
available_controllers = get_available_controllers()
parser = argparse.ArgumentParser()
parser.add_argument("--model_path", type=str, required=True)
parser.add_argument("--data_path", type=str, required=True)
parser.add_argument("--num_segs", type=int, default=100)
parser.add_argument("--test_controller", default='simple', choices=CONTROLLERS.keys())
parser.add_argument("--baseline_controller", default='simple', choices=CONTROLLERS.keys())
parser.add_argument("--test_controller", default='simple', choices=available_controllers)
parser.add_argument("--baseline_controller", default='simple', choices=available_controllers)
args = parser.parse_args()

tinyphysicsmodel = TinyPhysicsModel(args.model_path, debug=False)
Expand All @@ -80,8 +82,8 @@ def create_report(test, baseline, sample_rollouts, costs):
sample_rollouts = []
files = sorted(data_path.iterdir())[:args.num_segs]
for d, data_file in enumerate(tqdm(files, total=len(files))):
test_controller = CONTROLLERS[args.test_controller]()
baseline_controller = CONTROLLERS[args.baseline_controller]()
test_controller = importlib.import_module(f'controllers.{args.test_controller}').Controller()
baseline_controller = importlib.import_module(f'controllers.{args.baseline_controller}').Controller()
test_sim = TinyPhysicsSimulator(tinyphysicsmodel, str(data_file), controller=test_controller, debug=False)
test_cost = test_sim.rollout()
baseline_sim = TinyPhysicsSimulator(tinyphysicsmodel, str(data_file), controller=baseline_controller, debug=False)
Expand Down
4 changes: 2 additions & 2 deletions experiment.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"outputs": [],
"source": [
"from tinyphysics import TinyPhysicsModel, TinyPhysicsSimulator, CONTROL_START_IDX\n",
"from controllers import SimpleController\n",
"from controllers import simple\n",
"from matplotlib import pyplot as plt\n",
"import seaborn as sns\n",
"\n",
Expand Down Expand Up @@ -38,7 +38,7 @@
"outputs": [],
"source": [
"model = TinyPhysicsModel(\"./models/tinyphysics.onnx\", debug=True)\n",
"controller = SimpleController()"
"controller = simple.Controller()"
]
},
{
Expand Down
14 changes: 10 additions & 4 deletions tinyphysics.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
import importlib
import numpy as np
import onnxruntime as ort
import pandas as pd
Expand All @@ -12,7 +13,7 @@
from typing import List, Union, Tuple
from tqdm import tqdm

from controllers import BaseController, CONTROLLERS
from controllers import BaseController

sns.set_theme()
signal.signal(signal.SIGINT, signal.SIG_DFL) # Enable Ctrl-C on plot windows
Expand Down Expand Up @@ -193,28 +194,33 @@ def rollout(self) -> float:
return self.compute_cost()


def get_available_controllers():
return [f.stem for f in Path('controllers').iterdir() if f.is_file() and f.suffix == '.py' and f.stem != '__init__']


if __name__ == "__main__":
available_controllers = get_available_controllers()
parser = argparse.ArgumentParser()
parser.add_argument("--model_path", type=str, required=True)
parser.add_argument("--data_path", type=str, required=True)
parser.add_argument("--num_segs", type=int, default=100)
parser.add_argument("--debug", action='store_true')
parser.add_argument("--controller", default='simple', choices=CONTROLLERS.keys())
parser.add_argument("--controller", default='simple', choices=available_controllers)
args = parser.parse_args()

tinyphysicsmodel = TinyPhysicsModel(args.model_path, debug=args.debug)

data_path = Path(args.data_path)
if data_path.is_file():
controller = CONTROLLERS[args.controller]()
controller = importlib.import_module(f'controllers.{args.controller}').Controller()
sim = TinyPhysicsSimulator(tinyphysicsmodel, args.data_path, controller=controller, debug=args.debug)
costs = sim.rollout()
print(f"\nAverage lataccel_cost: {costs['lataccel_cost']:>6.4}, average jerk_cost: {costs['jerk_cost']:>6.4}, average total_cost: {costs['total_cost']:>6.4}")
elif data_path.is_dir():
costs = []
files = sorted(data_path.iterdir())[:args.num_segs]
for data_file in tqdm(files, total=len(files)):
controller = CONTROLLERS[args.controller]()
controller = importlib.import_module(f'controllers.{args.controller}').Controller()
sim = TinyPhysicsSimulator(tinyphysicsmodel, str(data_file), controller=controller, debug=args.debug)
cost = sim.rollout()
costs.append(cost)
Expand Down

0 comments on commit 8a40605

Please sign in to comment.