Initial fused GPTQ implementation #141
base: main
Conversation
@jeromeku Oh my this is a LARGE PR!!!! I'll take a read through it today :)
Ohh now I understand why you added the merged matmul Triton kernels rather than a separate dequantize kernel followed by a matmul, i.e.:

out = dequantize_and_matmul(X, W)

vs

W = dequantize(W)
out = torch.matmul(X, W)

I took a look through GPTQ's repo, and yes, I cannot find any dequantization kernel, written in Triton or otherwise. To attain maximal performance, technically that means including the GPTQ kernels. I'll see what I can do if I have some more bandwidth - sadly I don't have too much knowledge about GPTQ, so I'll have to dive into their papers a bit on how their dequantization even works :) Great work so far @jeromeku and thanks so much for trying to add GPTQ!
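For context, here is a minimal sketch of the two strategies being compared, assuming a simplified GPTQ-style layout (8 4-bit values packed per int32 along `in_features`, per-group scales and zero points). The exact packing order and zero-point offset used by `auto_gptq` differ, so treat this as illustrative, not a drop-in replacement:

```python
import torch

def dequantize_4bit(qweight, scales, zeros, g_idx):
    """Unpack 4-bit weights packed 8-per-int32 along in_features, then rescale."""
    shifts = torch.arange(0, 32, 4, device=qweight.device, dtype=torch.int32)
    # (in_features // 8, 1, out) >> (1, 8, 1) -> (in_features // 8, 8, out)
    w = (qweight.unsqueeze(1) >> shifts.view(1, -1, 1)) & 0xF
    w = w.reshape(-1, qweight.shape[-1]).to(scales.dtype)
    # Per-row lookup of group scale / zero point via g_idx.
    return (w - zeros[g_idx]) * scales[g_idx]

# Toy shapes (all made up): in_features=16, out_features=4, group_size=8.
in_f, out_f, gsize = 16, 4, 8
qweight = torch.randint(0, 2**31 - 1, (in_f // 8, out_f), dtype=torch.int32)
scales = torch.rand(in_f // gsize, out_f)
zeros = torch.full((in_f // gsize, out_f), 8.0)
g_idx = torch.arange(in_f) // gsize
X = torch.randn(2, in_f)

# Strategy A: materialize the full weight, then a regular matmul.
W = dequantize_4bit(qweight, scales, zeros, g_idx)
out = torch.matmul(X, W)

# Strategy B (the merged Triton kernels): perform the same unpacking inside the
# matmul kernel, tile by tile, so the dequantized W never hits global memory.
```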
@danielhanchen
@jeromeku Ok cool!! :)
Stripped out the kernels. Promising early results (forward only): these are median times (ms) for various sequence lengths. However, running both forward and backward degrades the performance of the compiled version vs the reference, which is confusing since the backwards graph is just a transposed matmul. Needs further investigation. Interestingly, the …
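A rough sketch of how median forward vs forward+backward timings like these can be collected with `torch.utils.benchmark`; the shapes, dtype, and sequence lengths below are placeholders, and the plain `X @ W` stands in for the quantized matmul under test:

```python
import torch
import torch.utils.benchmark as benchmark

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

def median_ms(fn):
    # Timer handles warmup and CUDA synchronization; .median is in seconds.
    t = benchmark.Timer(stmt="fn()", globals={"fn": fn})
    return t.blocked_autorange(min_run_time=1.0).median * 1e3

for seq_len in (512, 1024, 2048):
    X = torch.randn(seq_len, 4096, device=device, dtype=dtype, requires_grad=True)
    W = torch.randn(4096, 4096, device=device, dtype=dtype, requires_grad=True)

    def fwd():
        return X @ W

    def fwd_bwd():
        (X @ W).sum().backward()

    print(f"seq_len={seq_len}: fwd={median_ms(fwd):.3f} ms, fwd+bwd={median_ms(fwd_bwd):.3f} ms")
```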
@jeromeku Cool, great work again! Ye it definitely looks like torch.compile is destroying the hand-written GPTQ kernel inside HF's codebase lol! Ye the backward is a transposed matmul - but I'm assuming it's because the strides are reversed, causing a performance hit - just my speculation.
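A small check of that point: for Y = X @ W, autograd's backward is a pair of transposed matmuls, and the transposes are stride-swapped views, so the backward GEMMs can see non-contiguous operands. A minimal sketch, independent of the GPTQ kernels:

```python
import torch

X = torch.randn(8, 16, requires_grad=True)
W = torch.randn(16, 32, requires_grad=True)
dY = torch.randn(8, 32)

(X @ W).backward(dY)

assert torch.allclose(X.grad, dY @ W.T, atol=1e-5)  # dX = dY @ W^T
assert torch.allclose(W.grad, X.T @ dY, atol=1e-5)  # dW = X^T @ dY
print(W.stride(), W.T.stride())  # (32, 1) vs (1, 32): the transpose reverses strides
```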
Good news -- refactored the kernels. Performance now is on par with …. Will run some additional tests / benchmarks, and the PR should be ready for review.

Trainer results after 20 steps on guanaco:
{
"train_runtime": 113.4277,
"train_samples_per_second": 1.411,
"train_steps_per_second": 0.176,
"train_loss": 1.3709101617336272,
"epoch": 0.02
}
{
"train_runtime": 69.5648,
"train_samples_per_second": 2.3,
"train_steps_per_second": 0.288,
"train_loss": 1.3829106092453003,
"epoch": 0.02
}
{
"train_runtime": 63.8765,
"train_samples_per_second": 2.505,
"train_steps_per_second": 0.313,
"train_loss": 1.3803951740264893,
"epoch": 0.02
}
@jeromeku Extremely extremely fabulous work!!! Now that is a fantastic performance boost from HF's GPTQ!! It looks like splitting the dequantization step and matmul did the trick!! Again, super duper appreciate you adding GPTQ support into Unsloth - highly appreciate it :)
Cleaned up the code. Re-running the above benchmark (20 train steps on guanaco):

{
"train_runtime": 67.3811,
"train_samples_per_second": 2.375,
"train_steps_per_second": 0.297,
"train_loss": 1.3829236447811126,
"epoch": 0.02
}

To reproduce, run:

python benchmark.py --model_name=llama --model_type=unsloth-gptq-triton --dtype=float16 --dataset_id=guanaco --output_dir=./bench_results

Replace `model_type` to benchmark the other implementations. See `benchmarks/Profiling.MD` for more details.
@jeromeku Super duper great work again! I will take a look later today! Thanks so much for your contribution again!
@jeromeku Hey, sorry for the delay! Extreme apologies again, I didn't have time to take a look :( I will do so asap in the next few days! Sorry again, and super great work again! :)
GPTQ Peft Fine-tuning

GPTQ fast_lora

- Adds a `fast_lora` implementation for `peft` fine-tuning of `GPTQ`-quantized models.
- Similar to the `bitsandbytes` `fast_lora` custom autograd, it fuses the `triton` quant / dequant matmul kernels from `auto_gptq` with the `LoRA` adapters into a custom `torch.autograd.Function` (see `unsloth/gptq/fast_lora.py`); a rough sketch is shown after this list.
- `Huggingface` `GPTQ` peft fine-tuning uses the `auto_gptq` `cuda` `QuantLinear` layer, which in turn falls back to a `torch`-only implementation, since the custom `cuda` kernel employed by `auto_gptq` does not implement backwards.
- Profiling of `unsloth` models against `huggingface` models: see `benchmarks/Profiling.MD` for documentation.
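A minimal sketch of the `fast_lora` idea described above, assuming a frozen quantized base weight and trainable LoRA `A` / `B` matrices. Here `quant_matmul` is a dense stand-in for the `auto_gptq` Triton quant-matmul kernels, and the class and variable names are illustrative, not the actual `unsloth/gptq/fast_lora.py` API:

```python
import torch

def quant_matmul(X, W):
    # Placeholder for the fused dequantize + matmul Triton kernel.
    return X @ W

class LoRA_QuantMatmul(torch.autograd.Function):
    """Y = quant_matmul(X, W) + scaling * (X @ A) @ B, with W frozen."""

    @staticmethod
    def forward(ctx, X, W, A, B, scaling):
        ctx.save_for_backward(X, W, A, B)
        ctx.scaling = scaling
        return quant_matmul(X, W) + (X @ A) @ B * scaling

    @staticmethod
    def backward(ctx, dY):
        X, W, A, B = ctx.saved_tensors
        s = ctx.scaling
        # A real implementation would call a transposed quant-matmul variant here
        # instead of transposing a dense weight.
        dX = quant_matmul(dY, W.t()) + (dY @ B.t()) @ A.t() * s
        dA = X.t() @ (dY @ B.t()) * s
        dB = (X @ A).t() @ dY * s
        # No gradients for the frozen quantized weight or the scaling constant.
        return dX, None, dA, dB, None

# Gradient check against autograd on a small dense example.
X = torch.randn(4, 64, dtype=torch.float64, requires_grad=True)
W = torch.randn(64, 128, dtype=torch.float64)
A = torch.randn(64, 8, dtype=torch.float64, requires_grad=True)
B = torch.randn(8, 128, dtype=torch.float64, requires_grad=True)
torch.autograd.gradcheck(lambda x, a, b: LoRA_QuantMatmul.apply(x, W, a, b, 0.5), (X, A, B))
```

Routing both forward and backward through the quantized matmul is what lets the base weight stay packed during training, rather than falling back to the torch-only path the way the Huggingface `QuantLinear` route does.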