Commit
Many code changes, including renaming tests/test_accept.py to measure_acceptance_rates.py, and making big code updates
avnermay committed Apr 16, 2024
1 parent ba77471 commit 6314476
Showing 7 changed files with 445 additions and 187 deletions.
File renamed without changes.
38 changes: 26 additions & 12 deletions data_converter.py
@@ -3,48 +3,62 @@
from datasets import load_dataset
from transformers.utils import check_min_version
from transformers.utils.versions import require_version
check_min_version("4.28.0.dev0")
check_min_version('4.28.0.dev0')

logger = get_logger(__name__)
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")
require_version('datasets>=1.8.0', 'To fix: pip install -r examples/pytorch/translation/requirements.txt')


def convert_wiki_dataset(tokenizer, seq_len=256):
    dataset = load_dataset('wikimedia/wikipedia', '20231101.en', split='train[0:2000]')
    def tokenize_function(examples):
        return tokenizer(examples['text'], return_tensors='pt', max_length=seq_len, padding=True, truncation=True)
    dataset = dataset.map(tokenize_function, batched=True, remove_columns=['text'])
    dataset.set_format(type='torch', columns=['input_ids', 'attention_mask'])
    return dataset


def convert_cnn_dataset(tokenizer, seq_len=256):
    dataset = load_dataset('cnn_dailymail', '1.0.0', split='test[0:2000]')
    def tokenize_function(examples):
        return tokenizer(examples['article'], return_tensors='pt', max_length=seq_len, padding=True, truncation=True)
    dataset = dataset.map(tokenize_function, batched=True, remove_columns=['article'])
    dataset.set_format(type='torch', columns=['input_ids', 'attention_mask'])
    return dataset


def convert_c4_dataset_eval(tokenizer, seq_len=256):
    dataset = load_dataset('c4', 'en', split='validation[0:2000]')
    def tokenize_function(examples):
        return tokenizer(examples['text'], return_tensors='pt', max_length=seq_len, padding=True, truncation=True)
    dataset = dataset.map(tokenize_function, batched=True, remove_columns=['text', 'timestamp', 'url'])
    dataset.set_format(type='torch', columns=['input_ids', 'attention_mask'])
    return dataset


def convert_jsonl_file(tokenizer, data_files, seq_len=256, text_field='text', remove_columns=None):
    if remove_columns is None:
        remove_columns = [text_field]
    dataset = load_dataset('json', data_files=data_files, split='train', streaming=False)
    def tokenize_function(examples):
        # Tokenize the configured text field (not a hard-coded 'text' key).
        return tokenizer(examples[text_field], return_tensors='pt', max_length=seq_len, padding=True, truncation=True)
    dataset = dataset.map(tokenize_function, batched=True, remove_columns=remove_columns)
    dataset.set_format(type='torch', columns=['input_ids', 'attention_mask'])
    return dataset


def convert_dataset(tokenizer, file_path):
    dataset = load_dataset('json', data_files=file_path, split='train')
    def tokenize_function(examples):
        input_ids = torch.Tensor(examples['input_ids'])
        labels = input_ids.clone()
        if tokenizer.pad_token_id is not None:
            # Mask out padding so it is ignored by the loss.
            labels[labels == tokenizer.pad_token_id] = -100
        ret = {
            'input_ids': input_ids,
            'labels': labels
        }
        return ret
    dataset = dataset.map(tokenize_function, batched=True, remove_columns=['input_tokens'])
    dataset.set_format(type='torch', columns=['input_ids', 'labels'])
    return dataset
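
The new convert_jsonl_file helper follows the same map/set_format pattern as the converters above. A minimal usage sketch, assuming the wikitext_dev.jsonl path that measure_acceptance_rates.py passes in below:

from transformers import AutoTokenizer
from data_converter import convert_jsonl_file

tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-2-7b-hf', use_fast=False)
tokenizer.pad_token = tokenizer.eos_token
dataset = convert_jsonl_file(tokenizer, 'dataset/wikitext_dev.jsonl', seq_len=256)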
237 changes: 237 additions & 0 deletions measure_acceptance_rates.py
@@ -0,0 +1,237 @@
import argparse
import json
import os
from tqdm import tqdm

from accelerate import Accelerator
from datasets import load_from_disk
import torch
from torch.utils.data.dataloader import DataLoader
from transformers import DataCollatorForLanguageModeling, AutoTokenizer


from data_converter import convert_wiki_dataset, convert_cnn_dataset, convert_c4_dataset_eval, convert_jsonl_file, convert_dataset
from Tree.SpecTree import SpecTreeTest
from Tree.GreedyTree import GreedyTreeTest
from Engine.Engine import GraphInferenceEngine, GraphInferenceEngineTG


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--draft', type=str, help='draft model')
    parser.add_argument('--target', type=str, help='target model')
    parser.add_argument('--dataset', type=str, default='dataset/c4_small.json', help='dataset name or path')
    parser.add_argument('--start', type=int, default=0, help='start')
    parser.add_argument('--end', type=int, default=200, help='end')
    parser.add_argument('--temp', type=float, default=0.6, help='temperature')  # T
    parser.add_argument('--top_p', type=float, default=0.9, help='top_p')  # P
    parser.add_argument('--algorithm', type=str, default='stochastic',
                        choices=['stochastic', 'greedy'], help='algorithm')
    parser.add_argument('--max_width', type=int, default=16, help='max width')  # W
    parser.add_argument('--max_length', type=int, default=256, help='max length')  # M
    parser.add_argument('--offloading', action='store_true')
    parser.add_argument('--output_dir', type=str, default='.', help='destination directory for acceptance rate vector')
    parser.add_argument('--output_file', type=str, default=None, help='filename for acceptance rate vector')
    args = parser.parse_args()
    print(f'{args=}')
    return args


def get_output_filename(args):
    output_dir = args.output_dir or '.'
    if args.output_file is None:
        # Use double quotes inside the f-string; nesting the same quote type is a
        # syntax error on Python < 3.12.
        target_str = f'target_{args.target.replace("/", "--")}'
        draft_str = f'draft_{args.draft.replace("/", "--")}'
        temp_str = f'temp_{args.temp}'
        top_p_str = f'top_p_{args.top_p}'
        output_file = f'{args.dataset}_{target_str}_{draft_str}_{temp_str}_{top_p_str}_acc_rates.json'
    else:
        output_file = args.output_file
    return f'{output_dir}/{output_file}'
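

# Example (hypothetical model names; the target matches the tokenizer used below): with
# --dataset cnn, --target meta-llama/Llama-2-7b-hf, --draft JackFram/llama-68m and the
# default temperature/top_p, this returns
# './cnn_target_meta-llama--Llama-2-7b-hf_draft_JackFram--llama-68m_temp_0.6_top_p_0.9_acc_rates.json'.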


def save_acceptance_rates(acceptance_rate_vector, output_file):
    print(f'{acceptance_rate_vector=}')
    if output_file.endswith('.json'):
        with open(output_file, 'w') as f:
            json.dump({'acceptance_rates': acceptance_rate_vector.cpu().numpy().tolist()}, f, indent=4)
    else:
        torch.save(acceptance_rate_vector, output_file)


def simulation_stochastic(
    target_model: GraphInferenceEngineTG,
    draft_model: GraphInferenceEngine,
    dataloader: DataLoader,
    output_file: str,
    temp=0.6,
    top_p=0.9,
    max_width=4,
    max_length=512,
):
    num_eval_steps = len(dataloader)
    num_decoding_steps = 0
    num_large_model_steps = 0
    dtype = torch.float16
    attn_mask = torch.full((max_length, max_length), torch.finfo(dtype).min, dtype=dtype, device='cuda:0')
    sequence = torch.tensor(list(range(max_length)), device='cuda:0').long().unsqueeze(-1)
    new_tokens_buffer = torch.zeros(max_length).long().to('cuda:0')
    parents_buffer = torch.zeros(max_length).long().to('cuda:0')
    position_ids = torch.zeros(max_length).long().to('cuda:0')
    # branch_prob[b] counts verification steps in which b draft tokens were accepted.
    branch_prob = torch.zeros(max_width + 1).to('cuda:0')
    output_branch_prob = torch.zeros(max_width + 2).to('cuda:0')
    with torch.no_grad():
        for step, batch in tqdm(enumerate(dataloader), total=num_eval_steps):
            input_ids = batch['input_ids'][..., :128]
            labels = batch['labels'][..., :128]
            terminate = False
            if labels[0][-1] == -100:
                terminate = True
            draft_kv_len = 0
            target_kv_len = 0
            while input_ids.shape[1] < 256 and not terminate:
                attn_mask.fill_(torch.finfo(dtype).min)
                spectree = SpecTreeTest(prefix=input_ids.squeeze(0), device='cuda:0', temperature=temp,
                                        top_p=top_p,
                                        draft_kv_len=draft_kv_len, target_kv_len=target_kv_len,
                                        draft_model_engine=draft_model, target_model_engine=target_model,
                                        max_length=max_length, attn_mask=attn_mask, sequence=sequence,
                                        new_tokens_buffer=new_tokens_buffer, parents_buffer=parents_buffer,
                                        position_ids=position_ids, max_width=max_width)

                valid_tokens, draft_kv_len, target_kv_len, b, terminate = spectree.verify(benchmark=True)
                branch_prob[b] += 1

                num_decoding_steps += (valid_tokens.shape[0] - input_ids.shape[1])
                num_large_model_steps += 1
                input_ids = valid_tokens.unsqueeze(0)
                # Stop at EOS (2) or padding (0).
                if (input_ids[0] == 2).any() or (input_ids[0] == 0).any():
                    terminate = True

            draft_model.clear_kv()
            target_model.clear_kv()
            if num_large_model_steps > 0:
                print(num_decoding_steps / num_large_model_steps)
    print(
        'total decoding steps: {}.'.format(num_decoding_steps),
        'large model steps: {}.'.format(num_large_model_steps),
        'avg decoding step: {}.'.format(num_decoding_steps / num_large_model_steps),
    )
    # Normalize counts into an empirical distribution; index 0 of the output is left at 0.
    branch_prob = branch_prob / branch_prob.sum(dim=-1)
    output_branch_prob[1:] = branch_prob
    save_acceptance_rates(output_branch_prob, output_file)
    return num_decoding_steps / num_large_model_steps


def simulation_greedy(
    target_model: GraphInferenceEngineTG,
    draft_model: GraphInferenceEngine,
    dataloader: DataLoader,
    output_file: str,
    max_width=4,
    max_length=512,
):
    num_eval_steps = len(dataloader)
    num_decoding_steps = 0
    num_large_model_steps = 0
    dtype = torch.float16
    attn_mask = torch.full((max_length, max_length), torch.finfo(dtype).min, dtype=dtype, device='cuda:0')
    sequence = torch.tensor(list(range(max_length)), device='cuda:0').long().unsqueeze(-1)
    new_tokens_buffer = torch.zeros(max_length).long().to('cuda:0')
    parents_buffer = torch.zeros(max_length).long().to('cuda:0')
    position_ids = torch.zeros(max_length).long().to('cuda:0')
    # branch_prob[b] counts verification steps in which b draft tokens were accepted.
    branch_prob = torch.zeros(max_width + 1).to('cuda:0')
    output_branch_prob = torch.zeros(max_width + 2).to('cuda:0')
    with torch.no_grad():
        for step, batch in tqdm(enumerate(dataloader), total=num_eval_steps):
            input_ids = batch['input_ids'][..., :128]
            labels = batch['labels'][..., :128]
            terminate = False
            if labels[0][-1] == -100:
                terminate = True
            draft_kv_len = 0
            target_kv_len = 0
            while input_ids.shape[1] < 256 and not terminate:
                attn_mask.fill_(torch.finfo(dtype).min)
                spectree = GreedyTreeTest(prefix=input_ids.squeeze(0), device='cuda:0',
                                          draft_kv_len=draft_kv_len, target_kv_len=target_kv_len,
                                          draft_model_engine=draft_model, target_model_engine=target_model,
                                          max_length=max_length, attn_mask=attn_mask, sequence=sequence,
                                          new_tokens_buffer=new_tokens_buffer, parents_buffer=parents_buffer,
                                          position_ids=position_ids, max_width=max_width)

                valid_tokens, draft_kv_len, target_kv_len, b, terminate = spectree.verify(benchmark=True)
                initial_size = input_ids.shape[1]
                input_ids = valid_tokens.unsqueeze(0)

                # Stop at EOS (2) or padding (0); terminated steps are excluded from the counts.
                if (input_ids[0] == 2).any() or (input_ids[0] == 0).any():
                    terminate = True
                if not terminate:
                    branch_prob[b] += 1
                    num_decoding_steps += (valid_tokens.shape[0] - initial_size)
                    num_large_model_steps += 1

            draft_model.clear_kv()
            target_model.clear_kv()
            if num_large_model_steps > 0:
                print(num_decoding_steps / num_large_model_steps)
    print(
        'total decoding steps: {}.'.format(num_decoding_steps),
        'large model steps: {}.'.format(num_large_model_steps),
        'avg decoding step: {}.'.format(num_decoding_steps / num_large_model_steps),
    )
    # Normalize counts into an empirical distribution; index 0 of the output is left at 0.
    branch_prob = branch_prob / branch_prob.sum(dim=-1)
    output_branch_prob[1:] = branch_prob
    save_acceptance_rates(output_branch_prob, output_file)
    return num_decoding_steps / num_large_model_steps


def get_tokenized_dataloader(dataset, start, end):
    tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-2-7b-hf', use_fast=False)
    tokenizer.pad_token = tokenizer.eos_token
    current_dir = os.path.dirname(os.path.abspath(__file__))
    if dataset == 'openwebtext':
        tokenized_dataset_eval = load_from_disk(f'{current_dir}/dataset/openwebtext_eval')
    elif dataset == 'wiki':
        tokenized_dataset_eval = convert_wiki_dataset(tokenizer=tokenizer)
    elif dataset == 'cnn':
        tokenized_dataset_eval = convert_cnn_dataset(tokenizer=tokenizer)
    elif dataset == 'wikitext':
        tokenized_dataset_eval = convert_jsonl_file(tokenizer, f'{current_dir}/dataset/wikitext_dev.jsonl')
    elif dataset == 'oasst':
        tokenized_dataset_eval = convert_jsonl_file(tokenizer, f'{current_dir}/dataset/oasst_dev.jsonl')
    elif dataset == 'c4_small':
        tokenized_dataset_eval = convert_dataset(tokenizer=tokenizer, file_path=f'{current_dir}/dataset/c4_small.json')
    else:
        raise ValueError(f'Unsupported dataset: {dataset}.')

    tokenized_dataset_eval = tokenized_dataset_eval.select(list(range(start, end)))
    data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
    dataloader = DataLoader(tokenized_dataset_eval, batch_size=1, collate_fn=data_collator, shuffle=False)
    return dataloader


if __name__ == '__main__':
    args = get_args()
    dataloader = get_tokenized_dataloader(args.dataset, args.start, args.end)
    draft_model = GraphInferenceEngine(max_length=args.max_length, model_name_or_path=args.draft, dtype=torch.float16, device='cuda:0')
    target_model = GraphInferenceEngineTG(max_length=args.max_length, model_name_or_path=args.target, dtype=torch.float16, device='cuda:0', offloading=args.offloading)
    graph_capture_list = list(range(1, 129))
    draft_model.initialize_cuda_graph(graph_capture_list)

    accelerator = Accelerator()
    dataloader = accelerator.prepare(dataloader)

    output_file = get_output_filename(args)
    if args.algorithm == 'stochastic':
        simulation_stochastic(
            target_model=target_model, draft_model=draft_model, dataloader=dataloader, output_file=output_file,
            temp=args.temp, top_p=args.top_p, max_width=args.max_width, max_length=args.max_length,
        )
    elif args.algorithm == 'greedy':
        simulation_greedy(
            target_model=target_model, draft_model=draft_model, dataloader=dataloader, output_file=output_file,
            max_width=args.max_width, max_length=args.max_length,
        )
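
For reference, a typical invocation of the new script might look like the following (the draft model name is a placeholder, not taken from this commit; the target matches the tokenizer the script loads):

python measure_acceptance_rates.py --draft JackFram/llama-68m --target meta-llama/Llama-2-7b-hf \
    --dataset cnn --start 0 --end 200 --algorithm stochastic --temp 0.6 --top_p 0.9 \
    --max_width 16 --max_length 256 --output_dir .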