Skip to content

Commit

Permalink
[Model Compression] Pruning Scheduler (microsoft#4089)
Browse files Browse the repository at this point in the history
  • Loading branch information
J-shang authored Sep 10, 2021
1 parent 04f439a commit e98ebcf
Show file tree
Hide file tree
Showing 11 changed files with 673 additions and 47 deletions.
1 change: 1 addition & 0 deletions nni/algorithms/compression/v2/pytorch/base/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .compressor import Compressor, LayerInfo
from .pruner import Pruner, PrunerModuleWrapper
from .scheduler import BasePruningScheduler, Task, TaskResult
11 changes: 11 additions & 0 deletions nni/algorithms/compression/v2/pytorch/base/compressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,17 @@ def reset(self, model: Module, config_list: List[Dict]):

self._wrap_model()

def clear_model_references(self):
    """
    Drop every reference this compressor holds to the model, purely to free up memory.

    The model is unwrapped first, then the bound model, the config list, the wrapper
    collection and the cached list of modules to compress are all set to ``None``.
    ``reset`` must be called again before any other compressor function is used.
    """
    # Unwrap while the wrappers are still reachable; the references are dropped afterwards.
    self._unwrap_model()
    self.bound_model = self.config_list = None
    self.modules_wrapper = self._modules_to_compress = None

def _detect_modules_to_compress(self) -> List[Tuple[LayerInfo, Dict]]:
"""
Detect all modules should be compressed, and save the result in `self._modules_to_compress`.
Expand Down
40 changes: 15 additions & 25 deletions nni/algorithms/compression/v2/pytorch/base/pruner.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,17 @@ def load_masks(self, masks: Dict[str, Dict[str, Tensor]]):
Parameters
----------
masks
The masks dict with format {'op_name': {'weight_mask': mask, 'bias_mask': mask}}.
The masks dict with format {'op_name': {'weight': mask, 'bias': mask}}.
"""
wrappers = self.get_modules_wrapper()
for name, layer_mask in masks.items():
assert name in wrappers, '{} is not in wrappers of this pruner, can not apply the mask.'.format(name)
for mask_type, mask in layer_mask.items():
assert hasattr(wrappers[name], mask_type), 'there is no attribute {} in wrapper'.format(mask_type)
setattr(wrappers[name], mask_type, mask)
if layer_mask.get('weight') is not None:
assert hasattr(wrappers[name], 'weight_mask'), 'There is no attribute weight_mask in wrapper.'
setattr(wrappers[name], 'weight_mask', layer_mask.get('weight'))
if layer_mask.get('bias') is not None:
assert hasattr(wrappers[name], 'bias_mask'), 'There is no attribute bias_mask in wrapper.'
setattr(wrappers[name], 'bias_mask', layer_mask.get('bias'))

def compress(self) -> Tuple[Module, Dict[str, Dict[str, Tensor]]]:
"""
Expand Down Expand Up @@ -126,27 +129,21 @@ def show_pruned_weights(self, dim: int = 0):
index = torch.nonzero(weight_mask.abs().sum(sum_idx) != 0, as_tuple=False).tolist()
_logger.info(f'simulated prune {wrapper.name} remain/total: {len(index)}/{weight_mask.size(dim)}')

def export_model(self, model_path, mask_path=None, onnx_path=None, input_shape=None, device=None):
def export_model(self, model_path: str, mask_path: Optional[str] = None):
"""
Export pruned model weights, masks and onnx model(optional)
Parameters
----------
model_path
Path to save pruned model state_dict.
Path to save pruned model state_dict. The weight and bias have already multiplied the masks.
mask_path
(optional) path to save mask dict.
onnx_path
(optional) path to save onnx model.
input_shape
Input shape to onnx model.
device
Device of the model, used to place the dummy input tensor for exporting onnx file.
The tensor is placed on cpu if ```device``` is None.
Path to save mask dict.
"""
assert model_path is not None, 'model_path must be specified'
assert self.bound_model is not None, 'The bound model reference has been cleared.'
assert model_path is not None, 'model_path must be specified.'
mask_dict = {}
self._unwrap_model() # used for generating correct state_dict name without wrapper state
self._unwrap_model()

for name, wrapper in self.get_modules_wrapper().items():
weight_mask = wrapper.weight_mask
Expand All @@ -159,20 +156,13 @@ def export_model(self, model_path, mask_path=None, onnx_path=None, input_shape=N
if bias_mask is not None:
wrapper.module.bias.data = wrapper.module.bias.data.mul(bias_mask)
# save mask to dict
mask_dict[name] = {"weight_mask": weight_mask, "bias_mask": bias_mask}
mask_dict[name] = {"weight": weight_mask, "bias": bias_mask}

torch.save(self.bound_model.state_dict(), model_path)
_logger.info('Model state_dict saved to %s', model_path)

if mask_path is not None:
torch.save(mask_dict, mask_path)
_logger.info('Mask dict saved to %s', mask_path)
if onnx_path is not None:
assert input_shape is not None, 'input_shape must be specified to export onnx model'
# input info needed
if device is None:
device = torch.device('cpu')
input_data = torch.Tensor(*input_shape)
torch.onnx.export(self.bound_model, input_data.to(device), onnx_path)
_logger.info('Model in onnx with input shape %s saved to %s', input_data.shape, onnx_path)

self._wrap_model()
184 changes: 184 additions & 0 deletions nni/algorithms/compression/v2/pytorch/base/scheduler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import gc
import logging
import os
from pathlib import Path
from typing import List, Dict, Tuple, Literal, Optional

import json_tricks
import torch
from torch.nn import Module
from torch.tensor import Tensor

_logger = logging.getLogger(__name__)


class Task:
    """
    One pruning task, described by the three files it reads: the model, the masks
    and the config list.

    Every task registers its file paths in a class-level reference counter so that a
    file shared by several tasks is only deleted once the last task referencing it has
    been cleaned up.
    """
    # NOTE: If we want to support multi-thread, this part need to refactor, maybe use file and lock to sync.
    # Shared by all Task instances on purpose: maps a referenced path to the number of live tasks using it.
    _reference_counter = {}

    def __init__(self, task_id: int, model_path: str, masks_path: str, config_list_path: str) -> None:
        """
        Parameters
        ----------
        task_id
            The unique id of task.
        model_path
            The path of the unwrapped pytorch model that will be pruned in this task.
        masks_path
            The path of the masks that applied on the model before pruning.
        config_list_path
            The path of the config list that used in this task.
        """
        self.task_id = task_id
        self.model_path = model_path
        self.masks_path = masks_path
        self.config_list_path = config_list_path

        self.status: Literal['Pending', 'Running', 'Finished'] = 'Pending'
        self.score: Optional[float] = None

        self.state = {}

        # Register one reference per path used by this task.
        for ref in self.referenced_paths():
            self._reference_counter.setdefault(ref, 0)
            self._reference_counter[ref] += 1

        self._cleaned = False

    def to_dict(self) -> Dict:
        """
        Returns
        -------
        Dict
            A plain dict describing this task: its id, the three paths (converted to `str`),
            the status, the score and the state dict.
        """
        return {
            'task_id': self.task_id,
            'model_path': str(self.model_path),
            'masks_path': str(self.masks_path),
            'config_list_path': str(self.config_list_path),
            'status': self.status,
            'score': self.score,
            'state': self.state
        }

    def load_data(self) -> Tuple[Module, Dict[str, Dict[str, Tensor]], List[Dict]]:
        """
        Returns
        -------
        Tuple[Module, Dict[str, Dict[str, Tensor]], List[Dict]]
            Return the model pruning in this task, the masks of the model before pruning,
            the config list used in this task.
        """
        # NOTE(security): `torch.load` unpickles the file content, only load files produced by a trusted source.
        model = torch.load(self.model_path)
        masks = torch.load(self.masks_path)
        with Path(self.config_list_path).open('r') as f:
            config_list = json_tricks.load(f)
        return model, masks, config_list

    def referenced_paths(self) -> List[str]:
        """
        Return the path list that need to count reference in this task.
        """
        return [self.model_path, self.masks_path, self.config_list_path]

    def clean_up(self):
        """
        Counter of referenced file paths subtract 1. If the counter reach 0, then delete the file.

        A file that no longer exists when its counter reaches 0 is treated as already deleted.
        Calling this function a second time on the same task only logs a warning.
        """
        if self._cleaned:
            _logger.warning('Already clean up task %d', self.task_id)
            return

        for ref in self.referenced_paths():
            self._reference_counter[ref] -= 1
            remaining = self._reference_counter[ref]
            if remaining <= 0:
                try:
                    os.remove(ref)
                except FileNotFoundError:
                    # The file is already gone, which is exactly the state we want.
                    # Raising here would abort the clean up half-way and leave the
                    # remaining counters of this task out of sync.
                    pass
            if remaining < 0:
                _logger.warning('Referance counter error, the number of %s is %d',
                                ref, remaining)
        self._cleaned = True


class TaskResult:
    def __init__(self, task_id: int, compact_model: Module, compact_model_masks: Dict[str, Dict[str, Tensor]],
                 pruner_generated_masks: Dict[str, Dict[str, Tensor]], score: Optional[float]) -> None:
        """
        The outcome of one pruning task. All arguments are stored unchanged as attributes of the same name.

        Parameters
        ----------
        task_id
            The unique id of the task this result belongs to.
        compact_model
            The unwrapped compact pytorch model after pruning. If the model has been sped up during
            the pruning process, its structure is smaller than the structure of the model before pruning,
            otherwise the structure is the same as before pruning.
        compact_model_masks
            The masks on the compact model. It is always an empty dict if the compact model has been
            sped up during the pruning process, otherwise it is the same as `pruner_generated_masks`.
        pruner_generated_masks
            The masks that can be applied on the model before pruning. It is always the output of `pruner.compress()`.
            TODO: If the compact model has been sped up, the auto inferred masks may also be needed.
        score
            The score of the pruning effect, i.e., the accuracy or latency after pruning.
        """
        # identification and evaluation
        self.task_id = task_id
        self.score = score

        # pruning outputs
        self.compact_model = compact_model
        self.compact_model_masks = compact_model_masks
        self.pruner_generated_masks = pruner_generated_masks


class BasePruningScheduler:
    def generate_task(self) -> Optional[Task]:
        """
        Hook to be implemented by subclasses.

        Returns
        -------
        Optional[Task]
            The next pruning task. `compress` stops looping as soon as this returns None.
        """
        raise NotImplementedError()

    def record_task_result(self, task_result: TaskResult):
        """
        Hook to be implemented by subclasses.

        Parameters
        ----------
        task_result
            The result of the task that has just been executed.
        """
        raise NotImplementedError()

    def pruning_one_step(self, task: Task) -> TaskResult:
        """
        Hook to be implemented by subclasses, prune the model defined in the task.

        Parameters
        ----------
        task
            The pruning task in this step.

        Returns
        -------
        TaskResult
            The result of the task in this step.
        """
        raise NotImplementedError()

    def get_best_result(self) -> Tuple[int, Module, Dict[str, Dict[str, Tensor]], float, List[Dict]]:
        """
        Hook to be implemented by subclasses.

        Returns
        -------
        Tuple[int, Module, Dict[str, Dict[str, Tensor]], float, List[Dict]]
            The task result that has the best performance,
            include task id, the compact model, the masks on the compact model, score and config list used in this task.
        """
        raise NotImplementedError()

    def compress(self):
        """
        The pruning schedule main loop: generate a task, run it, record its result,
        and repeat until no more task is generated.
        """
        while True:
            task = self.generate_task()
            if task is None:
                break

            task_result = self.pruning_one_step(task)
            self.record_task_result(task_result)

            # Drop the local reference to the result and collect garbage before the next step.
            del task_result
            gc.collect()
14 changes: 7 additions & 7 deletions nni/algorithms/compression/v2/pytorch/pruning/basic_pruner.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@
}


class OneShotPruner(Pruner):
class BasicPruner(Pruner):
def __init__(self, model: Module, config_list: List[Dict]):
self.data_collector: DataCollector = None
self.metrics_calculator: MetricsCalculator = None
Expand Down Expand Up @@ -120,7 +120,7 @@ def compress(self) -> Tuple[Module, Dict]:
return self.bound_model, masks


class LevelPruner(OneShotPruner):
class LevelPruner(BasicPruner):
def __init__(self, model: Module, config_list: List[Dict]):
"""
Parameters
Expand Down Expand Up @@ -154,7 +154,7 @@ def reset_tools(self):
self.sparsity_allocator = NormalSparsityAllocator(self)


class NormPruner(OneShotPruner):
class NormPruner(BasicPruner):
def __init__(self, model: Module, config_list: List[Dict], p: int,
mode: str = 'normal', dummy_input: Optional[Tensor] = None):
"""
Expand Down Expand Up @@ -275,7 +275,7 @@ def __init__(self, model: Module, config_list: List[Dict],
super().__init__(model, config_list, 2, mode, dummy_input)


class FPGMPruner(OneShotPruner):
class FPGMPruner(BasicPruner):
def __init__(self, model: Module, config_list: List[Dict],
mode: str = 'normal', dummy_input: Optional[Tensor] = None):
"""
Expand Down Expand Up @@ -331,7 +331,7 @@ def reset_tools(self):
raise NotImplementedError('Only support mode `normal` and `dependency_aware`')


class SlimPruner(OneShotPruner):
class SlimPruner(BasicPruner):
def __init__(self, model: Module, config_list: List[Dict], trainer: Callable[[Module, Optimizer, Callable], None],
optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor],
training_epochs: int, scale: float = 0.0001, mode='global'):
Expand Down Expand Up @@ -427,7 +427,7 @@ def reset_tools(self):
raise NotImplementedError('Only support mode `normal` and `global`')


class ActivationPruner(OneShotPruner):
class ActivationPruner(BasicPruner):
def __init__(self, model: Module, config_list: List[Dict], trainer: Callable[[Module, Optimizer, Callable], None],
optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor], training_batches: int, activation: str = 'relu',
mode: str = 'normal', dummy_input: Optional[Tensor] = None):
Expand Down Expand Up @@ -544,7 +544,7 @@ def _get_metrics_calculator(self) -> MetricsCalculator:
return MeanRankMetricsCalculator(dim=1)


class TaylorFOWeightPruner(OneShotPruner):
class TaylorFOWeightPruner(BasicPruner):
def __init__(self, model: Module, config_list: List[Dict], trainer: Callable[[Module, Optimizer, Callable], None],
optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor], training_batches: int,
mode: str = 'normal', dummy_input: Optional[Tensor] = None):
Expand Down
Loading

0 comments on commit e98ebcf

Please sign in to comment.