NF4 per-channel support for AWQ and Scale Estimation (openvinotoolkit#2898)

### Changes

Added NF4 mode support for the Scale Estimation and AWQ algorithms.
All results below were collected with and without the Scale Estimation algorithm and with the Lora Correction algorithm.


![image](https://github.com/user-attachments/assets/eaab96a9-f7c1-438c-9aef-99a37794b10f)

### Reason for changes

NF4 per-channel quantization with scale estimation may give promising results for NPU, since its accuracy is on par with int4 group-wise quantization.
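
For reviewers, here is a minimal usage sketch of the newly enabled combination (not part of the diff; the toy OpenVINO model and the `group_size=-1` per-channel setting are assumptions based on the tests added in this PR):

```python
# Hypothetical usage sketch: NF4 per-channel weight compression with the
# data-aware algorithms enabled by this PR.
import numpy as np
import openvino as ov
from openvino.runtime import opset13 as opset
import nncf
from nncf import CompressWeightsMode, Dataset

# A toy single-MatMul model standing in for a real LLM.
x = opset.parameter([8, 8], dtype=np.float32, name="x")
w = opset.constant(np.random.rand(8, 8).astype(np.float32))
toy_model = ov.Model([opset.matmul(x, w, transpose_a=False, transpose_b=True)], [x])

compressed = nncf.compress_weights(
    toy_model,
    mode=CompressWeightsMode.NF4,
    ratio=1.0,
    group_size=-1,                                          # per-channel scales
    all_layers=True,                                        # compress every MatMul, including the last one
    dataset=Dataset([np.ones([8, 8], dtype=np.float32)]),   # calibration data is required
    awq=True,                                               # now allowed together with NF4
    scale_estimation=True,                                  # now allowed together with NF4
)
```
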
### Related tickets

150560

### Tests
- [x] OV 2024.5: job/NNCF/job/manual/job/post_training_weight_compression/182

![image](https://github.com/user-attachments/assets/f76d40b8-70ab-4cb1-8f6b-eba7d9a41c02)
- [x] OV 2024.4: job/NNCF/job/manual/job/post_training_weight_compression/181

![image](https://github.com/user-attachments/assets/899bebc2-4a1a-4ae2-9740-6eba8d3a0fd2)
- [x] OV 2024.3: job/NNCF/job/manual/job/post_training_weight_compression/180

![image](https://github.com/user-attachments/assets/338a0edb-9425-4868-9a1d-7d1b5eec1631)
ljaljushkin authored Sep 23, 2024
1 parent d9b3f38 commit 05f37f5
Showing 14 changed files with 139 additions and 67 deletions.
12 changes: 2 additions & 10 deletions nncf/quantization/algorithms/weight_compression/algorithm.py
@@ -365,11 +365,7 @@ def apply(
self._set_weight_compression_config(ratio_defining_params, model, graph, activations)
nncf_logger.info(self._get_bitwidth_distribution_str(all_weight_params, ratio_defining_params))

if (
self._awq
and activations is not None
and self._mode not in [CompressWeightsMode.NF4, CompressWeightsMode.E2M1]
):
if self._awq and activations is not None and self._mode != CompressWeightsMode.E2M1:
awq_params = self._advanced_parameters.awq_params
awq_algo = AWQ(
model,
@@ -399,11 +395,7 @@ def apply(
backend_entity=self._backend_entity,
)
else:
if (
self._scale_estimation
and activations is not None
and self._mode not in [CompressWeightsMode.NF4, CompressWeightsMode.E2M1]
):
if self._scale_estimation and activations is not None and self._mode != CompressWeightsMode.E2M1:
scale_estimation_params = self._advanced_parameters.scale_estimation_params
scale_algo = ScaleEstimation(
model,
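
The gating change above boils down to: AWQ (and, symmetrically, Scale Estimation below) now runs for every mode except E2M1 whenever activation statistics are available. An illustrative distillation, not the literal diff:

```python
# Illustrative distillation of the updated gating in apply(): NF4 is no longer
# excluded; only E2M1 and missing activation statistics disable the
# data-aware algorithms.
from nncf.parameters import CompressWeightsMode


def data_aware_algo_enabled(algo_requested: bool, activations, mode: CompressWeightsMode) -> bool:
    return algo_requested and activations is not None and mode != CompressWeightsMode.E2M1
```
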
19 changes: 14 additions & 5 deletions nncf/quantization/algorithms/weight_compression/awq.py
@@ -24,11 +24,15 @@
from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
from nncf.common.utils.backend import BackendType
from nncf.common.utils.backend import get_backend
from nncf.parameters import CompressWeightsMode
from nncf.quantization.algorithms.algorithm import Algorithm
from nncf.quantization.algorithms.weight_compression.activation_stats import process_stats
from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters
from nncf.quantization.algorithms.weight_compression.weight_lowering import calculate_nf4_scale
from nncf.quantization.algorithms.weight_compression.weight_lowering import do_int_dequantization
from nncf.quantization.algorithms.weight_compression.weight_lowering import do_int_quantization
from nncf.quantization.algorithms.weight_compression.weight_lowering import do_nf4_dequantization
from nncf.quantization.algorithms.weight_compression.weight_lowering import do_nf4_quantization
from nncf.quantization.passes import transform_to_inference_graph
from nncf.tensor import functions as fns

@@ -244,11 +248,16 @@ def apply(
alpha = self._alpha_min
for _ in range(self._steps):
cur_scale = gscale**alpha

g_compressed_weighs, g_c_scale, g_c_zp = do_int_quantization(
gweight * cur_scale, reduction_axis, awq_config
)
g_decompressed_weighs = do_int_dequantization(g_compressed_weighs, g_c_scale, g_c_zp)
weights_to_fake_quantize = gweight * cur_scale
if config.mode == CompressWeightsMode.NF4:
g_c_scale = calculate_nf4_scale(weights_to_fake_quantize, reduction_axis)
g_compressed_weighs = do_nf4_quantization(weights_to_fake_quantize, g_c_scale)
g_decompressed_weighs = do_nf4_dequantization(g_compressed_weighs, g_c_scale)
else:
g_compressed_weighs, g_c_scale, g_c_zp = do_int_quantization(
weights_to_fake_quantize, reduction_axis, awq_config
)
g_decompressed_weighs = do_int_dequantization(g_compressed_weighs, g_c_scale, g_c_zp)
sacts = gacts / fns.unsqueeze(cur_scale, 1)

cur_out = fns.matmul(g_decompressed_weighs, sacts)
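
For clarity, a condensed sketch of what the new NF4 branch does for each group inside the AWQ search loop, using the same `weight_lowering` helpers imported above (the function and variable names below are illustrative):

```python
# Condensed sketch of the NF4 branch: fake-quantize the scaled group weights
# through NF4 and use the reconstruction to score the current alpha.
from nncf.quantization.algorithms.weight_compression.weight_lowering import (
    calculate_nf4_scale,
    do_nf4_dequantization,
    do_nf4_quantization,
)


def nf4_fake_quantize(weights, reduction_axis):
    # NF4 has no zero point: compute a scale, snap each value to the nearest
    # NF4 quant, then dequantize back to floating point.
    scale = calculate_nf4_scale(weights, reduction_axis)
    quantized = do_nf4_quantization(weights, scale)
    return do_nf4_dequantization(quantized, scale)
```
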
@@ -337,7 +337,10 @@ def get_compress_decompress_pipeline(config: WeightCompressionConfig, w_shape, s
@staticmethod
def get_compress_pipeline(config: WeightCompressionConfig, w_shape, s_shape, z_p_shape=None, return_nodes=False):
mode = config.mode
assert mode in [CompressWeightsMode.INT4_SYM, CompressWeightsMode.INT4_ASYM]
assert mode in [
CompressWeightsMode.INT4_SYM,
CompressWeightsMode.INT4_ASYM,
], f"Only int4 supported, but given={mode}"
num_bits = config.num_bits

asym_quant = mode in [CompressWeightsMode.INT4_ASYM]
68 changes: 47 additions & 21 deletions nncf/quantization/algorithms/weight_compression/scale_estimation.py
@@ -19,12 +19,16 @@
from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
from nncf.common.utils.backend import BackendType
from nncf.common.utils.backend import get_backend
from nncf.parameters import CompressWeightsMode
from nncf.quantization.algorithms.weight_compression.activation_stats import process_stats
from nncf.quantization.algorithms.weight_compression.backend import WeightCompressionAlgoBackend
from nncf.quantization.algorithms.weight_compression.config import WeightCompressionConfig
from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters
from nncf.quantization.algorithms.weight_compression.weight_lowering import calculate_normalized_weight_and_fp4_scale
from nncf.quantization.algorithms.weight_compression.weight_lowering import do_int_dequantization
from nncf.quantization.algorithms.weight_compression.weight_lowering import do_int_quantization
from nncf.quantization.algorithms.weight_compression.weight_lowering import do_nf4_dequantization
from nncf.quantization.algorithms.weight_compression.weight_lowering import do_nf4_quantization
from nncf.quantization.algorithms.weight_compression.weight_lowering import reshape_weight_for_grouped_quantization
from nncf.tensor import Tensor
from nncf.tensor import TensorDataType
@@ -206,11 +210,18 @@ def calculate_quantization_params(
cur_config.group_size = group_size

original_weight = fns.zeros_like(weight) + weight

compressed_weights, scale, zp = do_int_quantization(original_weight, reduction_axis, cur_config)
if zp is not None:
zp = zp.astype(scale.dtype)
q_weights = do_int_dequantization(compressed_weights, scale, zp, reduction_axis)
if config.mode == CompressWeightsMode.NF4:
norm_weight, scale = calculate_normalized_weight_and_fp4_scale(
original_weight, reduction_axis, cur_config.group_size
)
compressed_weights = do_nf4_quantization(norm_weight, scale, is_normalized_weight=True)
q_weights = do_nf4_dequantization(compressed_weights, scale, reduction_axis)
zp = None
else:
compressed_weights, scale, zp = do_int_quantization(original_weight, reduction_axis, cur_config)
if zp is not None:
zp = zp.astype(scale.dtype)
q_weights = do_int_dequantization(compressed_weights, scale, zp, reduction_axis)

s = fns.unsqueeze(s, 0)
s, _ = reshape_weight_for_grouped_quantization(s, reduction_axis, group_size)
@@ -246,18 +257,19 @@ def calculate_quantization_params(
key = (config.mode, config.num_bits) + q_weights.shape + scale.shape
if zp is not None:
key += zp_shape
if key in ScaleEstimation.compress_decompress_cache:
compress_decompress_model = ScaleEstimation.compress_decompress_cache[key]["compress_decompress_model"]
compress_model = ScaleEstimation.compress_decompress_cache[key]["compress_model"]
else:
compress_decompress_model = backend_entity.get_compress_decompress_pipeline(
config, q_weights.shape, scale.shape, zp_shape
)
compress_model = backend_entity.get_compress_pipeline(config, q_weights.shape, scale.shape, zp_shape)
ScaleEstimation.compress_decompress_cache[key] = {
"compress_decompress_model": compress_decompress_model,
"compress_model": compress_model,
}
if config.mode != CompressWeightsMode.NF4:
if key in ScaleEstimation.compress_decompress_cache:
compress_decompress_model = ScaleEstimation.compress_decompress_cache[key]["compress_decompress_model"]
compress_model = ScaleEstimation.compress_decompress_cache[key]["compress_model"]
else:
compress_decompress_model = backend_entity.get_compress_decompress_pipeline(
config, q_weights.shape, scale.shape, zp_shape
)
compress_model = backend_entity.get_compress_pipeline(config, q_weights.shape, scale.shape, zp_shape)
ScaleEstimation.compress_decompress_cache[key] = {
"compress_decompress_model": compress_decompress_model,
"compress_model": compress_model,
}
scale_sign = scale / fns.abs(scale)
zero_scale = 0.001
zero_mask = zero_scale * zero_mask.astype(original_weight.dtype)
@@ -271,7 +283,11 @@ def calculate_quantization_params(
near_to_ideal_scale = near_to_ideal_scale * scale_sign
input_tensors[1] = near_to_ideal_scale.data

out = compress_decompress_model(input_tensors)
if config.mode == CompressWeightsMode.NF4:
g_compressed_weighs = do_nf4_quantization(original_weight, near_to_ideal_scale)
out = do_nf4_dequantization(g_compressed_weighs, near_to_ideal_scale)
else:
out = compress_decompress_model(input_tensors)
q_weights_ = fns.zeros_like(original_weight) + out
q_outs = fns.matmul(fns.transpose(q_weights_, (1, 0, 2)), X)

@@ -297,7 +313,10 @@ def calculate_quantization_params(
input_tensors[1] = near_to_ideal_scale.data

if i < initial_steps - 1:
out = compress_model(input_tensors)
if config.mode == CompressWeightsMode.NF4:
out = do_nf4_quantization(original_weight, near_to_ideal_scale)
else:
out = compress_model(input_tensors)
compressed_weights = fns.zeros_like(original_weight) + out
target, zero_mask = get_target_zero_mask(compressed_weights, zp)
zero_mask = zero_scale * zero_mask.astype(original_weight.dtype)
@@ -308,7 +327,10 @@ def calculate_quantization_params(
scaled_scale = factor * scale

input_tensors[1] = scaled_scale.data
out = compress_model(input_tensors)
if config.mode == CompressWeightsMode.NF4:
out = do_nf4_quantization(original_weight, scaled_scale)
else:
out = compress_model(input_tensors)
compressed_weights = fns.zeros_like(original_weight) + out

target, zero_mask = get_target_zero_mask(compressed_weights, zp)
@@ -317,7 +339,11 @@ def calculate_quantization_params(
near_to_ideal_scale = near_to_ideal_scale * scale_sign

input_tensors[1] = near_to_ideal_scale.data
out = compress_decompress_model(input_tensors)
if config.mode == CompressWeightsMode.NF4:
g_compressed_weighs = do_nf4_quantization(original_weight, near_to_ideal_scale)
out = do_nf4_dequantization(g_compressed_weighs, near_to_ideal_scale)
else:
out = compress_decompress_model(input_tensors)
q_weights_ = fns.zeros_like(original_weight) + out

q_outs = fns.matmul(fns.transpose(q_weights_, (1, 0, 2)), X)
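
In short, wherever the int4 path evaluates a candidate scale through the compiled OpenVINO compress/decompress models, the NF4 path now does the round trip directly with the `weight_lowering` helpers, roughly as in this illustrative sketch:

```python
# Illustrative helper mirroring the NF4 branches above: evaluate a candidate
# scale via an NF4 quantize/dequantize round trip instead of running the
# compiled int4 compress/decompress models.
from nncf.quantization.algorithms.weight_compression.weight_lowering import (
    do_nf4_dequantization,
    do_nf4_quantization,
)


def nf4_round_trip(original_weight, candidate_scale):
    compressed = do_nf4_quantization(original_weight, candidate_scale)
    return do_nf4_dequantization(compressed, candidate_scale)
```
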
22 changes: 10 additions & 12 deletions nncf/quantization/algorithms/weight_compression/weight_lowering.py
@@ -178,39 +178,37 @@ def calculate_normalized_weight(weight: Tensor, scale: Tensor) -> Tensor:

def do_nf4_quantization(weight: Tensor, scale: Tensor, is_normalized_weight: bool = False) -> Tensor:
"""
Performs NF4 quantization - the floating point value is represented by floating point scale, look-up table of
16 NF4 values Quantizes the weight tensor to NF4 format.
Performs NF4 quantization. The floating point values are represented by floating point scale and look-up with
16 floating-point values on [-1, 1]. Scale normalizes original values to [-1, 1] interval and look-up table
"rounds" or "quantize" to the closest quant.
:param weight: Weight tensor to quantize.
:param scale: Scale tensor used for normalization.
:param is_normalized_weight: Whether weight was scaled to [-1, 1] interval. Defaults to False.
:return: Tensor of indexes from 0 to 15 that represents the position in look-up table with the corresponding
NF4 values from -1 to 1.
:return: Tensor with floating-point values, where each of them corresponds to 1 out of 16 quants on [-1, 1].
"""
norm_weight = weight if is_normalized_weight else calculate_normalized_weight(weight, scale)
center_nf4_quantiles = fns.from_numpy(CENTER_OF_NF4_QUANTILES, backend=norm_weight.backend)
indexes = fns.searchsorted(center_nf4_quantiles, norm_weight)
return indexes
nf4_quantiles = fns.from_numpy(NF4_QUANTILES, backend=indexes.backend)
nf4_weight = nf4_quantiles[indexes]
return nf4_weight


def do_nf4_dequantization(indexes: Tensor, scale: Tensor, reduction_axis: int = -1) -> Tensor:
def do_nf4_dequantization(nf4_weight: Tensor, scale: Tensor, reduction_axis: int = -1) -> Tensor:
"""
Decompresses the NF4 quantized weight tensor.
:param indexes: Tensor of indexes from 0 to 15 that represents the position in look-up table with the corresponding
NF4 values from -1 to 1.
:param nf4_weight: Tensor with floating-point values,
where each of them corresponds to 1 out of 16 quants on [-1, 1].
:param scale: Scale tensor used for decompression.
:param reduction_axis: axis along which weights were reshaped for group quantization and will be reshaped back to
original shapes. If equals to -1, weights are not reshaped, assumed not a group quantization. Defaults to -1.
:return: Decompressed weight tensor.
"""
nf4_quantiles = fns.from_numpy(NF4_QUANTILES, backend=indexes.backend)
nf4_weight = nf4_quantiles[indexes]

decompressed_weight = nf4_weight * scale
if reduction_axis != -1:
decompressed_weight = ungroup_weights(decompressed_weight, reduction_axis)

return decompressed_weight


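
To make the reworked docstrings concrete, here is a standalone NumPy illustration of the look-up behaviour; the quantile values are the published NF4 table (assumed to match `NF4_QUANTILES`), and the max-abs scale mirrors what `calculate_nf4_scale` is expected to produce:

```python
# Standalone NumPy sketch of NF4 quantization/dequantization (illustrative,
# not a copy of the NNCF implementation).
import numpy as np

# Assumed NF4 look-up table: 16 values on [-1, 1], rounded for brevity.
NF4_QUANTILES = np.array([
    -1.0, -0.6962, -0.5251, -0.3949, -0.2844, -0.1848, -0.0911, 0.0,
    0.0796, 0.1609, 0.2461, 0.3379, 0.4407, 0.5626, 0.7230, 1.0,
])
# Decision boundaries: midpoints between adjacent quantiles.
CENTERS = (NF4_QUANTILES[:-1] + NF4_QUANTILES[1:]) / 2.0


def nf4_quantize(weight, scale):
    norm = weight / scale                  # normalize to [-1, 1]
    idx = np.searchsorted(CENTERS, norm)   # nearest quantile index (0..15)
    return NF4_QUANTILES[idx]              # return the quantile value itself


def nf4_dequantize(nf4_weight, scale):
    return nf4_weight * scale              # undo the normalization


weight = np.array([0.5, -0.2, 0.03])
scale = np.abs(weight).max()               # per-channel max-abs scale
print(nf4_dequantize(nf4_quantize(weight, scale), scale))  # close to the input
```
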
9 changes: 4 additions & 5 deletions nncf/quantization/quantize_model.py
@@ -493,14 +493,13 @@ def compress_weights(
if backend == BackendType.OPENVINO:
from nncf.openvino.quantization.quantize_model import compress_weights_impl as ov_compress_weights_impl

if any((awq, scale_estimation)) and (
dataset is None or mode in [CompressWeightsMode.NF4, CompressWeightsMode.E2M1]
if any((awq, scale_estimation, gptq, lora_correction)) and (
dataset is None or mode == CompressWeightsMode.E2M1
):
raise AttributeError(
"Scale estimation or AWQ algorithm is defined, but dataset is None or mode is (NF4 or E2M1)."
"Scale estimation, AWQ, GPTQ or Lora Correction algorithm is defined, "
"but dataset is None or mode is E2M1."
)
if any((gptq, lora_correction)) and (dataset is None or mode == CompressWeightsMode.E2M1):
raise AttributeError("GPTQ or Lora Correction algorithm is defined, but dataset is None or mode is E2M1.")

if gptq and lora_correction:
raise AttributeError(
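
A hypothetical example of the relaxed check above: NF4 now passes the data-aware validation, but a dataset is still mandatory (and E2M1 is still rejected):

```python
# Hypothetical example: requesting Scale Estimation for NF4 without a dataset
# is still expected to raise the AttributeError from the check above.
import numpy as np
import openvino as ov
from openvino.runtime import opset13 as opset
import nncf
from nncf import CompressWeightsMode

x = opset.parameter([8, 8], dtype=np.float32, name="x")
w = opset.constant(np.random.rand(8, 8).astype(np.float32))
toy_model = ov.Model([opset.matmul(x, w, transpose_a=False, transpose_b=True)], [x])

try:
    nncf.compress_weights(toy_model, mode=CompressWeightsMode.NF4, scale_estimation=True)
except AttributeError as err:
    print(err)  # "... but dataset is None or mode is E2M1."
```
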
32 changes: 24 additions & 8 deletions tests/openvino/native/quantization/test_weights_compression.py
@@ -40,6 +40,7 @@
from nncf.quantization.algorithms.weight_compression.weight_lowering import reshape_weight_for_grouped_quantization
from nncf.scopes import IgnoredScope
from nncf.tensor import Tensor
from nncf.tensor import TensorDataType
from tests.cross_fw.shared.helpers import compare_stats
from tests.cross_fw.shared.helpers import dump_to_json
from tests.cross_fw.shared.helpers import load_json
@@ -710,7 +711,7 @@ def test_raise_error_with_unsupported_params_for_int8(mode, params):
compress_weights(ov.Model([], []), mode=mode, **params)


@pytest.mark.parametrize("mode", INT4_MODES)
@pytest.mark.parametrize("mode", INT4_NF4_MODES)
@pytest.mark.parametrize(
"params",
({"dataset": "anything", "lora_correction": True, "gptq": True},),
@@ -748,7 +749,7 @@ def test_call_max_var_criterion_with_dataset_by_default_awq(mode):
compress_weights(model, mode=mode, ratio=1.0, group_size=2, dataset=dataset, awq=True)


@pytest.mark.parametrize("mode", INT4_MODES)
@pytest.mark.parametrize("mode", INT4_NF4_MODES)
@pytest.mark.parametrize("with_multiply", (True, False))
def test_call_max_var_criterion_with_dataset_by_default_awq_act_matmul(mode, with_multiply):
n_layers = 8
@@ -765,15 +766,15 @@ def test_call_max_var_criterion_with_dataset_by_default_awq_act_matmul(mode, with_multiply):
assert awq_num == n_awq_target


@pytest.mark.parametrize("mode", INT4_MODES)
@pytest.mark.parametrize("mode", INT4_NF4_MODES)
def test_call_max_var_criterion_with_dataset_awq_for_compressed_model(mode):
model = AWQMatmulModel(is_int8=True).ov_model
dataset = Dataset([np.ones([8, 8])])

compress_weights(model, mode=mode, ratio=1.0, group_size=2, dataset=dataset, awq=True)


@pytest.mark.parametrize("mode", INT4_MODES)
@pytest.mark.parametrize("mode", INT4_NF4_MODES)
def test_call_max_var_criterion_with_dataset_awq_neg_group_size(mode):
model = AWQMatmulModel().ov_model
dataset = Dataset([np.ones([8, 8])])
@@ -875,23 +876,38 @@ def test_duplicate_names_generation():
op_names.add(name)


@pytest.mark.parametrize("mode", INT4_MODES)
def test_call_max_var_criterion_with_dataset_by_default_scale_estimation(mode):
@pytest.mark.parametrize(
("mode", "compressed_weight_dtype"),
(
(CompressWeightsMode.INT4_SYM, TensorDataType.int8),
(CompressWeightsMode.INT4_ASYM, TensorDataType.uint8),
(CompressWeightsMode.NF4, TensorDataType.float32),
),
)
def test_call_max_var_criterion_with_dataset_by_default_scale_estimation(mode, compressed_weight_dtype, mocker):
model = AWQMatmulModel().ov_model
dataset = Dataset([np.ones([8, 8])])
from nncf.quantization.algorithms.weight_compression import scale_estimation
from nncf.quantization.algorithms.weight_compression.algorithm import ScaleEstimation

se_spy = mocker.spy(ScaleEstimation, "apply")
tzm_spy = mocker.spy(scale_estimation, "get_target_zero_mask")

compress_weights(model, mode=mode, ratio=1.0, group_size=2, dataset=dataset, scale_estimation=True)

assert se_spy.call_count == 1
assert tzm_spy.call_args_list[0][0][0].dtype == compressed_weight_dtype

@pytest.mark.parametrize("mode", INT4_MODES)

@pytest.mark.parametrize("mode", INT4_NF4_MODES)
def test_call_max_var_criterion_with_dataset_scale_estimation_for_compressed_model(mode):
model = AWQMatmulModel(is_int8=True).ov_model
dataset = Dataset([np.ones([8, 8])])

compress_weights(model, mode=mode, ratio=1.0, group_size=2, dataset=dataset, scale_estimation=True)


@pytest.mark.parametrize("mode", INT4_MODES)
@pytest.mark.parametrize("mode", INT4_NF4_MODES)
def test_call_max_var_criterion_with_dataset_scale_estimation_neg_group_size(mode):
model = AWQMatmulModel().ov_model
dataset = Dataset([np.ones([8, 8])])
6 changes: 5 additions & 1 deletion tests/post_training/data/wc_reference_data.yaml
@@ -35,4 +35,8 @@ tinyllama_scale_estimation_per_channel_backend_OV:
tinyllama_data_aware_lora_stateful_backend_OV:
metric_value: 0.83446
num_int4: 94
num_int8: 500
num_int8: 500
tinyllama_NF4_scale_estimation_stateful_per_channel_backend_OV:
metric_value: 0.88663
num_int4: 11
num_int8: 290
5 changes: 5 additions & 0 deletions tests/post_training/data/wc_reference_data_2024.4.yaml
@@ -19,4 +19,9 @@ tinyllama_scale_estimation_per_channel_backend_OV:
metric_value: 0.80853
num_int4: 188
num_int8: 124
metrics_xfail_reason: "Issue-148819"
tinyllama_NF4_scale_estimation_stateful_per_channel_backend_OV:
metric_value: 0.87942
num_int4: 11
num_int8: 290
metrics_xfail_reason: "Issue-148819"
6 changes: 5 additions & 1 deletion tests/post_training/data/wc_reference_data_2024.5.yaml
@@ -7,4 +7,8 @@ tinyllama_scale_estimation_per_channel_backend_OV:
metric_value: 0.80798
num_int4: 188
num_int8: 124
metrics_xfail_reason: "Issue-148819"
metrics_xfail_reason: "Issue-148819"
tinyllama_NF4_scale_estimation_stateful_per_channel_backend_OV:
metric_value: 0.87132
num_int4: 11
num_int8: 290
