quantize_autogptq_wikitext.py (forked from allenai/open-instruct)
"""
Run 4-bit model quantization with GPTQ, using Wikitext as train data.
Based on `examples/quantization/basic_usage_wikitext2` in AutoGPT.
Usage example (runs on a single GPU):
python quantize_autogptq.py \
--pretrained_model_dir "/net/nfs.cirrascale/allennlp/hamishi/open-instruct/alpaca_fixed_65b" \
--quantized_model_dir "/net/nfs.cirrascale/allennlp/davidw/checkpoints/gptq_alpaca_fixed_65b"
"""
import argparse
import time

import numpy as np
import torch
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
from datasets import load_dataset
from transformers import AutoTokenizer

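# Build the GPTQ calibration set: `nsamples` randomly positioned windows of
# `seqlen` tokens drawn from the concatenated Wikitext-2 train split. The
# tokenized test split is also returned (it is unpacked but unused in `main`).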
def get_wikitext2(nsamples, seed, seqlen, model):
    traindata = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
    testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")

    tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)
    trainenc = tokenizer("\n\n".join(traindata["text"]), return_tensors="pt")
    testenc = tokenizer("\n\n".join(testdata["text"]), return_tensors="pt")

    import random

    random.seed(seed)
    np.random.seed(0)
    torch.random.manual_seed(0)

    traindataset = []
    for _ in range(nsamples):
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        attention_mask = torch.ones_like(inp)
        traindataset.append({"input_ids": inp, "attention_mask": attention_mask})
    return traindataset, testenc

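# Command-line interface: paths for the unquantized input model and the
# quantized output directory, plus the number of Wikitext calibration samples.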
def get_args():
    parser = argparse.ArgumentParser(
        description="Run 4-bit model quantization using GPTQ."
    )
    parser.add_argument(
        "--pretrained_model_dir", type=str, help="Path to the unquantized model."
    )
    parser.add_argument(
        "--quantized_model_dir", type=str, help="Output path for the quantized model."
    )
    parser.add_argument(
        "--n_samples",
        type=int,
        help="Number of Wikitext calibration samples to draw.",
        default=128,
    )
    args = parser.parse_args()
    return args

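# End-to-end pipeline: build 2048-token Wikitext-2 calibration samples, load
# the full-precision model (AutoGPTQ keeps it on CPU), run 4-bit GPTQ with the
# Triton kernels, and write the quantized checkpoint to disk.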
def main():
    """Run quantization."""
    args = get_args()

    print("Getting data.")
    trainloader, testenc = get_wikitext2(
        args.n_samples, 0, 2048, args.pretrained_model_dir
    )
    print("Done.")

    quantize_config = BaseQuantizeConfig(
        bits=4,  # quantize the model to 4-bit precision
        group_size=128,  # recommended group size for GPTQ
    )

    print("Loading unquantized model")
    # Load the unquantized model; AutoGPTQ always force-loads it onto CPU first.
    model = AutoGPTQForCausalLM.from_pretrained(
        args.pretrained_model_dir, quantize_config
    )
    print("Done")

    # Quantize the model. The examples must be a list of dicts whose only keys
    # are "input_ids" and "attention_mask", with torch.LongTensor values.
    print("Quantizing")
    tick = time.time()
    model.quantize(trainloader, use_triton=True)
    elapsed = (time.time() - tick) / 60
    print(f"Elapsed time: {elapsed:0.2f} minutes.")

    # Save the quantized model.
    print("Saving")
    model.save_quantized(args.quantized_model_dir)
    print("Done")

if __name__ == "__main__":
    main()
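
# A minimal sketch of loading the saved checkpoint back for generation (not part
# of this script). It assumes AutoGPTQ's `from_quantized` loader, a CUDA device,
# and a hypothetical prompt; adjust the paths and settings to your setup:
#
#     from transformers import AutoTokenizer
#     from auto_gptq import AutoGPTQForCausalLM
#
#     tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=False)
#     model = AutoGPTQForCausalLM.from_quantized(
#         quantized_model_dir, device="cuda:0", use_triton=True
#     )
#     inputs = tokenizer("Tell me about quantization.", return_tensors="pt").to("cuda:0")
#     print(tokenizer.decode(model.generate(**inputs, max_new_tokens=64)[0]))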