Skip to content

Commit

Permalink
add model analysis tool
Browse files Browse the repository at this point in the history
Summary: Add a model analysis tool that supports flop/activation counting averaged over multiple inputs (100 by default).

Reviewed By: rbgirshick

Differential Revision: D20988652

fbshipit-source-id: 394bd3050fd4edc0d726bdb431edd40211f66810
  • Loading branch information
ppwwyyxx authored and facebook-github-bot committed Apr 21, 2020
1 parent 5525cf5 commit 4bd92bb
Show file tree
Hide file tree
Showing 4 changed files with 184 additions and 22 deletions.
8 changes: 6 additions & 2 deletions detectron2/engine/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,19 @@
__all__ = ["default_argument_parser", "default_setup", "DefaultPredictor", "DefaultTrainer"]


def default_argument_parser():
def default_argument_parser(epilog=None):
"""
Create a parser with some common arguments used by detectron2 users.
Args:
epilog (str): epilog passed to ArgumentParser describing the usage.
Returns:
argparse.ArgumentParser:
"""
parser = argparse.ArgumentParser(
epilog=f"""
epilog=epilog
or f"""
Examples:
Run on single machine:
Expand Down
3 changes: 2 additions & 1 deletion detectron2/layers/mask_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,8 @@ def paste_masks_in_image(masks, boxes, image_shape, threshold=0.5):
num_chunks = N
else:
# GPU benefits from parallelism for larger chunks, but may have memory issue
num_chunks = int(np.ceil(N * img_h * img_w * BYTES_PER_FLOAT / GPU_MEM_LIMIT))
# int(img_h) because shape may be tensors in tracing
num_chunks = int(np.ceil(N * int(img_h) * int(img_w) * BYTES_PER_FLOAT / GPU_MEM_LIMIT))
assert (
num_chunks <= N
), "Default GPU_MEM_LIMIT in mask_ops.py is too small; try increasing it"
Expand Down
81 changes: 62 additions & 19 deletions detectron2/utils/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from fvcore.nn import activation_count, flop_count, parameter_count, parameter_count_table
from torch import nn

from detectron2.structures import BitMasks, Boxes, ImageList, Instances

from .logger import log_first_n

__all__ = [
Expand All @@ -20,6 +22,34 @@
ACTIVATIONS_MODE = "activations"


# Extra operators to ignore when counting flops/activations.
# `_wrapper_count_operators` registers each name below with a handler that
# returns an empty dict; entries can still be overridden by a caller-provided
# `supported_ops` mapping.
_IGNORED_OPS = [
"aten::batch_norm",
"aten::div",
"aten::div_",
"aten::rsub",
"aten::sub",
"aten::relu_",
"aten::add_",
"aten::mul",
"aten::add",
"aten::relu",
"aten::sigmoid",
"aten::sigmoid_",
"aten::sort",
"aten::exp",
"aten::mul_",
"aten::max_pool2d",
"aten::constant_pad_nd",
"aten::sqrt",
"aten::softmax",
"aten::log2",
"aten::nonzero_numpy",
"prim::PythonOp",
"torchvision::nms",
]


def flop_count_operators(
model: nn.Module, inputs: list, **kwargs
) -> typing.DefaultDict[str, float]:
Expand Down Expand Up @@ -64,10 +94,38 @@ def activation_count_operators(
return _wrapper_count_operators(model=model, inputs=inputs, mode=ACTIVATIONS_MODE, **kwargs)


def _flatten_to_tuple(outputs):
result = []
if isinstance(outputs, torch.Tensor):
result.append(outputs)
elif isinstance(outputs, (list, tuple)):
for v in outputs:
result.extend(_flatten_to_tuple(v))
elif isinstance(outputs, dict):
for _, v in outputs.items():
result.extend(_flatten_to_tuple(v))
elif isinstance(outputs, Instances):
result.extend(_flatten_to_tuple(outputs.get_fields()))
elif isinstance(outputs, (Boxes, BitMasks, ImageList)):
result.append(outputs.tensor)
else:
log_first_n(
logging.WARN,
f"Output of type {type(outputs)} not included in flops/activations count.",
n=10,
)
return tuple(result)


def _wrapper_count_operators(
model: nn.Module, inputs: list, mode: str, **kwargs
) -> typing.DefaultDict[str, float]:

# ignore some ops
supported_ops = {k: lambda *args, **kwargs: {} for k in _IGNORED_OPS}
supported_ops.update(kwargs.pop("supported_ops", {}))
kwargs["supported_ops"] = supported_ops

assert len(inputs) == 1, "Please use batch size=1"
tensor_input = inputs[0]["image"]

Expand All @@ -84,25 +142,10 @@ def __init__(self, model):
def forward(self, image):
# jit requires the input/output to be Tensors
inputs = [{"image": image}]
outputs = self.model.forward(inputs)[0]
if isinstance(outputs, dict) and "instances" in outputs:
# Only the subgraph that computes the returned tensor will be
# counted. So we return everything we found in Instances.
inst = outputs["instances"]
ret = [inst.pred_boxes.tensor]
inst.remove("pred_boxes")
for k, v in inst.get_fields().items():
if isinstance(v, torch.Tensor):
ret.append(v)
else:
log_first_n(
logging.WARN,
f"Field '{k}' in output instances is not included"
" in flops/activations count.",
n=10,
)
return tuple(ret)
raise NotImplementedError("Count for segmentation models is not supported yet.")
outputs = self.model.forward(inputs)
# Only the subgraph that computes the returned tuple of tensor will be
# counted. So we flatten everything we found to tuple of tensors.
return _flatten_to_tuple(outputs)

old_train = model.training
with torch.no_grad():
Expand Down
114 changes: 114 additions & 0 deletions tools/analyze_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# -*- coding: utf-8 -*-
# noqa: B950

import logging
from collections import Counter
import tqdm

from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.data import build_detection_test_loader
from detectron2.engine import default_argument_parser
from detectron2.modeling import build_model
from detectron2.utils.analysis import (
activation_count_operators,
flop_count_operators,
parameter_count_table,
)
from detectron2.utils.logger import setup_logger

logger = logging.getLogger("detectron2")


def setup(args):
    """
    Build the frozen config for this tool and initialize logging.

    Args:
        args: parsed command-line arguments; ``config_file`` and ``opts`` are read.

    Returns:
        The frozen config object.
    """
    config = get_cfg()
    config.merge_from_file(args.config_file)
    # Assigned before merging `args.opts`, so a value given in opts is applied on top.
    config.DATALOADER.NUM_WORKERS = 0
    config.merge_from_list(args.opts)
    config.freeze()

    setup_logger()
    return config


def do_flop(cfg):
    """
    Log the per-operator flop counts averaged over the processed inputs.

    Builds the test data loader for ``cfg.DATASETS.TEST[0]`` and the model,
    loads ``cfg.MODEL.WEIGHTS``, then accumulates ``flop_count_operators``
    results over at most ``args.num_inputs`` inputs. ``args`` is the
    module-level namespace parsed in the ``__main__`` block.

    Args:
        cfg: the config used to build the data loader and the model.
    """
    data_loader = build_detection_test_loader(cfg, cfg.DATASETS.TEST[0])
    model = build_model(cfg)
    DetectionCheckpointer(model).load(cfg.MODEL.WEIGHTS)
    model.eval()

    counts = Counter()
    # Count the inputs explicitly: the loop index is zero-based, so dividing by
    # it would over-report the average and fail when only one input is seen.
    num_processed = 0
    for _, data in zip(tqdm.trange(args.num_inputs), data_loader):  # noqa
        counts += flop_count_operators(model, data)
        num_processed += 1

    if num_processed == 0:
        logger.warning("No inputs were processed; flops cannot be computed.")
        return

    logger.info(
        "(G)Flops for Each Type of Operators:\n"
        + str([(k, v / num_processed) for k, v in counts.items()])
    )


def do_activation(cfg):
    """
    Log the per-operator activation counts averaged over the processed inputs.

    Builds the test data loader for ``cfg.DATASETS.TEST[0]`` and the model,
    loads ``cfg.MODEL.WEIGHTS``, then accumulates ``activation_count_operators``
    results over at most ``args.num_inputs`` inputs. ``args`` is the
    module-level namespace parsed in the ``__main__`` block.

    Args:
        cfg: the config used to build the data loader and the model.
    """
    data_loader = build_detection_test_loader(cfg, cfg.DATASETS.TEST[0])
    model = build_model(cfg)
    DetectionCheckpointer(model).load(cfg.MODEL.WEIGHTS)
    model.eval()

    counts = Counter()
    # Count the inputs explicitly: the loop index is zero-based, so dividing by
    # it would over-report the average and fail when only one input is seen.
    num_processed = 0
    for _, data in zip(tqdm.trange(args.num_inputs), data_loader):  # noqa
        counts += activation_count_operators(model, data)
        num_processed += 1

    if num_processed == 0:
        logger.warning("No inputs were processed; activations cannot be computed.")
        return

    logger.info(
        "(Million) Activations for Each Type of Operators:\n"
        + str([(k, v / num_processed) for k, v in counts.items()])
    )


def do_parameter(cfg):
    """Log the parameter-count table (``max_depth=5``) of the model built from ``cfg``."""
    table = parameter_count_table(build_model(cfg), max_depth=5)
    logger.info("Parameter Count:\n" + table)


def do_structure(cfg):
    """Log the string representation of the model built from ``cfg``."""
    net = build_model(cfg)
    description = str(net)
    logger.info("Model Structure:\n" + description)


if __name__ == "__main__":
    # Script entry point: parse arguments, build the config, run each requested task.
    parser = default_argument_parser(
        epilog="""
Examples:
To show parameters of a model:
$ ./analyze_model.py --tasks parameter \\
--config-file ../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml
Flops and activations are data-dependent, therefore inputs and model weights
are needed to count them:
$ ./analyze_model.py --num-inputs 100 --tasks flop \\
--config-file ../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml \\
MODEL.WEIGHTS /path/to/model.pkl
"""
    )
    parser.add_argument(
        "--tasks",
        choices=["flop", "activation", "parameter", "structure"],
        required=True,
        nargs="+",
    )
    parser.add_argument(
        "--num-inputs",
        default=100,
        type=int,
        help="number of inputs used to compute statistics for flops/activations, "
        "both are data dependent.",
    )
    # `args` must stay a module-level name: do_flop/do_activation read args.num_inputs.
    args = parser.parse_args()

    # Validate explicitly instead of with `assert`, which is stripped under
    # `python -O` and gives the user a bare traceback with no usage message.
    if args.eval_only:
        parser.error("this tool does not support eval_only")
    if args.num_gpus != 1:
        parser.error("this tool requires num_gpus == 1")

    cfg = setup(args)

    # Build the task dispatch table once, not on every loop iteration.
    task_handlers = {
        "flop": do_flop,
        "activation": do_activation,
        "parameter": do_parameter,
        "structure": do_structure,
    }
    for task in args.tasks:
        task_handlers[task](cfg)

0 comments on commit 4bd92bb

Please sign in to comment.