Skip to content

Commit

Permalink
add attention
Browse files Browse the repository at this point in the history
  • Loading branch information
xiayuyang committed Dec 17, 2022
1 parent d3b8055 commit b6a4d20
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 54 deletions.
111 changes: 61 additions & 50 deletions algs/pdqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@


class ReplayBuffer:
"""经验回放池"""

def __init__(self, capacity) -> None:
self.buffer = collections.deque(maxlen=capacity) # 队列,先进先出
Expand All @@ -22,10 +21,10 @@ def add(self, state, action, action_param, reward, next_state, truncated, done,
lane_center = info["offlane"]
reward_ttc = info["TTC"]
reward_eff = info["velocity"]
# if reward_ttc < -0.1 or reward_eff < 3:
# self.change_buffer.append((state, action, action_param, reward, next_state, truncated, done))
# if truncated:
# self.change_buffer.append((state, action, action_param, reward, next_state, truncated, done))
if reward_ttc < -0.1 or reward_eff < 3:
self.change_buffer.append((state, action, action_param, reward, next_state, truncated, done))
if truncated:
self.change_buffer.append((state, action, action_param, reward, next_state, truncated, done))
if action == 0 or action == 2:
self.change_buffer.append((state, action, action_param, reward, next_state, truncated, done))
self.tmp_buffer.append((state, action, action_param, reward, next_state, truncated, done))
Expand Down Expand Up @@ -72,10 +71,11 @@ def _compress(self, state):
state_veh_rear = np.array(state['vehicle_info'][4], dtype=np.float32).reshape((1, -1))
state_veh_right_rear = np.array(state['vehicle_info'][5], dtype=np.float32).reshape((1, -1))
state_ev = np.array(state['ego_vehicle'], dtype=np.float32).reshape((1, -1))

state_ = np.concatenate((state_left_wps, state_veh_left_front, state_veh_left_rear,
state_center_wps, state_veh_front, state_veh_rear,
state_right_wps, state_veh_right_front, state_veh_right_rear, state_ev), axis=1)
state_light = np.array(state['light'], dtype=np.float32).reshape((1, -1))
print('state[light].shape', state_light.shape)
state_ = np.concatenate((state_left_wps, state_veh_left_front, state_veh_left_rear, state_light,
state_center_wps, state_veh_front, state_veh_rear, state_light,
state_right_wps, state_veh_right_front, state_veh_right_rear, state_light, state_ev), axis=1)
return state_


Expand All @@ -89,45 +89,54 @@ def __init__(self, state_dim, train=True):
self.light_encoder = nn.Linear(state_dim['light'], 32)
self.agg = nn.Linear(128, 64)

def forward(self, lane_veh):
def forward(self, lane_veh, ego_info):
lane = lane_veh[:, :self.state_dim["waypoints"]]
veh = lane_veh[:, self.state_dim["waypoints"]:-self.state_dim['light']]
light = lane_veh[:, -self.state_dim['light']:]
lane_enc = F.relu(self.lane_encoder(lane))
veh_enc = F.relu(self.veh_encoder(veh))
light_enc = F.relu(self.light_encoder(light))
state_cat = torch.cat((lane_enc, veh_enc, lane_enc), dim=1)
state_cat = torch.cat((lane_enc, veh_enc, light_enc), dim=1)
state_enc = F.relu(self.agg(state_cat))
return state_enc


# class lane_wise_cross_attention_encoder(torch.nn.Module):
# def __init__(self, state_dim, train=True):
# super().__init__()
# self.state_dim = state_dim
# self.train = train
# self.lane_encoder = nn.Linear(state_dim['waypoints'], 32)
# self.veh_encoder = nn.Linear(state_dim['conventional_vehicle'] * 2, 32)
# self.light_encoder = nn.Linear(state_dim['light'], 32)
# self.ego_encoder = nn.Linear(state_dim['ego_vehicle'], 32)
# self.w = nn.Linear(64, 64)
# self.a = nn.Linear(64, 1)
# self.leaky_relu = nn.LeakyReLU(negative_slope=0.1)
#
#
# def forward(self, lane_veh, ego_info):
# lane = lane_veh[:, :self.state_dim["waypoints"]]
# veh = lane_veh[:, self.state_dim["waypoints"]:-self.state_dim['light']]
# light = lane_veh[:, -self.state_dim['light']]
# ego_enc = F.relu(self.ego_encoder(ego_info))
# lane_enc = self.w(torch.cat((F.relu(self.lane_encoder(lane)), ego_enc), 1))
# veh_enc = self.w(torch.cat((F.relu(self.veh_encoder(veh)), ego_enc), 1))
# light_enc = self.w(torch.cat((F.relu(self.light_encoder(light)), ego_enc), 1))
# score_lane = self.a(lane_enc)
# score_veh = self.a(veh_enc)
# score_light = self.a(light_enc)
#
# return state_enc
class lane_wise_cross_attention_encoder(torch.nn.Module):
def __init__(self, state_dim, train=True):
super().__init__()
self.state_dim = state_dim
self.train = train
self.hidden_size = 64
self.lane_encoder = nn.Linear(state_dim['waypoints'], self.hidden_size)
self.veh_encoder = nn.Linear(state_dim['conventional_vehicle'] * 2, self.hidden_size)
self.light_encoder = nn.Linear(state_dim['light'], self.hidden_size)
self.ego_encoder = nn.Linear(state_dim['ego_vehicle'], self.hidden_size)
self.w = nn.Linear(self.hidden_size, self.hidden_size)
self.ego_a = nn.Linear(self.hidden_size, 1)
self.ego_o = nn.Linear(self.hidden_size, 1)
self.leaky_relu = nn.LeakyReLU(negative_slope=0.1)

def forward(self, lane_veh, ego_info):
batch_size = lane_veh.shape[0]
lane = lane_veh[:, :self.state_dim["waypoints"]]
veh = lane_veh[:, self.state_dim["waypoints"]:-self.state_dim['light']]
light = lane_veh[:, -self.state_dim['light']]
print('ego_info.shape: ', ego_info.shape)
ego_enc = self.w(F.relu(self.ego_encoder(ego_info)))
lane_enc = self.w(F.relu(self.lane_encoder(lane)))
veh_enc = self.w(F.relu(self.veh_encoder(veh)))
light_enc = self.w(F.relu(self.light_encoder(light)))
state_enc = torch.cat((lane_enc, veh_enc, light_enc), 1).reshape(batch_size, 3, self.hidden_size)
# _enc: [batch_size, 32]
score_lane = self.leaky_relu(self.ego_a(ego_enc) + self.ego_o(lane_enc))
score_veh = self.leaky_relu(self.ego_a(ego_enc) + self.ego_o(veh_enc))
score_light = self.leaky_relu(self.ego_a(ego_enc) + self.ego_o(light_enc))
# score_: [batch_size, 1]
score = torch.cat((score_lane, score_veh, score_light), 1)
score = F.softmax(score, 1).reshape(batch_size, 1, 3)
state_enc = torch.matmul(score, state_enc)
# state_enc: [N, 64]
return state_enc


class PolicyNet_multi(torch.nn.Module):
Expand All @@ -138,9 +147,9 @@ def __init__(self, state_dim, action_parameter_size, action_bound, train=True) -
self.action_bound = action_bound
self.action_parameter_size = action_parameter_size
self.train = train
self.left_encoder = veh_lane_encoder(self.state_dim)
self.center_encoder = veh_lane_encoder(self.state_dim)
self.right_encoder = veh_lane_encoder(self.state_dim)
self.left_encoder = lane_wise_cross_attention_encoder(self.state_dim)
self.center_encoder = lane_wise_cross_attention_encoder(self.state_dim)
self.right_encoder = lane_wise_cross_attention_encoder(self.state_dim)
self.ego_encoder = nn.Linear(self.state_dim['ego_vehicle'], 64)
self.fc = nn.Linear(256, 256)
self.fc_out = nn.Linear(256, self.action_parameter_size)
Expand All @@ -155,10 +164,12 @@ def __init__(self, state_dim, action_parameter_size, action_bound, train=True) -
def forward(self, state):
# state: (waypoints + 2 * conventional_vehicle0 * 3
one_state_dim = self.state_dim['waypoints'] + self.state_dim['conventional_vehicle'] * 2 + self.state_dim['light']
print(state.shape, one_state_dim)
ego_info = state[:, 3*one_state_dim:]
left_enc = self.left_encoder(state[:, :one_state_dim])
center_enc = self.center_encoder(state[:, one_state_dim:2*one_state_dim])
right_enc = self.right_encoder(state[:, 2*one_state_dim:3*one_state_dim])
print(ego_info.shape)
left_enc = self.left_encoder(state[:, :one_state_dim], ego_info)
center_enc = self.center_encoder(state[:, one_state_dim:2*one_state_dim], ego_info)
right_enc = self.right_encoder(state[:, 2*one_state_dim:3*one_state_dim], ego_info)
ego_enc = self.ego_encoder(ego_info)
state_ = torch.cat((left_enc, center_enc, right_enc, ego_enc), dim=1)
hidden = F.relu(self.fc(state_))
Expand Down Expand Up @@ -187,9 +198,9 @@ def __init__(self, state_dim, action_param_dim, num_actions) -> None:
self.state_dim = state_dim
self.action_param_dim = action_param_dim
self.num_actions = num_actions
self.left_encoder = veh_lane_encoder(self.state_dim)
self.center_encoder = veh_lane_encoder(self.state_dim)
self.right_encoder = veh_lane_encoder(self.state_dim)
self.left_encoder = lane_wise_cross_attention_encoder(self.state_dim)
self.center_encoder = lane_wise_cross_attention_encoder(self.state_dim)
self.right_encoder = lane_wise_cross_attention_encoder(self.state_dim)
self.ego_encoder = nn.Linear(self.state_dim['ego_vehicle'], 32)
self.action_encoder = nn.Linear(self.action_param_dim, 32)
self.fc = nn.Linear(256, 256)
Expand All @@ -203,9 +214,9 @@ def __init__(self, state_dim, action_param_dim, num_actions) -> None:
def forward(self, state, action):
one_state_dim = self.state_dim['waypoints'] + self.state_dim['conventional_vehicle'] * 2 + self.state_dim['light']
ego_info = state[:, 3*one_state_dim:]
left_enc = self.left_encoder(state[:, :one_state_dim])
center_enc = self.center_encoder(state[:, one_state_dim:2*one_state_dim])
right_enc = self.right_encoder(state[:, 2*one_state_dim:3*one_state_dim])
left_enc = self.left_encoder(state[:, :one_state_dim], ego_info)
center_enc = self.center_encoder(state[:, one_state_dim:2*one_state_dim], ego_info)
right_enc = self.right_encoder(state[:, 2*one_state_dim:3*one_state_dim], ego_info)
ego_enc = self.ego_encoder(ego_info)
action_enc = self.action_encoder(action)
state_ = torch.cat((left_enc, center_enc, right_enc, ego_enc, action_enc), dim=1)
Expand Down
4 changes: 2 additions & 2 deletions gym_carla/env/carla_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -854,9 +854,9 @@ def _get_reward(self, last_action, last_lane, current_lane, distance_to_front_ve
v_s = v_3d.length() * math.cos(theta_v)
if v_s * 3.6 > self.speed_limit:
# fEff = 1
fEff = math.exp(self.speed_limit - v_s * 3.6) - 1
fEff = math.exp(self.speed_limit - v_s * 3.6)
else:
fEff = v_s * 3.6 / self.speed_limit - 1
fEff = v_s * 3.6 / self.speed_limit

cur_acc = self.get_acc_s(self.ego_vehicle.get_acceleration(), yaw_forward)

Expand Down
2 changes: 1 addition & 1 deletion gym_carla/env/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@
)
ARGS.add_argument(
'--pre_train_steps', type=int,
default=20000,
default=200,
help='Let the RL controller and PID controller alternatively take control every 500 steps'
)
ARGS.add_argument(
Expand Down
2 changes: 1 addition & 1 deletion train_pdqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
TAU = 0.01 # 软更新参数
EPSILON = 0.5 # epsilon-greedy
BUFFER_SIZE = 40000
MINIMAL_SIZE = 10000
MINIMAL_SIZE = 300
BATCH_SIZE = 128
REPLACE_A = 500
REPLACE_C = 300
Expand Down

0 comments on commit b6a4d20

Please sign in to comment.