Skip to content

Commit

Permalink
do not reuse the controller in each iteration
Browse files Browse the repository at this point in the history
  • Loading branch information
nuwandavek committed May 7, 2024
1 parent ec68106 commit 0ee669a
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
4 changes: 2 additions & 2 deletions eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,6 @@ def create_report(test, baseline, sample_rollouts, costs):
args = parser.parse_args()

tinyphysicsmodel = TinyPhysicsModel(args.model_path, debug=False)
test_controller = CONTROLLERS[args.test_controller]()
baseline_controller = CONTROLLERS[args.baseline_controller]()

data_path = Path(args.data_path)
assert data_path.is_dir(), "data_path should be a directory"
Expand All @@ -82,6 +80,8 @@ def create_report(test, baseline, sample_rollouts, costs):
sample_rollouts = []
files = sorted(data_path.iterdir())[:args.num_segs]
for d, data_file in enumerate(tqdm(files, total=len(files))):
test_controller = CONTROLLERS[args.test_controller]()
baseline_controller = CONTROLLERS[args.baseline_controller]()
test_sim = TinyPhysicsSimulator(tinyphysicsmodel, str(data_file), controller=test_controller, debug=False)
test_cost = test_sim.rollout()
baseline_sim = TinyPhysicsSimulator(tinyphysicsmodel, str(data_file), controller=baseline_controller, debug=False)
Expand Down
3 changes: 2 additions & 1 deletion tinyphysics.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,17 +203,18 @@ def rollout(self) -> float:
args = parser.parse_args()

tinyphysicsmodel = TinyPhysicsModel(args.model_path, debug=args.debug)
controller = CONTROLLERS[args.controller]()

data_path = Path(args.data_path)
if data_path.is_file():
controller = CONTROLLERS[args.controller]()
sim = TinyPhysicsSimulator(tinyphysicsmodel, args.data_path, controller=controller, debug=args.debug)
costs = sim.rollout()
print(f"\nAverage lataccel_cost: {costs['lataccel_cost']:>6.4}, average jerk_cost: {costs['jerk_cost']:>6.4}, average total_cost: {costs['total_cost']:>6.4}")
elif data_path.is_dir():
costs = []
files = sorted(data_path.iterdir())[:args.num_segs]
for data_file in tqdm(files, total=len(files)):
controller = CONTROLLERS[args.controller]()
sim = TinyPhysicsSimulator(tinyphysicsmodel, str(data_file), controller=controller, debug=args.debug)
cost = sim.rollout()
costs.append(cost)
Expand Down

0 comments on commit 0ee669a

Please sign in to comment.