forked from GRAAL-Research/poutyne
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Remove variables and add device management inside Model
- Loading branch information
Showing
10 changed files
with
211 additions
and
136 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,80 +1,47 @@ | ||
import torch | ||
from torch.autograd import Variable | ||
|
||
|
||
def torch_to_numpy(obj): | ||
""" | ||
Convert to numpy arrays all tensors and variables inside a Python object | ||
composed of the supported types. | ||
Args: | ||
obj: The Python object to convert. | ||
Returns: | ||
A new Python object with the same structure as `obj` but where the | ||
tensors and variables are now numpy arrays. Not supported type are left | ||
as reference in the new object. | ||
""" | ||
if isinstance(obj, Variable): | ||
obj = obj.data | ||
if isinstance(obj, list) or isinstance(obj, tuple): | ||
return type(obj)(torch_to_numpy(el) for el in obj) | ||
if isinstance(obj, dict): | ||
return {k:torch_to_numpy(el) for k,el in obj.items()} | ||
if not torch.is_tensor(obj): | ||
return obj | ||
return obj.cpu().numpy() | ||
|
||
def tensors_to_variables(obj, *args, **kwargs): | ||
""" | ||
Convert to variables all tensors inside a Python object composed of the | ||
Convert to numpy arrays all tensors inside a Python object composed of the | ||
supported types. | ||
Args: | ||
obj: The Python object to convert. | ||
*args: The arguments to pass to the Variable constructor. | ||
**kwargs: The keyword arguments to pass to the Variable constructor. | ||
Returns: | ||
A new Python object with the same structure as `obj` but where the | ||
tensors are now variables. | ||
tensors are now numpy arrays. Not supported type are left as reference | ||
in the new object. | ||
Raises: | ||
ValueError: If a not supported type is inside `obj`. | ||
See: | ||
`pytoune.torch_apply` for supported types. | ||
""" | ||
if isinstance(obj, Variable): | ||
return obj | ||
if torch.is_tensor(obj): | ||
return Variable(obj, *args, **kwargs) | ||
if isinstance(obj, list) or isinstance(obj, tuple): | ||
return type(obj)(tensors_to_variables(el, *args, **kwargs) for el in obj) | ||
if isinstance(obj, dict): | ||
return {k:tensors_to_variables(el, *args, **kwargs) for k,el in obj.items()} | ||
return torch_apply(obj, lambda t: t.cpu().numpy()) | ||
|
||
raise ValueError("The type '%s' is not supported by this function." % type(obj).__name__) | ||
def torch_to(obj, *args, **kargs): | ||
return torch_apply(obj, lambda t: t.to(*args, **kargs)) | ||
|
||
def variables_to_tensors(obj): | ||
def torch_apply(obj, func): | ||
""" | ||
Convert to tensors all variables inside a Python object composed of the | ||
Apply a function to all tensors inside a Python object composed of the | ||
supported types. | ||
Supported types are: list, tuple and dict. | ||
Args: | ||
obj: The Python object to convert. | ||
func: The function to apply. | ||
Returns: | ||
A new Python object with the same structure as `obj` but where the | ||
variables are now tensors. | ||
Raises: | ||
ValueError: If a not supported type is inside `obj`. | ||
tensors have been applied the function `func`. Not supported type are | ||
left as reference in the new object. | ||
""" | ||
if torch.is_tensor(obj): | ||
return obj | ||
if isinstance(obj, Variable): | ||
return obj.data | ||
if isinstance(obj, list) or isinstance(obj, tuple): | ||
return type(obj)(variables_to_tensors(el) for el in obj) | ||
return type(obj)(torch_apply(el, func) for el in obj) | ||
if isinstance(obj, dict): | ||
return {k:variables_to_tensors(el) for k,el in obj.items()} | ||
|
||
raise ValueError("The type '%s' is not supported by this function." % type(obj).__name__) | ||
return {k:torch_apply(el, func) for k,el in obj.items()} | ||
if not torch.is_tensor(obj): | ||
return obj | ||
return func(obj) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.