Fix flops counter for bs>1 (microsoft#4154)
ultmaster authored Sep 10, 2021
1 parent e98ebcf commit bcc55c5
Showing 2 changed files with 24 additions and 31 deletions.
37 changes: 14 additions & 23 deletions nni/compression/pytorch/utils/counter.py
@@ -121,43 +121,43 @@ def _count_linear(self, m, x, y):
         return self._get_result(m, total_ops)
 
     def _count_bn(self, m, x, y):
-        total_ops = 2 * x[0].numel()
+        total_ops = 2 * x[0][0].numel()
         return self._get_result(m, total_ops)
 
     def _count_relu(self, m, x, y):
-        total_ops = x[0].numel()
+        total_ops = x[0][0].numel()
         return self._get_result(m, total_ops)
 
     def _count_avgpool(self, m, x, y):
-        total_ops = y.numel()
+        total_ops = y[0].numel()
         return self._get_result(m, total_ops)
 
     def _count_adap_avgpool(self, m, x, y):
         kernel = torch.Tensor([*(x[0].shape[2:])]) // torch.Tensor(list((m.output_size,))).squeeze()
         total_add = int(torch.prod(kernel))
         total_div = 1
         kernel_ops = total_add + total_div
-        num_elements = y.numel()
+        num_elements = y[0].numel()
         total_ops = kernel_ops * num_elements
 
         return self._get_result(m, total_ops)
 
     def _count_upsample(self, m, x, y):
         if m.mode == 'linear':
-            total_ops = y.nelement() * 5  # 2 muls + 3 adds
+            total_ops = y[0].nelement() * 5  # 2 muls + 3 adds
         elif m.mode == 'bilinear':
             # https://en.wikipedia.org/wiki/Bilinear_interpolation
-            total_ops = y.nelement() * 11  # 6 muls + 5 adds
+            total_ops = y[0].nelement() * 11  # 6 muls + 5 adds
         elif m.mode == 'bicubic':
             # https://en.wikipedia.org/wiki/Bicubic_interpolation
             # Product matrix [4x4] x [4x4] x [4x4]
             ops_solve_A = 224  # 128 muls + 96 adds
             ops_solve_p = 35  # 16 muls + 12 adds + 4 muls + 3 adds
-            total_ops = y.nelement() * (ops_solve_A + ops_solve_p)
+            total_ops = y[0].nelement() * (ops_solve_A + ops_solve_p)
         elif m.mode == 'trilinear':
             # https://en.wikipedia.org/wiki/Trilinear_interpolation
             # can be viewed as 2 bilinear + 1 linear
-            total_ops = y.nelement() * (13 * 2 + 5)
+            total_ops = y[0].nelement() * (13 * 2 + 5)
         else:
             total_ops = 0
 
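The pattern in this hunk is uniform: every count that used to be taken over the whole batch (x[0], y) is now taken over a single sample (x[0][0], y[0]), so the reported FLOPs no longer scale with the batch dimension. A minimal plain-PyTorch sketch of why that indexing gives a per-sample count (the tensor shape here is arbitrary, chosen only for illustration):

import torch

# A dummy activation with batch size 4.
y = torch.randn(4, 16, 8, 8)
print(y.numel())     # 4096 -- grows with the batch size
print(y[0].numel())  # 1024 -- per-sample element count, as used by the fixed counter
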
@@ -202,26 +202,16 @@ def _count_cell_flops(self, input_size, hidden_size, cell_type):

         return total_ops
 
-
     def _count_rnn_cell(self, m, x, y):
         total_ops = self._count_cell_flops(m.input_size, m.hidden_size, 'rnn')
-        batch_size = x[0].size(0)
-        total_ops *= batch_size
-
         return self._get_result(m, total_ops)
 
     def _count_gru_cell(self, m, x, y):
         total_ops = self._count_cell_flops(m.input_size, m.hidden_size, 'gru')
-        batch_size = x[0].size(0)
-        total_ops *= batch_size
-
         return self._get_result(m, total_ops)
 
     def _count_lstm_cell(self, m, x, y):
         total_ops = self._count_cell_flops(m.input_size, m.hidden_size, 'lstm')
-        batch_size = x[0].size(0)
-        total_ops *= batch_size
-
         return self._get_result(m, total_ops)
 
     def _get_bsize_nsteps(self, m, x):
@@ -243,18 +233,17 @@ def _count_rnn_module(self, m, x, y, module_name):
         hidden_size = m.hidden_size
         num_layers = m.num_layers
 
-        batch_size, num_steps = self._get_bsize_nsteps(m, x)
+        _, num_steps = self._get_bsize_nsteps(m, x)
         total_ops = self._count_cell_flops(input_size, hidden_size, module_name)
 
         for _ in range(num_layers - 1):
             if m.bidirectional:
                 cell_flops = self._count_cell_flops(hidden_size * 2, hidden_size, module_name) * 2
             else:
-                cell_flops = self._count_cell_flops(hidden_size, hidden_size,module_name)
+                cell_flops = self._count_cell_flops(hidden_size, hidden_size, module_name)
             total_ops += cell_flops
 
         total_ops *= num_steps
-        total_ops *= batch_size
         return total_ops
 
     def _count_rnn(self, m, x, y):
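With the batch_size factor removed from _count_rnn_module (and from the cell counters above), recurrent layers are reported per sample as well. A rough, untested sketch of how that could be checked, assuming count_flops_params is importable from this module and accepts a (batch, steps, features) dummy-input shape:

import torch.nn as nn
from nni.compression.pytorch.utils.counter import count_flops_params

class TinyLSTM(nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(input_size=16, hidden_size=32, num_layers=2, batch_first=True)

    def forward(self, x):
        out, _ = self.lstm(x)
        return out

# After this change, the reported FLOPs should be identical for batch sizes 1 and 4.
flops_bs1, _, _ = count_flops_params(TinyLSTM(), (1, 10, 16), verbose=False)
flops_bs4, _, _ = count_flops_params(TinyLSTM(), (4, 10, 16), verbose=False)
assert flops_bs1 == flops_bs4
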
@@ -272,7 +261,6 @@ def _count_lstm(self, m, x, y):

         return self._get_result(m, total_ops)
 
-
     def count_module(self, m, x, y, name):
         # assume x is tuple of single tensor
         result = self.ops[type(m)](m, x, y)
@@ -335,7 +323,10 @@ def count_flops_params(model, x, custom_ops=None, verbose=True, mode='default'):
     identify the mask on the module and take the pruned shape into consideration.
     Note that, for structured pruning, we only identify the remaining filters
     according to their masks, and do not take the pruned input channels into consideration,
-    so the calculated FLOPs will be larger than real number.
+    so the calculated FLOPs will be larger than the real number.
+
+    The FLOPs are counted "per sample", which means that if the input has a batch size larger than 1,
+    the calculated FLOPs should not differ from those for a batch size of 1.
 
     Parameters
     ----------
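The docstring addition makes the new contract explicit: the returned FLOPs are per sample, so passing a larger batch should not change the numbers. A short usage sketch in the same spirit as the updated unit test below (the import path is the module changed in this commit; the model here is an arbitrary example):

import torch.nn as nn
from nni.compression.pytorch.utils.counter import count_flops_params

model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.Conv2d(8, 4, 1))

# Per-sample counting: the results for batch size 1 and batch size 8 should match.
flops_bs1, params_bs1, _ = count_flops_params(model, (1, 3, 32, 32), verbose=False)
flops_bs8, params_bs8, _ = count_flops_params(model, (8, 3, 32, 32), verbose=False)
assert (flops_bs1, params_bs1) == (flops_bs8, params_bs8)
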
18 changes: 10 additions & 8 deletions test/ut/sdk/test_compression_utils.py
@@ -138,6 +138,7 @@ def test_mask_conflict(self):
         assert b_index1 == b_index2
 
+
 class FlopsCounterTest(TestCase):
     def test_flops_params(self):
         class Model1(nn.Module):
             def __init__(self):
@@ -170,16 +171,17 @@ def forward(self, x):
             for _ in range(5):
                 x = self.conv2(x)
             return x
 
-        flops, params, results = count_flops_params(Model1(), (1, 3, 2, 2), mode='full', verbose=False)
-        assert (flops, params) == (610, 240)
-
-        flops, params, results = count_flops_params(Model2(), (1, 3, 2, 2), verbose=False)
-        assert (flops, params) == (560, 50)
+        for bs in [1, 2]:
+            flops, params, results = count_flops_params(Model1(), (bs, 3, 2, 2), mode='full', verbose=False)
+            assert (flops, params) == (610, 240)
 
-        from torchvision.models import resnet50
-        flops, params, results = count_flops_params(resnet50(), (1, 3, 224, 224), verbose=False)
-        assert (flops, params) == (4089184256, 25503912)
+            flops, params, results = count_flops_params(Model2(), (bs, 3, 2, 2), verbose=False)
+            assert (flops, params) == (560, 50)
+
+            from torchvision.models import resnet50
+            flops, params, results = count_flops_params(resnet50(), (bs, 3, 224, 224), verbose=False)
+            assert (flops, params) == (4089184256, 25503912)
 
 
 if __name__ == '__main__':
