[ Misc ] Support Act Order in Compressed Tensors #6358

Open
wants to merge 147 commits into
base: main
Changes from 6 commits
Commits
1dfc42d
added
Jun 26, 2024
aa4a9f5
nits
Jun 26, 2024
27f9a03
cleanup
Jun 26, 2024
de7a064
stash
Jun 27, 2024
ec6a833
refactor gptq marlin
robertgshaw2-redhat Jul 3, 2024
966f7be
back out w4a16 act-order compressed tensors
robertgshaw2-redhat Jul 3, 2024
d391f44
back out w4a16 act-order compressed tensors
robertgshaw2-redhat Jul 3, 2024
db075c3
missed
robertgshaw2-redhat Jul 3, 2024
695dc05
formatted
robertgshaw2-redhat Jul 3, 2024
75c8a11
fix models without gidx
robertgshaw2-redhat Jul 3, 2024
525cf08
format
robertgshaw2-redhat Jul 3, 2024
81f028e
fix test failure
robertgshaw2-redhat Jul 3, 2024
a8fbe89
fix perms not being on gpu
robertgshaw2-redhat Jul 3, 2024
cc843ad
stash
robertgshaw2-redhat Jul 3, 2024
b260c90
stage
robertgshaw2-redhat Jul 3, 2024
c8e97b1
updated
robertgshaw2-redhat Jul 3, 2024
e58063d
nit
robertgshaw2-redhat Jul 3, 2024
383e471
added
robertgshaw2-redhat Jul 3, 2024
865b743
format
robertgshaw2-redhat Jul 3, 2024
9c24525
newline
robertgshaw2-redhat Jul 3, 2024
8b5ac5a
formatting
robertgshaw2-redhat Jul 3, 2024
0e46e4b
working
robertgshaw2-redhat Jul 3, 2024
a47a251
added compressed tensors fp8 to automation
robertgshaw2-redhat Jul 3, 2024
c6be536
missed file
robertgshaw2-redhat Jul 3, 2024
0441171
format
robertgshaw2-redhat Jul 3, 2024
d404f00
remove unnecessary file changes
robertgshaw2-redhat Jul 3, 2024
6569323
restructure quant ops
robertgshaw2-redhat Jul 3, 2024
aa56475
updated to transpose in process_after_loading
robertgshaw2-redhat Jul 3, 2024
d94d07e
updated with varuns suggestion
robertgshaw2-redhat Jul 3, 2024
54308d7
fixed nit
robertgshaw2-redhat Jul 3, 2024
173b93b
name change
robertgshaw2-redhat Jul 3, 2024
afa1ee1
format
robertgshaw2-redhat Jul 3, 2024
5ffe0e4
fixed
robertgshaw2-redhat Jul 3, 2024
4c0e565
fixed tests
robertgshaw2-redhat Jul 3, 2024
ee58d33
Merge branch 'unify-w8a8' into compressed-tensors-fp8
robertgshaw2-redhat Jul 3, 2024
282a038
merge w8a8 unify
robertgshaw2-redhat Jul 3, 2024
a0fd035
fix nit
robertgshaw2-redhat Jul 3, 2024
ba1116b
nits
robertgshaw2-redhat Jul 3, 2024
c1d4375
cleanup
robertgshaw2-redhat Jul 3, 2024
a12bfd5
stash
robertgshaw2-redhat Jul 6, 2024
6aad8f6
Merge branch 'main' into compressed-tensors-fp8
robertgshaw2-redhat Jul 6, 2024
4fc0177
autofp8 working
robertgshaw2-redhat Jul 6, 2024
1d99867
stash
robertgshaw2-redhat Jul 6, 2024
ccee126
stash
robertgshaw2-redhat Jul 6, 2024
0969c67
format
robertgshaw2-redhat Jul 6, 2024
b2eeb84
fix imported marlin_permute_scales
robertgshaw2-redhat Jul 6, 2024
9316f92
format
robertgshaw2-redhat Jul 7, 2024
4ff23c8
added w8a8 to correctness testing
robertgshaw2-redhat Jul 7, 2024
08a8e4e
added testing
robertgshaw2-redhat Jul 7, 2024
4238ac9
format
robertgshaw2-redhat Jul 7, 2024
d1c7517
merged
robertgshaw2-redhat Jul 7, 2024
94d6b35
stash
robertgshaw2-redhat Jul 7, 2024
d48ba9d
readded
robertgshaw2-redhat Jul 7, 2024
0dd2c6a
remove nm-vllm-env
robertgshaw2-redhat Jul 7, 2024
29f40f5
remove old qwen2 moe
robertgshaw2-redhat Jul 7, 2024
ad17c88
readded utils
robertgshaw2-redhat Jul 7, 2024
fd7d825
format
robertgshaw2-redhat Jul 7, 2024
697edfa
Update models-small.txt
robertgshaw2-redhat Jul 7, 2024
e30bd57
gptq marlin tests passing
robertgshaw2-redhat Jul 7, 2024
382d230
add missing files
robertgshaw2-redhat Jul 7, 2024
ba4c7b3
refactoring in progress
robertgshaw2-redhat Jul 7, 2024
0916182
Update models-small.txt
robertgshaw2-redhat Jul 7, 2024
de0242f
stash
robertgshaw2-redhat Jul 7, 2024
9fe4fce
removed lm-eval
robertgshaw2-redhat Jul 7, 2024
c044a86
stash
robertgshaw2-redhat Jul 7, 2024
a5f0aee
remove run
robertgshaw2-redhat Jul 7, 2024
d3299f8
Merge branch 'main' into compressed-tensors-fp8
robertgshaw2-redhat Jul 7, 2024
bcfcd38
added integration test for compressed-tensors-w4-a16
robertgshaw2-redhat Jul 7, 2024
763ab2c
formatting
robertgshaw2-redhat Jul 7, 2024
950de45
Merge branch 'compressed-tensors-fp8' into refactor-gptq-marlin
robertgshaw2-redhat Jul 7, 2024
eb2fdfa
removed
robertgshaw2-redhat Jul 7, 2024
2f49425
Merge branch 'refactor-gptq-marlin' of https://github.com/neuralmagic…
robertgshaw2-redhat Jul 7, 2024
93812eb
add comment
robertgshaw2-redhat Jul 7, 2024
d4b25cf
Update w8a8_utils.py
robertgshaw2-redhat Jul 7, 2024
48b220e
Update w8a8_utils.py
robertgshaw2-redhat Jul 7, 2024
f1d8ee4
cleanup unnecessary changes
robertgshaw2-redhat Jul 7, 2024
cfe27be
Merge branch 'refactor-gptq-marlin' of https://github.com/neuralmagic…
robertgshaw2-redhat Jul 7, 2024
72b9368
fix gptq marlin
robertgshaw2-redhat Jul 7, 2024
73ae598
formatting
robertgshaw2-redhat Jul 7, 2024
f854c54
cleanup
robertgshaw2-redhat Jul 7, 2024
13d4e93
Merge branch 'main' into refactor-gptq-marlin
robertgshaw2-redhat Jul 7, 2024
4e09688
Update benchmark_marlin.py
robertgshaw2-redhat Jul 7, 2024
db694e0
Update compressed_tensors_wNa16.py
robertgshaw2-redhat Jul 7, 2024
4b2dba2
Update marlin_utils_test.py
robertgshaw2-redhat Jul 7, 2024
9d8d12f
Update test_marlin_gemm.py
robertgshaw2-redhat Jul 7, 2024
54cf4f2
format
robertgshaw2-redhat Jul 7, 2024
7abc2b1
Merge branch 'refactor-gptq-marlin' of https://github.com/neuralmagic…
robertgshaw2-redhat Jul 7, 2024
ed178d4
formatting
robertgshaw2-redhat Jul 7, 2024
03b11b2
more formatting
robertgshaw2-redhat Jul 7, 2024
e2a5e7a
fix
robertgshaw2-redhat Jul 7, 2024
6f62ada
yapf
robertgshaw2-redhat Jul 7, 2024
933bec3
fixed failing tests
robertgshaw2-redhat Jul 8, 2024
fe6ae88
tweak scores
robertgshaw2-redhat Jul 8, 2024
8285ef6
tweak scores
robertgshaw2-redhat Jul 8, 2024
fcc8925
stash
robertgshaw2-redhat Jul 9, 2024
c0b5d13
format
robertgshaw2-redhat Jul 9, 2024
f6910a5
seems to still be working
robertgshaw2-redhat Jul 9, 2024
84ed30f
stash
robertgshaw2-redhat Jul 11, 2024
62368af
added tests
robertgshaw2-redhat Jul 11, 2024
b618961
seems to be working!
robertgshaw2-redhat Jul 12, 2024
f2755f2
Update build.sh
robertgshaw2-redhat Jul 12, 2024
cd392f5
Merge branch 'main' into act-order
robertgshaw2-redhat Jul 12, 2024
b092079
Merge branch 'act-order' of https://github.com/neuralmagic/nm-vllm in…
robertgshaw2-redhat Jul 12, 2024
5cbed16
cleanup bad merge
robertgshaw2-redhat Jul 12, 2024
054e2db
removed files that should not have been added
robertgshaw2-redhat Jul 12, 2024
7e0b0ec
Update run-lm-eval-gsm-vllm-baseline.sh
robertgshaw2-redhat Jul 12, 2024
bddf9d3
Update test_compressed_tensors.py
robertgshaw2-redhat Jul 12, 2024
ad43c4e
undo
robertgshaw2-redhat Jul 12, 2024
0aa9181
undo bad merge
robertgshaw2-redhat Jul 12, 2024
777e74b
last undo?
robertgshaw2-redhat Jul 12, 2024
77988d3
twas not last
robertgshaw2-redhat Jul 12, 2024
39ed988
cleanup
robertgshaw2-redhat Jul 12, 2024
7d2fff8
stash
robertgshaw2-redhat Jul 12, 2024
2e74b0b
remove more
robertgshaw2-redhat Jul 12, 2024
a845475
fix
robertgshaw2-redhat Jul 12, 2024
2e7bf61
format
robertgshaw2-redhat Jul 12, 2024
18596e2
format
robertgshaw2-redhat Jul 12, 2024
48aae94
more cleanup
robertgshaw2-redhat Jul 12, 2024
b34ca83
undo changes to gptq marlin
robertgshaw2-redhat Jul 12, 2024
881afd7
another nit
robertgshaw2-redhat Jul 12, 2024
3cd8b55
another nit
robertgshaw2-redhat Jul 12, 2024
02637af
final bad merge?
robertgshaw2-redhat Jul 12, 2024
81f41ed
last bad merge?
robertgshaw2-redhat Jul 12, 2024
536fdde
cleanup
robertgshaw2-redhat Jul 12, 2024
1d10244
stopping point
robertgshaw2-redhat Jul 12, 2024
4c96377
stash
robertgshaw2-redhat Jul 12, 2024
4ca4a08
updated
robertgshaw2-redhat Jul 19, 2024
1080488
Merge branch 'main' into act-order
robertgshaw2-redhat Jul 22, 2024
8531380
Merge branch 'act-order' of https://github.com/neuralmagic/nm-vllm in…
robertgshaw2-redhat Jul 22, 2024
0ddd524
updated to have a default
robertgshaw2-redhat Jul 22, 2024
052cc93
switch order of arguments
robertgshaw2-redhat Jul 22, 2024
6211660
switch everything to actorder from act_order
robertgshaw2-redhat Jul 22, 2024
a0d0251
more cleanup
robertgshaw2-redhat Jul 22, 2024
f187922
more name change
robertgshaw2-redhat Jul 22, 2024
434b471
merge in main
kylesayrs Aug 17, 2024
04ed5d7
reorder for better diff
kylesayrs Aug 17, 2024
07ad850
remove doubled variables, fix shape for marlin_permute_scales
kylesayrs Aug 17, 2024
d2a923a
merge in main
kylesayrs Aug 17, 2024
22de619
merge in main
kylesayrs Aug 17, 2024
fb8ffb2
use BasevLLMParameter
kylesayrs Aug 17, 2024
3bb7294
apply style
kylesayrs Aug 17, 2024
0e396fc
documentation
kylesayrs Aug 17, 2024
14495ba
use layer.group_size
kylesayrs Aug 18, 2024
2f46596
add warning
kylesayrs Aug 29, 2024
ef08596
Merge remote-tracking branch 'upstream/main' into act-order
kylesayrs Aug 30, 2024
22e579e
Merge remote-tracking branch 'upstream/main' into act-order
kylesayrs Sep 1, 2024
cc2c9ab
Group Index Conditioning (#405)
kylesayrs Sep 3, 2024
@@ -4,8 +4,8 @@ tasks:
 - name: "gsm8k"
   metrics:
   - name: "exact_match,strict-match"
-    value: 0.776
+    value: 0.752
   - name: "exact_match,flexible-extract"
-    value: 0.776
+    value: 0.752
 limit: 250
 num_fewshot: 5
@@ -4,8 +4,8 @@ tasks:
 - name: "gsm8k"
   metrics:
   - name: "exact_match,strict-match"
-    value: 0.744
+    value: 0.728
   - name: "exact_match,flexible-extract"
-    value: 0.744
+    value: 0.728
 limit: 250
 num_fewshot: 5
16 changes: 1 addition & 15 deletions benchmarks/kernels/benchmark_marlin.py
@@ -12,7 +12,7 @@
     GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
     GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
-    marlin_quantize)
+    marlin_quantize, MarlinWorkspace)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
     marlin_24_quantize)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
@@ -26,20 +26,6 @@
 K_FULL_OPTS = [False, True]


-class MarlinWorkspace:
-
-    def __init__(self, out_features, min_thread_n, max_parallel):
-        assert (out_features % min_thread_n == 0), (
-            "out_features = {} is undivisible by min_thread_n = {}".format(
-                out_features, min_thread_n))
-
-        max_workspace_size = ((out_features // min_thread_n) * max_parallel)
-
-        self.scratch = torch.zeros(max_workspace_size,
-                                   dtype=torch.int,
-                                   device="cuda")
-
-
 def bench_run(results: List[benchmark.Measurement], model: str,
               act_order: bool, is_k_full: bool, num_bits: int, group_size: int,
               size_m: int, size_k: int, size_n: int):
16 changes: 1 addition & 15 deletions tests/kernels/test_marlin_gemm.py
@@ -17,7 +17,7 @@
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
     pack_fp8_to_int32)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
-    get_weight_perm, marlin_quantize, marlin_weights)
+    get_weight_perm, marlin_quantize, marlin_weights, MarlinWorkspace)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
     marlin_24_quantize)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
@@ -44,20 +44,6 @@
 DTYPES = [torch.float16, torch.bfloat16]


-class MarlinWorkspace:
-
-    def __init__(self, out_features, min_thread_n, max_parallel):
-        assert (out_features % min_thread_n == 0), (
-            "out_features = {} is undivisible by min_thread_n = {}".format(
-                out_features, min_thread_n))
-
-        max_workspace_size = ((out_features // min_thread_n) * max_parallel)
-
-        self.scratch = torch.zeros(max_workspace_size,
-                                   dtype=torch.int,
-                                   device="cuda")
-
-
 def compute_max_diff(output, output_ref):
     return torch.mean(torch.abs(output - output_ref)) / torch.mean(
         torch.abs(output_ref))
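
The metric computed by the `compute_max_diff` context lines above can be sketched in plain Python for readers without a GPU at hand. Despite its name, that helper returns the mean absolute error normalized by the mean absolute reference value, not a maximum. The name `relative_mean_abs_diff` and the list-based inputs are illustrative assumptions, not part of this PR.

```python
from typing import Sequence


def relative_mean_abs_diff(output: Sequence[float],
                           output_ref: Sequence[float]) -> float:
    """Torch-free sketch of mean(|output - ref|) / mean(|ref|).

    Hypothetical helper for illustration only; the test file in the PR
    operates on torch tensors instead of Python sequences.
    """
    if len(output) != len(output_ref) or len(output_ref) == 0:
        raise ValueError("inputs must be non-empty and of equal length")

    n = len(output_ref)
    mean_abs_err = sum(abs(o - r) for o, r in zip(output, output_ref)) / n
    mean_abs_ref = sum(abs(r) for r in output_ref) / n

    # An all-zero reference makes the ratio undefined; Python raises
    # ZeroDivisionError here.
    return mean_abs_err / mean_abs_ref
```

Because the value is a ratio of means, a single large outlier is averaged over all elements rather than reported on its own.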
@@ -50,9 +50,7 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
         output_size_per_partition = sum(output_partition_sizes)

         # If group_size is -1, we are in channelwise case.
-        if self.group_size is None:
-            raise ValueError("Gr")
-        elif self.group_size == -1:
+        if self.group_size == -1:
             group_size = input_size
         else:
             group_size = self.group_size
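
The group-size handling that remains after this hunk can be sketched in isolation as follows. `resolve_group_size` is a hypothetical helper name used only for illustration; in the PR the logic lives inline in `create_weights`.

```python
def resolve_group_size(group_size: int, input_size: int) -> int:
    """Return the effective quantization group size along the input dim.

    Hypothetical standalone version of the inline branch in the diff.
    """
    # -1 marks the channelwise case: one group spans the whole input
    # dimension, so the effective group size equals input_size.
    if group_size == -1:
        return input_size
    return group_size
```

For example, `resolve_group_size(-1, 4096)` yields 4096, while `resolve_group_size(128, 4096)` leaves the configured value of 128 unchanged.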
12 changes: 12 additions & 0 deletions vllm/model_executor/layers/quantization/utils/marlin_utils_test.py
@@ -8,6 +8,18 @@
 from .marlin_utils import GPTQ_MARLIN_TILE, marlin_permute_scales
 from .quant_utils import get_pack_factor, quantize_weights, sort_weights

+class MarlinWorkspace:
+    def __init__(self, out_features, min_thread_n, max_parallel):
+        assert (out_features % min_thread_n == 0), (
+            "out_features = {} is undivisible by min_thread_n = {}".format(
+                out_features, min_thread_n))
+
+        max_workspace_size = ((out_features // min_thread_n) * max_parallel)
+
+        self.scratch = torch.zeros(max_workspace_size,
+                                   dtype=torch.int,
+                                   device="cuda")
+

 def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE):
     assert q_w.shape == (size_k, size_n)
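
The sizing rule inside `MarlinWorkspace` can be checked without CUDA. The sketch below reproduces only the arithmetic from the class above; `marlin_workspace_size` is a hypothetical helper name, and it raises `ValueError` where the class uses an `assert`.

```python
def marlin_workspace_size(out_features: int, min_thread_n: int,
                          max_parallel: int) -> int:
    """Number of int entries MarlinWorkspace would allocate for `scratch`.

    Hypothetical helper: mirrors the size computation only and does not
    allocate a tensor.
    """
    if out_features % min_thread_n != 0:
        raise ValueError(
            f"out_features = {out_features} is not divisible by "
            f"min_thread_n = {min_thread_n}")
    return (out_features // min_thread_n) * max_parallel
```

With illustrative values (not taken from this PR) of `out_features=4096`, `min_thread_n=64` and `max_parallel=16`, the workspace holds 64 * 16 = 1024 entries.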