import logging
import os
import sys
from typing import Any, Optional, TYPE_CHECKING

import numpy as np
import tree  # pip install dm_tree

import ray
from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI
from ray.rllib.utils.deprecation import Deprecated
from ray.rllib.utils.typing import (
    TensorShape,
    TensorStructType,
    TensorType,
)

if TYPE_CHECKING:
    from ray.rllib.algorithms.algorithm_config import AlgorithmConfig

logger = logging.getLogger(__name__)


@PublicAPI
def convert_to_tensor(
    data: TensorStructType,
    framework: str,
    device: Optional[str] = None,
):
    """Converts any nested numpy struct into framework-specific tensors.

    Args:
        data: The input data (numpy) to convert to framework-specific tensors.
        framework: The framework to convert to. Only "torch" and "tf2" are allowed.
        device: An optional device name (for torch only).

    Returns:
        The converted tensor struct matching the input data.
    """
    if framework == "torch":
        from ray.rllib.utils.torch_utils import convert_to_torch_tensor

        return convert_to_torch_tensor(data, device=device)
    elif framework == "tf2":
        _, tf, _ = try_import_tf()
        return tree.map_structure(lambda s: tf.convert_to_tensor(s), data)

    raise NotImplementedError(
        f"framework={framework} not supported in `convert_to_tensor()`!"
    )
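

# Usage sketch (hypothetical batch; assumes torch is installed):
#
#     import numpy as np
#
#     batch = {"obs": np.zeros((4, 3), dtype=np.float32), "rewards": np.ones(4)}
#     tensors = convert_to_tensor(batch, framework="torch", device="cpu")
#     # `tensors` mirrors the nested structure of `batch`, with every numpy
#     # array converted to a torch.Tensor on the given device.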


@PublicAPI
def get_device(config: "AlgorithmConfig", num_gpus_requested: int = 1):
    """Returns a single device (CPU or some GPU) depending on a config.

    Args:
        config: An AlgorithmConfig from which to extract information about the
            device to use.
        num_gpus_requested: The number of GPUs actually requested. This may be the
            value of `config.num_gpus_per_env_runner` when, for example, calling
            this function from an EnvRunner.

    Returns:
        A single device (or name) given `config` and `num_gpus_requested`.
    """
    if config.framework_str == "torch":
        torch, _ = try_import_torch()

        # TODO (Kourosh): How do we handle model parallelism?
        # TODO (Kourosh): Instead of using _TorchAccelerator, we should use the
        #  public API in ray.train but allow for session to be None without any
        #  errors raised.
        if num_gpus_requested > 0:
            from ray.air._internal.torch_utils import get_devices

            # `get_devices()` returns a list that contains the 0th device if
            # it is called from outside a Ray Train session. It's necessary to give
            # the user the option to run on the GPU of their choice, so we enable
            # that option here through `config.local_gpu_idx`.
            devices = get_devices()
            # Note: If we have a single learner and we do not run on Ray Tune, the
            # local learner is not a Ray actor and Ray does not manage devices
            # for it.
            if (
                len(devices) == 1
                and ray._private.worker._mode() == ray._private.worker.WORKER_MODE
            ):
                return devices[0]
            else:
                assert config.local_gpu_idx < torch.cuda.device_count(), (
                    f"local_gpu_idx {config.local_gpu_idx} is not a valid GPU ID "
                    "or is not available."
                )
                # This is an index into the available CUDA devices. For example, if
                # `os.environ["CUDA_VISIBLE_DEVICES"] = "1"`, then
                # `torch.cuda.device_count()` is 1 and `torch.device(0)` maps to
                # the GPU with ID=1 on the node.
                return torch.device(config.local_gpu_idx)
        else:
            return torch.device("cpu")
    else:
        raise NotImplementedError(
            f"`framework_str` {config.framework_str} not supported!"
        )
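

# Usage sketch (a minimal illustration; requesting 0 GPUs falls back to CPU):
#
#     from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
#
#     config = AlgorithmConfig().framework("torch")
#     device = get_device(config, num_gpus_requested=0)
#     # -> torch.device("cpu")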


@PublicAPI
def try_import_jax(error: bool = False):
    """Tries importing JAX and FLAX and returns both modules (or Nones).

    Args:
        error: Whether to raise an error if JAX/FLAX cannot be imported.

    Returns:
        Tuple containing the jax and flax modules.

    Raises:
        ImportError: If error=True and JAX is not installed.
    """
    if "RLLIB_TEST_NO_JAX_IMPORT" in os.environ:
        logger.warning("Not importing JAX for test purposes.")
        return None, None

    try:
        import jax
        import flax
    except ImportError:
        if error:
            raise ImportError(
                "Could not import JAX! RLlib requires you to "
                "install at least one deep-learning framework: "
                "`pip install [torch|tensorflow|jax]`."
            )
        return None, None

    return jax, flax


@PublicAPI
def try_import_tf(error: bool = False):
    """Tries importing tf and returns the module (or None).

    Args:
        error: Whether to raise an error if tf cannot be imported.

    Returns:
        Tuple containing
        1) the tf1.x module (either from tf2.x.compat.v1 OR as tf1.x),
        2) the tf module (resulting from `import tensorflow`; either tf1.x or
        tf2.x), and
        3) the actually installed tf version as int: 1 or 2.

    Raises:
        ImportError: If error=True and tf is not installed.
    """
    tf_stub = _TFStub()
    # Make sure these are reset after each test case that uses them:
    # del os.environ["RLLIB_TEST_NO_TF_IMPORT"]
    if "RLLIB_TEST_NO_TF_IMPORT" in os.environ:
        logger.warning("Not importing TensorFlow for test purposes.")
        return None, tf_stub, None

    if "TF_CPP_MIN_LOG_LEVEL" not in os.environ:
        os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

    # Try to reuse an already imported tf module. This avoids going through
    # the initial import steps below and thereby switching off v2_behavior
    # (switching off v2 behavior twice breaks all-framework tests for eager).
    was_imported = False
    if "tensorflow" in sys.modules:
        tf_module = sys.modules["tensorflow"]
        was_imported = True
    else:
        try:
            import tensorflow as tf_module
        except ImportError:
            if error:
                raise ImportError(
                    "Could not import TensorFlow! RLlib requires you to "
                    "install at least one deep-learning framework: "
                    "`pip install [torch|tensorflow|jax]`."
                )
            return None, tf_stub, None

    # Try "reducing" tf to tf.compat.v1.
    try:
        tf1_module = tf_module.compat.v1
        tf1_module.logging.set_verbosity(tf1_module.logging.ERROR)
        if not was_imported:
            tf1_module.disable_v2_behavior()
            tf1_module.enable_resource_variables()
        tf1_module.logging.set_verbosity(tf1_module.logging.WARN)
    # No compat.v1 -> return tf as-is.
    except AttributeError:
        tf1_module = tf_module

    if not hasattr(tf_module, "__version__"):
        version = 1  # sphinx doc gen
    else:
        version = 2 if "2." in tf_module.__version__[:2] else 1

    return tf1_module, tf_module, version
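

# Usage sketch: the 3-tuple lets callers support both TF1- and TF2-style code
# paths (a minimal illustration, assuming TensorFlow 2.x is installed):
#
#     tf1, tf, tfv = try_import_tf()
#     if tf:  # The stub returned on import failure evaluates to False.
#         x = tf.constant([1.0, 2.0])  # TF2-style eager op.
#         assert tfv == 2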


# Fake module for tf.
class _TFStub:
    def __init__(self) -> None:
        self.keras = _KerasStub()

    def __bool__(self):
        # A missing tf should evaluate to False in boolean checks.
        return False


# Fake module for tf.keras.
class _KerasStub:
    def __init__(self) -> None:
        self.Model = _FakeTfClassStub


# Fake classes under keras (e.g. for tf.keras.Model).
class _FakeTfClassStub:
    def __init__(self, *a, **kw):
        raise ImportError("Could not import `tensorflow`. Try pip install tensorflow.")
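

# The class stubs let user code that merely subclasses tf.keras.Model be
# imported without TensorFlow; the ImportError fires only on instantiation.
# A minimal sketch of this behavior:
#
#     _, tf, _ = try_import_tf()
#
#     class MyModel(tf.keras.Model):  # OK even without tf installed.
#         pass
#
#     # MyModel() would raise ImportError if tf is missing.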


@DeveloperAPI
def tf_function(tf_module):
    """Conditional decorator for @tf.function.

    Use @tf_function(tf) instead of @tf.function to avoid errors if tf is not
    installed.
    """

    # The actual decorator to use (pass in `tf`, which could be None).
    def decorator(func):
        # If tf is not installed or running eagerly -> return the function as-is.
        if tf_module is None or tf_module.executing_eagerly():
            return func
        # If tf is installed (and in graph mode), return the
        # @tf.function-decorated function.
        return tf_module.function(func)

    return decorator
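

# Usage sketch (`times_two` is a hypothetical function):
#
#     _, tf, _ = try_import_tf()
#
#     @tf_function(tf)
#     def times_two(x):
#         return x * 2.0
#
#     # If tf is missing or eager execution is active, `times_two` stays a
#     # plain python function; otherwise it is wrapped via tf.function.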


@PublicAPI
def try_import_tfp(error: bool = False):
    """Tries importing tfp and returns the module (or None).

    Args:
        error: Whether to raise an error if tfp cannot be imported.

    Returns:
        The tfp module.

    Raises:
        ImportError: If error=True and tfp is not installed.
    """
    if "RLLIB_TEST_NO_TF_IMPORT" in os.environ:
        logger.warning("Not importing TensorFlow Probability for test purposes.")
        return None

    try:
        import tensorflow_probability as tfp

        return tfp
    except ImportError as e:
        if error:
            raise e
        return None


# Fake module for torch.nn.
class _NNStub:
    def __init__(self, *a, **kw):
        # Fake nn.functional module within torch.nn.
        self.functional = None
        self.Module = _FakeTorchClassStub
        self.parallel = _ParallelStub()


# Fake class for e.g. torch.nn.Module to allow it to be inherited from.
class _FakeTorchClassStub:
    def __init__(self, *a, **kw):
        raise ImportError("Could not import `torch`. Try pip install torch.")


class _ParallelStub:
    def __init__(self, *a, **kw):
        self.DataParallel = _FakeTorchClassStub
        self.DistributedDataParallel = _FakeTorchClassStub


@PublicAPI
def try_import_torch(error: bool = False):
    """Tries importing torch and returns the module (or None).

    Args:
        error: Whether to raise an error if torch cannot be imported.

    Returns:
        Tuple consisting of the torch and torch.nn modules.

    Raises:
        ImportError: If error=True and PyTorch is not installed.
    """
    if "RLLIB_TEST_NO_TORCH_IMPORT" in os.environ:
        logger.warning("Not importing PyTorch for test purposes.")
        return _torch_stubs()

    try:
        import torch
        import torch.nn as nn

        return torch, nn
    except ImportError:
        if error:
            raise ImportError(
                "Could not import PyTorch! RLlib requires you to "
                "install at least one deep-learning framework: "
                "`pip install [torch|tensorflow|jax]`."
            )
        return _torch_stubs()


def _torch_stubs():
    nn = _NNStub()
    return None, nn
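

# Usage sketch (a minimal illustration): `torch` is None when PyTorch is
# missing, while `nn` is always usable for subclassing (see the stubs above).
#
#     torch, nn = try_import_torch()
#     if torch is not None:
#         layer = nn.Linear(4, 2)
#         out = layer(torch.zeros(1, 4))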


@DeveloperAPI
def get_variable(
    value: Any,
    framework: str = "tf",
    trainable: bool = False,
    tf_name: str = "unnamed-variable",
    torch_tensor: bool = False,
    device: Optional[str] = None,
    shape: Optional[TensorShape] = None,
    dtype: Optional[TensorType] = None,
) -> Any:
    """Creates a tf variable, a torch tensor, or a python primitive.

    Args:
        value: The initial value to use. In the non-tf case, this will
            be returned as-is. In the tf case, this could be a tf-Initializer
            object.
        framework: One of "tf", "torch", or None.
        trainable: Whether the generated variable should be
            trainable (tf) / require_grad (torch) or not (default: False).
        tf_name: For framework="tf": An optional name for the tf.Variable.
        torch_tensor: For framework="torch": Whether to actually create
            a torch.tensor, or just a python value (default).
        device: An optional torch device to use for the created torch tensor.
        shape: An optional shape to use iff `value`
            does not have any (e.g. if it's an initializer w/o explicit value).
        dtype: An optional dtype to use iff `value` does
            not have any (e.g. if it's an initializer w/o explicit value).
            This should always be a numpy dtype (e.g. np.float32, np.int64).

    Returns:
        A framework-specific variable (tf.Variable, torch.tensor, or
        python primitive).
    """
    if framework in ["tf2", "tf"]:
        import tensorflow as tf

        dtype = dtype or getattr(
            value,
            "dtype",
            tf.float32
            if isinstance(value, float)
            else tf.int32
            if isinstance(value, int)
            else None,
        )
        return tf.compat.v1.get_variable(
            tf_name,
            initializer=value,
            dtype=dtype,
            trainable=trainable,
            **({} if shape is None else {"shape": shape}),
        )
    elif framework == "torch" and torch_tensor is True:
        torch, _ = try_import_torch()
        if not isinstance(value, np.ndarray):
            value = np.array(value)
        var_ = torch.from_numpy(value)
        if dtype in [torch.float32, np.float32]:
            var_ = var_.float()
        elif dtype in [torch.int32, np.int32]:
            var_ = var_.int()
        elif dtype in [torch.float64, np.float64]:
            var_ = var_.double()

        if device:
            var_ = var_.to(device)
        var_.requires_grad = trainable
        return var_

    # torch or None: Return python primitive.
    return value
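

# Usage sketch (hypothetical values; assumes torch is installed):
#
#     var = get_variable(
#         3.0, framework="torch", torch_tensor=True, dtype=np.float32
#     )
#     # -> a float32 torch tensor holding 3.0 (requires_grad=False).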


@Deprecated(
    old="rllib/utils/framework.py::get_activation_fn",
    new="rllib/models/utils.py::get_activation_fn",
    error=True,
)
def get_activation_fn(name: Optional[str] = None, framework: str = "tf"):
    pass