example.py (KIVI example script, forked from jy-yuan/KIVI)
# Example: LLaMA model with KIVI KV-cache quantization
import warnings
warnings.filterwarnings("ignore")
import torch
import random
from models.llama_kivi import LlamaForCausalLM_KIVI
from transformers import LlamaConfig, AutoTokenizer
from datasets import load_dataset
# For reproducibility
random.seed(0)
torch.manual_seed(0)

model_name = "meta-llama/Meta-Llama-3-8B-Instruct"  # Llama-3 model name
# model_name = "meta-llama/Llama-2-7b-hf"  # Llama-2 alternative
config = LlamaConfig.from_pretrained(model_name)
config.k_bits = 2 # KIVI currently supports 2-bit or 4-bit K/V quantization
config.v_bits = 2
config.group_size = 32
config.residual_length = 32 # number of most recent tokens kept in full precision (fp16)
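# Note (added): per the KIVI paper, keys are quantized per-channel and values per-token,
# with one scale/zero-point shared by each group of `group_size` elements; the most recent
# `residual_length` tokens stay in fp16 and are quantized once a full group accumulates.
# In the KIVI repo, residual_length is typically set to a multiple of group_size.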
# CACHE_DIR = "./"
CACHE_DIR = None
model = LlamaForCausalLM_KIVI.from_pretrained(
    pretrained_model_name_or_path=model_name,
    config=config,
    cache_dir=CACHE_DIR,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
).cuda()
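# Note (added, untested sketch): .cuda() places the whole model on a single GPU. If the
# model does not fit, transformers' standard `device_map` argument may work here as well,
# since the KIVI class reuses from_pretrained (requires `accelerate`), e.g.:
# model = LlamaForCausalLM_KIVI.from_pretrained(
#     model_name, config=config, torch_dtype=torch.float16, device_map="auto")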
enc = AutoTokenizer.from_pretrained(
    model_name,
    use_fast=False,
    trust_remote_code=True,
    tokenizer_type='llama')
dataset = load_dataset('gsm8k', 'main')

# Build a 5-shot GSM8K prompt followed by a new test question
prompt = ''
for i in range(5):
    prompt += 'Question: ' + dataset['train'][i]['question'] + '\nAnswer: ' + dataset['train'][i]['answer'] + '\n'
prompt += "Question: John takes care of 10 dogs. Each dog takes .5 hours a day to walk and take care of their business. How many hours a week does he spend taking care of dogs?"
inputs = enc(prompt, return_tensors="pt").input_ids.cuda()
output = model.generate(inputs, max_new_tokens=96)
config_str = f"# prompt tokens: {inputs.shape[1]}, k_bits: {config.k_bits}, v_bits: {config.v_bits}, group_size: {config.group_size}, residual_length: {config.residual_length}"
print(prompt + "\n" + "=" * 10 + f'\n{config_str}\n' + "=" * 10 + "\nKIVI Output:")
print(enc.decode(output[0].tolist()[inputs.shape[1]:], skip_special_tokens=True))
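
# Optional (added, not part of the original example): report peak GPU memory to gauge
# the effect of the quantized KV cache; torch.cuda.max_memory_allocated is standard PyTorch.
peak_gib = torch.cuda.max_memory_allocated() / (1024 ** 3)
print(f"Peak GPU memory allocated: {peak_gib:.2f} GiB")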