Skip to content

Commit

Permalink
refine model compression examples (microsoft#1804)
Browse files Browse the repository at this point in the history
  • Loading branch information
chicm-ms committed Dec 2, 2019
1 parent 962d9ae commit e5cb4ed
Show file tree
Hide file tree
Showing 6 changed files with 10 additions and 351 deletions.
16 changes: 7 additions & 9 deletions examples/model_compress/fpgm_torch_mnist.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from nni.compression.torch import FPGMPruner
import torch
import torch.nn.functional as F
from torchvision import datasets, transforms

from nni.compression.torch import FPGMPruner

class Mnist(torch.nn.Module):
def __init__(self):
Expand All @@ -23,8 +22,8 @@ def forward(self, x):
return F.log_softmax(x, dim=1)

def _get_conv_weight_sparsity(self, conv_layer):
num_zero_filters = (conv_layer.weight.data.sum((2,3)) == 0).sum()
num_filters = conv_layer.weight.data.size(0) * conv_layer.weight.data.size(1)
num_zero_filters = (conv_layer.weight.data.sum((1, 2, 3)) == 0).sum()
num_filters = conv_layer.weight.data.size(0)
return num_zero_filters, num_filters, float(num_zero_filters)/num_filters

def print_conv_filter_sparsity(self):
Expand All @@ -41,7 +40,8 @@ def train(model, device, train_loader, optimizer):
output = model(data)
loss = F.nll_loss(output, target)
if batch_idx % 100 == 0:
print('{:2.0f}% Loss {}'.format(100 * batch_idx / len(train_loader), loss.item()))
print('{:.2f}% Loss {:.4f}'.format(100 * batch_idx / len(train_loader), loss.item()))
if batch_idx == 0:
model.print_conv_filter_sparsity()
loss.backward()
optimizer.step()
Expand All @@ -59,7 +59,7 @@ def test(model, device, test_loader):
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)

print('Loss: {} Accuracy: {}%)\n'.format(
print('Loss: {:.4f} Accuracy: {}%)\n'.format(
test_loss, 100 * correct / len(test_loader.dataset)))


Expand All @@ -78,9 +78,6 @@ def main():
model = Mnist()
model.print_conv_filter_sparsity()

'''you can change this to LevelPruner to implement it
pruner = LevelPruner(configure_list)
'''
configure_list = [{
'sparsity': 0.5,
'op_types': ['Conv2d']
Expand All @@ -96,6 +93,7 @@ def main():
train(model, device, train_loader, optimizer)
test(model, device, test_loader)

pruner.export_model('model.pth', 'mask.pth')

if __name__ == '__main__':
main()
6 changes: 3 additions & 3 deletions examples/model_compress/lottery_torch_mnist_fc.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def forward(self, x):
def train(model, train_loader, optimizer, criterion):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.train()
for batch_idx, (imgs, targets) in enumerate(train_loader):
for imgs, targets in train_loader:
optimizer.zero_grad()
imgs, targets = imgs.to(device), targets.to(device)
output = model(imgs)
Expand Down Expand Up @@ -64,7 +64,7 @@ def test(model, test_loader, criterion):
criterion = nn.CrossEntropyLoss()

configure_list = [{
'prune_iterations': 10,
'prune_iterations': 5,
'sparsity': 0.96,
'op_types': ['default']
}]
Expand All @@ -75,7 +75,7 @@ def test(model, test_loader, criterion):
pruner.prune_iteration_start()
loss = 0
accuracy = 0
for epoch in range(50):
for epoch in range(10):
loss = train(model, train_loader, optimizer, criterion)
accuracy = test(model, test_loader, criterion)
print('current epoch: {0}, loss: {1}, accuracy: {2}'.format(epoch, loss, accuracy))
Expand Down
132 changes: 0 additions & 132 deletions examples/model_compress/main_tf_pruner.py

This file was deleted.

119 changes: 0 additions & 119 deletions examples/model_compress/main_tf_quantizer.py

This file was deleted.

2 changes: 0 additions & 2 deletions examples/model_compress/main_torch_pruner.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,6 @@ def main():

pruner = AGP_Pruner(model, configure_list)
model = pruner.compress()
# you can also use compress(model) method
# like that pruner.compress(model)

optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
for epoch in range(10):
Expand Down
Loading

0 comments on commit e5cb4ed

Please sign in to comment.