[apex.contrib.sparsity] Grouped conv permutations (NVIDIA#1628)
* adjusting epsilon slightly to avoid oscillations

* adding verbosity flag, removing some old comments

* large refactor: new support for MHA, grouped convs, attributes, unexpected-behavior handling, utilizing multiple GPUs in DDP, among other changes (usage sketch below)

* adding permutation application regression tests

* supporting PyTorch >= 1.8, with a workaround (WAR) for a missing tabulate library
jpool-nv authored Mar 31, 2023
1 parent 57057e2 commit 89cc215
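
For context, a minimal usage sketch of the flow this commit targets. The toy model, its shapes, and the optimizer below are illustrative assumptions, not part of the commit; only the ASP calls mirror the API shown in the diff.

import torch
from apex.contrib.sparsity import ASP

class ToyModel(torch.nn.Module):
    # Hypothetical model mixing a grouped conv and multi-head attention,
    # the two layer types this commit's permutation support focuses on.
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(64, 64, kernel_size=3, padding=1, groups=4)
        self.attn = torch.nn.MultiheadAttention(embed_dim=64, num_heads=8)
        self.fc = torch.nn.Linear(64, 10)

    def forward(self, x):
        x = self.conv(x)                    # (N, 64, H, W)
        s = x.flatten(2).permute(2, 0, 1)   # (H*W, N, 64), sequence-first for MHA
        s, _ = self.attn(s, s, s)
        return self.fc(s.mean(dim=0))

model = ToyModel().cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# One call: init_model_for_pruning (whose default whitelist now includes
# torch.nn.MultiheadAttention), init_optimizer_for_pruning, compute_sparse_masks.
ASP.prune_trained_model(model, optimizer)
# ...then fine-tune as usual; the masked optimizer step keeps pruned weights at zero.
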
Showing 5 changed files with 2,272 additions and 740 deletions.
97 changes: 43 additions & 54 deletions apex/contrib/sparsity/asp.py
@@ -39,7 +39,7 @@ class ASP:
@classmethod
def init_model_for_pruning(cls, model, mask_calculator="m4n2_1d",
verbosity=3,
whitelist=[torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d],
whitelist=[torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.MultiheadAttention],
allowed_layer_names=None, disallowed_layer_names=[],
allow_recompute_mask=False, custom_layer_dict={},
allow_permutation=True):
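
As a hedged reminder of how the custom_layer_dict argument above is consumed further down in this file: its entries are merged into sparse_parameter_list and its keys are appended to the whitelist. The GatedLinear module below is purely hypothetical; only the ASP calls and keyword names come from this diff.

import torch
from apex.contrib.sparsity import ASP

class GatedLinear(torch.nn.Module):
    # Hypothetical custom layer with a prunable 'weight' and a non-prunable 'gate'.
    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.empty(out_features, in_features))
        self.gate = torch.nn.Parameter(torch.ones(out_features))
        torch.nn.init.kaiming_uniform_(self.weight, a=5 ** 0.5)

    def forward(self, x):
        return torch.nn.functional.linear(x, self.weight) * self.gate

model = torch.nn.Sequential(GatedLinear(64, 64), torch.nn.ReLU(), GatedLinear(64, 64)).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

ASP.init_model_for_pruning(
    model,
    mask_calculator="m4n2_1d",
    verbosity=2,
    custom_layer_dict={GatedLinear: ["weight"]},  # only 'weight' is 2:4-pruned
    allow_permutation=False,  # assumption: skip permutation tracing for unknown custom modules
)
ASP.init_optimizer_for_pruning(optimizer)
ASP.compute_sparse_masks()
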
@@ -99,39 +99,18 @@ def create_mask_from_pattern(param):
torchvision_version_major = int(torchvision_version.split('.')[0])
torchvision_version_minor = int(torchvision_version.split('.')[1])
if torchvision_version_major == 0 and torchvision_version_minor < 12:
sparse_parameter_list = {torch.nn.Linear: ['weight'], torch.nn.Conv1d: ['weight'], torch.nn.Conv2d: ['weight'], torch.nn.Conv3d: ['weight'], torchvision.ops.misc.Conv2d: ['weight']}
sparse_parameter_list = {torch.nn.Linear: ['weight'], torch.nn.Conv1d: ['weight'], torch.nn.Conv2d: ['weight'], torch.nn.Conv3d: ['weight'], torch.nn.modules.linear.NonDynamicallyQuantizableLinear: ['weight'], torch.nn.MultiheadAttention: ['q_proj_weight', 'k_proj_weight', 'v_proj_weight', 'in_proj_weight'], torchvision.ops.misc.Conv2d: ['weight']}
else: # torchvision 0.12.0 removed APIs deprecated before 0.8 (#5386), including torchvision.ops.misc.Conv2d
sparse_parameter_list = {torch.nn.Linear: ['weight'], torch.nn.Conv1d: ['weight'], torch.nn.Conv2d: ['weight'], torch.nn.Conv3d: ['weight']}
sparse_parameter_list = {torch.nn.Linear: ['weight'], torch.nn.Conv1d: ['weight'], torch.nn.Conv2d: ['weight'], torch.nn.Conv3d: ['weight'], torch.nn.modules.linear.NonDynamicallyQuantizableLinear: ['weight'], torch.nn.MultiheadAttention: ['q_proj_weight', 'k_proj_weight', 'v_proj_weight', 'in_proj_weight']}
else:
sparse_parameter_list = {torch.nn.Linear: ['weight'], torch.nn.Conv1d: ['weight'], torch.nn.Conv2d: ['weight'], torch.nn.Conv3d: ['weight']}
sparse_parameter_list = {torch.nn.Linear: ['weight'], torch.nn.Conv1d: ['weight'], torch.nn.Conv2d: ['weight'], torch.nn.Conv3d: ['weight'], torch.nn.modules.linear.NonDynamicallyQuantizableLinear: ['weight'], torch.nn.MultiheadAttention: ['q_proj_weight', 'k_proj_weight', 'v_proj_weight', 'in_proj_weight']}
if custom_layer_dict: # update the default list with user-supplied {layer type: parameter names} entries; the tensors must be something ASP knows how to prune
sparse_parameter_list.update(custom_layer_dict)
whitelist += list(custom_layer_dict.keys())

for module_type in whitelist:
assert (module_type in sparse_parameter_list), "Module %s :: Don't know how to sparsify module." % module_type

if allow_permutation: # find all named modules, extract parameters and decorate, used for offline permutation in K dim
for module_name, module in model.named_modules():
module_type_str = str(type(module)).split("\'")[1]
if module_type_str == 'torch.nn.modules.container.Sequential' or module_type_str.startswith('torchvision.models'):
# filter out the 'torch.nn.modules.container.Sequential' type and the whole model, like 'torchvision.models.vgg.VGG'
continue
for p_name, p in module.named_parameters():
cls.__all_parameters.append((module_name, module, p_name, p))
if module_type_str == 'torch.nn.modules.batchnorm.BatchNorm2d':
# need to get the running_mean and running_var from model.state_dict(), as they are not the learnable parameters
module_mean_name = module_name + '.running_mean'
module_var_name = module_name + '.running_var'
for param_key in model.state_dict():
if module_mean_name == param_key or module_var_name == param_key:
cls.__all_parameters.append((module_name, module, param_key.split(".")[-1], model.state_dict()[param_key]))
# add the __permutation_output_dir field to save the intermediate results for permutation
cls.__permutation_output_dir = '.'
# Set the corresponding params from ASP class to the Permutation class
Permutation.set_permutation_params_from_asp(cls.__model, cls.__sparse_parameters, cls.__all_parameters)
# Set the identical random seed for all GPUs to make sure the same results generated in permutation search
Permutation.set_identical_seed()

# find all sparse modules, extract sparse parameters and decorate
def add_sparse_attributes(module_name, module):
@@ -165,6 +144,30 @@ def add_sparse_attributes(module_name, module):
for name, sparse_module in eligible_modules(model, tuple(whitelist), allowed_layer_names, disallowed_layer_names):
add_sparse_attributes(name, sparse_module)

if allow_permutation: # find all named modules, extract parameters and decorate, used for offline permutation in K dim
for module_name, module in model.named_modules():
module_type_str = str(type(module)).split("\'")[1]
if module_type_str == 'torch.nn.modules.container.Sequential' or module_type_str.startswith('torchvision.models'):
# filter out the 'torch.nn.modules.container.Sequential' type and the whole model, like 'torchvision.models.vgg.VGG'
continue
for p_name, p in module.named_parameters():
cls.__all_parameters.append((module_name, module, p_name, p))
if module_type_str == 'torch.nn.modules.batchnorm.BatchNorm2d':
# need to get the running_mean and running_var from model.state_dict(), as they are not the learnable parameters
module_mean_name = module_name + '.running_mean'
module_var_name = module_name + '.running_var'
for param_key in model.state_dict():
if module_mean_name == param_key or module_var_name == param_key:
cls.__all_parameters.append((module_name, module, param_key.split(".")[-1], model.state_dict()[param_key]))
# add the __permutation_output_dir field to save the intermediate results for permutation
cls.__permutation_output_dir = '.'
# Set the corresponding params from ASP class to the Permutation class
permutation_verbosity = 5
Permutation.set_permutation_params_from_asp(cls.__model, cls.__sparse_parameters, cls.__all_parameters, permutation_verbosity)
# Set the identical random seed for all GPUs to make sure the same results generated in permutation search
Permutation.set_identical_seed()


@classmethod
def already_init_asp_model(cls):
"""Call this method to check whether ASP has been initialized already.
@@ -215,38 +218,24 @@ def compute_sparse_masks(cls):
if cls.__allow_permutation:
# Step 1: use the Torch.FX library to build the graph
# Step 2: permutation search with the customized kernel
# Note: the Torch.FX graph must be built on a single GPU
# The simplest approach without user intervention:
# A. try to trace the model in its distributed (DDP-wrapped) form
# B. if that raises an error, fall back to the non-distributed model
start_time_build_offline_permutation_graph = time.perf_counter()
success_in_build_offline_permutation_graph = False
start_time_permute = time.perf_counter()
successful_permutation = False
try:
offline_permutation_fx_graph, success_in_build_offline_permutation_graph = Permutation.build_offline_permutation_graph(cls.__model.module, dump_fx_graph=cls.__save_permutation_graph, save_dumped_fx_graph=os.path.join(cls.__permutation_output_dir, 'model_offline_permutation_graph.json'))
if success_in_build_offline_permutation_graph:
print("\n[compute_sparse_masks] build offline permutation graph on distributed model.")
successful_permutation = Permutation.permute_model(cls.__model.module, dump_fx_graph=cls.__save_permutation_graph, save_dumped_fx_graph=os.path.join(cls.__permutation_output_dir, 'model_offline_permutation_graph.json'))
if successful_permutation:
print("\n[compute_sparse_masks] permuted the (distributed) model.")
except AttributeError:
offline_permutation_fx_graph, success_in_build_offline_permutation_graph = Permutation.build_offline_permutation_graph(cls.__model, dump_fx_graph=cls.__save_permutation_graph, save_dumped_fx_graph=os.path.join(cls.__permutation_output_dir, 'model_offline_permutation_graph.json'))
if success_in_build_offline_permutation_graph:
print("\n[compute_sparse_masks] build offline permutation graph on none-distributed model.")

if success_in_build_offline_permutation_graph:
duration_build_offline_permutation_graph = time.perf_counter() - start_time_build_offline_permutation_graph
print("[compute_sparse_masks] Take {:.4f} seconds to finish build_offline_permutation_graph function.".format(duration_build_offline_permutation_graph))

# Step 3: off-line permutation to avoid the runtime overhead in deployment
start_time_apply_offline_permutation = time.perf_counter()
try:
Permutation.apply_offline_permutation(cls.__model.module, fx_graph=offline_permutation_fx_graph)
print("\n[compute_sparse_masks] apply offline permutation on distributed model.")
except AttributeError:
Permutation.apply_offline_permutation(cls.__model, fx_graph=offline_permutation_fx_graph)
print("\n[compute_sparse_masks] apply offline permutation on none-distributed model.")
duration_apply_offline_permutation = time.perf_counter() - start_time_apply_offline_permutation
print("[compute_sparse_masks] Take {:.4f} seconds to finish apply_offline_permutation function.\n".format(duration_apply_offline_permutation))
else:
print("[compute_sparse_masks] skip applying offline permutation because there is no valid offline_permutation_fx_graph.")
# Finally, once permutation search and offline permutation are done, give the model back to ASP to generate the normal structured sparse mask
successful_permutation = Permutation.permute_model(cls.__model, dump_fx_graph=cls.__save_permutation_graph, save_dumped_fx_graph=os.path.join(cls.__permutation_output_dir, 'model_offline_permutation_graph.json'))
if successful_permutation:
print("\n[compute_sparse_masks] permuted the model.")

if successful_permutation:
duration_permute = time.perf_counter() - start_time_permute
print("[compute_sparse_masks] Took {:.4f} seconds to find and apply permutations.".format(duration_permute))


for module_name, module, p_name, p, mask, pruned in cls.__sparse_parameters:
if mask.sum() < mask.numel(): # when recalculating masks
@@ -261,7 +250,7 @@ def compute_sparse_masks(cls):

p.mul_(mask) # in-place multiplication, so pruned weights are 0-values, hence checkpoint will have 0s for pruned weights
if cls.__verbosity >= 2:
print("[ASP] Enabled %.2f%% sparsity for %s::%s of size=%s and type=%s" % (100.0-100.0*mask.sum()/mask.numel(), module_name, p_name, str(p.size()), str(p.dtype)))
print("[ASP] Enabled %.2f%% sparsity for %s::%s of size=%s and type=%s with magnitude %s" % (100.0-100.0*mask.sum()/mask.numel(), module_name, p_name, str(p.size()), str(p.dtype), torch.sum(torch.abs(p))))

@classmethod
def restore_pruned_weights(cls):
@@ -302,7 +291,7 @@ def is_sparsity_enabled(cls):
@classmethod
def prune_trained_model(cls, model, optimizer):
# add mask buffers to model (init_model_for_pruning), augment optimizer (init_optimizer_for_pruning) and compute masks (compute_sparse_masks)
cls.init_model_for_pruning(model, mask_calculator="m4n2_1d", verbosity=2, whitelist=[torch.nn.Linear, torch.nn.Conv2d], allow_recompute_mask=False)
cls.init_model_for_pruning(model, mask_calculator="m4n2_1d", verbosity=2, whitelist=[torch.nn.Linear, torch.nn.Conv2d, torch.nn.MultiheadAttention], allow_recompute_mask=False)
cls.init_optimizer_for_pruning(optimizer)
cls.compute_sparse_masks()
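
For completeness, a small self-contained sketch of the workflow that prune_trained_model is meant to slot into; the toy model, data, and step count are placeholders, and the restore_pruned_weights note is an assumption that only applies when masks were created with allow_recompute_mask=True.

import torch
from apex.contrib.sparsity import ASP

model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU(), torch.nn.Linear(64, 64)).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = torch.nn.MSELoss()

ASP.prune_trained_model(model, optimizer)   # adds mask buffers, wraps optimizer, computes 2:4 masks
assert ASP.is_sparsity_enabled()

for step in range(10):                      # stand-in for a real fine-tuning loop
    x = torch.randn(32, 64, device="cuda")
    optimizer.zero_grad()
    criterion(model(x), x).backward()
    optimizer.step()                        # masked step keeps pruned weights at zero

# ASP.restore_pruned_weights() can undo pruning, but only if the masks were
# built with allow_recompute_mask=True (not the case for prune_trained_model).
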

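One more note on the permutation changes above: the search machinery relies on the algebraic fact that permuting a layer's input channels can be exactly compensated by permuting the output channels of its producers, so the network's function is unchanged while the 2:4 mask quality along the channel dimension can improve. A tiny self-contained check of that invariant, shown with dense matmuls for brevity (the same idea carries over to grouped convolutions, where permutations additionally have to respect the group structure):

import torch

torch.manual_seed(0)
x = torch.randn(8, 16)        # activations: batch x channels
w1 = torch.randn(16, 16)      # producer weight (C_out x C_in)
w2 = torch.randn(4, 16)       # consumer weight whose input channels get permuted

perm = torch.randperm(16)

y_ref = (x @ w1.t()) @ w2.t()     # original two-layer output

w1_p = w1[perm, :]                # permute producer output channels...
w2_p = w2[:, perm]                # ...and the consumer's input channels to match

y_perm = (x @ w1_p.t()) @ w2_p.t()
print(torch.allclose(y_ref, y_perm, atol=1e-5))   # True: the function is unchanged
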
(Diffs for the remaining 4 changed files in this commit are not shown here.)