Fix weight only quant bug. (ModelTC#342)
Co-authored-by: hiworldwzj <[email protected]>
helloyongyang and hiworldwzj authored Mar 4, 2024
1 parent 01a87de commit 982e802
Showing 12 changed files with 37 additions and 37 deletions.
@@ -84,15 +84,15 @@ def matmul_dequantize_int8(a, b, b_scale, out=None):
     return c


-def quantize_int8(weight, axis=0):
+def quantize_int8(weight, axis=0, tp_rank=0):
     # Weight shape: [H1, H2]
     # Scale shape: [H2]
     scale = weight.abs().amax(axis, keepdim=True) / 127.
     weight = (weight / scale).to(torch.int8)
     if axis == 0:
         weight = weight.t().contiguous().t()
     scale = scale.squeeze(axis)
-    return weight, scale
+    return weight.contiguous().cuda(tp_rank), scale.contiguous().cuda(tp_rank)


 def test_int8(M, K, N):
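For reference, a minimal usage sketch of the updated helper, using the import path this commit switches the weight loaders to (shapes are illustrative and a visible CUDA device for the given tp_rank is assumed):

```python
import torch

from lightllm.common.basemodel.triton_kernel.dequantize_gemm_int8 import quantize_int8

# Illustrative shapes; tp_rank selects the GPU that receives the quantized weight.
w = torch.randn(4096, 11008, dtype=torch.float16)
qweight, scale = quantize_int8(w, axis=0, tp_rank=0)

assert qweight.dtype == torch.int8 and qweight.is_cuda
assert scale.shape == (11008,)  # one scale per output column when quantizing over axis 0
```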
2 changes: 1 addition & 1 deletion lightllm/models/internlm_wquant/model.py
@@ -20,7 +20,7 @@ def __init__(self, kvargs):

     def _verify_params(self):
         assert self.load_way in ["HF", "DS"], "llama only supports HF and DS format to load Now!"
-        assert any("int8weight" in mode_ or "int4weight" in mode_ for mode_ in self.mode), "only for weight quant model"
+        assert any("w4a16" in mode_ or "w8a16" in mode_ for mode_ in self.mode), "only for weight quant model"
         assert self.config["num_key_value_heads"] % self.world_size_ == 0
         assert self.config["num_attention_heads"] % self.world_size_ == 0
         return
@@ -48,15 +48,15 @@ def _bind_func(self):
         return

     def _bind_norm(self):
-        if "ppl_int8_activation_weight" in self.mode:
+        if "ppl_w8a8" in self.mode:
             self._awquant_att_norm = partial(LlamaTransformerLayerInferAWquant._awquant_att_norm_ppl_int8, self)
             self._awquant_ffn_norm = partial(LlamaTransformerLayerInferAWquant._awquant_ffn_norm_ppl_int8, self)
         else:
             raise Exception(f"error mode {self.mode}")
         return

     def _bind_matmul(self):
-        if "ppl_int8_activation_weight" in self.mode:
+        if "ppl_w8a8" in self.mode:
             self._awquant_matmul_for_qkv = partial(
                 LlamaTransformerLayerInferAWquant._awquant_matmul_ppl_int8_quant_dequant, self
             )
@@ -70,13 +70,13 @@ def _bind_matmul(self):
                 LlamaTransformerLayerInferAWquant._awquant_matmul_ppl_int8_quant_dequant, self
             )
             if self.tp_rank_ == 0 and self.layer_num_ == 0:
-                print("model use ppl_int8_activation_weight kernel")
+                print("model use ppl_w8a8 kernel")
         else:
             raise Exception(f"error mode {self.mode}")
         return

     def _bind_silu(self):
-        if "ppl_int8_activation_weight" in self.mode:
+        if "ppl_w8a8" in self.mode:
             func = partial(LlamaTransformerLayerInferAWquant._awquant_silu_ppl_int8, self)
             self._awquant_silu = func
         else:
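The _bind_* methods above follow a simple dispatch pattern: the mode list is inspected once at construction time and the chosen kernel is bound to the instance with functools.partial, so the per-token forward path never re-checks the mode. A stripped-down sketch of that pattern (class and method names are invented for illustration and are not part of lightllm):

```python
from functools import partial

import torch


class TinyDispatchSketch:
    """Hypothetical miniature of the _bind_matmul pattern shown above."""

    def __init__(self, mode):
        self.mode = mode
        self._bind_matmul()

    def _matmul_ppl_w8a8(self, x, w):
        # Stand-in for the real ppl w8a8 quant/dequant matmul kernel.
        return x @ w

    def _bind_matmul(self):
        if "ppl_w8a8" in self.mode:
            # Bind the implementation once; hot-path code just calls self._matmul.
            self._matmul = partial(TinyDispatchSketch._matmul_ppl_w8a8, self)
        else:
            raise Exception(f"error mode {self.mode}")


layer = TinyDispatchSketch(mode=["ppl_w8a8"])
out = layer._matmul(torch.ones(2, 3), torch.ones(3, 4))
assert out.shape == (2, 4)
```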
@@ -14,8 +14,10 @@ def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mo
         self.init_quant_mode()

     def init_quant_mode(self):
-        if "ppl_int8_activation_weight" in self.mode:
+        if "ppl_w8a8" in self.mode:
             self.quantize_weight = partial(dynamic_channelwise_quant_fp16_i8_ppl, tp_rank=self.tp_rank_)
+        else:
+            raise Exception(f"error mode {self.mode}")

     def load_hf_weights(self, weights):
         self._load_qkvo_weights(weights)
2 changes: 1 addition & 1 deletion lightllm/models/llama_awquant/model.py
@@ -23,7 +23,7 @@ def __init__(self, kvargs):

     def _verify_params(self):
         assert self.load_way in ["HF", "DS"], "llama only supports HF and DS format to load Now!"
-        assert any("int8_activation_weight" in mode_ or "int4_activation_weight" in mode_ for mode_ in self.mode), "only for weight quant model"
+        assert any("w8a8" in mode_ for mode_ in self.mode), "only for weight-activation quant model"
         assert self.config["num_key_value_heads"] % self.world_size_ == 0
         assert self.config["num_attention_heads"] % self.world_size_ == 0
         return
24 changes: 10 additions & 14 deletions lightllm/models/llama_wquant/layer_infer/transformer_layer_infer.py
@@ -51,38 +51,38 @@ def _bind_func(self):
         return

     def _bind_matmul(self):
-        if "triton_int8weight" in self.mode:
+        if "triton_w8a16" in self.mode:
             func = partial(LlamaTransformerLayerInferWquant._wquant_matmul_triton_int8weight_only_quant, self)
             self._wquant_matmul_for_qkv = func
             self._wquant_matmul_for_o = func
             self._wquant_matmul_for_ffn_up = func
             self._wquant_matmul_for_ffn_down = func
             if self.tp_rank_ == 0 and self.layer_num_ == 0:
-                logger.info("model use triton_int8weight kernel")
-        elif "triton_int4weight" in self.mode:
+                logger.info("model use triton_w8a16 kernel")
+        elif "triton_w4a16" in self.mode:
             func = partial(LlamaTransformerLayerInferWquant._wquant_matmul_triton_int4weight_only_quant, self)
             self._wquant_matmul_for_qkv = func
             self._wquant_matmul_for_o = func
             self._wquant_matmul_for_ffn_up = func
             self._wquant_matmul_for_ffn_down = func
             if self.tp_rank_ == 0 and self.layer_num_ == 0:
-                logger.info("model use triton_int4weight kernel")
-        elif "lmdeploy_int4weight" in self.mode:
+                logger.info("model use triton_w4a16 kernel")
+        elif "lmdeploy_w4a16" in self.mode:
             func = partial(LlamaTransformerLayerInferWquant._wquant_matmul_lmdeploy_int4weight_only_quant, self)
             self._wquant_matmul_for_qkv = func
             self._wquant_matmul_for_o = func
             self._wquant_matmul_for_ffn_up = func
             self._wquant_matmul_for_ffn_down = func
             if self.tp_rank_ == 0 and self.layer_num_ == 0:
-                logger.info("model use lmdeploy_int4weight kernel")
-        elif "ppl_int4weight" in self.mode:
+                logger.info("model use lmdeploy_w4a16 kernel")
+        elif "ppl_w4a16" in self.mode:
             func = partial(LlamaTransformerLayerInferWquant._wquant_matmul_ppl_int4weight_only_quant, self)
             self._wquant_matmul_for_qkv = func
             self._wquant_matmul_for_o = func
             self._wquant_matmul_for_ffn_up = func
             self._wquant_matmul_for_ffn_down = func
             if self.tp_rank_ == 0 and self.layer_num_ == 0:
-                logger.info("model use ppl_int4weight kernel")
+                logger.info("model use ppl_w4a16 kernel")
         else:
             raise Exception(f"error mode {self.mode}")
         return
@@ -132,12 +132,8 @@ def _wquant_matmul_triton_int8weight_only_quant(
         self, input, quant_weight_params, infer_state: LlamaInferStateInfo, out=None, bias=None, has_act=False
     ):
         assert has_act is False
-        if infer_state.is_splitfuse is False and infer_state.is_prefill:
-            qweight, scale = quant_weight_params
-            out = matmul_dequantize_int8(input, qweight, scale, out=out)
-        else:
-            qweight, scale = quant_weight_params
-            out = matmul_quantize_int8(input, qweight, scale, out=out)
+        qweight, scale = quant_weight_params
+        out = matmul_dequantize_int8(input, qweight, scale, out=out)
         if bias is None:
             return out
         else:
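The simplification above drops the splitfuse/prefill special case, so every call now goes through matmul_dequantize_int8. Conceptually, a weight-only int8 matmul dequantizes the weight with its per-column scale and performs an ordinary GEMM on full-precision activations; a rough reference of that math (not the Triton kernel, and using float32 on CPU purely for illustration):

```python
import torch


def matmul_dequantize_int8_ref(a, qweight, scale):
    # a: [M, K] activations, qweight: [K, N] int8, scale: [N] per-column scale.
    # Weight-only quantization keeps activations in full precision, so the kernel
    # only has to undo the weight scaling before (or fused into) the GEMM.
    return a @ (qweight.to(a.dtype) * scale)


a = torch.randn(8, 4096)
w = torch.randn(4096, 11008)
# Mirrors quantize_int8(w, axis=0) from the first hunk: per-column scale, int8 weight.
scale = w.abs().amax(0, keepdim=True) / 127.
qweight = (w / scale).to(torch.int8)

out = matmul_dequantize_int8_ref(a, qweight, scale.squeeze(0))
assert out.shape == (8, 11008)
```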
@@ -5,7 +5,7 @@
 from functools import partial

 from lightllm.common.basemodel import TransformerLayerWeight
-from lightllm.common.basemodel.triton_kernel.quantize_gemm_int8 import quantize_int8
+from lightllm.common.basemodel.triton_kernel.dequantize_gemm_int8 import quantize_int8
 from lightllm.common.basemodel.triton_kernel.dequantize_gemm_int4 import quantize_int4
 from lightllm.common.basemodel.cuda_kernel.lmdeploy_wquant import quantize_int4_lmdeploy
 from lightllm.common.basemodel.cuda_kernel.ppl_wquant import quantize_int4_ppl
@@ -17,28 +17,30 @@ def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mo
         self.init_quant_mode()

     def init_quant_mode(self):
-        if "triton_int8weight" in self.mode:
+        if "triton_w8a16" in self.mode:
             self.quantize_weight = partial(quantize_int8, tp_rank=self.tp_rank_)
-        if "triton_int4weight" in self.mode:
+        elif "triton_w4a16" in self.mode:
             self.int4_q_group_size = 128
             for _mode in self.mode:
                 if _mode.startswith("g"):
                     self.int4_q_group_size = int(_mode[1:])
             self.quantize_weight = partial(quantize_int4, group_size=self.int4_q_group_size, tp_rank=self.tp_rank_)
-        if "lmdeploy_int4weight" in self.mode:
+        elif "lmdeploy_w4a16" in self.mode:
             self.int4_q_group_size = 128
             for _mode in self.mode:
                 if _mode.startswith("g"):
                     self.int4_q_group_size = int(_mode[1:])
             self.quantize_weight = partial(
                 quantize_int4_lmdeploy, group_size=self.int4_q_group_size, tp_rank=self.tp_rank_
             )
-        if "ppl_int4weight" in self.mode:
+        elif "ppl_w4a16" in self.mode:
             self.int4_q_group_size = 128
             for _mode in self.mode:
                 if _mode.startswith("g"):
                     self.int4_q_group_size = int(_mode[1:])
             self.quantize_weight = partial(quantize_int4_ppl, group_size=self.int4_q_group_size, tp_rank=self.tp_rank_)
+        else:
+            raise Exception(f"error mode {self.mode}")

     def load_hf_weights(self, weights):
         self._load_qkvo_weights(weights)
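All of the w4a16 branches above share the same convention for choosing the int4 quantization group size: it defaults to 128 and can be overridden by an extra mode token such as "g64". A small standalone sketch of that parsing (the helper name is hypothetical):

```python
def parse_int4_group_size(mode, default=128):
    # Mode entries like "g64" or "g128" override the int4 quantization group size.
    group_size = default
    for _mode in mode:
        if _mode.startswith("g"):
            group_size = int(_mode[1:])
    return group_size


assert parse_int4_group_size(["triton_w4a16"]) == 128
assert parse_int4_group_size(["lmdeploy_w4a16", "g64"]) == 64
```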
2 changes: 1 addition & 1 deletion lightllm/models/llama_wquant/model.py
@@ -21,7 +21,7 @@ def __init__(self, kvargs):

     def _verify_params(self):
         assert self.load_way in ["HF", "DS"], "llama only supports HF and DS format to load Now!"
-        assert any("int8weight" in mode_ or "int4weight" in mode_ for mode_ in self.mode), "only for weight quant model"
+        assert any("w4a16" in mode_ or "w8a16" in mode_ for mode_ in self.mode), "only for weight quant model"
         assert self.config["num_key_value_heads"] % self.world_size_ == 0
         assert self.config["num_attention_heads"] % self.world_size_ == 0
         return
2 changes: 1 addition & 1 deletion lightllm/models/qwen_wquant/model.py
@@ -18,7 +18,7 @@ def __init__(self, kvargs):

     def _verify_params(self):
         super()._verify_params()
-        assert any("int8weight" in mode_ or "int4weight" in mode_ for mode_ in self.mode), "only for weight quant model"
+        assert any("w4a16" in mode_ or "w8a16" in mode_ for mode_ in self.mode), "only for weight quant model"
         assert self.config["num_key_value_heads"] % self.world_size_ == 0
         assert self.config["num_attention_heads"] % self.world_size_ == 0
         return
2 changes: 1 addition & 1 deletion lightllm/models/starcoder_wquant/model.py
@@ -23,7 +23,7 @@ def __init__(self, kvargs):

     def _verify_params(self):
         assert self.load_way in ["HF", "DS"], "llama only supports HF and DS format to load Now!"
-        assert any("int8weight" in mode_ or "int4weight" in mode_ for mode_ in self.mode), "only for weight quant model"
+        assert any("w4a16" in mode_ or "w8a16" in mode_ for mode_ in self.mode), "only for weight quant model"
         return

     def _init_mem_manager(self):
2 changes: 1 addition & 1 deletion lightllm/server/api_server.py
@@ -332,7 +332,7 @@ def main():
     parser.add_argument("--mode", type=str, default=[], nargs='+',
                         help="""Model mode: [triton_int8kv | ppl_int8kv | ppl_fp16 | triton_flashdecoding
                         | triton_gqa_attention | triton_gqa_flashdecoding]
-                        [triton_int8weight | triton_int4weight | lmdeploy_int4weight | ppl_int4weight],
+                        [triton_w4a16 | triton_w8a16 | lmdeploy_w4a16 | ppl_w4a16 | ppl_w8a8],
                         triton_flashdecoding mode is for long context, current support llama llama2 qwen;
                         triton_gqa_attention and triton_gqa_flashdecoding is fast kernel for model which use GQA;
                         triton_int8kv mode use int8 to store kv cache, can increase token capacity, use triton kernel;
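For launch scripts written against the old --mode values, the rename is one-for-one; a hypothetical migration helper (not part of the repository) capturing the mapping implied by this commit:

```python
# Old --mode values and their replacements introduced by this commit.
MODE_RENAMES = {
    "triton_int8weight": "triton_w8a16",
    "triton_int4weight": "triton_w4a16",
    "lmdeploy_int4weight": "lmdeploy_w4a16",
    "ppl_int4weight": "ppl_w4a16",
    "ppl_int8_activation_weight": "ppl_w8a8",
}


def migrate_mode(mode):
    """Translate a pre-rename mode list, leaving other tokens (e.g. 'g128') untouched."""
    return [MODE_RENAMES.get(m, m) for m in mode]


assert migrate_mode(["triton_int4weight", "g128"]) == ["triton_w4a16", "g128"]
```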
10 changes: 5 additions & 5 deletions lightllm/server/router/model_infer/model_rpc.py
@@ -92,17 +92,17 @@ def exposed_init_model(self, kvargs):
         if self.model_type == "bloom":
             self.model = BloomTpPartModel(model_kvargs)
         elif self.model_type == "llama":
-            if any('int8weight' in mode_ or 'int4weight' in mode_ for mode_ in self.mode):
+            if any('w4a16' in mode_ or 'w8a16' in mode_ for mode_ in self.mode):
                 self.model = LlamaTpPartModelWQuant(model_kvargs)
-            elif any('int8_activation_weight' in mode_ for mode_ in self.mode):
+            elif any('w8a8' in mode_ for mode_ in self.mode):
                 self.model = LlamaTpPartModelAWQuant(model_kvargs)
             else:
                 self.model = LlamaTpPartModel(model_kvargs)
         elif self.model_type == "qwen":
             if "visual" in model_cfg:
                 self.model = QWenVLTpPartModel(model_kvargs)
                 self.is_multimodal = True
-            elif any('int8weight' in mode_ or 'int4weight' in mode_ for mode_ in self.mode):
+            elif any('w8a16' in mode_ or 'w4a16' in mode_ for mode_ in self.mode):
                 self.model = QWenTpPartModelWQuant(model_kvargs)
             else:
                 self.model = QWenTpPartModel(model_kvargs)
@@ -120,14 +120,14 @@ def exposed_init_model(self, kvargs):
             else:
                 raise Exception('can not support baichuan format')
         elif self.model_type == 'gpt_bigcode':
-            if any('int8weight' in mode_ or 'int4weight' in mode_ for mode_ in self.mode):
+            if any('w8a16' in mode_ or 'w4a16' in mode_ for mode_ in self.mode):
                 self.model = StarcoderTpPartModelWQuant(model_kvargs)
             else:
                 self.model = StarcoderTpPartModel(model_kvargs)
         elif self.model_type == 'chatglm':
             self.model = ChatGlm2TpPartModel(model_kvargs)
         elif self.model_type == 'internlm' or self.model_type == 'internlm2':
-            if any('int8weight' in mode_ or 'int4weight' in mode_ for mode_ in self.mode):
+            if any('w8a16' in mode_ or 'w4a16' in mode_ for mode_ in self.mode):
                 self.model = InternlmTpPartModelWQuant(model_kvargs)
             else:
                 if model_cfg["architectures"][0] == 'InternLM2ForCausalLM':
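The selection logic in exposed_init_model now keys off the new substrings: any w4a16/w8a16 token selects the weight-only quantized model class, a w8a8 token selects the weight-activation quantized class, and anything else falls back to the plain model. A condensed sketch of the llama branch (the helper is hypothetical and returns class names as strings so the snippet stays self-contained):

```python
def select_llama_model_cls(mode):
    # Weight-only quant (w4a16 / w8a16) -> WQuant; weight+activation quant (w8a8)
    # -> AWQuant; otherwise the unquantized tensor-parallel model.
    if any("w4a16" in m or "w8a16" in m for m in mode):
        return "LlamaTpPartModelWQuant"
    elif any("w8a8" in m for m in mode):
        return "LlamaTpPartModelAWQuant"
    return "LlamaTpPartModel"


assert select_llama_model_cls(["triton_w4a16", "g128"]) == "LlamaTpPartModelWQuant"
assert select_llama_model_cls(["ppl_w8a8"]) == "LlamaTpPartModelAWQuant"
assert select_llama_model_cls([]) == "LlamaTpPartModel"
```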
