
Commit f06f680

Add Fake Quantization
1 parent: 8900418

11 files changed: +6240 −0 lines

fake_quant/README.md (+38)

@@ -0,0 +1,38 @@
# Fake Quantization in QuaRot

In this directory, we provide the PyTorch scripts for the experiments in QuaRot.

## Language Generation and Zero-Shot Evaluations

Currently, we only support **LLaMA-2** models. You can simply run `main.py` to reproduce the results in the paper. The most important arguments are:

- `--model`: the model name (or path to the weights)
- `--bsz`: the batch size for PPL evaluation
- `--rotate`: whether to rotate the model
- `--lm_eval`: whether to run LM-Eval on the zero-shot tasks
- `--tasks`: the tasks for LM-Eval
- `--cal_dataset`: the calibration dataset for GPTQ quantization
- `--a_bits`: the number of bits for activation quantization
- `--w_bits`: the number of bits for weight quantization
- `--v_bits`: the number of bits for value quantization
- `--k_bits`: the number of bits for key quantization
- `--w_clip`: whether to clip the weights
- `--a_clip_ratio`: the clipping ratio for activations
- `--k_clip_ratio`: the clipping ratio for keys
- `--v_clip_ratio`: the clipping ratio for values
- `--w_asym`: whether to use asymmetric quantization for weights
- `--a_asym`: whether to use asymmetric quantization for activations
- `--v_asym`: whether to use asymmetric quantization for values
- `--k_asym`: whether to use asymmetric quantization for keys
- `--a_groupsize`: the group size for activation quantization
- `--w_groupsize`: the group size for weight quantization
- `--v_groupsize`: the group size for value quantization
- `--k_groupsize`: the group size for key quantization

For example, to evaluate the perplexity of the `LLaMA2-7B` model with all weights and activations quantized, you can run the following command:

```bash
/bin/python main.py --model meta-llama/Llama-2-7b-hf --rotate --a_bits 4 --v_bits 4 --k_bits 4 --w_bits 4 --w_clip
```
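
To also collect zero-shot accuracy, the same command can enable LM-Eval via `--lm_eval` and a task list via `--tasks`. The sketch below is only an illustration: the task names and the space-separated `--tasks` syntax are assumptions about how `main.py` parses these flags, not something fixed by this README.

```bash
# Illustrative only: task names and --tasks syntax are assumptions.
/bin/python main.py --model meta-llama/Llama-2-7b-hf --rotate \
    --a_bits 4 --v_bits 4 --k_bits 4 --w_bits 4 --w_clip \
    --lm_eval --tasks piqa arc_easy hellaswag
```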

fake_quant/data_utils.py (+105)

@@ -0,0 +1,105 @@
import datasets
import random
import transformers


def get_wikitext2(nsamples, seed, seqlen, model, hf_token, eval_mode=False):

    if hf_token is None:
        tokenizer = transformers.AutoTokenizer.from_pretrained(model, use_fast=False)
    else:
        tokenizer = transformers.AutoTokenizer.from_pretrained(model, use_fast=False, use_auth_token=hf_token)

    if eval_mode:
        testdata = datasets.load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
        testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt')
        return testenc
    else:
        traindata = datasets.load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
        trainenc = tokenizer("\n\n".join(traindata['text']), return_tensors='pt')
        random.seed(seed)
        trainloader = []
        for _ in range(nsamples):
            # Draw a random window of seqlen tokens as one calibration sample.
            i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
            j = i + seqlen
            inp = trainenc.input_ids[:, i:j]
            tar = inp.clone()
            tar[:, :-1] = -100  # -100 labels are ignored by the loss; only the final token remains as a target.
            trainloader.append((inp, tar))
        return trainloader


def get_c4_new(nsamples, seed, seqlen, model, hf_token=None, eval_mode=False):

    if hf_token is None:
        tokenizer = transformers.AutoTokenizer.from_pretrained(model, use_fast=False)
    else:
        tokenizer = transformers.AutoTokenizer.from_pretrained(model, use_fast=False, use_auth_token=hf_token)

    if eval_mode:
        valdata = datasets.load_dataset(
            'allenai/c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation')
        valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt')
        valenc = valenc.input_ids[:, :(256 * seqlen)]

        class TokenizerWrapper:
            def __init__(self, input_ids):
                self.input_ids = input_ids

        valenc = TokenizerWrapper(valenc)
        return valenc
    else:
        traindata = datasets.load_dataset(
            'allenai/c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train')

        random.seed(seed)
        trainloader = []
        for _ in range(nsamples):
            # Resample documents until one is at least seqlen tokens long.
            while True:
                i = random.randint(0, len(traindata) - 1)
                trainenc = tokenizer(traindata[i]['text'], return_tensors='pt')
                if trainenc.input_ids.shape[1] >= seqlen:
                    break
            i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
            j = i + seqlen
            inp = trainenc.input_ids[:, i:j]
            tar = inp.clone()
            tar[:, :-1] = -100
            trainloader.append((inp, tar))
        return trainloader


def get_ptb_new(nsamples, seed, seqlen, model, hf_token, eval_mode=False):

    if hf_token is None:
        tokenizer = transformers.AutoTokenizer.from_pretrained(model, use_fast=False)
    else:
        tokenizer = transformers.AutoTokenizer.from_pretrained(model, use_fast=False, use_auth_token=hf_token)

    if eval_mode:
        testdata = datasets.load_dataset('ptb_text_only', 'penn_treebank', split='test')
        testenc = tokenizer(" ".join(testdata['sentence']), return_tensors='pt')
        return testenc
    else:
        traindata = datasets.load_dataset('ptb_text_only', 'penn_treebank', split='train')
        trainenc = tokenizer(" ".join(traindata['sentence']), return_tensors='pt')
        random.seed(seed)
        trainloader = []
        for _ in range(nsamples):
            i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
            j = i + seqlen
            inp = trainenc.input_ids[:, i:j]
            tar = inp.clone()
            tar[:, :-1] = -100
            trainloader.append((inp, tar))
        return trainloader


def get_loaders(
    name, nsamples=128, seed=0, seqlen=2048, model='', hf_token=None, eval_mode=False
):
    if 'wikitext2' in name:
        return get_wikitext2(nsamples, seed, seqlen, model, hf_token, eval_mode)
    if 'ptb' in name:
        return get_ptb_new(nsamples, seed, seqlen, model, hf_token, eval_mode)
    if 'c4' in name:
        return get_c4_new(nsamples, seed, seqlen, model, hf_token, eval_mode)
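
For orientation, here is a minimal sketch of how these loaders are typically consumed, assuming it runs from this directory with access to the Hugging Face datasets above; the model name and sample counts are illustrative, not prescribed by the file itself.

```python
# Usage sketch (assumptions: model name, nsamples, and seqlen are illustrative).
from data_utils import get_loaders

# 128 calibration samples of 2048 tokens each, drawn from WikiText-2.
trainloader = get_loaders('wikitext2', nsamples=128, seed=0, seqlen=2048,
                          model='meta-llama/Llama-2-7b-hf', eval_mode=False)

# Full tokenized test split for perplexity evaluation.
testenc = get_loaders('wikitext2', seqlen=2048,
                      model='meta-llama/Llama-2-7b-hf', eval_mode=True)

print(len(trainloader), testenc.input_ids.shape)  # e.g. 128, (1, total_test_tokens)
```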

fake_quant/eval_utils.py (+150)

@@ -0,0 +1,150 @@
import utils
import model_utils
import quant_utils
import torch
import os
import logging
from tqdm import tqdm


@torch.no_grad()
def evaluator(model, testenc, dev, args):

    model.eval()

    if 'opt' in args.model:
        opt_type = True
        llama_type = False
    elif 'meta' in args.model:
        llama_type = True
        opt_type = False
    else:
        raise ValueError(f'Unknown model {args.model}')

    use_cache = model.config.use_cache
    model.config.use_cache = False

    if opt_type:
        layers = model.model.decoder.layers
        model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(dev)
        model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(dev)
        if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out:
            model.model.decoder.project_out = model.model.decoder.project_out.to(dev)
        if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in:
            model.model.decoder.project_in = model.model.decoder.project_in.to(dev)

    elif llama_type:
        layers = model.model.layers
        model.model.embed_tokens = model.model.embed_tokens.to(dev)

    layers[0] = layers[0].to(dev)

    # Convert the whole text of the evaluation dataset into batches of sequences.
    input_ids = testenc.input_ids  # (1, text_len)
    nsamples = input_ids.numel() // model.seqlen  # The tail is truncated.
    input_ids = input_ids[:, :nsamples * model.seqlen].view(nsamples, model.seqlen).to(dev)  # (nsamples, seqlen)

    batch_size = args.bsz
    input_ids = [input_ids[i:i + batch_size] for i in range(0, nsamples, batch_size)]
    nbatches = len(input_ids)

    dtype = next(iter(model.parameters())).dtype
    # The input of the first decoder layer.
    inps = torch.zeros(
        (nbatches, batch_size, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev
    )
    # The preallocated tensor above is immediately replaced by a per-batch Python list.
    inps = [0] * nbatches
    cache = {'i': 0, 'attention_mask': None}

    class Catcher(torch.nn.Module):
        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, inp, **kwargs):
            inps[cache['i']] = inp
            cache['i'] += 1
            cache['attention_mask'] = kwargs['attention_mask']
            if llama_type:
                cache['position_ids'] = kwargs['position_ids']
            # Abort the forward pass once the first layer's input has been captured.
            raise ValueError

    layers[0] = Catcher(layers[0])

    for i in range(nbatches):
        batch = input_ids[i]
        try:
            model(batch)
        except ValueError:
            pass
    layers[0] = layers[0].module
    layers[0] = layers[0].cpu()

    if opt_type:
        model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.cpu()
        model.model.decoder.embed_positions = model.model.decoder.embed_positions.cpu()
        if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out:
            model.model.decoder.project_out = model.model.decoder.project_out.cpu()
        if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in:
            model.model.decoder.project_in = model.model.decoder.project_in.cpu()
    elif llama_type:
        model.model.embed_tokens = model.model.embed_tokens.cpu()
        position_ids = cache['position_ids']

    torch.cuda.empty_cache()
    outs = [0] * nbatches
    attention_mask = cache['attention_mask']

    for i in tqdm(range(len(layers)), desc="(Eval) Layers"):
        layer = layers[i].to(dev)

        # Dump the layer input and output
        if args.capture_layer_io and args.layer_idx == i:
            captured_io = model_utils.capture_layer_io(model_utils.get_model_type(model), layer, inps)
            save_path = model_utils.get_layer_io_save_path(args)
            os.makedirs(os.path.dirname(save_path), exist_ok=True)
            torch.save(captured_io, save_path)
            logging.info(f'Dumped layer input and output to: {save_path}')

        for j in range(nbatches):
            if opt_type:
                outs[j] = layer(inps[j], attention_mask=attention_mask)[0]
            elif llama_type:
                outs[j] = layer(inps[j], attention_mask=attention_mask, position_ids=position_ids)[0]
        layers[i] = layer.cpu()
        del layer
        torch.cuda.empty_cache()
        # The outputs of this layer become the inputs of the next one.
        inps, outs = outs, inps

    if opt_type:
        if model.model.decoder.final_layer_norm is not None:
            model.model.decoder.final_layer_norm = model.model.decoder.final_layer_norm.to(dev)
        if model.model.decoder.project_out is not None:
            model.model.decoder.project_out = model.model.decoder.project_out.to(dev)

    elif llama_type:
        if model.model.norm is not None:
            model.model.norm = model.model.norm.to(dev)

    model.lm_head = model.lm_head.to(dev)
    nlls = []
    loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
    for i in range(nbatches):
        hidden_states = inps[i]
        if opt_type:
            if model.model.decoder.final_layer_norm is not None:
                hidden_states = model.model.decoder.final_layer_norm(hidden_states)
            if model.model.decoder.project_out is not None:
                hidden_states = model.model.decoder.project_out(hidden_states)
        elif llama_type:
            if model.model.norm is not None:
                hidden_states = model.model.norm(hidden_states)
        lm_logits = model.lm_head(hidden_states)
        shift_logits = lm_logits[:, :-1, :]
        shift_labels = input_ids[i][:, 1:]
        loss = loss_fct(shift_logits.permute(0, 2, 1), shift_labels)
        neg_log_likelihood = loss.float().mean(dim=1)
        nlls.append(neg_log_likelihood)
    nlls_tensor = torch.cat(nlls)
    # Perplexity is the exponential of the mean per-token negative log-likelihood.
    ppl = torch.exp(nlls_tensor.mean())
    model.config.use_cache = use_cache
    logging.info(f'\n{args.eval_dataset.upper()} PPL: {ppl.item():.3f}')
    return ppl.item()
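
A minimal sketch of how this evaluator might be driven end to end, assuming the other files in this commit are importable from this directory and a CUDA device is available. The argument object only mimics the attributes `evaluator` actually reads (`model`, `bsz`, `capture_layer_io`, `layer_idx`, `eval_dataset`); the model choice, batch size, and manual `model.seqlen` assignment are illustrative assumptions rather than the project's `main.py`.

```python
# Illustrative driver (assumptions: argument values and model setup are examples).
import types
import torch
import transformers

from data_utils import get_loaders
from eval_utils import evaluator

model_name = 'meta-llama/Llama-2-7b-hf'
model = transformers.AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
model.seqlen = 2048  # evaluator() reads the sequence length from this attribute

args = types.SimpleNamespace(
    model=model_name,            # 'meta' in the name selects the LLaMA code path
    bsz=8,                       # PPL evaluation batch size
    eval_dataset='wikitext2',    # only used for the log message
    capture_layer_io=False,      # skip the layer-I/O dumping branch
    layer_idx=0,
)

testenc = get_loaders('wikitext2', seqlen=model.seqlen, model=model_name, eval_mode=True)
ppl = evaluator(model, testenc, torch.device('cuda:0'), args)
print(f'wikitext2 ppl: {ppl:.3f}')
```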
