Commit

Add the permutation related support as the extension for asp lib. (NVIDIA#1194)

* Add the permutation related support as the extension for asp lib.

* [Fix] Track the permutation sequence for progressive channel swap strategy.

* Fix the corner case where one layer is not sparse, but needs to apply permutation due to its siblings.

* Fix the deprecated functions in ASP unit tests.

* Fix the sparsity info typo in ASP lib.

* [Enhancement] Set the identical random seed for all GPUs to make sure the same results are generated in permutation search.

* Update the README.md with identical random seed setting and NeurIPS info.

* Integrate the Pybind11 enhancement of permutation search into ASP lib.
ChongyuNVIDIA authored Feb 1, 2022
1 parent 79c0187 commit 89edb81
Showing 12 changed files with 2,028 additions and 7 deletions.
58 changes: 57 additions & 1 deletion apex/contrib/sparsity/README.md
@@ -3,18 +3,21 @@
This serves as a quick-start for ASP (Automatic SParsity), a tool that enables sparse training and inference for PyTorch models by adding 2 lines of Python.

## Importing ASP

```
from apex.contrib.sparsity import ASP
```

## Initializing ASP

Apart from the import statement, it is sufficient to add just the following line of code before the training phase to augment the model and the optimizer for sparse training/inference:

```
ASP.prune_trained_model(model, optimizer)
```

In the context of a typical PyTorch training loop, it might look like this:

```
ASP.prune_trained_model(model, optimizer)
@@ -27,6 +30,7 @@ for epoch in range(epochs):
torch.save(...)
```

The `prune_trained_model` step calculates the sparse mask and applies it to the weights. This is done once, i.e., sparse locations in the weights matrix remain fixed after this step.
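
As a quick sanity check, a sketch along these lines (using a hypothetical helper `check_m4n2_1d`, and assuming the default `m4n2_1d` pattern of 2 nonzeros per group of 4 and a weight tensor whose element count is divisible by 4) can verify that a pruned weight obeys the pattern:

```
import torch

def check_m4n2_1d(weight):
    # Group the flattened weight into chunks of 4 and verify that
    # at most 2 entries in every chunk are non-zero (2:4 sparsity).
    groups = weight.detach().flatten().reshape(-1, 4)
    return bool(((groups != 0).sum(dim=1) <= 2).all())
```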

## Generate a Sparse Network
Expand All @@ -42,7 +46,6 @@ The following approach serves as a guiding example on how to generate a pruned m
In code, below is a sketch of how to use ASP for this approach (steps 1 and 2 above).

```
model = define_model(..., pretrained=True) # define model architecture and load parameter tensors with trained values (by reading a trained checkpoint)
criterion = ... # compare ground truth with model prediction; use the same criterion as used to generate the dense trained model
optimizer = ... # optimize model parameters; use the same optimizer as used to generate the dense trained model
@@ -72,7 +75,60 @@ ASP.compute_sparse_masks()

A more thorough example can be found in `./test/toy_problem.py`.

## Advanced Usage: Channel Permutation

We introduce channel permutations as an advanced method to maximize the accuracy of structured sparse networks. By permuting weight matrices along their channel dimension and adjusting the surrounding layers appropriately, we demonstrate accuracy recovery for even small, parameter-efficient networks, without affecting inference run-time.
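
As a small illustration of the idea (a self-contained sketch, not the library's search algorithm; the tensors and permutation here are arbitrary examples): applying the same permutation to one layer's output channels and the next layer's input channels leaves the composed function unchanged, which is why permutations can be applied offline without an inference run-time cost.

```
import torch

torch.manual_seed(0)
x = torch.randn(8, 16)     # a batch of inputs
w1 = torch.randn(32, 16)   # first linear layer: 16 -> 32
w2 = torch.randn(10, 32)   # second linear layer: 32 -> 10
perm = torch.randperm(32)  # a permutation of the 32 intermediate channels

y_ref = x @ w1.t() @ w2.t()
# Permute w1's output channels and w2's input channels consistently:
y_perm = x @ w1[perm].t() @ w2[:, perm].t()
assert torch.allclose(y_ref, y_perm, atol=1e-4)
```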

The final accuracy depends strongly on the quality of the permutations. We provide default algorithms to search for high-quality permutations. The permutation search can be accelerated by the Apex CUDA extension `apex.contrib.sparsity.permutation_search_kernels`.

If you want to use the GPU to accelerate the permutation search, we recommend installing Apex with the permutation search CUDA extension via

```
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--permutation_search" ./
```

If you want to disable the permutation search, pass `allow_permutation=False` to the `init_model_for_pruning` function. For example:

```
ASP.init_model_for_pruning(model, mask_calculator="m4n2_1d", verbosity=2, whitelist=[torch.nn.Linear, torch.nn.Conv2d], allow_recompute_mask=False, allow_permutation=False)
```

Please note that when using multiple GPUs, the identical random seed must be set on all GPUs to make sure the permutation search generates the same results on each of them. The library implements the `set_identical_seed` function in `permutation_lib.py`, which is called by the ASP library. We still suggest that users set the identical random seed in their own code when using multiple GPUs; example code follows:

```
import torch
import numpy
import random
torch.manual_seed(identical_seed)
torch.cuda.manual_seed_all(identical_seed)
numpy.random.seed(identical_seed)
random.seed(identical_seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
```

## Reference Papers

For more details about sparsity support on the NVIDIA Ampere GPU with Sparse Tensor Cores, please refer to our [white paper](https://arxiv.org/abs/2104.08378).

```
@article{mishra2021accelerating,
title={Accelerating sparse deep neural networks},
author={Mishra, Asit and Latorre, Jorge Albericio and Pool, Jeff and Stosic, Darko and Stosic, Dusan and Venkatesh, Ganesh and Yu, Chong and Micikevicius, Paulius},
journal={arXiv preprint arXiv:2104.08378},
year={2021}
}
```

For details about sparsity with permutations, please refer to our [paper](https://proceedings.neurips.cc/paper/2021/hash/6e8404c3b93a9527c8db241a1846599a-Abstract.html) published at the *Thirty-fifth Conference on Neural Information Processing Systems* (**NeurIPS 2021**):

```
@article{pool2021channel,
title={Channel Permutations for N:M Sparsity},
author={Pool, Jeff and Yu, Chong},
journal={Advances in Neural Information Processing Systems},
volume={34},
year={2021}
}
```
101 changes: 98 additions & 3 deletions apex/contrib/sparsity/asp.py
@@ -1,6 +1,7 @@
import types
import torch
from .sparse_masklib import create_mask
from .permutation_lib import Permutation

torchvision_imported=True
try:
@@ -9,6 +10,11 @@
print("[ASP][Warning] torchvision cannot be imported.")
torchvision_imported=False

import json
import os
import string
import time

def eligible_modules(model, whitelist_layer_types, allowed_layer_names, disallowed_layer_names):
    eligible_modules_list = []
    for name, mod in model.named_modules():
@@ -18,19 +24,25 @@ def eligible_modules(model, whitelist_layer_types, allowed_layer_names, disallow
                eligible_modules_list.append((name, mod))
    return eligible_modules_list


class ASP:
    __model = None
    __verbosity = 0
    __optimizer = None
    __sparse_parameters = []
    __calculate_mask = None
    __allow_permutation = True
    __all_parameters = []
    __save_permutation_graph = False
    __permutation_output_dir = ''

    @classmethod
    def init_model_for_pruning(cls, model, mask_calculator="m4n2_1d",
                               verbosity=3,
                               whitelist=[torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d],
                               allowed_layer_names=None, disallowed_layer_names=[],
                               allow_recompute_mask=False, custom_layer_dict={},
                               allow_permutation=True):
"""Call this method to modify your model to take advantage of sparse matrix multiplication.
Note that this call alone only augments the model with additional buffers needed for sparse MMA,
it does not enable use of sparse MMA.
@@ -63,12 +75,14 @@ def init_model_for_pruning(cls, model, mask_calculator="m4n2_1d",
        allow_recompute_mask    If True, stores pruned values so that dense weights can be restored.
                                Pruned weights are stored in CPU memory, hence this option does not increase GPU memory usage.
        custom_layer_dict       Dictionary of additional layer parameters to sparsify. e.g. {CustomLinear: ['weight']}
        allow_permutation       If True, allow the input channel permutation to ease the influence of weight pruning.
        [Future] Support for allow_recompute_mask can be removed, it is not part of the sparse inference recipe.
        """
        assert (cls.__model is None), "ASP has been initialized already."
        cls.__model = model
        cls.__verbosity = verbosity
        cls.__allow_permutation = allow_permutation

        if isinstance(mask_calculator, str):
            def create_mask_from_pattern(param):
@@ -91,6 +105,28 @@ def create_mask_from_pattern(param):
        for module_type in whitelist:
            assert (module_type in sparse_parameter_list), "Module %s :: Don't know how to sparsify module." % module_type

        if allow_permutation:  # find all named modules, extract parameters and decorate, used for offline permutation in K dim
            for module_name, module in model.named_modules():
                module_type_str = str(type(module)).split("\'")[1]
                if module_type_str == 'torch.nn.modules.container.Sequential' or module_type_str.startswith('torchvision.models'):
                    # filter out the 'torch.nn.modules.container.Sequential' type and the whole model, like 'torchvision.models.vgg.VGG'
                    continue
                for p_name, p in module.named_parameters():
                    cls.__all_parameters.append((module_name, module, p_name, p))
                if module_type_str == 'torch.nn.modules.batchnorm.BatchNorm2d':
                    # need to get the running_mean and running_var from model.state_dict(), as they are not the learnable parameters
                    module_mean_name = module_name + '.running_mean'
                    module_var_name = module_name + '.running_var'
                    for param_key in model.state_dict():
                        if module_mean_name == param_key or module_var_name == param_key:
                            cls.__all_parameters.append((module_name, module, param_key.split(".")[-1], model.state_dict()[param_key]))
            # add the __permutation_output_dir field to save the intermediate results for permutation
            cls.__permutation_output_dir = '.'
            # Set the corresponding params from the ASP class in the Permutation class
            Permutation.set_permutation_params_from_asp(cls.__model, cls.__sparse_parameters, cls.__all_parameters)
            # Set the identical random seed for all GPUs to make sure the same results are generated in permutation search
            Permutation.set_identical_seed()

        # find all sparse modules, extract sparse parameters and decorate
        def add_sparse_attributes(module_name, module):
            sparse_parameters = sparse_parameter_list[type(module)]
@@ -123,6 +159,19 @@ def add_sparse_attributes(module_name, module):
        for name, sparse_module in eligible_modules(model, tuple(whitelist), allowed_layer_names, disallowed_layer_names):
            add_sparse_attributes(name, sparse_module)

    @classmethod
    def already_init_asp_model(cls):
        """Call this method to check whether ASP has been initialized already.
        """
        if cls.__model is None:
            if cls.__verbosity >= 3:
                print("[ASP] ASP has not been initialized.")
            return False
        else:
            if cls.__verbosity >= 3:
                print("[ASP] ASP has been initialized already.")
            return True

    @classmethod
    def init_optimizer_for_pruning(cls, optimizer):
        """Call this method to monkey patch optimizer step function so that masks can be applied to
@@ -157,6 +206,38 @@ def compute_sparse_masks(cls):
        If init(...) was called with allow_recompute_mask=False AND sparsity is disabled, pruned field can be None.
        """
        with torch.no_grad():
            if cls.__allow_permutation:
                # Step 1: use the Torch.FX library to build the graph
                # Step 2: permutation search with the customized kernel
                # Notice: need to use a single GPU to build the Torch.FX graph
                # The simplest approach, without user intervention:
                # A. try to build the graph from the distributed-wrapped model
                # B. if that raises an AttributeError, fall back to the non-distributed model
                start_time_build_offline_permutation_graph = time.perf_counter()
                try:
                    offline_permutation_fx_graph, success_in_build_offline_permutation_graph = Permutation.build_offline_permutation_graph(cls.__model.module, dump_fx_graph=cls.__save_permutation_graph, save_dumped_fx_graph=os.path.join(cls.__permutation_output_dir, 'model_offline_permutation_graph.json'))
                    print("\n[compute_sparse_masks] build offline permutation graph on distributed model.")
                except AttributeError:
                    offline_permutation_fx_graph, success_in_build_offline_permutation_graph = Permutation.build_offline_permutation_graph(cls.__model, dump_fx_graph=cls.__save_permutation_graph, save_dumped_fx_graph=os.path.join(cls.__permutation_output_dir, 'model_offline_permutation_graph.json'))
                    print("\n[compute_sparse_masks] build offline permutation graph on non-distributed model.")
                duration_build_offline_permutation_graph = time.perf_counter() - start_time_build_offline_permutation_graph
                print("[compute_sparse_masks] Took {:.4f} seconds to finish the build_offline_permutation_graph function.".format(duration_build_offline_permutation_graph))

                # Step 3: off-line permutation to avoid the runtime overhead in deployment
                if success_in_build_offline_permutation_graph:
                    start_time_apply_offline_permutation = time.perf_counter()
                    try:
                        Permutation.apply_offline_permutation(cls.__model.module, fx_graph=offline_permutation_fx_graph)
                        print("\n[compute_sparse_masks] apply offline permutation on distributed model.")
                    except AttributeError:
                        Permutation.apply_offline_permutation(cls.__model, fx_graph=offline_permutation_fx_graph)
                        print("\n[compute_sparse_masks] apply offline permutation on non-distributed model.")
                    duration_apply_offline_permutation = time.perf_counter() - start_time_apply_offline_permutation
                    print("[compute_sparse_masks] Took {:.4f} seconds to finish the apply_offline_permutation function.\n".format(duration_apply_offline_permutation))
                else:
                    print("[compute_sparse_masks] skip applying offline permutation because there is no valid offline_permutation_fx_graph.")
                # Finally, permutation search and off-line permutation are done; give the model back to ASP to generate the normal structured sparse mask

            for module_name, module, p_name, p, mask, pruned in cls.__sparse_parameters:
                if mask.sum() < mask.numel(): # when recalculating masks
                    # restore dense parameter if allow_recompute_mask is enabled
@@ -170,7 +251,7 @@

                p.mul_(mask) # in-place multiplication, so pruned weights are 0-values, hence checkpoint will have 0s for pruned weights
                if cls.__verbosity >= 2:
                    print("[ASP] Enabled %.2f%% sparsity for %s::%s of size=%s and type=%s" % (100.0-100.0*mask.sum()/mask.numel(), module_name, p_name, str(p.size()), str(p.dtype)))

    @classmethod
    def restore_pruned_weights(cls):
@@ -215,3 +296,17 @@ def prune_trained_model(cls, model, optimizer):
        cls.init_optimizer_for_pruning(optimizer)
        cls.compute_sparse_masks()

    @classmethod
    def set_permutation_saving_params(cls, allow_permutation=True, save_permutation_graph=False, permutation_output_dir='.'):
        """This function is used to set the permutation saving related parameters in the ASP class and inside the Permutation class."""
        print("\n[ASP][set_permutation_saving_param] Set permutation saving related parameters")
        cls.__allow_permutation = allow_permutation
        print("[set_permutation_saving_param]\t Allow permutation: {}".format(cls.__allow_permutation))
        cls.__save_permutation_graph = save_permutation_graph
        print("[set_permutation_saving_param]\t Save permutation graphs: {}".format(cls.__save_permutation_graph))
        cls.__permutation_output_dir = permutation_output_dir
        print("[set_permutation_saving_param]\t Permutation graphs saving dir: {}".format(cls.__permutation_output_dir))

        Permutation.set_permutation_saving_params(allow_permutation, save_permutation_graph, permutation_output_dir)
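
    # Example usage of the saving-related setter above (a sketch; the output
    # directory name here is hypothetical):
    #   ASP.set_permutation_saving_params(allow_permutation=True,
    #                                     save_permutation_graph=True,
    #                                     permutation_output_dir='./permutation_graphs')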

