from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    Hashable,
    List,
    Optional,
    Sequence,
    Tuple,
    Type,
    TypeVar,
    Union,
)

import numpy as np
import gymnasium as gym
from ray.rllib.utils.annotations import OldAPIStack

if TYPE_CHECKING:
    from ray.rllib.core.rl_module.rl_module import RLModuleSpec
    from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec
    from ray.rllib.env.env_context import EnvContext
    from ray.rllib.env.multi_agent_episode import MultiAgentEpisode
    from ray.rllib.env.single_agent_episode import SingleAgentEpisode
    from ray.rllib.policy.dynamic_tf_policy_v2 import DynamicTFPolicyV2
    from ray.rllib.policy.eager_tf_policy_v2 import EagerTFPolicyV2
    from ray.rllib.policy.policy import PolicySpec
    from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch
    from ray.rllib.policy.view_requirement import ViewRequirement

from ray.rllib.utils import try_import_jax, try_import_tf, try_import_torch
_, tf, _ = try_import_tf()
torch, _ = try_import_torch()
jax, _ = try_import_jax()
jnp = None
if jax is not None:
    jnp = jax.numpy

# Represents a generic tensor type.
# This could be an np.ndarray, tf.Tensor, or a torch.Tensor.
TensorType = Union[np.ndarray, "jnp.ndarray", "tf.Tensor", "torch.Tensor"]
# Either a plain tensor, or a dict or tuple of tensors (or StructTensors).
TensorStructType = Union[TensorType, dict, tuple]
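# For example (illustrative only), a dict observation mapping column names to
# NumPy arrays is a valid TensorStructType:
#
#     obs: TensorStructType = {
#         "camera": np.zeros((84, 84, 3), dtype=np.float32),
#         "proprio": np.zeros((10,), dtype=np.float32),
#     }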
# A shape of a tensor.
TensorShape = Union[Tuple[int], List[int]]
# A neural network
NetworkType = Union["torch.nn.Module", "tf.keras.Model"]
# An RLModule spec (single-agent or multi-agent).
RLModuleSpecType = Union["RLModuleSpec", "MultiRLModuleSpec"]
# A state dict of an RLlib component (e.g. EnvRunner, Learner, RLModule).
StateDict = Dict[str, Any]
# Represents a fully filled-out config of an Algorithm class.
# Note: Policy config dicts are usually the same as AlgorithmConfigDict, but
# parts of them may sometimes be altered, e.g. in a multi-agent setup,
# where more than one Policy lives in the same Algorithm.
AlgorithmConfigDict = dict # @OldAPIStack
# An algorithm config dict that only has overrides. It needs to be combined with
# the default algorithm config to be used.
PartialAlgorithmConfigDict = dict # @OldAPIStack
# Represents the model config sub-dict of the algo config that is passed to
# the model catalog.
ModelConfigDict = dict # @OldAPIStack
# Conv2D configuration format.
# Each entry in the outer list represents one Conv2D layer.
# Each inner list has the format: [num_output_filters, kernel, stride], where kernel
# and stride may each be a single int (same value for width and height) or an
# (int, int) 2-tuple specifying width and height separately.
ConvFilterSpec = List[
    Tuple[int, Union[int, Tuple[int, int]], Union[int, Tuple[int, int]]]
]
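# A hypothetical example (layer sizes are illustrative, not defaults): a
# three-layer Conv2D stack for an 84x84 image input, in the format above:
#
#     conv_filters: ConvFilterSpec = [
#         (16, (8, 8), 4),
#         (32, (4, 4), 2),
#         (256, (11, 11), 1),
#     ]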
# Objects that can be created through the `from_config()` util method
# need a config dict with a "type" key, a class path (str), or a type directly.
FromConfigSpec = Union[Dict[str, Any], type, str]
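# For example (class path and class name are hypothetical), all of the
# following are valid FromConfigSpec values for `from_config()`:
#
#     spec_as_dict: FromConfigSpec = {"type": "my_pkg.my_module.MyClass"}
#     spec_as_path: FromConfigSpec = "my_pkg.my_module.MyClass"
#     spec_as_type: FromConfigSpec = MyClass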
# Represents the env_config sub-dict of the algo config that is passed to
# the env constructor.
EnvConfigDict = dict
# Represents an environment id. These could be:
# - An int index for a sub-env within a vectorized env.
# - An external env ID (str), which changes(!) each episode.
EnvID = Union[int, str]
# Represents a BaseEnv, MultiAgentEnv, ExternalEnv, ExternalMultiAgentEnv,
# VectorEnv, gym.Env, or ActorHandle.
# TODO (sven): Specify this type more strictly (it should just be gym.Env).
EnvType = Union[Any, gym.Env]
# A callable, taking an EnvContext object
# (config dict + properties: `worker_index`, `vector_index`, `num_workers`,
# and `remote`) and returning an env object (or None if no env is used).
EnvCreator = Callable[["EnvContext"], Optional[EnvType]]
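# A minimal sketch (assuming a hypothetical `MyEnv` gym.Env subclass) of an
# EnvCreator that reads its settings from the passed-in EnvContext:
#
#     def my_env_creator(ctx: "EnvContext") -> Optional[EnvType]:
#         if ctx.worker_index == 0:
#             return None  # Hypothetically: no env on the local worker.
#         return MyEnv(size=ctx.get("size", 10))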
# Represents a generic identifier for an agent (e.g., "agent1").
AgentID = Any
# Represents a generic identifier for a policy (e.g., "pol1").
PolicyID = str # @OldAPIStack
# Represents a generic identifier for a (single-agent) RLModule.
ModuleID = str
# Type of the config.policies dict for multi-agent training.
MultiAgentPolicyConfigDict = Dict[PolicyID, "PolicySpec"] # @OldAPIStack
# A new stack Episode type: Either single-agent or multi-agent.
EpisodeType = Union["SingleAgentEpisode", "MultiAgentEpisode"]
# Callable that decides, given a PolicyID and an optional MultiAgentBatch,
# whether that Policy should be trained.
# @OldAPIStack
IsPolicyToTrain = Callable[[PolicyID, Optional["MultiAgentBatch"]], bool]
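# A hypothetical example: train all policies except a frozen "random" one,
# regardless of the batch contents:
#
#     is_policy_to_train: IsPolicyToTrain = lambda pid, batch: pid != "random"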
# Agent-to-module mapping function and should-module-be-updated specifier.
AgentToModuleMappingFn = Callable[[AgentID, EpisodeType], ModuleID]
ShouldModuleBeUpdatedFn = Union[
    Sequence[ModuleID],
    Callable[[ModuleID, Optional["MultiAgentBatch"]], bool],
]
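# An illustrative sketch (agent and module names are hypothetical): map every
# agent whose ID starts with "opponent" to a shared "opponent_policy" module and
# all other agents to a "main" module:
#
#     def agent_to_module(agent_id: AgentID, episode: EpisodeType) -> ModuleID:
#         return "opponent_policy" if str(agent_id).startswith("opponent") else "main"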
# State dict of a Policy, mapping strings (e.g. "weights") to some state
# data (TensorStructType).
PolicyState = Dict[str, TensorStructType] # @OldAPIStack
# Any tf Policy type (static-graph or eager Policy).
TFPolicyV2Type = Type[Union["DynamicTFPolicyV2", "EagerTFPolicyV2"]] # @OldAPIStack
# Represents an episode id (old and new API stack).
EpisodeID = Union[int, str]
# Represents an "unroll" (maybe across different sub-envs in a vector env).
UnrollID = int # @OldAPIStack
# A dict keyed by agent ids, e.g. {"agent-1": value}.
MultiAgentDict = Dict[AgentID, Any]
# A dict keyed by env ids that contain further nested dictionaries keyed by
# agent ids. e.g., {"env-1": {"agent-1": value}}.
MultiEnvDict = Dict[EnvID, MultiAgentDict]
# Represents an observation returned from the env.
EnvObsType = Any
# Represents an action passed to the env.
EnvActionType = Any
# Info dictionary returned by calling `reset()` or `step()` on `gymnasium.Env`
# instances. Might be an empty dict.
EnvInfoDict = dict
# Represents a File object
FileType = Any
# Represents a ViewRequirements dict mapping column names (str) to
# ViewRequirement objects.
ViewRequirementsDict = Dict[str, "ViewRequirement"] # @OldAPIStack
# Represents the result dict returned by Algorithm.train() and algorithm components,
# such as EnvRunners, LearnerGroup, etc. Also, the MetricsLogger used by all these
# components returns this upon its `reduce()` method call, so a ResultDict can further
# be accumulated (and reduced again) by downstream components.
ResultDict = Dict
# A tf or torch local optimizer object.
LocalOptimizer = Union["torch.optim.Optimizer", "tf.keras.optimizers.Optimizer"]
Optimizer = LocalOptimizer
# A single framework-specific model parameter/variable (torch or tf).
Param = Union["torch.Tensor", "tf.Variable"]
# A hashable reference to a Param (usable as a dict key).
ParamRef = Hashable
# A dict mapping ParamRefs to Params, and a plain list of Params.
ParamDict = Dict[ParamRef, Param]
ParamList = List[Param]
# A single learning rate or a learning rate schedule (list of sub-lists, each of
# the format: [ts (int), lr_to_reach_by_ts (float)]).
LearningRateOrSchedule = Union[
    float,
    List[List[Union[int, float]]],
    List[Tuple[int, Union[int, float]]],
]
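# A minimal sketch (values are hypothetical): either a constant learning rate or
# a schedule that starts at 1e-3 and reaches 1e-4 by timestep 1,000,000:
#
#     lr_constant: LearningRateOrSchedule = 0.001
#     lr_schedule: LearningRateOrSchedule = [[0, 1e-3], [1_000_000, 1e-4]]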
# Dict of tensors returned by compute gradients on the policy, e.g.,
# {"td_error": [...], "learner_stats": {"vf_loss": ..., ...}}; for multi-agent:
# {"policy1": {"learner_stats": ...}, "policy2": ...}.
GradInfoDict = dict
# Dict of learner stats returned by compute gradients on the policy, e.g.,
# {"vf_loss": ..., ...}. This will always be nested under the "learner_stats"
# key(s) of a GradInfoDict. In the multi-agent case, this will be keyed by
# policy id.
LearnerStatsDict = dict
# List of grads+var tuples (tf) or list of gradient tensors (torch)
# representing model gradients and returned by compute_gradients().
ModelGradients = Union[List[Tuple[TensorType, TensorType]], List[TensorType]]
# Type of dict returned by get_weights() representing model weights.
ModelWeights = dict
# An input dict used for direct ModelV2 calls.
ModelInputDict = Dict[str, TensorType]
# Some kind of sample batch.
SampleBatchType = Union["SampleBatch", "MultiAgentBatch", Dict[str, Any]]
# A (possibly nested) space struct: Either a gym.spaces.Space or a
# (possibly nested) dict|tuple of gym.spaces.Spaces.
SpaceStruct = Union[gym.spaces.Space, dict, tuple]
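# For example (illustrative only), a dict of primitive spaces is a valid
# SpaceStruct:
#
#     space_struct: SpaceStruct = {
#         "camera": gym.spaces.Box(low=0.0, high=1.0, shape=(84, 84, 3)),
#         "command": gym.spaces.Discrete(4),
#     }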
# A list of batches of RNN states.
# Each item in this list has shape [B, S] (S=state vector size).
StateBatches = List[List[Any]] # @OldAPIStack
# Format of data output from policy forward pass.
# __sphinx_doc_begin_policy_output_type__
PolicyOutputType = Tuple[TensorStructType, StateBatches, Dict] # @OldAPIStack
# __sphinx_doc_end_policy_output_type__
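# An illustrative sketch (batch size 2, values are hypothetical) of a
# PolicyOutputType tuple: (computed actions, RNN state-out batches, extra fetches):
#
#     policy_output: PolicyOutputType = (
#         np.array([0, 1]),                          # actions
#         [np.zeros((2, 64)), np.zeros((2, 64))],    # RNN state batches, [B, S] each
#         {"action_logp": np.array([-0.7, -1.2])},   # extra action fetches
#     )
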
# __sphinx_doc_begin_agent_connector_data_type__
@OldAPIStack
class AgentConnectorDataType:
    """Data type that is fed into and yielded from agent connectors.

    Args:
        env_id: ID of the environment.
        agent_id: ID to help identify the agent from which the data is received.
        data: A payload (``data``). With RLlib's default sampler, the payload
            is a dictionary of arbitrary data columns (obs, rewards, terminateds,
            truncateds, etc).
    """

    def __init__(self, env_id: str, agent_id: str, data: Any):
        self.env_id = env_id
        self.agent_id = agent_id
        self.data = data


# __sphinx_doc_end_agent_connector_data_type__
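# A minimal usage sketch (IDs and payload are hypothetical) constructing such a
# record for one agent step:
#
#     ac_data = AgentConnectorDataType(
#         env_id="env_0",
#         agent_id="agent_0",
#         data={"obs": np.zeros((4,)), "rewards": 0.0, "terminateds": False},
#     )
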
# __sphinx_doc_begin_action_connector_output__
@OldAPIStack
class ActionConnectorDataType:
    """Data type that is fed into and yielded from action connectors.

    Args:
        env_id: ID of the environment.
        agent_id: ID to help identify the agent from which the data is received.
        input_dict: Input data that was passed into the policy.
            Sometimes output must be adapted based on the input, for example
            action masking. So the entire input data structure is provided here.
        output: An object of PolicyOutputType. It is composed of the
            action output, the internal state output, and additional data fetches.
    """

    def __init__(
        self,
        env_id: str,
        agent_id: str,
        input_dict: TensorStructType,
        output: PolicyOutputType,
    ):
        self.env_id = env_id
        self.agent_id = agent_id
        self.input_dict = input_dict
        self.output = output


# __sphinx_doc_end_action_connector_output__
# __sphinx_doc_begin_agent_connector_output__
@OldAPIStack
class AgentConnectorsOutput:
    """Final output data type of agent connectors.

    Args are populated depending on the AgentConnector settings.
    The branching happens in ViewRequirementAgentConnector.

    Args:
        raw_dict: The raw input dictionary that the sampler can use to
            build episodes and training batches.
            This raw dict also gets passed into ActionConnectors in case
            it contains data useful for action adaptation (e.g. action masks).
        sample_batch: The SampleBatch that can be immediately used for
            querying the policy for the next action.
    """

    def __init__(
        self, raw_dict: Dict[str, TensorStructType], sample_batch: "SampleBatch"
    ):
        self.raw_dict = raw_dict
        self.sample_batch = sample_batch


# __sphinx_doc_end_agent_connector_output__
# Generic type var.
T = TypeVar("T")