Fix weight norm applied to modules (openvinotoolkit#1058)
* Add test for WeightNorm-ed Conv2d wrapping

* Handle WeightNorm-ed modules when wrapping NNCF* modules
vshampor authored Jan 3, 2022
1 parent 4baee09 commit d379803
Showing 2 changed files with 66 additions and 15 deletions.
55 changes: 40 additions & 15 deletions nncf/torch/layers.py
@@ -10,6 +10,8 @@
See the License for the specific language governing permissions and
limitations under the License.
"""
+from typing import Dict

import math
import numbers
from typing import Optional
@@ -21,6 +23,7 @@
from torch import nn
from torch.nn import init
from torch.nn.utils.rnn import PackedSequence
+from torch.nn.utils.weight_norm import WeightNorm

from nncf.torch.dynamic_graph.context import forward_nncf_trace
from nncf.torch.utils import no_jit_trace
@@ -30,12 +33,34 @@
from nncf.torch.layer_utils import _NNCFModuleMixin


-def dict_update(src, dst, recursive=True):
-    for name, value in dst.items():
-        if recursive and name in src and isinstance(value, dict):
-            dict_update(src[name], value, recursive)
+def dict_update(src: Dict, dst: Dict, recursive: bool = True):
+    for name, value in src.items():
+        if recursive and name in dst and isinstance(value, dict):
+            dict_update(value, dst[name], recursive)
         else:
-            src[name] = value
+            dst[name] = value
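A quick sketch of the fixed semantics (the sample dictionaries below are made up for illustration, not taken from the commit): `dict_update` now reads from `src`, writes into `dst`, and merges nested dicts in place instead of replacing them wholesale.

# Illustrative only; not part of the commit.
src = {'_parameters': {'weight_g': 'g', 'weight_v': 'v'}, 'training': False}
dst = {'_parameters': {'weight': 'w'}, 'training': True}
dict_update(src, dst)
# dst's nested '_parameters' dict is merged in place, not overwritten:
assert dst == {'_parameters': {'weight': 'w', 'weight_g': 'g', 'weight_v': 'v'},
               'training': False}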


+def maybe_reapply_weight_norm(src: torch.nn.Module, dst: torch.nn.Module) -> torch.nn.Module:
+    #pylint:disable=protected-access
+    for k, hook in src._forward_pre_hooks.items():
+        if isinstance(hook, WeightNorm):
+            # The code below presumes that the `hook` object does not
+            # contain internal references to the module it was set up on
+            # (i.e. to the `src`) and takes the module to act on as a parameter.
+            # This is the case for the `WeightNorm` hook.
+            hook.remove(dst)
+            del dst._forward_pre_hooks[k]
+            name = hook.name
+            dim = hook.dim
+            WeightNorm.apply(dst, name=name, dim=dim)
+    return dst
+
+
+def align_module_internals(src: torch.nn.Module, dst: torch.nn.Module) -> torch.nn.Module:
+    dict_update(src.__dict__, dst.__dict__)
+    dst = maybe_reapply_weight_norm(src, dst)
+    return dst
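Taken together, the two new helpers let a freshly constructed wrapper inherit the source module's state and then have weight normalization re-applied cleanly. A minimal sketch with plain torch modules standing in for the NNCF wrappers (illustrative only; assumes the helpers above are in scope):

import torch
from torch.nn.utils import weight_norm

src = weight_norm(torch.nn.Conv1d(1, 1, 1))  # registers weight_g/weight_v and a WeightNorm pre-hook
dst = torch.nn.Conv1d(1, 1, 1)
dst = align_module_internals(src, dst)       # copies state, then re-applies the hook on dst
assert hasattr(dst, 'weight_g') and hasattr(dst, 'weight_v')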


class NNCFConv1d(_NNCFModuleMixin, nn.Conv1d):
@@ -48,7 +73,7 @@ def from_module(module):
module.in_channels, module.out_channels, module.kernel_size, module.stride,
module.padding, module.dilation, module.groups, hasattr(module, 'bias')
)
-        dict_update(nncf_conv.__dict__, module.__dict__)
+        nncf_conv = align_module_internals(module, nncf_conv)
return nncf_conv
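A hedged sketch of the effect (mirroring the new test in test_nncf_network.py further down, but calling from_module directly; illustrative only):

conv = weight_norm(torch.nn.Conv1d(2, 2, 1))
nncf_conv = NNCFConv1d.from_module(conv)
# The WeightNorm reparametrization survives the wrapping:
assert hasattr(nncf_conv, 'weight_g') and hasattr(nncf_conv, 'weight_v')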


@@ -88,7 +113,7 @@ def from_module(module):
module.in_channels, module.out_channels, module.kernel_size, module.stride,
module.padding, module.dilation, module.groups, hasattr(module, 'bias')
)
-        dict_update(nncf_conv.__dict__, module.__dict__)
+        nncf_conv = align_module_internals(module, nncf_conv)
return nncf_conv

# override class attribute of _NNCFModuleMixin
@@ -123,7 +148,7 @@ def from_module(module):
assert module.__class__.__name__ == nn.Linear.__name__

nncf_linear = NNCFLinear(module.in_features, module.out_features, hasattr(module, 'bias'))
-        dict_update(nncf_linear.__dict__, module.__dict__)
+        nncf_linear = align_module_internals(module, nncf_linear)
return nncf_linear


@@ -136,7 +161,7 @@ def from_module(module):
assert module.__class__.__name__ == nn.BatchNorm2d.__name__

nncf_bn = NNCFBatchNorm(module.num_features)
-        dict_update(nncf_bn.__dict__, module.__dict__)
+        nncf_bn = align_module_internals(module, nncf_bn)
return nncf_bn


@@ -152,7 +177,7 @@ def from_module(module):
num_channels=module.num_channels,
eps=module.eps,
affine=module.affine)
-        dict_update(nncf_bn.__dict__, module.__dict__)
+        nncf_bn = align_module_internals(module, nncf_bn)
return nncf_bn


@@ -169,7 +194,7 @@ def from_module(module):
if hasattr(module, 'padding_mode'):
args.append(module.padding_mode)
nncf_conv_transpose2d = NNCFConvTranspose2d(*args)
-        dict_update(nncf_conv_transpose2d.__dict__, module.__dict__)
+        nncf_conv_transpose2d = align_module_internals(module, nncf_conv_transpose2d)
return nncf_conv_transpose2d


@@ -184,7 +209,7 @@ def from_module(module):
module.in_channels, module.out_channels, module.kernel_size, module.stride,
module.padding, module.dilation, module.groups, hasattr(module, 'bias')
)
-        dict_update(nncf_conv3d.__dict__, module.__dict__)
+        nncf_conv3d = align_module_internals(module, nncf_conv3d)
return nncf_conv3d

class NNCFConvTranspose3d(_NNCFModuleMixin, nn.ConvTranspose3d):
@@ -200,7 +225,7 @@ def from_module(module):
if hasattr(module, 'padding_mode'):
args.append(module.padding_mode)
nncf_conv_transpose3d = NNCFConvTranspose3d(*args)
-        dict_update(nncf_conv_transpose3d.__dict__, module.__dict__)
+        nncf_conv_transpose3d = align_module_internals(module, nncf_conv_transpose3d)
return nncf_conv_transpose3d


@@ -216,7 +241,7 @@ def from_module(module):
module.max_norm, module.norm_type, module.scale_grad_by_freq,
module.sparse, module.weight]
nncf_embedding = NNCFEmbedding(*args)
-        dict_update(nncf_embedding.__dict__, module.__dict__)
+        nncf_embedding = align_module_internals(module, nncf_embedding)
return nncf_embedding


@@ -233,7 +258,7 @@ def from_module(module):
module.mode, module.sparse, module.weight,
module.include_last_offset]
nncf_embedding_bag = NNCFEmbeddingBag(*args)
-        dict_update(nncf_embedding_bag.__dict__, module.__dict__)
+        nncf_embedding_bag = align_module_internals(module, nncf_embedding_bag)
return nncf_embedding_bag

NNCF_MODULES_DICT = {
26 changes: 26 additions & 0 deletions tests/torch/test_nncf_network.py
@@ -24,6 +24,7 @@
import pytest
import torch
from torch import nn
+from torch.nn.utils import weight_norm

from nncf.common.graph import BaseLayerAttributes
from nncf.common.graph import NNCFGraph
@@ -124,6 +125,31 @@ def test_check_correct_modules_replacement():
assert set(nncf_modules) == set(nncf_model.get_nncf_modules())


+class WeightNormedConvModel(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv = weight_norm(torch.nn.Conv1d(1, 1, 1))
+
+    def forward(self, x):
+        return self.conv(x)
+
+
+def test_weight_normed_modules_are_replaced_correctly():
+    nncf_model = NNCFNetwork(WeightNormedConvModel(), input_infos=[ModelInputInfo([1, 1, 10])])
+
+    wrapped_conv = nncf_model.conv
+    assert hasattr(wrapped_conv, "weight_g")
+    assert hasattr(wrapped_conv, "weight_v")
+    assert hasattr(wrapped_conv, "weight")
+
+    assert isinstance(wrapped_conv.weight_g, torch.nn.Parameter)
+    assert isinstance(wrapped_conv.weight_v, torch.nn.Parameter)
+    assert not isinstance(wrapped_conv.weight, torch.nn.Parameter)
+
+    #pylint:disable=protected-access
+    assert len(wrapped_conv._forward_pre_hooks) == 1


@register_module()
class ModuleOfUser(torch.nn.Module):
def __init__(self):
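As a follow-up sanity check (not part of the commit), a forward pass through the wrapped model should still run the re-applied WeightNorm pre-hook and recompute `weight` from `weight_g`/`weight_v`:

# Illustrative only; builds on the test model above.
nncf_model = NNCFNetwork(WeightNormedConvModel(), input_infos=[ModelInputInfo([1, 1, 10])])
out = nncf_model(torch.ones(1, 1, 10))
assert out.shape == torch.Size([1, 1, 10])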
