import numpy as np
import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical


def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Initialize a layer with orthogonal weights and a constant bias."""
    nn.init.orthogonal_(layer.weight, std)
    nn.init.constant_(layer.bias, bias_const)
    return layer


class Agent(nn.Module):
    def __init__(self, envs) -> None:
        super(Agent, self).__init__()
        ### Parameters
        self.hidden_layer_size = 64

        ### Critic network
        self.critic = nn.Sequential(
            layer_init(
                nn.Linear(
                    np.array(envs.single_observation_space.shape).prod(),
                    self.hidden_layer_size,
                )
            ),
            nn.Tanh(),
            layer_init(nn.Linear(self.hidden_layer_size, self.hidden_layer_size)),
            nn.Tanh(),
            layer_init(nn.Linear(self.hidden_layer_size, 1), std=1.0),
        )

        ### Actor network
        # Initialize the output (action) layer with a small std (0.01)
        # so that all actions start with roughly equal probabilities.
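        # For example, with near-zero logits the softmax over two actions is
        # roughly (0.5, 0.5), so early exploration is close to uniform.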
        self.actor = nn.Sequential(
            layer_init(
                nn.Linear(
                    np.array(envs.single_observation_space.shape).prod(),
                    self.hidden_layer_size,
                )
            ),
            nn.Tanh(),
            layer_init(nn.Linear(self.hidden_layer_size, self.hidden_layer_size)),
            nn.Tanh(),
            layer_init(
                nn.Linear(self.hidden_layer_size, envs.single_action_space.n), std=0.01
            ),
        )
    def get_value(self, x):
        """Return the estimated value of the state.

        Args:
            x (torch.Tensor): State tensor. Shape of observation space.

        Returns:
            Value: Estimated value of the state by the critic network.
        """
        return self.critic(x)
    def get_action_and_value(self, x, action=None):
        """Return the action and value of the state.

        Args:
            x (torch.Tensor): State tensor. Shape of observation space.
            action (torch.Tensor, optional): Action tensor. Shape of action space. Defaults to None.

        Returns:
            action (torch.Tensor): Sampled action for each environment.
            log_prob (torch.Tensor): Log probability of each action for each environment.
            entropy (torch.Tensor): Entropy of each action probability distribution for each environment.
            value (torch.Tensor): Value for each environment.
        """
        logits = self.actor(x)
        probs = Categorical(logits=logits)
        if action is None:
            action = probs.sample()
        return action, probs.log_prob(action), probs.entropy(), self.critic(x)


class AtariAgent(nn.Module):
    def __init__(self, envs) -> None:
        super(AtariAgent, self).__init__()
        self.sharedNetwork_out = 512

        ### Detail 8: Shared feature extractor CNN
        # The feature extractor is shared between the actor and critic networks.
        # (4, 84, 84) -> (32, 20, 20) -> (64, 9, 9) -> (64, 7, 7)
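        # Shape check (no padding): out = floor((in - kernel) / stride) + 1, so
        # (84 - 8) / 4 + 1 = 20,  (20 - 4) / 2 + 1 = 9,  (9 - 3) / 1 + 1 = 7.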
        self.sharedNetwork = nn.Sequential(
            layer_init(nn.Conv2d(4, 32, 8, stride=4)),
            nn.ReLU(),
            layer_init(nn.Conv2d(32, 64, 4, stride=2)),
            nn.ReLU(),
            layer_init(nn.Conv2d(64, 64, 3, stride=1)),
            nn.ReLU(),
            nn.Flatten(),
            layer_init(nn.Linear(64 * 7 * 7, self.sharedNetwork_out)),
            nn.ReLU(),
        )

        self.actor = layer_init(
            nn.Linear(self.sharedNetwork_out, envs.single_action_space.n), std=0.01
        )
        self.critic = layer_init(nn.Linear(self.sharedNetwork_out, 1), std=1.0)
    def get_value(self, x):
        """Return the estimated value of the state.

        Args:
            x (torch.Tensor): State tensor. Shape of observation space.

        Returns:
            Value: Estimated value of the state by the critic network.
        """
        ### Detail 9: Scale input to [0, 1]
        # Each pixel has a range of [0, 255].
        # Scale the input to [0, 1] by dividing by 255.
        return self.critic(self.sharedNetwork(x / 255.0))
    def get_action_and_value(self, x, action=None):
        """Return the action and value of the state.

        Args:
            x (torch.Tensor): State tensor. Shape of observation space.
            action (torch.Tensor, optional): Action tensor. Shape of action space. Defaults to None.

        Returns:
            action (torch.Tensor): Sampled action for each environment.
            log_prob (torch.Tensor): Log probability of each action for each environment.
            entropy (torch.Tensor): Entropy of each action probability distribution for each environment.
            value (torch.Tensor): Value for each environment.
        """
        ### Detail 9: Scale input to [0, 1]
        # Each pixel has a range of [0, 255].
        # Scale the input to [0, 1] by dividing by 255.
        hidden = self.sharedNetwork(x / 255.0)
        logits = self.actor(hidden)
        probs = Categorical(logits=logits)
        if action is None:
            action = probs.sample()
        return action, probs.log_prob(action), probs.entropy(), self.critic(hidden)
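

if __name__ == "__main__":
    # Minimal usage sketch (assumption: gymnasium with CartPole-v1 installed;
    # this demo is not part of the original training code). It builds a small
    # vector env, instantiates the MLP Agent, and runs one forward pass.
    # The AtariAgent is used the same way, but expects stacked frames of
    # shape (num_envs, 4, 84, 84) with pixel values in [0, 255] as input.
    import gymnasium as gym

    envs = gym.vector.SyncVectorEnv(
        [lambda: gym.make("CartPole-v1") for _ in range(4)]
    )
    agent = Agent(envs)

    obs, _ = envs.reset(seed=0)
    obs = torch.as_tensor(obs, dtype=torch.float32)
    action, log_prob, entropy, value = agent.get_action_and_value(obs)
    print(action.shape, log_prob.shape, entropy.shape, value.shape)
    # Expected: torch.Size([4]) torch.Size([4]) torch.Size([4]) torch.Size([4, 1])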