[ Misc ] Support Act Order in Compressed Tensors #6358

Open · wants to merge 147 commits into base: main
Conversation

@robertgshaw2-neuralmagic (Collaborator) commented on Jul 12, 2024

Summary

Add support for compressed-tensors models that have been quantized with activation ordering (group-wise quantization applied in decreasing order of activation magnitude); a short sketch of the resulting g_idx convention follows the list below.

  • add actorder argument to CompressedTensorsWNA16
  • add weight_g_idx layer parameter
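
For illustration, here is a minimal, hypothetical sketch (not code from this PR) of the g_idx convention that activation ordering introduces: each entry records which quantization group an input column belongs to after columns are reordered by activation magnitude.

    import torch

    input_size = 8
    group_size = 4

    # Stand-in for the activation-magnitude ordering; in practice the
    # permutation comes from calibration statistics.
    perm = torch.randperm(input_size)

    # g_idx[col] = quantization group that column `col` was assigned to.
    g_idx = torch.empty(input_size, dtype=torch.int32)
    g_idx[perm] = (torch.arange(input_size) // group_size).to(torch.int32)

    # Without act-order this would be the contiguous [0, 0, 0, 0, 1, 1, 1, 1].
    print(g_idx)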

Evaluation

Accuracy

Full Precision

vllm (pretrained=Qwen/Qwen2-0.5B-Instruct,add_bos_token=True), gen_kwargs: (None), limit: 250.0, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.384|±  |0.0308|
|     |       |strict-match    |     5|exact_match|↑  |0.384|±  |0.0308|
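
For reference, tables in this form are standard lm-eval output; a command along these lines (reconstructed from the run header above, so the exact flags are an assumption) produces them:

    lm_eval --model vllm \
        --model_args pretrained=Qwen/Qwen2-0.5B-Instruct,add_bos_token=True \
        --tasks gsm8k \
        --num_fewshot 5 \
        --limit 250 \
        --batch_size auto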

Group Quantization Only (kylesayrs/gwen_group)

vllm (pretrained=kylesayrs/gwen_group,add_bos_token=True), gen_kwargs: (None), limit: 250.0, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.216|±  |0.0261|
|     |       |strict-match    |     5|exact_match|↑  |0.196|±  |0.0252|

Activation Ordering (ksayers/gwen_actorder)

vllm (pretrained=ksayers/gwen_actorder,add_bos_token=True), gen_kwargs: (None), limit: 250.0, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.248|±  |0.0274|
|     |       |strict-match    |     5|exact_match|↑  |0.248|±  |0.0274|

Latency Regression

Namespace(model='/home/ksayers/llm-compressor/gwen_actorder/', speculative_model=None, num_speculative_tokens=None, speculative_draft_tensor_parallel_size=None, tokenizer=None, quantization=None, tensor_parallel_size=1, input_len=32, output_len=128, batch_size=32, n=1, use_beam_search=False, num_iters_warmup=10, num_iters=30, trust_remote_code=False, max_model_len=None, dtype='auto', enforce_eager=False, kv_cache_dtype='auto', quantization_param_path=None, profile=False, profile_result_dir=None, device='auto', block_size=16, enable_chunked_prefill=False, enable_prefix_caching=False, use_v2_block_manager=False, ray_workers_use_nsight=False, download_dir=None, output_json=None, gpu_memory_utilization=0.9, load_format='auto', distributed_executor_backend=None, otlp_traces_endpoint=None)
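
This argument dump matches vLLM's benchmarks/benchmark_latency.py; a command along these lines (reconstructed from the Namespace above, so treat the flag spelling as an assumption) reproduces the setup:

    python benchmarks/benchmark_latency.py \
        --model /home/ksayers/llm-compressor/gwen_actorder/ \
        --input-len 32 \
        --output-len 128 \
        --batch-size 32 \
        --num-iters-warmup 10 \
        --num-iters 30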

Group Quantization Only

Avg latency: 0.8884373404396076 seconds
10% percentile latency: 0.8715801022946834 seconds
25% percentile latency: 0.8739993472117931 seconds
50% percentile latency: 0.876951577141881 seconds
75% percentile latency: 0.8830150356516242 seconds
90% percentile latency: 0.9393035409972071 seconds
99% percentile latency: 0.9404808702412992 seconds

Activation Ordering

Avg latency: 0.9159474782645702 seconds
10% percentile latency: 0.9001966264098883 seconds
25% percentile latency: 0.9010569080710411 seconds
50% percentile latency: 0.9041027296334505 seconds
75% percentile latency: 0.9064613012596965 seconds
90% percentile latency: 0.9662564094178379 seconds
99% percentile latency: 0.9761117453686893 seconds

@mgoin (Member) left a comment:

Thanks for the changes, LGTM with a model smoke test.

@alexm-neuralmagic (Collaborator):

LGTM

Comment on lines 130 to 133:

    # G_IDX (for activation reordering)
    g_idx = BasevLLMParameter(data=torch.empty(input_size_per_partition,
                                               dtype=torch.int32),
                              weight_loader=weight_loader)
@mgoin (Member):

Is it okay to make this parameter in every case? What about older checkpoints that don't have this parameter?

(Contributor):

gptq_marlin_gemm supports passing an empty tensor for g_idx; I'd prefer that, or a nullptr, to avoid excess memory usage.

@mgoin (Member):

I think my question was worded weirdly, sorry. I am just concerned about the weight loader trying to find this parameter in the checkpoint and it not being present.

(Contributor):

I regression tested using neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin without issue.

@dsikka (Contributor) commented on Aug 29, 2024:

I'd just update to only create the parameter if self.actorder is True.

(Contributor):

This is because the g_idx passed to the kernel is conditional on the actorder flag in the config:
https://github.com/vllm-project/vllm/pull/6358/files#diff-df5f822218e5ac1430f35a806bc9cebd78c99cfe1e6738de89ae3e9f5a1fdbecR162

(Contributor):

If self.actorder is True, it'll use the created parameter. Otherwise, it'll create an empty one. So I don't think you need to initialize it here if self.actorder is False.

@mgoin (Member):

Yeah, that seems to be the case from this else-case later, so there is no need to make the parameter:

            layer.weight_g_idx = marlin_make_empty_g_idx(device)
            layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
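
To make the suggestion concrete, here is a minimal sketch of the conditional creation being discussed (assumed placement inside create_weights; the names follow the snippet above, and the register_parameter call is an assumption, not the merged diff):

    # Only allocate a loadable g_idx when activation ordering is enabled;
    # otherwise process_weights_after_loading falls back to the empty
    # g_idx from marlin_make_empty_g_idx, as in the else-case quoted above.
    if self.actorder:
        g_idx = BasevLLMParameter(data=torch.empty(input_size_per_partition,
                                                   dtype=torch.int32),
                                  weight_loader=weight_loader)
        layer.register_parameter("weight_g_idx", g_idx)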

@kylesayrs (Contributor):

Do not merge; the tensor parallel bug needs to be fixed.

@kylesayrs (Contributor):

False alarm on the tensor parallelism bug. Regression testing was performed with TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T, with and without activation ordering, at tensor_parallel_size=2.

@kylesayrs (Contributor):

Moving to draft while support for static_grouping actorder is added.

@kylesayrs (Contributor):

Actually, I'll make a separate PR to address the static_grouping feature.

@@ -119,14 +127,21 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
                                              dtype=torch.int64),
                                    weight_loader=weight_loader)

            # G_IDX (for activation reordering)

(Contributor):

Can you add a test case for this?

  1. tests/quantization/test_compressed_tensors.py
  2. add a model to models.txt under tests/weight_loading


@mgoin added the "ready" label (ONLY add when PR is ready to merge/full CI is needed) on Aug 29, 2024.
@mgoin (Member) left a comment:

LGTM once the remaining issues are addressed.

@kylesayrs (Contributor) commented on Sep 1, 2024:

New requirements have been added to act-order to support different strategies, such as weight-only ordering and group ordering. See neuralmagic/compressed-tensors#146.

I've made a PR that I'd like to merge into this branch; it conditions activation ordering on g_idx directly rather than relying on the config. See neuralmagic#405.
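
In other words, roughly (a hypothetical sketch of the idea, not the code in neuralmagic#405):

    import torch

    def has_actorder(weight_g_idx: torch.Tensor) -> bool:
        # Hypothetical helper: an empty or monotonically non-decreasing
        # g_idx implies the columns were never reordered, so activation
        # ordering can be inferred from the tensor itself rather than
        # from the config.
        return (weight_g_idx.numel() > 0
                and bool(torch.any(weight_g_idx[:-1] > weight_g_idx[1:])))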

@kylesayrs (Contributor):

Moved to #8135.

@simon-mo requested a review from youkaichao as a code owner on November 26, 2024.

mergify bot commented on Nov 26, 2024:

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @robertgshaw2-neuralmagic.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

The mergify bot added the needs-rebase label on Nov 26, 2024.
Labels: needs-rebase, ready