[follow-up] Python Attr Serialization (pytorch#88913)
Ref: pytorch#81616 (comment)
Pull Request resolved: pytorch#88913
Approved by: https://github.com/albanD
kshitij12345 authored and pytorchmergebot committed Jan 13, 2023
1 parent a72bcb3 commit 745fe35
Showing 5 changed files with 27 additions and 25 deletions.
6 changes: 1 addition & 5 deletions test/test_serialization.py
@@ -948,11 +948,7 @@ def _test_save_load_attr(t):
 
         t = torch.zeros(3, 3)
         _test_save_load_attr(t)
-        # This should start failing once Parameter
-        # supports saving Python Attribute.
-        err_msg = "'Parameter' object has no attribute"
-        with self.assertRaisesRegex(AttributeError, err_msg):
-            _test_save_load_attr(torch.nn.Parameter(t))
+        _test_save_load_attr(torch.nn.Parameter(t))
 
     def test_weights_only_assert(self):
         class HelloWorld:
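A minimal sketch of the round trip the updated test now expects to succeed (the buffer and attribute names are illustrative, not taken from the test itself):

import io

import torch

# After this change, a Python attribute set on a Parameter survives
# torch.save / torch.load (weights_only defaults to False at this commit).
p = torch.nn.Parameter(torch.zeros(3, 3))
p.foo = "bar"

buf = io.BytesIO()
torch.save(p, buf)
buf.seek(0)
loaded = torch.load(buf)

assert isinstance(loaded, torch.nn.Parameter)
assert loaded.foo == "bar"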
2 changes: 0 additions & 2 deletions torch/_utils.py
@@ -357,8 +357,6 @@ def _rebuild_parameter(data, requires_grad, backward_hooks):
     return param
 
 
-# TODO(kshitij12345): Support serializing nn.Parameter with Python Attributes.
-# NOTE: We are just defining it here now for future use.
 def _rebuild_parameter_with_state(data, requires_grad, backward_hooks, state):
     param = torch.nn.Parameter(data, requires_grad)
     # NB: This line exists only for backwards compatibility; the
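For orientation, a hedged sketch of how this rebuild helper is exercised during unpickling; the state dict contents below are illustrative, and the exact restoration mechanics live inside _rebuild_parameter_with_state:

from collections import OrderedDict

import torch

data = torch.zeros(2, 2)
state = {"extra": 123}  # Python attributes captured at save time (illustrative)

# Unpickling calls the rebuild function with the pieces emitted by
# Parameter.__reduce_ex__ (see torch/nn/parameter.py below).
param = torch._utils._rebuild_parameter_with_state(data, True, OrderedDict(), state)

assert isinstance(param, torch.nn.Parameter)
assert param.extra == 123  # attribute restored onto the rebuilt Parameter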
27 changes: 13 additions & 14 deletions torch/distributed/optim/apply_optimizer_in_backward.py
@@ -4,6 +4,12 @@
 
 __all__: List[str] = []
 
+# WeakTensorKeyDictionary to store relevant meta-data for the Tensor/Parameter
+# without changing its life-time.
+# NOTE: An alternative is to add the meta-data as an attribute to the tensor,
+# but that would serialize the meta-data if the Tensor is serialized.
+param_to_optim_hook_handle_map = torch.utils.weak.WeakTensorKeyDictionary()
+param_to_acc_grad_map = torch.utils.weak.WeakTensorKeyDictionary()
 
 @no_type_check
 def _apply_optimizer_in_backward(
@@ -44,19 +50,12 @@ def _apply_optimizer_in_backward_to_param(param: torch.nn.Parameter) -> None:
         # this parameter is ready (has been accumulated into .grad field)
 
         # Don't create a new acc_grad if we already have one
-        # i.e.f or shared parameters or attaching multiple optimizers to a param.
-        if not hasattr(param, "acc_grad"):
-            acc_grad = param.view_as(param).grad_fn.next_functions[0][0]
-        else:
-            acc_grad = param._acc_grad
+        # i.e. for shared parameters or attaching multiple optimizers to a param.
+        if param not in param_to_acc_grad_map:
+            param_to_acc_grad_map[param] = param.view_as(param).grad_fn.next_functions[0][0]
 
         optimizer = optimizer_class([param], **optimizer_kwargs)
 
-        # Keep the grad accumulator around for the lifetime of the Tensor,
-        # store it on the param to avoid uncollectable ref-cycle
-        if not hasattr(param, "acc_grad"):
-            param._acc_grad = acc_grad  # type: ignore[attr-defined]
-
         if not hasattr(param, "_in_backward_optimizers"):
             param._in_backward_optimizers = []  # type: ignore[attr-defined]
             # TODO: investigate whether we really need these attributes.
@@ -73,10 +72,10 @@ def optimizer_hook(*_unused) -> None:
 
             param.grad = None
 
-        handle = param._acc_grad.register_hook(optimizer_hook)  # type: ignore[attr-defined]
-        if not hasattr(param, '_optimizer_hook_handles'):
-            param._optimizer_hook_handles = []  # type: ignore[attr-defined]
-        param._optimizer_hook_handles.append(handle)  # type: ignore[attr-defined]
+        handle = param_to_acc_grad_map[param].register_hook(optimizer_hook)  # type: ignore[attr-defined]
+        if param not in param_to_optim_hook_handle_map:
+            param_to_optim_hook_handle_map[param] = []
+        param_to_optim_hook_handle_map[param].append(handle)
 
     for param in params:
         _apply_optimizer_in_backward_to_param(param)
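A short sketch of the WeakTensorKeyDictionary behavior this file now relies on: entries are keyed by tensor identity but do not keep the parameter alive, so the bookkeeping never ends up in the tensor's serialized state. The payload and the explicit gc.collect() are only there to make the illustration deterministic:

import gc

import torch
from torch.utils.weak import WeakTensorKeyDictionary

meta = WeakTensorKeyDictionary()

p = torch.nn.Parameter(torch.zeros(2))
meta[p] = ["some hook handle"]  # illustrative payload
assert p in meta

del p          # drop the last strong reference to the parameter
gc.collect()   # make collection deterministic for the assertion below
assert len(meta) == 0  # the entry went away with the parameter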
3 changes: 2 additions & 1 deletion torch/nn/parallel/distributed.py
@@ -710,8 +710,9 @@ def __init__(
         # Remove hooks that apply_optim_in_backward had registered because
         # DDP customizes how optimizer is overlapped with backward due to
         # the allreduce.
+        param_to_handle_map = dist.optim.apply_optimizer_in_backward.param_to_optim_hook_handle_map
         for p in self._module_parameters:
-            for handle in getattr(p, '_optimizer_hook_handles', []):
+            for handle in param_to_handle_map.get(p, []):
                 handle.remove()
 
         # Need a weakref to the reducer in order to run all_reduce.
14 changes: 11 additions & 3 deletions torch/nn/parameter.py
@@ -60,11 +60,19 @@ def __repr__(self):
         return 'Parameter containing:\n' + super(Parameter, self).__repr__()
 
     def __reduce_ex__(self, proto):
-        # TODO(kshitij12345): Support saving Python Attribute
+        state = torch._utils._get_obj_state(self)
+
         # See Note [Don't serialize hooks]
+        hooks = OrderedDict()
+        if not state:
+            return (
+                torch._utils._rebuild_parameter,
+                (self.data, self.requires_grad, hooks)
+            )
+
         return (
-            torch._utils._rebuild_parameter,
-            (self.data, self.requires_grad, OrderedDict())
+            torch._utils._rebuild_parameter_with_state,
+            (self.data, self.requires_grad, hooks, state)
         )
 
     __torch_function__ = _disabled_torch_function_impl
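A hedged sketch of the dispatch this hunk introduces: a Parameter with no extra Python state should still reduce through _rebuild_parameter (preserving the old on-disk format), while one carrying attributes reduces through _rebuild_parameter_with_state. The assumption that a freshly created Parameter has empty state is mine, not stated in the diff:

import torch

plain = torch.nn.Parameter(torch.zeros(2))
tagged = torch.nn.Parameter(torch.zeros(2))
tagged.tag = "important"  # illustrative Python attribute

# __reduce_ex__ returns (rebuild_fn, args); pickle calls rebuild_fn(*args).
assert plain.__reduce_ex__(2)[0] is torch._utils._rebuild_parameter
assert tagged.__reduce_ex__(2)[0] is torch._utils._rebuild_parameter_with_state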
