# action_dist.py
import numpy as np
import gymnasium as gym

from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.utils.annotations import OldAPIStack
from ray.rllib.utils.typing import TensorType, List, Union, ModelConfigDict


@OldAPIStack
class ActionDistribution:
    """The policy action distribution of an agent.

    Attributes:
        inputs: input vector to compute samples from.
        model (ModelV2): reference to the model producing the inputs.
    """

    def __init__(self, inputs: List[TensorType], model: ModelV2):
        """Initializes an ActionDistribution object.

        Args:
            inputs: input vector to compute samples from.
            model (ModelV2): reference to the model producing the inputs.
                This is mainly useful if you want to use model variables to
                compute action outputs (i.e., for autoregressive action
                distributions, see examples/autoregressive_action_dist.py).
        """
        self.inputs = inputs
        self.model = model

    def sample(self) -> TensorType:
        """Draws a sample from the action distribution."""
        raise NotImplementedError

    def deterministic_sample(self) -> TensorType:
        """Gets the deterministic "sampling" output from the distribution.

        This is usually the maximum-likelihood output, i.e., the mean for a
        Normal distribution, the argmax for a Categorical, etc.
        """
        raise NotImplementedError

    def sampled_action_logp(self) -> TensorType:
        """Returns the log probability of the last sampled action."""
        raise NotImplementedError

    def logp(self, x: TensorType) -> TensorType:
        """The log-likelihood of the action distribution."""
        raise NotImplementedError

    def kl(self, other: "ActionDistribution") -> TensorType:
        """The KL-divergence between two action distributions."""
        raise NotImplementedError

    def entropy(self) -> TensorType:
        """The entropy of the action distribution."""
        raise NotImplementedError

    def multi_kl(self, other: "ActionDistribution") -> TensorType:
        """The KL-divergence between two action distributions.

        This differs from kl() in that it can return an array for
        MultiDiscrete. TODO(ekl) consider removing this.
        """
        return self.kl(other)

    def multi_entropy(self) -> TensorType:
        """The entropy of the action distribution.

        This differs from entropy() in that it can return an array for
        MultiDiscrete. TODO(ekl) consider removing this.
        """
        return self.entropy()

    @staticmethod
    @OldAPIStack
    def required_model_output_shape(
        action_space: gym.Space, model_config: ModelConfigDict
    ) -> Union[int, np.ndarray]:
        """Returns the required shape of an input parameter tensor for a
        particular action space and an optional dict of distribution-specific
        options.

        Args:
            action_space (gym.Space): The action space this distribution will
                be used for, whose shape attributes will be used to determine
                the required shape of the input parameter tensor.
            model_config: Model's config dict (as defined in catalog.py).

        Returns:
            model_output_shape (int or np.ndarray of ints): size of the
                required input vector (minus leading batch dimension).
        """
        raise NotImplementedError
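

# ---------------------------------------------------------------------------
# Example (not part of the original RLlib module): a minimal sketch of a
# concrete subclass, assuming plain numpy arrays of logits as `inputs` and a
# gym.spaces.Discrete action space. RLlib ships framework-specific
# implementations (see ray.rllib.models.torch.torch_action_dist and
# ray.rllib.models.tf.tf_action_dist); this toy `NumpyCategorical` only
# illustrates which methods a subclass is expected to override.
# ---------------------------------------------------------------------------
class NumpyCategorical(ActionDistribution):
    """Categorical distribution over discrete actions, parameterized by a
    [batch, num_actions] array of unnormalized logits (hypothetical example
    class, for illustration only)."""

    def __init__(self, inputs: np.ndarray, model: ModelV2 = None):
        super().__init__(inputs, model)
        # Numerically stable softmax over the last axis.
        z = self.inputs - self.inputs.max(axis=-1, keepdims=True)
        self.probs = np.exp(z) / np.exp(z).sum(axis=-1, keepdims=True)

    def sample(self) -> np.ndarray:
        # Draw one action per batch row; remember it so that
        # sampled_action_logp() can refer back to it.
        self.last_sample = np.array(
            [np.random.choice(len(p), p=p) for p in self.probs]
        )
        return self.last_sample

    def deterministic_sample(self) -> np.ndarray:
        # Maximum-likelihood action: the argmax of the logits.
        self.last_sample = self.inputs.argmax(axis=-1)
        return self.last_sample

    def sampled_action_logp(self) -> np.ndarray:
        return self.logp(self.last_sample)

    def logp(self, x: np.ndarray) -> np.ndarray:
        # Log probability of action x in each batch row.
        batch = np.arange(len(self.probs))
        return np.log(self.probs[batch, x])

    def entropy(self) -> np.ndarray:
        return -(self.probs * np.log(self.probs + 1e-12)).sum(axis=-1)

    def kl(self, other: "NumpyCategorical") -> np.ndarray:
        return (
            self.probs
            * (np.log(self.probs + 1e-12) - np.log(other.probs + 1e-12))
        ).sum(axis=-1)

    @staticmethod
    def required_model_output_shape(
        action_space: gym.Space, model_config: ModelConfigDict
    ) -> int:
        # One logit per discrete action (assumes a Discrete action space).
        return action_space.n


if __name__ == "__main__":
    # Usage sketch: two batch rows of logits over three actions.
    logits = np.array([[2.0, 0.5, -1.0], [0.0, 0.0, 0.0]])
    dist = NumpyCategorical(logits)
    actions = dist.sample()
    print(actions, dist.sampled_action_logp(), dist.entropy())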