[Major] Add TinyChat support for GQA, faster context stage and memory efficient loading.
kentang-mit committed Sep 18, 2023
1 parent a095b3e commit ca11f3e
Showing 18 changed files with 500 additions and 496 deletions.
5 changes: 3 additions & 2 deletions README.md
@@ -53,8 +53,9 @@ pip install -e .
* For **edge devices** like Orin, before running the commands above, please:

1. Modify [pyproject.toml](pyproject.toml) by commenting out [this line](https://github.com/mit-han-lab/llm-awq/blob/3fce69061682fdd528824e5da3d03a8a8b545f2a/pyproject.toml#L17).
2. Manually install precompiled PyTorch binaries (>=2.0.0) from [NVIDIA](https://forums.developer.nvidia.com/t/pytorch-for-jetson/72048).
3. Set the appropriate Python version for conda environment (e.g., `conda create -n awq python=3.8 -y` for JetPack 5).
2. Set [this line](https://github.com/mit-han-lab/llm-awq/blob/3fce69061682fdd528824e5da3d03a8a8b545f2a/pyproject.toml#L18) to `transformers==4.32.0`.
3. Manually install precompiled PyTorch binaries (>=2.0.0) from [NVIDIA](https://forums.developer.nvidia.com/t/pytorch-for-jetson/72048).
4. Set the appropriate Python version for conda environment (e.g., `conda create -n awq python=3.8 -y` for JetPack 5).

3. Install the efficient W4A16 (4-bit weight, 16-bit activation) CUDA kernel and optimized FP16 kernels (e.g., layernorm, positional encodings).
```
2 changes: 1 addition & 1 deletion awq/kernels/csrc/quantization/gemm_cuda.h
@@ -1,4 +1,4 @@
#include <torch/extension.h>

torch::Tensor gemm_forward_cuda(torch::Tensor _in_feats, torch::Tensor _kernel,
torch::Tensor _scaling_factors, torch::Tensor _zeros, int split_k_iters);
torch::Tensor _scaling_factors, torch::Tensor _zeros, int group_size, int split_k_iters);
486 changes: 153 additions & 333 deletions awq/kernels/csrc/quantization/gemm_cuda_gen.cu

Large diffs are not rendered by default.

8 changes: 5 additions & 3 deletions awq/quantize/qmodule.py
@@ -123,9 +123,11 @@ def from_linear(cls, linear, w_bit, group_size, init_only=False, scales=None, ze
@torch.no_grad()
def forward(self, x):
out_shape = x.shape[:-1] + (self.out_features, )
# out = awq_inference_engine.gemm_forward_cuda(x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, self.split_k_iters)
# print(x.shape, self.qweight.shape, self.scales.shape, self.qzeros.shape, self.group_size)
out = awq_inference_engine.gemv_forward_cuda(x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, self.group_size)
inputs = x.reshape(-1, x.shape[-1])
if inputs.shape[0] > 8:
out = awq_inference_engine.gemm_forward_cuda(inputs, self.qweight, self.scales, self.qzeros, self.group_size, self.split_k_iters)
else:
out = awq_inference_engine.gemv_forward_cuda(inputs, self.qweight, self.scales, self.qzeros, self.group_size)
out = out + self.bias if self.bias is not None else out
#print(out)
#assert 0
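
For context, the rewritten `forward` above dispatches between two kernels based on the number of rows in the flattened input: the context (prefill) stage feeds many tokens at once and goes through the batched W4A16 GEMM kernel (which now also takes `group_size`, matching the updated `gemm_forward_cuda` signature), while single-token generation keeps using the GEMV kernel. A minimal standalone sketch of that logic, assuming the `awq_inference_engine` extension is built from this commit (the helper name below is illustrative only):

```python
import awq_inference_engine  # compiled AWQ CUDA extension (assumed built from this commit)


def w4a16_forward(x, qweight, scales, qzeros, group_size, split_k_iters, bias=None):
    # Flatten all leading dimensions into rows of hidden features.
    inputs = x.reshape(-1, x.shape[-1])
    if inputs.shape[0] > 8:
        # Context/prefill stage: many tokens per call -> batched W4A16 GEMM kernel.
        out = awq_inference_engine.gemm_forward_cuda(
            inputs, qweight, scales, qzeros, group_size, split_k_iters
        )
    else:
        # Generation stage: one token per step -> W4A16 GEMV kernel.
        out = awq_inference_engine.gemv_forward_cuda(
            inputs, qweight, scales, qzeros, group_size
        )
    out = out.reshape(x.shape[:-1] + (out.shape[-1],))
    return out + bias if bias is not None else out
```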
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -15,7 +15,7 @@ classifiers = [
dependencies = [
"accelerate", "sentencepiece", "tokenizers>=0.12.1",
"torch>=2.0.0", "torchvision",
"transformers>=4.31.0",
"transformers>=4.32.0",
"lm_eval", "texttable",
"toml", "attributedict",
"protobuf"
62 changes: 36 additions & 26 deletions tinychat/README.md
@@ -2,8 +2,6 @@

We introduce TinyChat, a cutting-edge chatbot interface designed for lightweight resource consumption and fast inference speed on GPU platforms. It allows for seamless deployment on consumer-level GPUs such as 3090/4090 and low-power edge devices like the NVIDIA Jetson Orin, empowering users with a responsive conversational experience like never before.



The current release supports:

- LLaMA-2-7B/13B-chat;
@@ -14,8 +12,6 @@ The current release supports:

- Falcon-instruct.



## Contents

- [Examples](#examples)
@@ -26,7 +22,6 @@ The current release supports:

- [Reference](#reference)


## Examples

Thanks to AWQ, TinyChat can now deliver more prompt responses through 4-bit inference. The following examples showcase that TinyChat's W4A16 generation is up to 3.7x faster on RTX 4090 and 3.3x faster on Jetson Orin, compared to the FP16 baselines. (Tested with [LLaMA-2-7b]( https://huggingface.co/meta-llama/Llama-2-7b-chat-hf ) model.)
@@ -39,7 +34,6 @@ Thanks to AWQ, TinyChat can now deliver more prompt responses through 4-bit infe

![TinyChat on Jetson Orin: W4A16 is 1.4x faster than FP16](./figures/orin_example.gif)


## Benchmarks

We benchmark TinyChat on A6000 (server-class GPU), 4090 (desktop GPU) and Orin (edge GPU).
@@ -52,38 +46,38 @@ The latency reported in all tables are per-token latency for the generation stag

| Model       | FP16 latency (ms) | INT4 latency (ms) | Speedup |
| ----------- |:-----------------:|:-----------------:|:-------:|
| LLaMA-2-7B  | 27.14             | 8.71              | 3.12x   |
| LLaMA-2-13B | 47.28             | 14.64             | 3.23x   |
| Vicuna-7B   | 26.06             | 8.39              | 3.11x   |
| Vicuna-13B  | 44.91             | 13.46             | 3.34x   |
| MPT-7B      | 22.79             | 7.99              | 2.85x   |
| MPT-30B     | OOM               | 28.15             | --      |
| Falcon-7B   | 39.44             | 11.71             | 3.37x   |
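
To read these latencies as throughput, divide them into 1000 ms; for example, the INT4 LLaMA-2-7B row corresponds to roughly:

```python
# Per-token latency (ms) -> tokens/second for the A6000 INT4 LLaMA-2-7B entry.
print(1000 / 8.71)  # ~115 tokens/second
```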

### 4090 Results

| Model       | FP16 latency (ms) | INT4 latency (ms) | Speedup |
| ----------- |:-----------------:|:-----------------:|:-------:|
| LLaMA-2-7B  | 19.97             | 6.02*             | 3.31x   |
| LLaMA-2-13B | OOM               | 10.35             | --      |
| Vicuna-7B   | 19.09             | 5.33              | 3.58x   |
| Vicuna-13B  | OOM               | 9.17              | --      |
| MPT-7B      | 17.09             | 6.18              | 2.77x   |
| MPT-30B     | OOM               | 20.60             | --      |
| Falcon-7B   | 29.91             | 8.02              | 3.73x   |

*: LLaMA-2-7B is slower than Vicuna-7B because we need a longer prompt (> 500 tokens) to prevent the model from talking to itself. If we use the benchmarking strategy from exLLaMA (i.e., only 4 context tokens), our speed is around 195 tokens/second.

### Orin Results

| Model       | FP16 latency (ms) | INT4 latency (ms) | Speedup |
| ----------- |:-----------------:|:-----------------:|:-------:|
| LLaMA-2-7B  | 104.71            | 33.07*            | 3.17x   |
| LLaMA-2-13B | OOM               | 58.20             | --      |
| Vicuna-7B   | 93.12             | 30.73             | 3.03x   |
| Vicuna-13B  | OOM               | 54.98             | --      |
| MPT-7B      | 89.85             | 31.22             | 2.88x   |
| Falcon-7B   | 147.84            | 45.10             | 3.28x   |

*: We can similarly achieve 33 tokens / second on Orin if we use the benchmarking strategy from exLLaMA.

@@ -101,9 +95,7 @@ The latency reported in all tables are per-token latency for the generation stag

- For Falcon-instruct, please refer to [this link](https://huggingface.co/tiiuae/falcon-7b-instruct).


3. Quantize instruction-tuned LLMs with AWQ:

- We provide pre-computed AWQ search results for multiple model families, including LLaMA, OPT, Vicuna, and LLaVA. To get the pre-computed AWQ search results, run:

```bash
@@ -156,12 +148,33 @@ python demo.py --model_type llama \
--precision W16A16
```

The above command works well for most cloud and desktop GPUs, since their CPU and GPU memory spaces are separate. However, on edge GPUs with shared host and device memory, running larger models (e.g., LLaMA-2-70B on a 64GB Orin) requires breaking the pretrained checkpoints into smaller pieces:

```bash
python split_ckpt.py --input_path quant_cache/llama-2-7b-chat-w4-g128-awq.pt \
--output_path quant_cache/llama-2-7b-chat-w4-g128-awq
```

Then, to run the demo, one can use the following command. The only changes compared with the demo command above are:

- the `--load_quant` argument now points to the split checkpoint;

- the additional `--mem_efficient_load` flag is enabled.

```bash
cd tinychat
python demo.py --model_type llama \
--model_path /PATH/TO/LLAMA2/llama-2-7b-chat \
--q_group_size 128 --load_quant quant_cache/llama-2-7b-chat-w4-g128-awq \
    --precision W4A16 --mem_efficient_load
```
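
For intuition, splitting can be thought of as writing each tensor of the quantized state dict to its own small file, so that weights can later be moved to the GPU one piece at a time instead of materializing the full checkpoint in the shared host memory. A hypothetical sketch only (not the actual `split_ckpt.py` implementation, which may shard differently):

```python
# Hypothetical sketch only -- the real split_ckpt.py may use a different layout.
import os

import torch


def split_checkpoint(input_path: str, output_dir: str) -> None:
    state_dict = torch.load(input_path, map_location="cpu")
    os.makedirs(output_dir, exist_ok=True)
    for name, tensor in state_dict.items():
        # One file per parameter keeps peak host-memory usage low when the
        # model is later loaded piece by piece.
        torch.save({name: tensor}, os.path.join(output_dir, name + ".pt"))
```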

5. (Optional) Run the benchmark script:

```bash
cd tinychat
python benchmark.py --model_type llama \
    --model_path /PATH/TO/LLAMA2/llama-2-7b-chat \
--q_group_size 128
```

@@ -171,6 +184,3 @@ Note: The kv caches in the current implementation are pre-allocated. So if you r
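
As a rough aid to the note above: the pre-allocated cache grows linearly with `max_batch_size` and `max_seq_len`. A back-of-the-envelope estimate, assuming LLaMA-2-7B dimensions (32 layers, 32 KV heads, head dim 128, FP16); GQA models keep fewer KV heads than attention heads, which shrinks this further:

```python
def kv_cache_bytes(max_batch_size=1, max_seq_len=2048,
                   n_layers=32, n_kv_heads=32, head_dim=128, bytes_per_elem=2):
    # Keys and values are both cached, hence the factor of 2.
    return 2 * n_layers * max_batch_size * max_seq_len * n_kv_heads * head_dim * bytes_per_elem


print(kv_cache_bytes() / 2**30)  # ~1.0 GiB at the default demo settings
```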

TinyChat is inspired by the following open-source projects: [FasterTransformer](https://github.com/NVIDIA/FasterTransformer), [FlashAttention](https://github.com/Dao-AILab/flash-attention), [vLLM](https://github.com/vllm-project/vllm), [FastChat](https://github.com/lm-sys/FastChat), [llama_cu_awq](https://github.com/ankan-ban/llama_cu_awq).




12 changes: 6 additions & 6 deletions tinychat/benchmark.py
@@ -39,13 +39,10 @@ def main():
"--max_seq_len",
type=int,
default=2048,
help="maximum sequence length for kv cache"
help="maximum sequence length for kv cache",
)
parser.add_argument(
"--max_batch_size",
type=int,
default=1,
help="maximum batch size for kv cache"
"--max_batch_size", type=int, default=1, help="maximum batch size for kv cache"
)
args = parser.parse_args()

@@ -79,7 +76,10 @@ def main():
], "We only support llama & falcon & mpt now"
model = model_type_dict[args.model_type.lower()](config).half()
real_quantize_model_weight(
model, w_bit=4, q_config=dict(q_group_size=args.q_group_size, zero_point=True), init_only=True
model,
w_bit=4,
q_config=dict(q_group_size=args.q_group_size, zero_point=True),
init_only=True,
)
model = model.to(device)

26 changes: 16 additions & 10 deletions tinychat/demo.py
@@ -108,13 +108,15 @@ def stream_output(output_stream):
"--max_seq_len",
type=int,
default=2048,
help="maximum sequence length for kv cache"
help="maximum sequence length for kv cache",
)
parser.add_argument(
"--max_batch_size",
type=int,
default=1,
help="maximum batch size for kv cache"
"--max_batch_size", type=int, default=1, help="maximum batch size for kv cache"
)
parser.add_argument(
"--mem_efficient_load",
action="store_true",
help="enable mem_efficient_load mod",
)

args = parser.parse_args()
@@ -129,7 +131,14 @@ def stream_output(output_stream):
gen_params.n_vocab = 32000
tinychat.utils.constants.max_batch_size = args.max_batch_size
tinychat.utils.constants.max_seq_len = args.max_seq_len
# TODO (Haotian): a more elegant implementation here.
tinychat.utils.constants.mem_efficient_load = args.mem_efficient_load
if tinychat.utils.constants.mem_efficient_load:
print("=" * 80)
print(
"[Info] You have activated mem_efficient_load mode.\n Less on-chip memory will be consumed when loading the model.\n However, the loading process will take more time."
)
print("=" * 80)
# TODO (Haotian): a more elegant implementation here.
# We need to update these global variables before models use them.
from tinychat.models import FalconForCausalLM, LlamaForCausalLM, MPTForCausalLM

@@ -153,7 +162,6 @@ def skip(*args, **kwargs):
)
modeling_utils._init_weights = False
torch.set_default_dtype(torch.half)
model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)

model_type_dict = {
"llama": LlamaForCausalLM,
@@ -168,9 +176,7 @@ def skip(*args, **kwargs):
model, args.load_quant, 4, args.q_group_size, args.device
)
else:
model = (
model_type_dict[args.model_type.lower()](config).half()
)
model = model_type_dict[args.model_type.lower()](config).half()
model = load_awq_model(
model, args.load_quant, 4, args.q_group_size, args.device
)
1 change: 1 addition & 0 deletions tinychat/models/falcon.py
@@ -15,6 +15,7 @@
max_batch_size = tinychat.utils.constants.max_batch_size
max_seq_len = tinychat.utils.constants.max_seq_len


# rotary pos emb helpers (torch.jit.script does not seem to support staticmethod...)
def rotate_half(x):
x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]