Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[QUANT] Add GPTQModel Dynamic Quantization + lm_head Quantization #3790

Merged
Merged
Changes from 1 commit
Commits
Show all changes
55 commits
Select commit Hold shift + click to select a range
dab77cf
Changed VocabParallelEmbedding.linear_method to quant_method to be co…
ZX-ModelCloud Feb 22, 2025
9c097bd
call param.packed_factor instead of param.pack_factor
ZX-ModelCloud Feb 22, 2025
4e26757
add monkey_patch_vllm_get_linear_quant_method()
ZX-ModelCloud Feb 22, 2025
c52612c
pass prefix argument
ZX-ModelCloud Feb 22, 2025
84630d8
fix gptq_marlin error
ZX-ModelCloud Feb 22, 2025
7ad7159
cleanup
ZX-ModelCloud Feb 22, 2025
c870c5f
add prefix
ZX-ModelCloud Feb 22, 2025
7f3ffa0
add prefix
ZX-ModelCloud Feb 22, 2025
29a0e2a
use clearer api name and re-order args
Qubitium Feb 22, 2025
a143440
format
Qubitium Feb 22, 2025
82461e5
move import to top
Qubitium Feb 22, 2025
e726f1b
Merge branch 'main' into compat_gptqmodel_dynamic
Qubitium Feb 22, 2025
0d5a66d
reduce vllm depend: move dynamic config extraction method to sglang
Qubitium Feb 22, 2025
29518ba
Merge branch 'main' into compat_gptqmodel_dynamic
Qubitium Feb 22, 2025
cedd221
add unittest
ZX-ModelCloud Feb 25, 2025
3f64919
update unittest
ZX-ModelCloud Feb 25, 2025
cd06ba8
Merge branch 'main' into compat_gptqmodel_dynamic
Qubitium Feb 25, 2025
0085065
code format
ZX-ModelCloud Feb 25, 2025
7542b80
Merge branch 'main' into compat_gptqmodel_dynamic
Qubitium Feb 27, 2025
9c73721
Merge branch 'main' into compat_gptqmodel_dynamic
Qubitium Feb 28, 2025
5f72987
Merge branch 'main' into compat_gptqmodel_dynamic
Qubitium Feb 28, 2025
a3b5811
add gptqmodel tests to run_suite.py
Qubitium Mar 1, 2025
498bb0f
Merge branch 'main' into compat_gptqmodel_dynamic
Qubitium Mar 1, 2025
bd26863
Merge branch 'main' into compat_gptqmodel_dynamic
Qubitium Mar 1, 2025
21842dc
Update quantization.md
Qubitium Mar 1, 2025
b9b7f9b
Merge branch 'main' into compat_gptqmodel_dynamic
Qubitium Mar 1, 2025
21cf4e5
Merge branch 'main' into compat_gptqmodel_dynamic
Qubitium Mar 2, 2025
3abe3c2
format
Qubitium Mar 2, 2025
bc4a63e
Merge branch 'main' into compat_gptqmodel_dynamic
Qubitium Mar 2, 2025
caaeaf0
remove vllm depends
Qubitium Mar 3, 2025
cf3ef86
Merge branch 'main' into compat_gptqmodel_dynamic
Qubitium Mar 3, 2025
b60c657
remove more vllm 0.7.3 specific depend
Qubitium Mar 3, 2025
4aa0c5a
Merge branch 'compat_gptqmodel_dynamic' of https://github.com/ZX-Mode…
Qubitium Mar 3, 2025
585f65a
Merge branch 'main' into compat_gptqmodel_dynamic
Qubitium Mar 3, 2025
eb3f6b3
all prefix code use add_prefix
Qubitium Mar 3, 2025
131e055
Merge branch 'compat_gptqmodel_dynamic' of https://github.com/ZX-Mode…
Qubitium Mar 3, 2025
fdd4ff3
Merge branch 'main' into compat_gptqmodel_dynamic
Qubitium Mar 3, 2025
d29fe7f
format
Qubitium Mar 3, 2025
95be5bb
Merge branch 'main' into compat_gptqmodel_dynamic
Qubitium Mar 3, 2025
ee8bbd5
simplify
Qubitium Mar 3, 2025
d31410e
assert output
ZX-ModelCloud Mar 3, 2025
f00b8de
Merge branch 'main' into compat_gptqmodel_dynamic
Qubitium Mar 3, 2025
ff5f364
Merge branch 'main' into compat_gptqmodel_dynamic
Qubitium Mar 3, 2025
adf7df3
fix ci
Qubitium Mar 3, 2025
d1d9eb7
Merge branch 'main' into compat_gptqmodel_dynamic
Qubitium Mar 3, 2025
4101ce9
Merge branch 'main' into compat_gptqmodel_dynamic
Qubitium Mar 4, 2025
ea4952e
Merge branch 'main' into compat_gptqmodel_dynamic
Qubitium Mar 4, 2025
a4c269f
try to fix circular imports from vllm
Qubitium Mar 4, 2025
1dd58c5
try (2): fix circular imports
Qubitium Mar 4, 2025
91c09cc
Merge branch 'main' into compat_gptqmodel_dynamic
Qubitium Mar 4, 2025
218e12b
Merge branch 'main' into compat_gptqmodel_dynamic
Qubitium Mar 4, 2025
937ca01
Merge branch 'main' into compat_gptqmodel_dynamic
Qubitium Mar 4, 2025
c2bba8d
format
Qubitium Mar 4, 2025
cef4e20
Merge branch 'main' into compat_gptqmodel_dynamic
Qubitium Mar 4, 2025
97f3ebc
Merge branch 'main' into compat_gptqmodel_dynamic
Qubitium Mar 5, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
reduce vllm depend: move dynamic config extraction method to sglang
  • Loading branch information
Qubitium committed Feb 22, 2025
commit 0d5a66d1c954dfede7a32c07cdafe25c76e8b239
62 changes: 57 additions & 5 deletions python/sglang/srt/layers/quantization/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py
import re
from copy import deepcopy
from typing import Callable, Dict, Optional, Type
from typing import Callable, Dict, Optional, Type, Union

import torch
from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
Expand Down Expand Up @@ -60,16 +61,67 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
return QUANTIZATION_METHODS[quantization]


# Match the dynamic rules against the module name (prefix) and override the
# quantization config if the module (prefix) matches a rule
def override_config(config: QuantizationConfig, prefix: str) -> None:
    """Apply the dynamic per-module overrides for ``prefix`` to ``config`` in place.

    Looks up ``bits``, ``group_size`` and ``desc_act`` (plus ``sym`` for
    ``gptq_marlin``) via ``get_dynamic_override`` and writes any override of
    the expected type back onto ``config``. Afterwards ``pack_factor`` is
    recomputed and the resulting settings are validated.

    Args:
        config: Quantization config to mutate. It must expose
            ``weight_bits``, ``group_size``, ``desc_act`` and ``get_name()``;
            for ``gptq_marlin`` also ``is_sym`` and ``TYPE_MAP``.
        prefix: Module name matched against the dynamic rules.

    Raises:
        ValueError: If the resulting ``(weight_bits, is_sym)`` pair is not in
            ``config.TYPE_MAP`` (``gptq_marlin``), or ``weight_bits`` is not
            one of 2/3/4/8 (``gptq``).
    """
    # NOTE: get_dynamic_override returns False for a negatively matched
    # module. bool is a subclass of int, so a plain isinstance(x, int) check
    # would accept that False (or a stray True) as a bit width / group size
    # and later divide by zero. Bools are therefore rejected explicitly for
    # the integer-valued fields.
    weight_bits = get_dynamic_override(config, prefix, "bits", config.weight_bits)
    if isinstance(weight_bits, int) and not isinstance(weight_bits, bool):
        config.weight_bits = weight_bits

    group_size = get_dynamic_override(config, prefix, "group_size", config.group_size)
    if isinstance(group_size, int) and not isinstance(group_size, bool):
        config.group_size = group_size

    # NOTE(review): a negative match also yields False here, which cannot be
    # told apart from an explicit ``desc_act: false`` override -- callers
    # presumably filter excluded modules before calling this; confirm.
    desc_act = get_dynamic_override(config, prefix, "desc_act", config.desc_act)
    if isinstance(desc_act, bool):
        config.desc_act = desc_act

    # NOTE(review): integer division truncates for 3-bit weights
    # (32 // 3 == 10) -- confirm this is the intended pack factor for GPTQ.
    config.pack_factor = 32 // config.weight_bits  # packed into int32

    method_name = config.get_name()
    if method_name == "gptq_marlin":
        is_sym = get_dynamic_override(config, prefix, "sym", config.is_sym)
        if isinstance(is_sym, bool):
            config.is_sym = is_sym

        if (config.weight_bits, config.is_sym) not in config.TYPE_MAP:
            raise ValueError(
                "Unsupported quantization config: "
                f"bits={config.weight_bits}, sym={config.is_sym}"
            )

        config.quant_type = config.TYPE_MAP[(config.weight_bits, config.is_sym)]
    elif method_name == "gptq":
        if config.weight_bits not in [2, 3, 4, 8]:
            raise ValueError(
                "Currently, only 2/3/4/8-bit weight quantization is "
                f"supported for GPTQ, but got {config.weight_bits} bits."
            )


def get_dynamic_override(
    config: QuantizationConfig,
    layer_name: str,
    key: Optional[str] = None,
    default_value: Union[int, bool, None] = None,
) -> Union[Dict, int, bool, None]:
    """Resolve the dynamic quantization override for ``layer_name``.

    The rules in ``config.dynamic`` are tried in iteration order and the first
    rule whose regex matches the start of ``layer_name`` decides the result:

    * a rule prefixed with ``-:`` (negative match) returns ``False``, meaning
      the module is excluded from quantized init;
    * any other rule (optionally prefixed with ``+:``) returns its whole
      override dict when ``key`` is ``None``, otherwise the value stored
      under ``key``, falling back to ``default_value``.

    When no rule matches, ``default_value`` is returned.
    """
    for rule, overrides in config.dynamic.items():
        excluded = rule.startswith("-:")
        regex = rule.removeprefix("-:") if excluded else rule.removeprefix("+:")

        if re.match(regex, layer_name) is None:
            # This rule does not apply to the module; try the next one.
            continue

        if excluded:
            return False
        if key is None:
            return overrides
        return overrides.get(key, default_value)

    return default_value


def get_linear_quant_method(
config: QuantizationConfig,
layer: torch.nn.Module,
prefix: str,
linear_method_cls: type,
):
from vllm.model_executor.layers.quantization.utils.gptq_utils import (
get_dynamic_override,
override_config,
)

from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
from sglang.srt.layers.vocab_parallel_embedding import (
Expand Down
Loading