Skip to content

Commit

Permalink
Refactoring: utils.get_dummy_input()
Browse files Browse the repository at this point in the history
Remove the multiple instances of code that generates
dummy input per dataset.
  • Loading branch information
nzmora committed May 16, 2019
1 parent af5c721 commit bf1e6a0
Show file tree
Hide file tree
Showing 6 changed files with 16 additions and 35 deletions.
11 changes: 2 additions & 9 deletions distiller/thinning.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@
import logging
from collections import namedtuple
import torch
from .policy import ScheduledTrainingPolicy
import distiller
from .policy import ScheduledTrainingPolicy
from .summary_graph import SummaryGraph
msglogger = logging.getLogger(__name__)

Expand Down Expand Up @@ -63,14 +63,7 @@


def create_graph(dataset, model):
dummy_input = None
if dataset == 'imagenet':
dummy_input = torch.randn((1, 3, 224, 224), requires_grad=False)
elif dataset == 'cifar10':
dummy_input = torch.randn((1, 3, 32, 32), requires_grad=False)
assert dummy_input is not None, "Unsupported dataset ({}) - aborting draw operation".format(dataset)

dummy_input = dummy_input.to(distiller.model_device(model))
dummy_input = distiller.get_dummy_input(dataset, distiller.model_device(model))
return SummaryGraph(model, dummy_input)


Expand Down
8 changes: 7 additions & 1 deletion distiller/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,13 +556,19 @@ def has_children(module):
return False


def get_dummy_input(dataset):
def get_dummy_input(dataset, device=None):
    """Generate a representative dummy (random) input for the specified dataset.

    Args:
        dataset: name of the dataset; 'imagenet' and 'cifar10' are supported.
        device: optional target device. When it is not None, the dummy input
            is moved there with ``Tensor.to(device)``; otherwise the tensor is
            returned wherever ``torch.randn`` created it.

    Returns:
        A random tensor holding a single sample: shape (1, 3, 224, 224) for
        'imagenet' and (1, 3, 32, 32) for 'cifar10'.

    Raises:
        ValueError: if ``dataset`` is not one of the supported names.
    """
    if dataset == 'imagenet':
        dummy_input = torch.randn(1, 3, 224, 224)
    elif dataset == 'cifar10':
        dummy_input = torch.randn(1, 3, 32, 32)
    else:
        raise ValueError("dataset %s is not supported" % dataset)
    # Compare against None explicitly: a truthiness test would silently skip
    # falsy-but-valid device specifiers such as the CUDA ordinal 0.
    if device is not None:
        dummy_input = dummy_input.to(device)
    return dummy_input


Expand Down
8 changes: 0 additions & 8 deletions tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,5 @@ def find_module_by_name(model, module_to_find):
return None


def get_dummy_input(dataset):
if dataset == "imagenet":
return torch.randn(1, 3, 224, 224).cuda()
elif dataset == "cifar10":
return torch.randn(1, 3, 32, 32).cuda()
raise ValueError("Trying to use an unknown dataset " + dataset)


def almost_equal(a, b, max_diff=0.000001):
    """Return True when ``a`` and ``b`` differ by at most ``max_diff``.

    The comparison is on the absolute difference ``abs(a - b)`` and is
    inclusive: a difference exactly equal to ``max_diff`` counts as equal.
    """
    return abs(a - b) <= max_diff
4 changes: 2 additions & 2 deletions tests/test_model_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,15 @@ def test_compute_summary():
dataset = "cifar10"
arch = "simplenet_cifar"
model, _ = common.setup_test(arch, dataset, parallel=True)
df_compute = distiller.model_performance_summary(model, common.get_dummy_input(dataset))
df_compute = distiller.model_performance_summary(model, distiller.get_dummy_input(dataset))
module_macs = df_compute.loc[:, 'MACs'].to_list()
# [conv1, conv2, fc1, fc2, fc3]
assert module_macs == [352800, 240000, 48000, 10080, 840]

dataset = "imagenet"
arch = "mobilenet"
model, _ = common.setup_test(arch, dataset, parallel=True)
df_compute = distiller.model_performance_summary(model, common.get_dummy_input(dataset))
df_compute = distiller.model_performance_summary(model, distiller.get_dummy_input(dataset))
module_macs = df_compute.loc[:, 'MACs'].to_list()
expected_macs = [10838016, 3612672, 25690112, 1806336, 25690112, 3612672, 51380224, 903168,
25690112, 1806336, 51380224, 451584, 25690112, 903168, 51380224, 903168,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_pruning.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ def arbitrary_channel_pruning(config, channels_to_remove, is_parallel):
assert bn1.bias.size(0) == cnt_nnz_channels
assert bn1.weight.size(0) == cnt_nnz_channels

dummy_input = common.get_dummy_input(config.dataset)
dummy_input = distiller.get_dummy_input(config.dataset, distiller.model_device(model))
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.1)
run_forward_backward(model, optimizer, dummy_input)

Expand Down
18 changes: 4 additions & 14 deletions tests/test_summarygraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,8 @@
logger.addHandler(fh)


def get_input(dataset):
if dataset == 'imagenet':
return torch.randn((1, 3, 224, 224), requires_grad=False)
elif dataset == 'cifar10':
return torch.randn((1, 3, 32, 32))
return None


def create_graph(dataset, arch):
dummy_input = get_input(dataset)
assert dummy_input is not None, "Unsupported dataset ({}) - aborting draw operation".format(dataset)

dummy_input = distiller.get_dummy_input(dataset)
model = create_model(False, dataset, arch, parallel=False)
assert model is not None
return SummaryGraph(model, dummy_input)
Expand Down Expand Up @@ -163,7 +153,7 @@ def test_normalize_module_name():

def named_params_layers_test_aux(dataset, arch, dataparallel:bool):
model = create_model(False, dataset, arch, parallel=dataparallel)
sgraph = SummaryGraph(model, get_input(dataset))
sgraph = SummaryGraph(model, distiller.get_dummy_input(dataset))
sgraph_layer_names = set(k for k, i, j in sgraph.named_params_layers())
for layer_name in sgraph_layer_names:
assert sgraph.find_op(layer_name) is not None, '{} was not found in summary graph'.format(layer_name)
Expand Down Expand Up @@ -202,7 +192,7 @@ def test_sg_macs():
sg = create_graph('imagenet', 'mobilenet')
assert sg
model, _ = common.setup_test('mobilenet', 'imagenet', parallel=False)
df_compute = distiller.model_performance_summary(model, common.get_dummy_input('imagenet'))
df_compute = distiller.model_performance_summary(model, distiller.get_dummy_input('imagenet'))
modules_macs = df_compute.loc[:, ['Name', 'MACs']]
for name, mod in model.named_modules():
if isinstance(mod, (torch.nn.Conv2d, torch.nn.Linear)):
Expand All @@ -214,7 +204,7 @@ def test_sg_macs():
def test_weights_size_attr():
def test(dataset, arch, dataparallel:bool):
model = create_model(False, dataset, arch, parallel=dataparallel)
sgraph = SummaryGraph(model, get_input(dataset))
sgraph = SummaryGraph(model, distiller.get_dummy_input(dataset))

distiller.assign_layer_fq_names(model)
for name, mod in model.named_modules():
Expand Down

0 comments on commit bf1e6a0

Please sign in to comment.