Skip to content

Commit

Permalink
Fuse LeakyReLU along with other non-ReLU activations (openvinotoolkit…
Browse files Browse the repository at this point in the history
…#1076)

### Changes

torch.nn.functional.leaky_relu nodes are now properly fused with preceding operations.

### Reason for changes

Purported performance increases from fusing.

### Related tickets

N/A

### Tests
tests.torch.test_compressed_graph.test_synthetic_model_quantization

<!--- How was the correctness of changes tested and whether new tests were added -->
  • Loading branch information
vshampor authored Jan 28, 2022
1 parent 848bc2a commit 6ace468
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 1 deletion.
1 change: 1 addition & 0 deletions nncf/torch/graph/pattern_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
NON_RELU_ACTIVATIONS_OPERATIONS = {'type': ['elu',
'elu_',
'prelu',
'leaky_relu',
'sigmoid',
'gelu',
'silu',
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
strict digraph {
"0 /nncf_model_input_0" [id=0, type=nncf_model_input];
"1 SymmetricQuantizer/symmetric_quantize_0" [id=1, type=symmetric_quantize];
"2 ConvBNLeakyReLU/NNCFConv2d[conv]/ModuleDict[pre_ops]/UpdateWeight[0]/SymmetricQuantizer[op]/symmetric_quantize_0" [id=2, type=symmetric_quantize];
"3 ConvBNLeakyReLU/NNCFConv2d[conv]/conv2d_0" [id=3, type=conv2d];
"4 ConvBNLeakyReLU/NNCFBatchNorm[bn]/batch_norm_0" [id=4, type=batch_norm];
"5 ConvBNLeakyReLU/leaky_relu_0" [id=5, type=leaky_relu];
"6 /nncf_model_output_0" [id=6, type=nncf_model_output];
"0 /nncf_model_input_0" -> "1 SymmetricQuantizer/symmetric_quantize_0";
"1 SymmetricQuantizer/symmetric_quantize_0" -> "3 ConvBNLeakyReLU/NNCFConv2d[conv]/conv2d_0";
"2 ConvBNLeakyReLU/NNCFConv2d[conv]/ModuleDict[pre_ops]/UpdateWeight[0]/SymmetricQuantizer[op]/symmetric_quantize_0" -> "3 ConvBNLeakyReLU/NNCFConv2d[conv]/conv2d_0";
"3 ConvBNLeakyReLU/NNCFConv2d[conv]/conv2d_0" -> "4 ConvBNLeakyReLU/NNCFBatchNorm[bn]/batch_norm_0";
"4 ConvBNLeakyReLU/NNCFBatchNorm[bn]/batch_norm_0" -> "5 ConvBNLeakyReLU/leaky_relu_0";
"5 ConvBNLeakyReLU/leaky_relu_0" -> "6 /nncf_model_output_0";
}
4 changes: 3 additions & 1 deletion tests/torch/test_compressed_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import torch

from nncf.torch.utils import get_model_device
from tests.torch.test_models.synthetic import ConvBNLeakyReLU
from tests.torch.test_models.synthetic import ConvRelu6HSwishHSigmoid
from tests.torch.test_models.synthetic import MMDivConv
from tests.torch.test_models.synthetic import MatMulDivConv
Expand Down Expand Up @@ -728,7 +729,8 @@ def forward(self, x):
GeneralModelDesc(model_builder=MultiOutputSameTensorModel),
GeneralModelDesc(model_builder=MatMulDivConv, input_sample_sizes=([1, 1, 5, 5], [1, 1, 5, 5])),
GeneralModelDesc(model_builder=MMDivConv, input_sample_sizes=([5, 5], [5, 5])),
GeneralModelDesc(model_builder=ConvRelu6HSwishHSigmoid, input_sample_sizes=([1, 1, 5, 5],))
GeneralModelDesc(model_builder=ConvRelu6HSwishHSigmoid, input_sample_sizes=([1, 1, 5, 5],)),
GeneralModelDesc(model_builder=ConvBNLeakyReLU, input_sample_sizes=([1, 1, 5, 5],))
]


Expand Down
15 changes: 15 additions & 0 deletions tests/torch/test_models/synthetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
import torch
import torch.nn.functional as F
from abc import abstractmethod

from torch.nn import BatchNorm2d

from tests.torch.helpers import create_conv
from torch import nn
from torch.nn import Dropout
Expand Down Expand Up @@ -271,3 +274,15 @@ def forward(self, x: torch.Tensor):
z = self.conv2(z)
z = self._hsigmoid(z)
return z

class ConvBNLeakyReLU(nn.Module):
def __init__(self):
super().__init__()
self.conv = create_conv(1, 2, 2, 2)
self.bn = BatchNorm2d(2)

def forward(self, x: torch.Tensor):
z = self.conv(x)
z = self.bn(z)
z = torch.nn.functional.leaky_relu(z)
return z

0 comments on commit 6ace468

Please sign in to comment.