
Merge pull request NVIDIA#918 from a-maci/ASP_sparse_param_dict_update
Asp sparse param dict update
thorjohnsen authored Jul 23, 2020
2 parents 0ac5dd6 + b3c1641 commit 459de22
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions apex/contrib/sparsity/asp.py
@@ -30,7 +30,7 @@ def init_model_for_pruning(cls, model, mask_calculator="m4n2_1d",
verbosity=3,
whitelist=[torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d],
allowed_layer_names=None, disallowed_layer_names=[],
allow_recompute_mask=False):
allow_recompute_mask=False, custom_layer_dict={}):
"""Call this method to modify your model to take advantage of sparse matrix multiplication.
Note that this call alone only augments the model with the additional buffers needed for sparse MMA;
it does not enable use of sparse MMA.
@@ -62,7 +62,9 @@ def init_model_for_pruning(cls, model, mask_calculator="m4n2_1d",
disallowed_layer_names If not [], only layer names that do not appear in this list are considered for sparsity.
allow_recompute_mask If True, stores pruned values so that dense weights can be restored.
Pruned weights are stored in CPU memory, hence this option does not increase GPU memory usage.
Support for allow_recompute_mask can be removed, it is not part of our recipe -- AKM.
custom_layer_dict Dictionary of additional layer parameters to sparsify, e.g. {CustomLinear: ['weight']}
[Future] Support for allow_recompute_mask can be removed, it is not part of the sparse inference recipe -- AKM.
"""
assert (cls.__model is None), "ASP has been initialized already."
cls.__model = model
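A minimal usage sketch of the new custom_layer_dict argument, assuming apex is installed; CustomLinear below is a hypothetical user-defined layer, not part of apex:

    import torch
    from apex.contrib.sparsity import ASP

    class CustomLinear(torch.nn.Linear):
        # Hypothetical user layer; its 'weight' is a tensor ASP knows how to prune.
        pass

    model = torch.nn.Sequential(CustomLinear(64, 64), torch.nn.ReLU())

    # Register CustomLinear alongside the default Linear/Conv layer types.
    ASP.init_model_for_pruning(model, mask_calculator="m4n2_1d",
                               custom_layer_dict={CustomLinear: ['weight']})
    # Masks are computed and applied later, e.g. via ASP.compute_sparse_masks().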
@@ -82,6 +84,10 @@ def create_mask_from_pattern(param):
sparse_parameter_list = {torch.nn.Linear: ['weight'], torch.nn.Conv1d: ['weight'], torch.nn.Conv2d: ['weight'], torch.nn.Conv3d: ['weight'], torchvision.ops.misc.Conv2d: ['weight']}
else:
sparse_parameter_list = {torch.nn.Linear: ['weight'], torch.nn.Conv1d: ['weight'], torch.nn.Conv2d: ['weight'], torch.nn.Conv3d: ['weight']}
if custom_layer_dict: # Update the default list with user-supplied (layer type : parameter names) entries; the caller must make sure these tensors are something ASP knows how to prune
sparse_parameter_list.update(custom_layer_dict)
whitelist += list(custom_layer_dict.keys())

for module_type in whitelist:
assert (module_type in sparse_parameter_list), "Module %s :: Don't know how to sparsify module." % module_type.__name__

@@ -97,7 +103,7 @@ def add_sparse_attributes(module_name, module):
if p.dtype == torch.float16 and ((p.size()[0] % 8) != 0 or (p.size()[1] % 16) != 0): #For Conv2d dim= K x CRS; we prune along C
print("[ASP] Auto skipping pruning %s::%s of size=%s and type=%s for sparsity" % (module_name, p_name, str(p.size()), str(p.dtype)))
continue

if cls.__verbosity >= 3:
print("[ASP] Sparsifying %s::%s of size=%s and type=%s for sparsity" % (module_name, p_name, str(p.size()), str(p.dtype)))

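The size gate in the hunk above can be sketched in isolation; the modulo constants presumably align fp16 weights with sparse tensor-core tile shapes:

    import torch

    def is_prunable(p: torch.Tensor) -> bool:
        # fp16 weights are skipped unless dim 0 (K) is a multiple of 8 and
        # dim 1 (C for convolutions) is a multiple of 16.
        if p.dtype == torch.float16 and ((p.size(0) % 8) != 0 or (p.size(1) % 16) != 0):
            return False
        return True

    print(is_prunable(torch.empty(72, 48, dtype=torch.float16)))  # True
    print(is_prunable(torch.empty(70, 48, dtype=torch.float16)))  # False: 70 % 8 != 0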
@@ -110,6 +116,9 @@ def add_sparse_attributes(module_name, module):
else:
pruned = None
cls.__sparse_parameters.append((module_name, module, p_name, p, mask, pruned))
else:
if cls.__verbosity >= 3:
print("[ASP] Not sparsifying %s::%s of size=%s and type=%s" % (module_name, p_name, str(p.size()), str(p.dtype)))

for name, sparse_module in eligible_modules(model, tuple(whitelist), allowed_layer_names, disallowed_layer_names):
add_sparse_attributes(name, sparse_module)
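For readers unfamiliar with the mask_calculator argument, here is a hedged sketch of what an "m4n2_1d" pattern computes: keep the 2 largest-magnitude weights in every group of 4 along the last dimension. An illustration only, not apex's implementation:

    import torch

    def m4n2_1d_mask(weight: torch.Tensor) -> torch.Tensor:
        # View the tensor as groups of 4 consecutive elements.
        groups = weight.abs().reshape(-1, 4)
        # Keep the 2 largest-magnitude entries in each group of 4.
        keep = groups.topk(2, dim=1).indices
        mask = torch.zeros_like(groups, dtype=torch.bool)
        mask.scatter_(1, keep, True)
        return mask.reshape(weight.shape)

    w = torch.randn(8, 16)
    mask = m4n2_1d_mask(w)
    assert mask.sum().item() == w.numel() // 2  # exactly 2 of every 4 kept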
