ONNX Post-Training Quantization ImageNet classification sample (#1143)
* Introduce ONNX GraphConverter to get NNCFGraph from ONNX

* Move tests from experimental/onnx to onnx; introduce the ONNXGraph class instead of ONNXGraphHelper; some minor code changes;

* Fix pylint; Add requirements in tests/onnx

* Fix typo

* Add license, fix some comments

* minor

* Add docstring

* Update typehints

* Fix getting shapes of inputs, outputs; Add pylint exceptions

* Add requirements to ONNX

* Apply comments

* Add dtype attribute mapping from ONNX to NNCFGraph

* Align input, output nodes name with common NNCF; Extend onnx types to NNCF types mapping; Add test model with int edge

* Address Vasiliy's comment

* Add dot file

* Update requirements.txt

* Add hardware patterns; add hardware config support

* Create an early draft on PTQ API proposal

* Update API

* Move initialization flow to Initializer

* Add post-training quantization config

* Fix typos

* Add framework-agnostic logic for building CompressedModel inside CompressionBuilder; add serialization step;

* Implement the api ideas

* draft

* Implement the algorithm with minimal model inference

* draft x3

* draft x4

* Update sample

* Apply comments

* new draft

* Remove bn_adaptation files; Improve code style

* Fix some todos

* add draft biascorrection

* change directories

* update algorithms

* Improve code style

* Fix bug with adding Q/DQ nodes to the model;
All model IR versions are now changed to version 7;
Add functional test on quantization

* Remove all extra code;

* Fix typo; remove queue;

* Remove ONNXUpdateBias and bias transformations; Delete comments + debug code;

* Add test on quantized models graphs;
Make input argument to ONNXEngine strictly np.ndarray;
Add statistics collection to apply method of ONNXQuantizerRangeFinderAlgorithm;
Make sampler working with torch and numpy;
Remove test of ptq sample;

* Rename QuantizerRangeFinderAlgorithm to MinMaxQuantization;
Some minor changes

* Hide torchvision imports in helper

* Fix codestyle

* Remove CompressedModel

* Add typehints;
Add small class descriptions

* Add test of graph after PostTrainingQuantization

* Translate string parameters to ENUM;
Minor changes

* Add ONNXMeanMinMaxStatisticCollector;
Fix bugs;
Set default range_type to MEAN_MINMAX;
Update requirements;

* Small improvements in Engine;

* Add test on ONNXModelTransformer;
Small improvements;

* Add many quantizers in one transformation layout for test_model_transformer;

* Improve codestyle

* Add test on parameters of inserted quantizers by ONNXModelTransformer

* Algorithms no longer collect statistics internally; they receive statistics as an argument to apply(), so statistics collection must always run before the algorithms. CompressionBuilder is now essential;

* Make batch_size=1;
Support for larger batches should be added in the following PRs;
Minor changes

* Rename statistics_collector to statistics_aggregator

* Fix test

* Fix test x2;
Move the PTQ algorithm to algorithm.py

* Add license;
Create algorithm.py

* Fix pylint;
Add torch version to requirements

* Add test_sampler;
Fix bug in BatchSampler and RandomBatchSampler;

* Add torchvision in requirements.txt;
Move docstring

* Improve codestyle

* Determine backend once;
Add create_subalgorithms() to Algorithm;
Rename test

* Rename files and functions;

* Make names constants in ModelTransformer;
Add more comments;
Improve code style

* Typo fix

Co-authored-by: Lyalyushkin Nikolay <[email protected]>

* Typo fix

Co-authored-by: Lyalyushkin Nikolay <[email protected]>

* Add test on StatisticsAggregator;
Fix bug in test_samplers.py

* Improve code style

* Change deque to List

* Add nncf_logger

* Move Constants inside ONNXModelTransformer

* Move min_max_quantization.py to the quantization folder;
Fix tests

* Fix pylint

* Fix pylint;
Rename utils.py to model_normalizer.py

* Fix pylint

* Add ONNX ptq sample

* Add licenses;
Add mock dataset to test_sanity_sample.py

* Add Readme;
Change dir;
Add AC configs

* Add results

* Update results table

* Rename and remove dataset path from AC configs;
Add sample requirements;
Improve readme

* Add mean, std, crop_ratio to the args of create_dataloader_from_imagenet_torch_dataset

* Fix readme

* Add description;
Remove the 'train' suffix when creating the dataset path

* Set default init_samples to 300;
Update metrics calibrated on the val part of ImageNet

Co-authored-by: Lyalyushkin Nikolay <[email protected]>
2 people authored and vshampor committed Mar 29, 2022
1 parent 1c32f21 commit 41846be
Showing 15 changed files with 494 additions and 5 deletions.
63 changes: 63 additions & 0 deletions examples/experimental/onnx/README.md
@@ -0,0 +1,63 @@
# Classification sample

This sample demonstrates post-training quantization of classification models on the ImageNet dataset.

## Install

Install the requirements for ONNX Post-Training Quantization of NNCF:

```
pip install -r <nncf dir>/nncf/experimental/onnx/requirements.txt
pip install -r <nncf dir>/examples/experimental/onnx/requirements.txt
```

## Getting the quantized model

To run post-training quantization on your model, use the following command:

```
python onnx_ptq_classification.py -m <ONNX model path> -o <quantized ONNX model path> --data <ImageNet data path>
```

You can also adjust the quantization options; to see the argument descriptions, run:

```
python onnx_ptq_classification.py --help
```
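The script is a thin wrapper over the experimental NNCF post-training API. For reference, the same flow can be driven directly from Python; below is a condensed sketch of what `onnx_ptq_classification.py` does (the model and data paths are placeholders):

```
import onnx

from nncf.experimental.post_training.compression_builder import CompressionBuilder
from nncf.experimental.post_training.algorithms.quantization import PostTrainingQuantization
from nncf.experimental.post_training.algorithms.quantization import PostTrainingQuantizationParameters
from nncf.experimental.onnx.dataloaders.imagenet_dataloader import create_dataloader_from_imagenet_torch_dataset

# Load the FP32 model and build a calibration dataloader.
model = onnx.load("<ONNX model path>")
dataloader = create_dataloader_from_imagenet_torch_dataset("<ImageNet data path>", [1, 3, 224, 224])

# Quantize with 300 calibration samples (the sample's default).
builder = CompressionBuilder()
builder.add_algorithm(PostTrainingQuantization(PostTrainingQuantizationParameters(number_samples=300)))
quantized_model = builder.apply(model, dataloader)

onnx.save(quantized_model, "<quantized ONNX model path>")
```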

## Measuring the accuracy of the original and quantized models

If you would like to compare the accuracy of the original and quantized models, you can
use [accuracy_checker](https://github.com/openvinotoolkit/open_model_zoo/tree/master/tools/accuracy_checker). The
necessary config files are located [here](./ac_configs/). You only need to fill in each config with two pieces of
information: the path to the ImageNet folder and the path to the annotation file. The accuracy checker config is the
same for the original and quantized models.
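For example, after filling in these two paths, the `datasets` section of a config looks as follows (the paths here are illustrative):

```
datasets:
  - name: imagenet_1000_classes
    data_source: /data/imagenet/val
    annotation_conversion:
      converter: imagenet
      annotation_file: /data/imagenet/val.txt
```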

Use the following command to get the model accuracy:

```
accuracy_check -c <path to config file> -m <ONNX model>
```

## Results of Post-Training Quantization of ONNX models

| Model | Original accuracy | Quantized model accuracy |
|:-------------------------:|:-----------------:|:------------------------:|
| ResNet-50 | 75.17% | 74.74% |
| MobilenetV2 | 71.87% | 71.29% |
| InceptionV1 (GoogleNetV1) | 69.77% | 69.64% |
| InceptionV3 (GoogleNetV3) | 77.45% | 77.30% |
| SqueezenetV1.1 | 58.19% | 57.72% |

## Measuring the performance of the original and quantized models

If you would like to compare the performance of the original and quantized models, you can
use [benchmark_tool](https://github.com/openvinotoolkit/openvino/tree/master/tools/benchmark_tool).

Use the following command to get the model performance numbers:

```
benchmark_app -m <ONNX model>
```
Empty file.
41 changes: 41 additions & 0 deletions examples/experimental/onnx/ac_configs/inception_v1.yml
@@ -0,0 +1,41 @@
models:
- name: googlenet-v1
launchers:
- framework: onnx_runtime
adapter: classification
execution_providers: ['OpenVINOExecutionProvider']
inputs:
- name: result.1
type: INPUT
shape: [1,3,224,224]

datasets:
- name: imagenet_1000_classes
data_source: <ImageNet folder>
annotation_conversion:
converter: imagenet
annotation_file: <annotation file>
reader: pillow_imread

preprocessing:
- type: resize
size: 256
aspect_ratio_scale: greater
use_pillow: true
interpolation: BILINEAR

- type: crop
size: 224
use_pillow: true

- type: normalization
std: 255

- type: normalization
mean: (0.485, 0.456, 0.406)
std: (0.229, 0.224, 0.225)

metrics:
- name: accuracy@top1
type: accuracy
top_k: 1
33 changes: 33 additions & 0 deletions examples/experimental/onnx/ac_configs/inception_v3.yml
@@ -0,0 +1,33 @@
models:
- name: googlenet-v3
launchers:
- framework: onnx_runtime
adapter: classification
execution_providers: ['OpenVINOExecutionProvider']
datasets:
- name: imagenet_1000_classes
data_source: <ImageNet folder>
annotation_conversion:
converter: imagenet
annotation_file: <annotation file>
reader: pillow_imread
preprocessing:
- type: resize
size: 320
aspect_ratio_scale: greater
use_pillow: true
interpolation: BILINEAR
- type: crop
size: 299
use_pillow: true

- type: normalization
std: 255

- type: normalization
mean: (0.485, 0.456, 0.406)
std: (0.229, 0.224, 0.225)
metrics:
- name: accuracy@top1
type: accuracy
top_k: 1
41 changes: 41 additions & 0 deletions examples/experimental/onnx/ac_configs/mobilenet_v2.yml
@@ -0,0 +1,41 @@
models:
- name: mobilenet-v2
launchers:
- framework: onnx_runtime
adapter: classification
execution_providers: ['OpenVINOExecutionProvider']
inputs:
- name: input.1
type: INPUT
shape: [1,3,224,224]

datasets:
- name: imagenet_1000_classes
data_source: <ImageNet folder>
annotation_conversion:
converter: imagenet
annotation_file: <annotation file>
reader: pillow_imread

preprocessing:
- type: resize
size: 256
aspect_ratio_scale: greater
use_pillow: true
interpolation: BILINEAR

- type: crop
size: 224
use_pillow: true

- type: normalization
std: 255

- type: normalization
mean: (0.485, 0.456, 0.406)
std: (0.229, 0.224, 0.225)

metrics:
- name: accuracy@top1
type: accuracy
top_k: 1
41 changes: 41 additions & 0 deletions examples/experimental/onnx/ac_configs/resnet50.yml
@@ -0,0 +1,41 @@
models:
- name: resnet50
launchers:
- framework: onnx_runtime
adapter: classification
execution_providers: ['OpenVINOExecutionProvider']
inputs:
- name: input.1
type: INPUT
shape: [1,3,224,224]

datasets:
- name: imagenet_1000_classes
data_source: <ImageNet folder>
annotation_conversion:
converter: imagenet
annotation_file: <annotation file>
reader: pillow_imread

preprocessing:
- type: resize
size: 256
aspect_ratio_scale: greater
use_pillow: true
interpolation: BILINEAR

- type: crop
size: 224
use_pillow: true

- type: normalization
std: 255

- type: normalization
mean: (0.485, 0.456, 0.406)
std: (0.229, 0.224, 0.225)

metrics:
- name: accuracy@top1
type: accuracy
top_k: 1
41 changes: 41 additions & 0 deletions examples/experimental/onnx/ac_configs/squeezenet_1_1l.yml
@@ -0,0 +1,41 @@
models:
- name: squeezenet-v1-1
launchers:
- framework: onnx_runtime
adapter: classification
execution_providers: ['OpenVINOExecutionProvider']
inputs:
- name: input
type: INPUT
shape: [1,3,224,224]

datasets:
- name: imagenet_1000_classes
data_source: <ImageNet folder>
annotation_conversion:
converter: imagenet
annotation_file: <annotation file>
reader: pillow_imread

preprocessing:
- type: resize
size: 256
aspect_ratio_scale: greater
use_pillow: true
interpolation: BILINEAR

- type: crop
size: 224
use_pillow: true

- type: normalization
std: 255

- type: normalization
mean: (0.485, 0.456, 0.406)
std: (0.229, 0.224, 0.225)

metrics:
- name: accuracy@top1
type: accuracy
top_k: 1
84 changes: 84 additions & 0 deletions examples/experimental/onnx/onnx_ptq_classification.py
@@ -0,0 +1,84 @@
"""
Copyright (c) 2022 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import argparse

from typing import List

import onnx

from nncf.experimental.post_training.compression_builder import CompressionBuilder
from nncf.experimental.post_training.algorithms.quantization import PostTrainingQuantization
from nncf.experimental.post_training.algorithms.quantization import PostTrainingQuantizationParameters
from nncf.experimental.onnx.dataloaders.imagenet_dataloader import create_dataloader_from_imagenet_torch_dataset


def run(onnx_model_path: str, output_model_path: str,
dataset_path: str, batch_size: int, shuffle: bool, num_init_samples: int,
input_shape: List[int], ignored_scopes: List[str] = None):
print("Post-Training Quantization Parameters:")
print(" number of samples: ", num_init_samples)
print(" ignored_scopes: ", ignored_scopes)
onnx.checker.check_model(onnx_model_path)
original_model = onnx.load(onnx_model_path)
print(f"The model is loaded from {onnx_model_path}")

# Step 1: Initialize the data loader.
dataloader = create_dataloader_from_imagenet_torch_dataset(dataset_path, input_shape,
batch_size=batch_size, shuffle=shuffle)

# Step 2: Create a pipeline of compression algorithms.
builder = CompressionBuilder()

# Step 3: Create the quantization algorithm and add to the builder.
quantization_parameters = PostTrainingQuantizationParameters(
number_samples=num_init_samples,
ignored_scopes=ignored_scopes
)
quantization = PostTrainingQuantization(quantization_parameters)
builder.add_algorithm(quantization)

# Step 4: Execute the pipeline.
print("Post-Training Quantization has just started!")
quantized_model = builder.apply(original_model, dataloader)

# Step 5: Save the quantized model.
onnx.save(quantized_model, output_model_path)
print(f"The quantized model is saved on {output_model_path}")

onnx.checker.check_model(output_model_path)


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--onnx_model_path", "-m", help="Path to ONNX model", type=str)
parser.add_argument("--output_model_path", "-o", help="Path to output quantized ONNX model", type=str)
parser.add_argument("--data",
help="Path to ImageNet validation data in the ImageFolder torchvision format "
"(Please, take a look at torchvision.datasets.ImageFolder)",
type=str)
parser.add_argument("--batch_size", help="Batch size for initialization", default=1)
parser.add_argument("--shuffle", help="Whether to shuffle dataset for initialization", default=True)
parser.add_argument("--input_shape", help="Model's input shape", nargs="+", type=int, default=[1, 3, 224, 224])
parser.add_argument("--init_samples", help="Number of initialization samples", type=int, default=300)
parser.add_argument("--ignored_scopes", help="Ignored operations ot quantize", nargs="+", default=None)
args = parser.parse_args()
run(args.onnx_model_path,
args.output_model_path,
args.data,
args.batch_size,
args.shuffle,
args.init_samples,
args.input_shape,
args.ignored_scopes
)
1 change: 1 addition & 0 deletions examples/experimental/onnx/requirements.txt
@@ -0,0 +1 @@
torchvision
16 changes: 11 additions & 5 deletions nncf/experimental/onnx/dataloaders/imagenet_dataloader.py
@@ -35,19 +35,25 @@ def __len__(self):
         return len(self.dataset)
 
 
-def create_dataloader_from_imagenet_torch_dataset(dataset_dir, input_shape: List[int], batch_size: int = 1,
+def create_dataloader_from_imagenet_torch_dataset(dataset_dir: str,
+                                                  input_shape: List[int],
+                                                  mean=(0.485, 0.456, 0.406),
+                                                  std=(0.229, 0.224, 0.225),
+                                                  crop_ratio=0.875,
+                                                  batch_size: int = 1,
                                                   shuffle: bool = True):
     import torchvision
     from torchvision import transforms
     image_size = [input_shape[-2], input_shape[-1]]
-    size = int(image_size[0] / 0.875)
-    normalize = transforms.Normalize(mean=(0.485, 0.456, 0.406),
-                                     std=(0.229, 0.224, 0.225))
+    size = int(image_size[0] / crop_ratio)
+    normalize = transforms.Normalize(mean=mean,
+                                     std=std)
     transform = transforms.Compose([
         transforms.Resize(size),
         transforms.CenterCrop(image_size),
         transforms.ToTensor(),
         normalize,
     ])
-    initialization_dataset = torchvision.datasets.ImageFolder(os.path.join(dataset_dir, 'train'), transform)
+    # The best practice is to use the validation part of the dataset for calibration (aligning with POT)
+    initialization_dataset = torchvision.datasets.ImageFolder(os.path.join(dataset_dir), transform)
     return ImageNetDataLoader(initialization_dataset, batch_size, shuffle)
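
For reference, a minimal usage sketch of the updated helper, assuming an ImageNet-style `torchvision.datasets.ImageFolder` layout (the path is a placeholder; the keyword values shown are the new defaults):

```
from nncf.experimental.onnx.dataloaders.imagenet_dataloader import create_dataloader_from_imagenet_torch_dataset

# Point directly at the calibration split (e.g. the ImageNet validation folder);
# mean, std and crop_ratio now override the previously hard-coded values.
dataloader = create_dataloader_from_imagenet_torch_dataset(
    "<ImageNet data path>",
    input_shape=[1, 3, 224, 224],
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225),
    crop_ratio=0.875,
    batch_size=1,
    shuffle=True,
)
```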