
Commit

fix quantize pass error when quantization-supported ops are excluded in the model (apache#13596)
ciyongch authored and reminisce committed Dec 12, 2018
1 parent 002e0bb commit e36f888
Showing 2 changed files with 53 additions and 39 deletions.
5 changes: 3 additions & 2 deletions src/operator/quantization/quantize_graph_pass.cc
@@ -222,7 +222,7 @@ Graph QuantizeGraph(Graph &&src) {
           // skip non-quantized input
           continue;
         }
-        if (quantized_op_map.count(e.node->op())) {
+        if (NeedQuantize(e.node, excluded_nodes)) {
          // here we calculate the output number (exclude min/max, in order to
          // calculate min/max index from mirror node) based on assumption that
          // there is only 1min and 1max output from mirror node (which is
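The one-line change above is the core of the fix: the old test asked only whether an input's producing op has a quantized implementation (a lookup in quantized_op_map), ignoring the user's exclusion list, so an input coming from an excluded-but-quantizable node was wrongly treated as quantized and the min/max indexing described in the comment went wrong. NeedQuantize presumably combines both conditions; here is a minimal Python sketch of that semantics (the function and its arguments are illustrative, only the names quantized_op_map and excluded_nodes come from the C++ code):

    def need_quantize(op, name, quantized_op_map, excluded_nodes):
        # quantizable op AND not explicitly excluded by the user
        return op in quantized_op_map and name not in excluded_nodes

    # an excluded pooling node must not be treated as quantized:
    assert not need_quantize('Pooling', 'pool0', {'Pooling'}, {'pool0'})
    assert need_quantize('Pooling', 'pool1', {'Pooling'}, {'pool0'})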
@@ -314,7 +314,8 @@ Graph QuantizeGraph(Graph &&src) {
 
   std::vector<NodeEntry> outputs;
   for (const auto& e : src.outputs) {
-    if (quantized_op_map.count(e.node->op())) {
+    if (NeedQuantize(e.node, excluded_nodes)) {
+      // Only insert dequantize for ops that support quantization and are not excluded.
       NodePtr mirror_node = mirror_map.at(e.node.get());
       NodeEntry mirror_entry = NodeEntry{mirror_node, e.index, e.version};
       size_t num_inputs = e.node->num_inputs();
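The second hunk applies the same predicate on the output side: a dequantize node is inserted only after graph outputs whose producer really was quantized, not after every output whose op merely supports quantization. A self-contained sketch of that decision (the op names and data structures are illustrative stand-ins, not the pass's real ones):

    QUANTIZABLE = {'Convolution', 'Pooling'}   # stand-in for quantized_op_map
    EXCLUDED = {'pool0'}                       # e.g. the test's optional exclusion

    def dequantize_plan(output_nodes):
        # map (name, op) pairs to whether a trailing dequantize is inserted
        return {name: op in QUANTIZABLE and name not in EXCLUDED
                for name, op in output_nodes}

    print(dequantize_plan([('pool0', 'Pooling'), ('fc', 'FullyConnected')]))
    # {'pool0': False, 'fc': False} -- neither output gets a dequantize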
87 changes: 50 additions & 37 deletions tests/python/quantization/test_quantization.py
@@ -406,12 +406,16 @@ def get_fp32_sym():
 
 def get_fp32_residual():
     data = mx.sym.Variable('data')
-    conv = mx.sym.Convolution(data=data, num_filter=4, kernel=(1,1), pad=(0,0),
-                              no_bias=True, name='conv')
-    bn = mx.sym.BatchNorm(data=conv, fix_gamma=False, eps=2e-5, momentum=0.9, name='bn')
-    act = mx.sym.Activation(data=bn + data, act_type='relu', name='relu')
-    pool = mx.sym.Pooling(act, kernel=(4, 4), pool_type='avg', name='pool')
-    fc = mx.sym.FullyConnected(pool, num_hidden=10, flatten=True, name='fc')
+    conv0 = mx.sym.Convolution(data=data, num_filter=4, kernel=(1,1), pad=(0,0),
+                               no_bias=True, name='conv0')
+    bn = mx.sym.BatchNorm(data=conv0, fix_gamma=False, eps=2e-5, momentum=0.9, name='bn')
+    act0 = mx.sym.Activation(data=bn + data, act_type='relu', name='relu0')
+    pool0 = mx.sym.Pooling(act0, kernel=(4, 4), pool_type='avg', name='pool0')
+    conv1 = mx.sym.Convolution(data=pool0, num_filter=4, kernel=(1,1), pad=(0,0),
+                               no_bias=False, name='conv1')
+    act1 = mx.sym.Activation(data=conv1, act_type='relu', name='relu1')
+    pool1 = mx.sym.Pooling(act1, kernel=(4, 4), pool_type='avg', name='pool1')
+    fc = mx.sym.FullyConnected(pool1, num_hidden=10, flatten=True, name='fc')
     sym = mx.sym.SoftmaxOutput(fc, grad_scale=1, ignore_label=-1, multi_output=False,
                                out_grad=False, preserve_shape=False, use_ignore=False, name='softmax')
     return sym
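The rewritten helper extends the residual graph with a second convolution stage so that it contains 'pool0', a pooling op the quantization pass does support; the updated test below can then exclude it on demand and hit exactly the code path this commit fixes. To confirm the node exists (a usage sketch, assuming the get_fp32_residual above is in scope):

    sym = get_fp32_residual()
    # expect entries for both pool0 and pool1
    print([n for n in sym.get_internals().list_outputs() if n.startswith('pool')])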
@@ -574,38 +578,47 @@ def check_qsym_forward(qsym, qarg_params, qaux_params, data_shape, label_shape):
 
         mod.init_params()
         arg_params, aux_params = mod.get_params()
-        excluded_sym_names = []
+        excluded_names = []
         if mx.current_context() == mx.cpu():
-            excluded_sym_names += ['fc']
-            excluded_sym_names += ['concat']
-        qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=s,
-                                                                         arg_params=arg_params,
-                                                                         aux_params=aux_params,
-                                                                         excluded_sym_names=excluded_sym_names,
-                                                                         ctx=mx.current_context(),
-                                                                         quantized_dtype=qdtype,
-                                                                         calib_mode='none')
-        check_params(arg_params, qarg_params, qsym)
-        check_params(aux_params, qaux_params)
-        check_qsym_forward(qsym, qarg_params, qaux_params, dshape, lshape)
-
-        calib_data = mx.nd.random.uniform(shape=dshape)
-        calib_data = NDArrayIter(data=calib_data, batch_size=batch_size)
-        calib_data = DummyIter(calib_data)
-        qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=s,
-                                                                         arg_params=arg_params,
-                                                                         aux_params=aux_params,
-                                                                         excluded_sym_names=excluded_sym_names,
-                                                                         ctx=mx.current_context(),
-                                                                         quantized_dtype=qdtype,
-                                                                         calib_mode='naive',
-                                                                         calib_data=calib_data,
-                                                                         num_calib_examples=20)
-        check_params(arg_params, qarg_params, qsym)
-        check_params(aux_params, qaux_params)
-        check_qsym_calibrated(qsym)
-        check_qsym_qdtype(qsym, qdtype)
-        check_qsym_forward(qsym, qarg_params, qaux_params, dshape, lshape)
+            excluded_names += ['fc']
+            excluded_names += ['concat']
+
+        optional_names = ['pool0']
+        for skip_optional_names in [False, True]:
+            excluded_sym_names = []
+            if skip_optional_names:
+                excluded_sym_names = excluded_names
+            else:
+                excluded_sym_names = excluded_names + optional_names
+
+            qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=s,
+                                                                             arg_params=arg_params,
+                                                                             aux_params=aux_params,
+                                                                             excluded_sym_names=excluded_sym_names,
+                                                                             ctx=mx.current_context(),
+                                                                             quantized_dtype=qdtype,
+                                                                             calib_mode='none')
+            check_params(arg_params, qarg_params, qsym)
+            check_params(aux_params, qaux_params)
+            check_qsym_forward(qsym, qarg_params, qaux_params, dshape, lshape)
+
+            calib_data = mx.nd.random.uniform(shape=dshape)
+            calib_data = NDArrayIter(data=calib_data, batch_size=batch_size)
+            calib_data = DummyIter(calib_data)
+            qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=s,
+                                                                             arg_params=arg_params,
+                                                                             aux_params=aux_params,
+                                                                             excluded_sym_names=excluded_sym_names,
+                                                                             ctx=mx.current_context(),
+                                                                             quantized_dtype=qdtype,
+                                                                             calib_mode='naive',
+                                                                             calib_data=calib_data,
+                                                                             num_calib_examples=20)
+            check_params(arg_params, qarg_params, qsym)
+            check_params(aux_params, qaux_params)
+            check_qsym_calibrated(qsym)
+            check_qsym_qdtype(qsym, qdtype)
+            check_qsym_forward(qsym, qarg_params, qaux_params, dshape, lshape)
 
     for qdtype in ['int8', 'uint8']:
         check_quantize_model(qdtype)
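With the loop in place, each dtype is quantized twice: once with only the mandatory exclusions and once with the optional 'pool0' exclusion added. Before this commit, the second variant, which excludes an op that does support quantization, triggered the pass error named in the title. A condensed sketch of that variant (the keyword arguments mirror the calls above; sym, arg_params and aux_params are assumed to come from an already-built model):

    import mxnet as mx

    qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(
        sym=sym, arg_params=arg_params, aux_params=aux_params,
        excluded_sym_names=['pool0'],   # a quantizable op, explicitly excluded
        ctx=mx.current_context(), quantized_dtype='int8', calib_mode='none')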
