Skip to content

Commit

Permalink
add attention
Browse files Browse the repository at this point in the history
  • Loading branch information
xiayuyang committed Dec 17, 2022
1 parent b6a4d20 commit ff964a5
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 11 deletions.
17 changes: 9 additions & 8 deletions algs/pdqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,8 @@ def forward(self, lane_veh, ego_info):
batch_size = lane_veh.shape[0]
lane = lane_veh[:, :self.state_dim["waypoints"]]
veh = lane_veh[:, self.state_dim["waypoints"]:-self.state_dim['light']]
light = lane_veh[:, -self.state_dim['light']]
print('ego_info.shape: ', ego_info.shape)
light = lane_veh[:, -self.state_dim['light']:]
# print('ego_info.shape: ', ego_info.shape)
ego_enc = self.w(F.relu(self.ego_encoder(ego_info)))
lane_enc = self.w(F.relu(self.lane_encoder(lane)))
veh_enc = self.w(F.relu(self.veh_encoder(veh)))
Expand All @@ -134,7 +134,7 @@ def forward(self, lane_veh, ego_info):
# score_: [batch_size, 1]
score = torch.cat((score_lane, score_veh, score_light), 1)
score = F.softmax(score, 1).reshape(batch_size, 1, 3)
state_enc = torch.matmul(score, state_enc)
state_enc = torch.matmul(score, state_enc).reshape(batch_size, self.hidden_size)
# state_enc: [N, 64]
return state_enc

Expand Down Expand Up @@ -164,9 +164,9 @@ def __init__(self, state_dim, action_parameter_size, action_bound, train=True) -
def forward(self, state):
# state: (waypoints + 2 * conventional_vehicle0 * 3
one_state_dim = self.state_dim['waypoints'] + self.state_dim['conventional_vehicle'] * 2 + self.state_dim['light']
print(state.shape, one_state_dim)
# print(state.shape, one_state_dim)
ego_info = state[:, 3*one_state_dim:]
print(ego_info.shape)
# print(ego_info.shape)
left_enc = self.left_encoder(state[:, :one_state_dim], ego_info)
center_enc = self.center_encoder(state[:, one_state_dim:2*one_state_dim], ego_info)
right_enc = self.right_encoder(state[:, 2*one_state_dim:3*one_state_dim], ego_info)
Expand Down Expand Up @@ -289,10 +289,11 @@ def take_action(self, state, lane_id=-2, action_mask=True):
state_veh_left_rear = torch.tensor(state['vehicle_info'][3], dtype=torch.float32).view(1, -1).to(self.device)
state_veh_rear = torch.tensor(state['vehicle_info'][4], dtype=torch.float32).view(1, -1).to(self.device)
state_veh_right_rear = torch.tensor(state['vehicle_info'][5], dtype=torch.float32).view(1, -1).to(self.device)
state_light = torch.tensor(state['light'], dtype=torch.float32).view(1, -1).to(self.device)
state_ev = torch.tensor(state['ego_vehicle'],dtype=torch.float32).view(1,-1).to(self.device)
state_ = torch.cat((state_left_wps, state_veh_left_front, state_veh_left_rear,
state_center_wps, state_veh_front, state_veh_rear,
state_right_wps, state_veh_right_front, state_veh_right_rear, state_ev), dim=1)
state_ = torch.cat((state_left_wps, state_veh_left_front, state_veh_left_rear, state_light,
state_center_wps, state_veh_front, state_veh_rear, state_light,
state_right_wps, state_veh_right_front, state_veh_right_rear, state_light, state_ev), dim=1)
# print(state_.shape)
all_action_param = self.actor(state_)
q_a = torch.squeeze(self.critic(state_, all_action_param))
Expand Down
4 changes: 2 additions & 2 deletions gym_carla/env/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@
ARGS.add_argument(
'--hybrid',
action='store_true',
default=True,
default=False,
help='Activate hybrid mode for Traffic Manager')
ARGS.add_argument(
'--auto_lane_change',
Expand Down Expand Up @@ -173,7 +173,7 @@
)
ARGS.add_argument(
'--pre_train_steps', type=int,
default=200,
default=20000,
help='Let the RL controller and PID controller alternatively take control every 500 steps'
)
ARGS.add_argument(
Expand Down
Binary file modified out/ddpg_pre_trained.pth
Binary file not shown.
2 changes: 1 addition & 1 deletion train_pdqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
TAU = 0.01 # 软更新参数
EPSILON = 0.5 # epsilon-greedy
BUFFER_SIZE = 40000
MINIMAL_SIZE = 300
MINIMAL_SIZE = 10000
BATCH_SIZE = 128
REPLACE_A = 500
REPLACE_C = 300
Expand Down

0 comments on commit ff964a5

Please sign in to comment.