Skip to content

Commit

Permalink
Fpgm algo implementation unit test (microsoft#1746)
Browse files Browse the repository at this point in the history
* unit test for fpgm pruner
  • Loading branch information
chicm-ms authored Nov 21, 2019
1 parent 6a5864c commit 55b557f
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 35 deletions.
12 changes: 4 additions & 8 deletions src/sdk/pynni/nni/compression/tensorflow/builtin_pruners.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,17 +178,13 @@ def _get_min_gm_kernel_idx(self, weight, n):
assert len(weight.shape) >= 3
assert weight.shape[0] * weight.shape[1] > 2

dist_list, idx_list = [], []
dist_list = []
for in_i in range(weight.shape[0]):
for out_i in range(weight.shape[1]):
dist_sum = self._get_distance_sum(weight, in_i, out_i)
dist_list.append(dist_sum)
idx_list.append([in_i, out_i])
dist_tensor = tf.convert_to_tensor(dist_list)
idx_tensor = tf.constant(idx_list)

_, idx = tf.math.top_k(dist_tensor, k=n)
return tf.gather(idx_tensor, idx)
dist_list.append((dist_sum, (in_i, out_i)))
min_gm_kernels = sorted(dist_list, key=lambda x: x[0])[:n]
return [x[1] for x in min_gm_kernels]

def _get_distance_sum(self, weight, in_idx, out_idx):
w = tf.reshape(weight, (-1, weight.shape[-2], weight.shape[-1]))
Expand Down
104 changes: 77 additions & 27 deletions src/sdk/pynni/tests/test_compressor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from unittest import TestCase, main
import numpy as np
import tensorflow as tf
import torch
import torch.nn.functional as F
Expand All @@ -7,11 +8,11 @@
if tf.__version__ >= '2.0':
import nni.compression.tensorflow as tf_compressor

def get_tf_mnist_model():
def get_tf_model():
model = tf.keras.models.Sequential([
tf.keras.layers.Conv2D(filters=32, kernel_size=7, input_shape=[28, 28, 1], activation='relu', padding="SAME"),
tf.keras.layers.Conv2D(filters=5, kernel_size=7, input_shape=[28, 28, 1], activation='relu', padding="SAME"),
tf.keras.layers.MaxPooling2D(pool_size=2),
tf.keras.layers.Conv2D(filters=64, kernel_size=3, activation='relu', padding="SAME"),
tf.keras.layers.Conv2D(filters=10, kernel_size=3, activation='relu', padding="SAME"),
tf.keras.layers.MaxPooling2D(pool_size=2),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(units=128, activation='relu'),
Expand All @@ -23,43 +24,51 @@ def get_tf_mnist_model():
metrics=["accuracy"])
return model

class TorchMnist(torch.nn.Module):
class TorchModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)
self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
self.fc1 = torch.nn.Linear(4 * 4 * 50, 500)
self.fc2 = torch.nn.Linear(500, 10)
self.conv1 = torch.nn.Conv2d(1, 5, 5, 1)
self.conv2 = torch.nn.Conv2d(5, 10, 5, 1)
self.fc1 = torch.nn.Linear(4 * 4 * 10, 100)
self.fc2 = torch.nn.Linear(100, 10)

def forward(self, x):
x = F.relu(self.conv1(x))
x = F.max_pool2d(x, 2, 2)
x = F.relu(self.conv2(x))
x = F.max_pool2d(x, 2, 2)
x = x.view(-1, 4 * 4 * 50)
x = x.view(-1, 4 * 4 * 10)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)

def tf2(func):
def test_tf2_func(self):
def test_tf2_func(*args):
if tf.__version__ >= '2.0':
func()
func(*args)
return test_tf2_func

k1 = [[1]*3]*3
k2 = [[2]*3]*3
k3 = [[3]*3]*3
k4 = [[4]*3]*3
k5 = [[5]*3]*3

w = [[k1, k2, k3, k4, k5]] * 10

class CompressorTestCase(TestCase):
def test_torch_pruner(self):
model = TorchMnist()
def test_torch_level_pruner(self):
model = TorchModel()
configure_list = [{'sparsity': 0.8, 'op_types': ['default']}]
torch_compressor.LevelPruner(model, configure_list).compress()

def test_torch_fpgm_pruner(self):
model = TorchMnist()
configure_list = [{'sparsity': 0.5, 'op_types': ['Conv2d']}]
torch_compressor.FPGMPruner(model, configure_list).compress()
@tf2
def test_tf_level_pruner(self):
configure_list = [{'sparsity': 0.8, 'op_types': ['default']}]
tf_compressor.LevelPruner(get_tf_model(), configure_list).compress()

def test_torch_quantizer(self):
model = TorchMnist()
def test_torch_naive_quantizer(self):
model = TorchModel()
configure_list = [{
'quant_types': ['weight'],
'quant_bits': {
Expand All @@ -70,18 +79,59 @@ def test_torch_quantizer(self):
torch_compressor.NaiveQuantizer(model, configure_list).compress()

@tf2
def test_tf_pruner(self):
configure_list = [{'sparsity': 0.8, 'op_types': ['default']}]
tf_compressor.LevelPruner(get_tf_mnist_model(), configure_list).compress()
def test_tf_naive_quantizer(self):
tf_compressor.NaiveQuantizer(get_tf_model(), [{'op_types': ['default']}]).compress()

@tf2
def test_tf_quantizer(self):
tf_compressor.NaiveQuantizer(get_tf_mnist_model(), [{'op_types': ['default']}]).compress()
def test_torch_fpgm_pruner(self):
"""
With filters(kernels) defined as above (k1 - k5), it is obvious that k3 is the Geometric Median
which minimize the total geometric distance by defination of Geometric Median in this paper:
Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration,
https://arxiv.org/pdf/1811.00250.pdf
So if sparsity is 0.2, the expected masks should mask out all k3, this can be verified through:
`all(torch.sum(masks, (0, 2, 3)).numpy() == np.array([90., 90., 0., 90., 90.]))`
If sparsity is 0.6, the expected masks should mask out all k2, k3, k4, this can be verified through:
`all(torch.sum(masks, (0, 2, 3)).numpy() == np.array([90., 0., 0., 0., 90.]))`
"""

model = TorchModel()
config_list = [{'sparsity': 0.2, 'op_types': ['Conv2d']}, {'sparsity': 0.6, 'op_types': ['Conv2d']}]
pruner = torch_compressor.FPGMPruner(model, config_list)

model.conv2.weight.data = torch.tensor(w).float()
layer = torch_compressor.compressor.LayerInfo('conv2', model.conv2)
masks = pruner.calc_mask(layer, config_list[0])
assert all(torch.sum(masks, (0, 2, 3)).numpy() == np.array([90., 90., 0., 90., 90.]))

pruner.update_epoch(1)
model.conv2.weight.data = torch.tensor(w).float()
masks = pruner.calc_mask(layer, config_list[1])
assert all(torch.sum(masks, (0, 2, 3)).numpy() == np.array([90., 0., 0., 0., 90.]))

@tf2
def test_tf_fpgm_pruner(self):
configure_list = [{'sparsity': 0.5, 'op_types': ['Conv2D']}]
tf_compressor.FPGMPruner(get_tf_mnist_model(), configure_list).compress()
model = get_tf_model()
config_list = [{'sparsity': 0.2, 'op_types': ['Conv2D']}, {'sparsity': 0.6, 'op_types': ['Conv2D']}]

pruner = tf_compressor.FPGMPruner(model, config_list)
weights = model.layers[2].weights
weights[0] = np.array(w).astype(np.float32).transpose([2, 3, 0, 1]).transpose([0, 1, 3, 2])
model.layers[2].set_weights([weights[0], weights[1].numpy()])

layer = tf_compressor.compressor.LayerInfo(model.layers[2])
masks = pruner.calc_mask(layer, config_list[0]).numpy()
masks = masks.transpose([2, 3, 0, 1]).transpose([1, 0, 2, 3])

assert all(masks.sum((0, 2, 3)) == np.array([90., 90., 0., 90., 90.]))

pruner.update_epoch(1)
model.layers[2].set_weights([weights[0], weights[1].numpy()])
masks = pruner.calc_mask(layer, config_list[1]).numpy()
masks = masks.transpose([2, 3, 0, 1]).transpose([1, 0, 2, 3])

assert all(masks.sum((0, 2, 3)) == np.array([90., 0., 0., 0., 90.]))


if __name__ == '__main__':
Expand Down

0 comments on commit 55b557f

Please sign in to comment.