update profiler to support multiple positional inputs (facebookresearch#663)

Summary:
Pull Request resolved: facebookresearch#663

Currently, when the profiler computes FLOPs/activations for individual modules, it assumes the forward function takes a single positional argument:

    def forward(self, x):
        ...

In some modules, such as the `MatMul` module used in the visual transformer (https://fburl.com/diffusion/0tubwixm), the forward function takes more than one argument:

    def forward(self, A, B):
        ...

Therefore, this diff updates the profiler to support multiple positional inputs in the `forward` function.
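
As an illustration (not part of this diff), here is a minimal sketch of the kind of module this change enables. The `MatMulLike` and `SelfAttentionScores` modules, the input shape, and the FLOPs formula are all invented for the example; only `compute_flops` comes from the codebase:

    import torch.nn as nn

    from classy_vision.generic.profiler import compute_flops


    class MatMulLike(nn.Module):
        """Hypothetical module whose forward takes two positional inputs."""

        def forward(self, A, B):
            return A.matmul(B)

        def flops(self, A, B):
            # One multiply-add per output element per reduced dimension.
            return A.numel() * B.size(-1)


    class SelfAttentionScores(nn.Module):
        """Hypothetical wrapper that feeds two tensors into MatMulLike."""

        def __init__(self):
            super().__init__()
            self.mm = MatMulLike()

        def forward(self, x):
            return self.mm(x, x.transpose(-2, -1))


    # The profiler can now traverse MatMulLike even though its forward
    # takes two positional arguments.
    model = nn.Sequential(SelfAttentionScores())
    flops = compute_flops(model, input_shape=(4, 8))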

Reviewed By: mannatsingh

Differential Revision: D25214757

fbshipit-source-id: c5d759e3244d50e1b0894f50d4bcd84e2907b97b
stephenyan1231 authored and facebook-github-bot committed Dec 6, 2020
1 parent bd5c260 commit 0673482
Showing 2 changed files with 57 additions and 13 deletions.
43 changes: 30 additions & 13 deletions classy_vision/generic/profiler.py
@@ -86,7 +86,7 @@ def get_shape(x: Union[Tuple, List, Dict]) -> Union[Tuple, List, Dict]:
         return x.size()
 
 
-def _layer_flops(layer: nn.Module, x: Any, y: Any) -> int:
+def _layer_flops(layer: nn.Module, layer_args: List[Any], y: Any) -> int:
     """
     Computes the number of FLOPs required for a single layer.
@@ -101,6 +101,7 @@ def flops(self, x):
     """
 
+    x = layer_args[0]
     # get layer type:
     typestr = layer.__repr__()
     layer_type = typestr[: typestr.find("(")].strip()
@@ -323,7 +324,12 @@ def flops(self, x):
     # Class MyModule(nn.Module):
     #     def flops(self, x):
     #         ...
-    flops = layer.flops(x)
+    # or
+    #
+    # Class MyModule(nn.Module):
+    #     def flops(self, x1, x2):
+    #         ...
+    flops = layer.flops(*layer_args)
 
     if flops is None:
         raise ClassyProfilerNotImplementedError(layer)
@@ -336,10 +342,10 @@ def flops(self, x):
         f"flops(M): {int(flops) / 1e6}",
     ]
     logging.debug("\t".join(message))
-    return flops
+    return int(flops)
 
 
-def _layer_activations(layer: nn.Module, x: Any, out: Any) -> int:
+def _layer_activations(layer: nn.Module, layer_args: List[Any], out: Any) -> int:
     """
     Computes the number of activations produced by a single layer.
@@ -348,20 +354,21 @@ def _layer_activations(layer: nn.Module, layer_args: List[Any], out: Any) -> int:
     will be used to compute the activations instead.
+
     Class MyModule(nn.Module):
-        def activations(self, x, out):
+        def activations(self, out, *layer_args):
             ...
     """
 
     typestr = layer.__repr__()
     if hasattr(layer, "activations"):
-        activations = layer.activations(x, out)
+        activations = layer.activations(out, *layer_args)
     elif isinstance(layer, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
         activations = out.numel()
     else:
         return 0
 
     message = [f"module: {typestr}", f"activations: {activations}"]
     logging.debug("\t".join(message))
-    return activations
+    return int(activations)
 
 
 def summarize_profiler_info(prof: torch.autograd.profiler.profile) -> str:
@@ -392,11 +399,13 @@ def __init__(self, compute_fn: Callable, count_unique: bool):
         self.count = 0
         self.seen_modules = set()
 
-    def compute(self, layer: nn.Module, x: Any, out: Any, module_name: str):
+    def compute(
+        self, layer: nn.Module, layer_args: List[Any], out: Any, module_name: str
+    ):
         if self.count_unique and module_name in self.seen_modules:
             return
-        logging.debug(f"module name: {module_name}")
-        self.count += self.compute_fn(layer, x, out)
+        self.count += self.compute_fn(layer, layer_args, out)
+        logging.debug(f"module name: {module_name}, count {self.count}")
         self.seen_modules.add(module_name)
 
     def reset(self):
@@ -423,7 +432,7 @@ def _original_forward(self, *args, **kwargs):
 
     def forward(self, *args, **kwargs):
         out = self._original_forward(*args, **kwargs)
-        complexity_computer.compute(self, args[0], out, module_name)
+        complexity_computer.compute(self, list(args), out, module_name)
         return out
 
     def __repr__(self):
@@ -524,7 +533,11 @@ def compute_flops(
     Compute the number of FLOPs needed for a forward pass.
     """
     return compute_complexity(
-        model, _layer_flops, input_shape, input_key, patch_attr="flops"
+        model,
+        _layer_flops,
+        input_shape,
+        input_key,
+        patch_attr="flops",
     )
 
 
@@ -537,7 +550,11 @@ def compute_activations(
     Compute the number of activations created in a forward pass.
     """
     return compute_complexity(
-        model, _layer_activations, input_shape, input_key, patch_attr="activations"
+        model,
+        _layer_activations,
+        input_shape,
+        input_key,
+        patch_attr="activations",
     )
 
 
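One detail worth calling out from the diff above: a custom `activations` hook now receives the layer output first, followed by the unpacked positional inputs (previously it was `activations(self, x, out)`). Below is a minimal sketch of a module opting into the new hook; the module itself is invented for illustration:

    import torch
    import torch.nn as nn


    class ConcatPair(nn.Module):
        """Hypothetical two-input module with a custom activations hook."""

        def forward(self, x1, x2):
            return torch.cat([x1, x2], dim=-1)

        # New hook signature: output first, then the forward's positional args.
        def activations(self, out, *layer_args):
            return out.numel()
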
27 changes: 27 additions & 0 deletions test/generic_profiler_test.py
@@ -53,6 +53,28 @@ def flops(self, x):
         return 0
 
 
+class TestModuleWithTwoArguments(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x1, x2):
+        return x1 + x2
+
+    def flops(self, x1, x2):
+        return x1.numel()
+
+
+class TestModuleDoubleValue(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.linear = nn.Linear(10, 10, bias=False)
+        self.add = TestModuleWithTwoArguments()
+
+    def forward(self, x):
+        x = self.linear(x)
+        return self.add(x, x)
+
+
 class TestModel(nn.Module):
     def __init__(self):
         super().__init__()
@@ -179,6 +201,11 @@ def test_flops_calculation(self):
             compute_flops(model, input_shape=input_shape), TestModuleWithFlops._flops
         )  # the conv is applied twice
 
+        # test a model with a module that takes two positional arguments
+        model = nn.Sequential(TestModuleDoubleValue())
+        input_shape = (10,)
+        self.assertEqual(compute_flops(model, input_shape=input_shape), 110)
+
 
 class TestHelperFunctions(unittest.TestCase):
     def test_get_shape(self) -> None:
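For reference, the expected value of 110 asserted in the new test decomposes as follows; the breakdown is reconstructed from the test modules, not stated in the diff:

    linear_flops = 10 * 10  # nn.Linear(10, 10, bias=False): one multiply-add per weight
    add_flops = 10          # TestModuleWithTwoArguments.flops returns x1.numel()
    assert linear_flops + add_flops == 110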
