[MAJOR] add new backend kernels & support multi-modal models
ys-2020 committed Feb 23, 2024
1 parent e427849 commit 5f06dbb
Showing 58 changed files with 6,481 additions and 417 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -1,6 +1,9 @@
.DS_Store

data/
checkpoints
demo_images
serve_images
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
69 changes: 60 additions & 9 deletions README.md
@@ -8,19 +8,39 @@
The current release supports:

- AWQ search for accurate quantization.
- Pre-computed AWQ model zoo for LLMs (LLaMA, Llama2, OPT, CodeLlama, StarCoder, Vicuna, LLaVA; load to generate quantized weights).
- Pre-computed AWQ model zoo for LLMs (LLaMA, Llama2, OPT, CodeLlama, StarCoder, Vicuna, VILA, LLaVA; load to generate quantized weights).
- Memory-efficient 4-bit Linear in PyTorch (see the sketch after this list).
- Efficient CUDA kernel implementation for fast inference (support context and decoding stage).
- Examples on 4-bit inference of an instruction-tuned model (Vicuna) and multi-modal LM (LLaVA).
- Examples on 4-bit inference of an instruction-tuned model (Vicuna) and **multi-modal LM** (VILA).
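
A minimal conceptual sketch of the group-wise 4-bit weight quantization behind the memory-efficient 4-bit Linear (an illustration with made-up function names, not this repo's implementation; the `w4-g128` checkpoints referenced below correspond to 4-bit weights with group size 128):

```python
import torch

def pseudo_quantize_w4_g128(w: torch.Tensor, group_size: int = 128) -> torch.Tensor:
    """Fake-quantize a [out_features, in_features] weight to 4 bits, group-wise."""
    out_f, in_f = w.shape
    wg = w.reshape(out_f, in_f // group_size, group_size)
    w_min = wg.amin(dim=-1, keepdim=True)
    w_max = wg.amax(dim=-1, keepdim=True)
    scale = (w_max - w_min).clamp(min=1e-5) / 15        # 16 levels per group
    zero = (-w_min / scale).round()                     # asymmetric zero-point
    q = (wg / scale + zero).round().clamp(0, 15)        # values stored as uint4
    return ((q - zero) * scale).reshape(out_f, in_f)    # dequantized view

w = torch.randn(4096, 4096)
err = (w - pseudo_quantize_w4_g128(w)).abs().mean()
print(f"mean absolute quantization error: {err:.5f}")
```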

**Thanks to AWQ, TinyChat can deliver more efficient responses with LLM/VLM chatbots through 4-bit inference.**

* TinyChat on RTX 4090 (3.4x faster than FP16):

![TinyChat on RTX 4090: W4A16 is 3.4x faster than FP16](./tinychat/figures/4090_example.gif)

* TinyChat on Jetson Orin (3.2x faster than FP16):

![TinyChat on Orin: W4A16 is 3.2x faster than FP16](./tinychat/figures/orin_example.gif)

Check out [TinyChat](tinychat), which delivers **30 tokens/second** inference performance (**3.2x faster** than FP16) for the **Llama2** chatbot on the resource-constrained NVIDIA Jetson Orin!
**TinyChat also supports inference with vision language models (e.g., VILA, LLaVA). In the following examples, W4A16 quantized models from the VILA family are launched with TinyChat.**

* TinyChat with VILA-13B on RTX 4090 (multi-image inputs supported):

![TinyChat with VILA on 4090](./tinychat/figures/4090_vila_example.gif)

* TinyChat with VILA-7B/13B on Jetson Orin:

![TinyChat with VILA on Orin](./tinychat/figures/orin_vila_example.gif)

<!-- Check out [TinyChat](tinychat), which delivers **30 tokens/second** inference performance (**3.2x faster** than FP16) for the **Llama2** chatbot on the resource-constrained NVIDIA Jetson Orin! -->

It also offers a turn-key solution for **on-device inference** of LLMs on **resource-constrained edge platforms**. With TinyChat, it is now possible to run **large** models on **small** and **low-power** devices even without Internet connection.
Check out [TinyChat](tinychat), which offers a turn-key solution for **on-device inference** of LLMs and VLMs on **resource-constrained edge platforms**. With TinyChat, it is now possible to efficiently run **large** models on **small** and **low-power** devices even without Internet connection!


## News
- [2024/02] 🔥 We added support for the [VILA model family](https://arxiv.org/abs/2312.07533) in AWQ & TinyChat! Check out our latest demos with multi-image inputs!
- [2024/02] 🔥 We released a new version of the quantized GEMM/GEMV kernels in [**TinyChat**](tinychat), reaching **38 tokens/second** inference speed on NVIDIA Jetson Orin!
- [2023/11] 🔥 We added AWQ support and pre-computed search results for CodeLlama, StarCoder, StableCode models. Check out our model zoo [here](https://huggingface.co/datasets/mit-han-lab/awq-model-zoo)!
- [2023/11] 🔥 AWQ is now integrated natively in Hugging Face transformers through `from_pretrained`. You can either load quantized models from the Hub or your own HF quantized models.
- [2023/10] AWQ is integrated into NVIDIA [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM/)
@@ -33,11 +53,16 @@ It also offers a turn-key solution for **on-device inference** of LLMs on **reso

## Contents

- [Install](#install)
- [AWQ Model Zoo](#awq-model-zoo)
- [Examples](#examples)
- [Usage](#usage)
- [Reference](#reference)
- [AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration](#awq-activation-aware-weight-quantization-for-llm-compression-and-acceleration)
- [News](#news)
- [Contents](#contents)
- [Install](#install)
- [AWQ Model Zoo](#awq-model-zoo)
- [Examples](#examples)
- [Usage](#usage)
- [Evaluation](#evaluation)
- [Reference](#reference)
- [Related Projects](#related-projects)

## Install

@@ -88,6 +113,7 @@ The detailed support list:
| StarCoder | 15.5B |||
| Vicuna-v1.1 | 7B/13B || |
| LLaVA-v0 | 13B || |
| VILA | 7B/13B || |

## Examples

@@ -136,6 +162,31 @@ python -m awq.entry --model_path /PATH/TO/OPT/opt-6.7b \
--load_quant quant_cache/opt-6.7b-w4-g128-awq.pt
```

## Results on Vision-Language Models (VILA-7B/13B)

AWQ also seamlessly supports large multi-modal models (LMMs). We demonstrate the results on the recent [VILA](https://github.com/Efficient-Large-Model/VILA) model family.

| VILA-7B | VQA-v2 | GQA | VizWiz | ScienceQA | TextVQA | POPE | MME | MMBench | MMBench-CN | SEED |
| ----------- |:-----------------:|:-----------------:|:-------:|:-----------------:|:-----------------:|:-------:|:-------:|:-----------------:|:-------------:|:-------:|
| FP16 | 80.3 | 63.1 | 59.6 | 68.0 | 62.6 | 86.3 | 1489.4 | 69.8 | 61.0 | 61.7 |
| AWQ-INT4 | 80.1 | 63.0 | 57.8 | 68.3 | 61.9 | 85.3 | 1486.3 | 68.8 | 58.9 | 61.3 |

| VILA-13B | VQA-v2 | GQA | VizWiz | ScienceQA | TextVQA | POPE | MME | MMBench | MMBench-CN | SEED |
| ----------- |:-----------------:|:-----------------:|:-------:|:-----------------:|:-----------------:|:-------:|:-------:|:-----------------:|:-------------:|:-------:|
| FP16 | 80.5 | 63.6 | 63.1 | 70.5 | 64.0 | 86.3 | 1553.6 | 73.8 | 66.7 | 62.8 |
| AWQ-INT4 | 80.4 | 63.6 | 63.0 | 71.2 | 63.5 | 87.0 | 1552.9 | 73.6 | 66.3 | 62.2 |


## Inference Speed (tokens/sec)

| Model | Precision | A100 | 4090 | Orin |
| --- | --- |--- | --- | --- |
| VILA-7B | fp16 | 81.6 | 58.5 | 11.5 |
| VILA-7B-AWQ| int4 |155.3| 168.1| 35.6 |
| VILA-13B | fp16 | 48.5 | OOM | 6.1 |
| VILA-13B-AWQ | int4 | 102.1| 99.0| 17.5 |


## Reference

If you find AWQ useful or relevant to your research, please kindly cite our paper:
48 changes: 31 additions & 17 deletions awq/entry.py
@@ -23,7 +23,6 @@
from torch import nn
import tqdm


parser = argparse.ArgumentParser()
parser.add_argument("--model_path", type=str, help="path of the hf model")
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
@@ -87,15 +86,26 @@ def build_model_and_enc(model_path):
print(f"* Building model {model_path}")

# all hf model
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
if "mpt" in config.__class__.__name__.lower():
enc = AutoTokenizer.from_pretrained(
config.tokenizer_name, trust_remote_code=True
if "llava" in model_path.lower() or "vila" in model_path.lower():
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path

enc, model, image_processor, context_len = load_pretrained_model(
model_path=model_path,
model_base=None,
model_name=get_model_name_from_path(model_path),
device="cpu",
)
else:
enc = AutoTokenizer.from_pretrained(
model_path, use_fast=False, trust_remote_code=True
)
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
if "mpt" in config.__class__.__name__.lower():
enc = AutoTokenizer.from_pretrained(
config.tokenizer_name, trust_remote_code=True
)
else:
enc = AutoTokenizer.from_pretrained(
model_path, use_fast=False, trust_remote_code=True
)

if args.load_quant: # directly load quantized weights
print("Loading pre-computed quantized weights...")
@@ -137,9 +147,10 @@ def build_model_and_enc(model_path):
args.run_awq &= not args.load_awq # if load_awq, no need to run awq
# Init model on CPU:
kwargs = {"torch_dtype": torch.float16, "low_cpu_mem_usage": True}
model = AutoModelForCausalLM.from_pretrained(
model_path, config=config, trust_remote_code=True, **kwargs
)
if not ("llava" in model_path.lower() or "vila" in model_path.lower()):
model = AutoModelForCausalLM.from_pretrained(
model_path, config=config, trust_remote_code=True, **kwargs
)

model.eval()

@@ -178,6 +189,9 @@ def build_model_and_enc(model_path):
elif args.q_backend == "real": # real quantization
real_quantize_model_weight(model, w_bit=args.w_bit, q_config=q_config)
if args.dump_quant:
if not args.dump_quant.endswith("v2.pt"):
print("[Info] Auto-change the dump_quant file name to *v2.pt")
args.dump_quant = args.dump_quant.replace(".pt", "-v2.pt")
dirpath = os.path.dirname(args.dump_quant)
os.makedirs(dirpath, exist_ok=True)

@@ -272,12 +286,12 @@ def main():

print(evaluator.make_table(results))

if args.output_path is not None:
os.makedirs(os.path.dirname(args.output_path), exist_ok=True)
# otherwise cannot save
results["config"]["model"] = args.model_path
with open(args.output_path, "w") as f:
json.dump(results, f, indent=2)
if args.output_path is not None:
os.makedirs(os.path.dirname(args.output_path), exist_ok=True)
# otherwise cannot save
results["config"]["model"] = args.model_path
with open(args.output_path, "w") as f:
json.dump(results, f, indent=2)


if __name__ == "__main__":
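
Condensed, the model/tokenizer loading branch added in `awq/entry.py` amounts to the sketch below. It mirrors the calls shown in the diff above; `load_for_awq` is an illustrative wrapper name, not a function in the repo.

```python
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

def load_for_awq(model_path: str):
    if "llava" in model_path.lower() or "vila" in model_path.lower():
        # Multi-modal checkpoints go through the LLaVA/VILA builder, which also
        # returns the image processor needed for vision inputs.
        from llava.model.builder import load_pretrained_model
        from llava.mm_utils import get_model_name_from_path

        enc, model, image_processor, _ = load_pretrained_model(
            model_path=model_path,
            model_base=None,
            model_name=get_model_name_from_path(model_path),
            device="cpu",
        )
    else:
        # Plain language models keep the original Hugging Face path.
        config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
        if "mpt" in config.__class__.__name__.lower():
            enc = AutoTokenizer.from_pretrained(config.tokenizer_name, trust_remote_code=True)
        else:
            enc = AutoTokenizer.from_pretrained(model_path, use_fast=False, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_path, config=config,
            torch_dtype=torch.float16, low_cpu_mem_usage=True, trust_remote_code=True,
        )
        image_processor = None
    return model.eval(), enc, image_processor
```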
4 changes: 4 additions & 0 deletions awq/kernels/csrc/pybind.cpp
@@ -4,13 +4,17 @@
#include "layernorm/layernorm.h"
#include "quantization/gemm_cuda.h"
#include "quantization/gemv_cuda.h"
#include "quantization_new/gemm/gemm_cuda.h"
#include "quantization_new/gemv/gemv_cuda.h"
#include "position_embedding/pos_encoding.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("layernorm_forward_cuda", &layernorm_forward_cuda, "FasterTransformer layernorm kernel");
m.def("gemm_forward_cuda", &gemm_forward_cuda, "Quantized GEMM kernel.");
m.def("gemv_forward_cuda", &gemv_forward_cuda, "Quantized GEMV kernel.");
m.def("gemm_forward_cuda_new", &gemm_forward_cuda_new, "New quantized GEMM kernel.");
m.def("gemv_forward_cuda_new", &gemv_forward_cuda_new, "New quantized GEMV kernel.");
m.def("rotary_embedding_neox", &rotary_embedding_neox, "Apply GPT-NeoX style rotary embedding to query and key");
m.def("single_query_attention", &single_query_attention, "Attention with a single query",
py::arg("q"), py::arg("k"), py::arg("v"), py::arg("k_cache"), py::arg("v_cache"),
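
After rebuilding the kernels, the two new bindings can be sanity-checked from Python along these lines (assumption: the extension is built under the name `awq_inference_engine`, as configured in `awq/kernels/setup.py`):

```python
# Assumes the CUDA extension builds as `awq_inference_engine` (see awq/kernels/setup.py).
import awq_inference_engine as engine

for name in (
    "gemm_forward_cuda", "gemv_forward_cuda",            # original kernels
    "gemm_forward_cuda_new", "gemv_forward_cuda_new",    # kernels added in this commit
):
    print(name, "->", "available" if hasattr(engine, name) else "missing (rebuild kernels)")
```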
77 changes: 77 additions & 0 deletions awq/kernels/csrc/quantization_new/dequantize.cuh
@@ -0,0 +1,77 @@
/*
Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
@article{lin2023awq,
title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
journal={arXiv},
year={2023}
}
*/
#include <cuda_fp16.h>
#pragma once

__inline__ __device__ void dequantize_s4_to_fp16x2(half2 const &source, uint4 *result)
{
// uint4 result;

uint32_t *h = reinterpret_cast<uint32_t *>(result);
uint32_t const i4s = reinterpret_cast<uint32_t const &>(source);

// First, we extract the i4s and construct an intermediate fp16 number.
static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
static constexpr uint32_t TOP_MASK = 0x00f000f0;
static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;

// Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing
// format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions.
// In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and
// elt_67 to fp16 without having to shift them to the bottom bits before hand.

// Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue
// immediately before required.
const uint32_t top_i4s = i4s >> 8;
// Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[0])
: "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[1])
: "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[2])
: "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[3])
: "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));

// I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the
// half2 ctor. In this case, I chose performance reliability over code readability.

// This is the half2 {1032, 1032} represented as an integer.
// static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408;
// Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7]
static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400;
// This is the half2 {1 / 16, 1 / 16} represented as an integer.
static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
// This is the half2 {-72, -72} represented as an integer.
// static constexpr uint32_t NEG_72 = 0xd480d480;
// Haotian: Let's use {-64, -64}.
static constexpr uint32_t NEG_64 = 0xd400d400;

// Finally, we construct the output numbers.
// Convert elt_01
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
// Convert elt_23
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
// Convert elt_45
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
// Convert elt_67
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));

// return result;
}
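
For reference, the magic-number trick in the PTX above can be reproduced in plain NumPy: each 4-bit value is planted in the mantissa of an fp16 biased at 1024 (bit pattern `0x6400`), and the bias is then removed with a `sub` for the bottom nibbles or an `fma` by 1/16 minus 64 for the top nibbles. A small illustration of the arithmetic (not of the kernel itself):

```python
import numpy as np

def ref_bottom_nibble(n: int) -> float:
    # (i4s & 0x000f) | 0x6400 plants the nibble in the mantissa of 1024.0, so the
    # half holds 1024 + n; sub.f16x2 with 0x6400 (= 1024.0) recovers n.
    bits = np.array([0x6400 | n], dtype=np.uint16)
    return float(bits.view(np.float16)[0] - np.float16(1024))

def ref_top_nibble(n: int) -> float:
    # (i4s & 0x00f0) | 0x6400 leaves the nibble 4 bits higher, so the half holds
    # 1024 + 16*n; fma by 1/16 (0x2c00) with -64 (0xd400) recovers n.
    bits = np.array([0x6400 | (n << 4)], dtype=np.uint16)
    return float(bits.view(np.float16)[0] * np.float16(1 / 16) + np.float16(-64))

assert [ref_bottom_nibble(n) for n in range(16)] == list(range(16))
assert [ref_top_nibble(n) for n in range(16)] == list(range(16))
```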
