Skip to content

Commit

Permalink
make future targets available to the controller
Browse files Browse the repository at this point in the history
  • Loading branch information
nuwandavek committed May 23, 2024
1 parent a615066 commit d42dd27
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 15 deletions.
11 changes: 10 additions & 1 deletion controllers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,12 @@
class BaseController:
def update(self, target_lataccel, current_lataccel, state):
def update(self, target_lataccel, current_lataccel, state, target_future):
"""
Args:
target_lataccel: The target lateral acceleration.
current_lataccel: The current lateral acceleration.
state: The current state of the vehicle.
target_future: The future target lateral acceleration plan for the next N frames.
Returns:
The control signal to be applied to the vehicle.
"""
raise NotImplementedError
2 changes: 1 addition & 1 deletion controllers/open.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@ class Controller(BaseController):
"""
An open-loop controller
"""
def update(self, target_lataccel, current_lataccel, state):
def update(self, target_lataccel, current_lataccel, state, target_future):
return target_lataccel
2 changes: 1 addition & 1 deletion controllers/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@ class Controller(BaseController):
"""
A simple controller that is the error between the target and current lateral acceleration times some factor
"""
def update(self, target_lataccel, current_lataccel, state):
def update(self, target_lataccel, current_lataccel, state, target_future):
return (target_lataccel - current_lataccel) * 0.3
34 changes: 22 additions & 12 deletions tinyphysics.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
signal.signal(signal.SIGINT, signal.SIG_DFL) # Enable Ctrl-C on plot windows

ACC_G = 9.81
FPS = 10
CONTROL_START_IDX = 100
COST_END_IDX = 550
CONTEXT_LENGTH = 20
VOCAB_SIZE = 1024
LATACCEL_RANGE = [-5, 5]
Expand All @@ -29,6 +31,8 @@
DEL_T = 0.1
LAT_ACCEL_COST_MULTIPLIER = 5.0

FUTURE_PLAN_STEPS = FPS * 5 # 5 secs

State = namedtuple('State', ['roll_lataccel', 'v_ego', 'a_ego'])


Expand Down Expand Up @@ -95,10 +99,12 @@ def __init__(self, model: TinyPhysicsModel, data_path: str, controller: BaseCont

def reset(self) -> None:
self.step_idx = CONTEXT_LENGTH
self.state_history = [self.get_state_target(i)[0] for i in range(self.step_idx)]
state_targetfutures = [self.get_state_targetfuture(i) for i in range(self.step_idx)]
self.state_history = [x['state'] for x in state_targetfutures]
self.action_history = self.data['steer_command'].values[:self.step_idx].tolist()
self.current_lataccel_history = [self.get_state_target(i)[1] for i in range(self.step_idx)]
self.target_lataccel_history = [self.get_state_target(i)[1] for i in range(self.step_idx)]
self.current_lataccel_history = [x['targetfuture'][0] for x in state_targetfutures]
self.target_lataccel_history = [x['targetfuture'][0] for x in state_targetfutures]
self.target_future = None
self.current_lataccel = self.current_lataccel_history[-1]
seed = int(md5(self.data_path.encode()).hexdigest(), 16) % 10**4
np.random.seed(seed)
Expand All @@ -124,25 +130,29 @@ def sim_step(self, step_idx: int) -> None:
if step_idx >= CONTROL_START_IDX:
self.current_lataccel = pred
else:
self.current_lataccel = self.get_state_target(step_idx)[1]
self.current_lataccel = self.get_state_targetfuture(step_idx)['targetfuture'][0]

self.current_lataccel_history.append(self.current_lataccel)

def control_step(self, step_idx: int) -> None:
action = self.controller.update(self.target_lataccel_history[step_idx], self.current_lataccel, self.state_history[step_idx])
action = self.controller.update(self.target_lataccel_history[step_idx], self.current_lataccel, self.state_history[step_idx], target_future=self.target_future)
if step_idx < CONTROL_START_IDX:
action = self.data['steer_command'].values[step_idx]
action = np.clip(action, STEER_RANGE[0], STEER_RANGE[1])
self.action_history.append(action)

def get_state_target(self, step_idx: int) -> Tuple[State, float]:
def get_state_targetfuture(self, step_idx: int) -> Tuple[State, float]:
state = self.data.iloc[step_idx]
return State(roll_lataccel=state['roll_lataccel'], v_ego=state['v_ego'], a_ego=state['a_ego']), state['target_lataccel']
return {
'state': State(roll_lataccel=state['roll_lataccel'], v_ego=state['v_ego'], a_ego=state['a_ego']),
'targetfuture': self.data['target_lataccel'].values[step_idx:step_idx + FUTURE_PLAN_STEPS].tolist()
}

def step(self) -> None:
state, target = self.get_state_target(self.step_idx)
self.state_history.append(state)
self.target_lataccel_history.append(target)
state_targetfuture = self.get_state_targetfuture(self.step_idx)
self.state_history.append(state_targetfuture['state'])
self.target_lataccel_history.append(state_targetfuture['targetfuture'][0])
self.target_future = state_targetfuture['targetfuture'][1:]
self.control_step(self.step_idx)
self.sim_step(self.step_idx)
self.step_idx += 1
Expand All @@ -158,8 +168,8 @@ def plot_data(self, ax, lines, axis_labels, title) -> None:
ax.set_ylabel(axis_labels[1])

def compute_cost(self) -> dict:
target = np.array(self.target_lataccel_history)[CONTROL_START_IDX:]
pred = np.array(self.current_lataccel_history)[CONTROL_START_IDX:]
target = np.array(self.target_lataccel_history)[CONTROL_START_IDX:COST_END_IDX]
pred = np.array(self.current_lataccel_history)[CONTROL_START_IDX:COST_END_IDX]

lat_accel_cost = np.mean((target - pred)**2) * 100
jerk_cost = np.mean((np.diff(pred) / DEL_T)**2) * 100
Expand Down

0 comments on commit d42dd27

Please sign in to comment.