Skip to content

Commit

Permalink
Merge pull request commaai#4 from grekiki2/types
Browse files Browse the repository at this point in the history
Fix typehints
  • Loading branch information
nuwandavek authored May 2, 2024
2 parents 240715a + 70d3074 commit ec68106
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions tinyphysics.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def softmax(self, x, axis=-1):
e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
return e_x / np.sum(e_x, axis=axis, keepdims=True)

def predict(self, input_data: dict, temperature=1.) -> dict:
def predict(self, input_data: dict, temperature=1.) -> int:
res = self.ort_session.run(None, input_data)[0]
probs = self.softmax(res / temperature, axis=-1)
# we only care about the last timestep (batch size is just 1)
Expand All @@ -96,7 +96,6 @@ def __init__(self, model: TinyPhysicsModel, data_path: str, controller: BaseCont
self.data = self.get_data(data_path)
self.controller = controller
self.debug = debug
self.times = []
self.reset()

def reset(self) -> None:
Expand Down Expand Up @@ -142,7 +141,7 @@ def control_step(self, step_idx: int) -> None:
action = np.clip(action, STEER_RANGE[0], STEER_RANGE[1])
self.action_history.append(action)

def get_state_target(self, step_idx: int) -> Tuple[List, float]:
def get_state_target(self, step_idx: int) -> Tuple[State, float]:
state = self.data.iloc[step_idx]
return State(roll_lataccel=state['roll_lataccel'], v_ego=state['v_ego'], a_ego=state['a_ego']), state['target_lataccel']

Expand All @@ -164,7 +163,7 @@ def plot_data(self, ax, lines, axis_labels, title) -> None:
ax.set_xlabel(axis_labels[0])
ax.set_ylabel(axis_labels[1])

def compute_cost(self) -> float:
def compute_cost(self) -> dict:
target = np.array(self.target_lataccel_history)[CONTROL_START_IDX:]
pred = np.array(self.current_lataccel_history)[CONTROL_START_IDX:]

Expand All @@ -173,7 +172,7 @@ def compute_cost(self) -> float:
total_cost = (lat_accel_cost * LAT_ACCEL_COST_MULTIPLIER) + jerk_cost
return {'lataccel_cost': lat_accel_cost, 'jerk_cost': jerk_cost, 'total_cost': total_cost}

def rollout(self) -> None:
def rollout(self) -> float:
if self.debug:
plt.ion()
fig, ax = plt.subplots(4, figsize=(12, 14), constrained_layout=True)
Expand Down

0 comments on commit ec68106

Please sign in to comment.