Skip to content

Commit

Permalink
added examples
Browse files Browse the repository at this point in the history
  • Loading branch information
MDK8888 committed Feb 23, 2024
1 parent c823189 commit cba4528
Show file tree
Hide file tree
Showing 3 changed files with 165 additions and 0 deletions.
55 changes: 55 additions & 0 deletions Examples/gpt-neo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import os
import torch
from transformers import AutoTokenizer
from GPTFast.Core import gpt_fast
from GPTFast.Helpers import timed

torch._dynamo.reset()
os.environ["TOKENIZERS_PARALLELISM"] = "false"

device = "cuda" if torch.cuda.is_available() else "cpu"

def argmax_variation(self, probabilities:torch.Tensor, temperature:float = 1, k:int = 5):
# Apply temperature scaling
device = probabilities.device
scaled_probabilities = probabilities / temperature

# Ensure k is within a valid range
k = min(k, probabilities.size(-1))

# Get the indices of the top-k scaled probabilities along the specified dimension
top_k_indices = torch.topk(scaled_probabilities, k, dim=-1).indices

# Generate random indices for sampling
random_indices = torch.randint(0, k, (1,) * probabilities.dim()).to(device)

# Use gathered indices to get the final sampled token
sampled_token = top_k_indices.gather(-1, random_indices).to(device)

return sampled_token.unsqueeze(0)

def argmax(self, probabilities):
# Use argmax to get the token with the maximum probability
max_prob_index = torch.argmax(probabilities, dim=-1)
return max_prob_index.unsqueeze(0)

model_name = "EleutherAI/gpt-neo-1.3B"
draft_model_name = "EleutherAI/gpt-neo-125m"

tokenizer = AutoTokenizer.from_pretrained(model_name)
initial_string = "Write me a short story."
input_tokens = tokenizer.encode(initial_string, return_tensors="pt").to(device)

N_ITERS=10
MAX_TOKENS=50

gpt_fast_model = gpt_fast(model_name, draft_model_name=draft_model_name, sample_function=argmax)
gpt_fast_model.to(device)

fast_compile_times = []
for i in range(N_ITERS):
with torch.no_grad():
res, compile_time = timed(lambda: gpt_fast_model.generate(cur_tokens=input_tokens, max_tokens=MAX_TOKENS, speculate_k=6))
fast_compile_times.append(compile_time)
print(f"gpt fast eval time {i}: {compile_time}")
print("~" * 10)
55 changes: 55 additions & 0 deletions Examples/gpt2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import os
import torch
from transformers import AutoTokenizer
from GPTFast.Core import gpt_fast
from GPTFast.Helpers import timed

torch._dynamo.reset()
os.environ["TOKENIZERS_PARALLELISM"] = "false"

device = "cuda" if torch.cuda.is_available() else "cpu"

def argmax_variation(self, probabilities:torch.Tensor, temperature:float = 1, k:int = 5):
# Apply temperature scaling
device = probabilities.device
scaled_probabilities = probabilities / temperature

# Ensure k is within a valid range
k = min(k, probabilities.size(-1))

# Get the indices of the top-k scaled probabilities along the specified dimension
top_k_indices = torch.topk(scaled_probabilities, k, dim=-1).indices

# Generate random indices for sampling
random_indices = torch.randint(0, k, (1,) * probabilities.dim()).to(device)

# Use gathered indices to get the final sampled token
sampled_token = top_k_indices.gather(-1, random_indices).to(device)

return sampled_token.unsqueeze(0)

def argmax(self, probabilities):
# Use argmax to get the token with the maximum probability
max_prob_index = torch.argmax(probabilities, dim=-1)
return max_prob_index.unsqueeze(0)

model_name = "gpt2-xl"
draft_model_name = "gpt2"

tokenizer = AutoTokenizer.from_pretrained(model_name)
initial_string = "Write me a short story."
input_tokens = tokenizer.encode(initial_string, return_tensors="pt").to(device)

N_ITERS=10
MAX_TOKENS=50

gpt_fast_model = gpt_fast(model_name, draft_model_name=draft_model_name, sample_function=argmax)
gpt_fast_model.to(device)

fast_compile_times = []
for i in range(N_ITERS):
with torch.no_grad():
res, compile_time = timed(lambda: gpt_fast_model.generate(cur_tokens=input_tokens, max_tokens=MAX_TOKENS, speculate_k=6))
fast_compile_times.append(compile_time)
print(f"gpt fast eval time {i}: {compile_time}")
print("~" * 10)
55 changes: 55 additions & 0 deletions Examples/opt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import os
import torch
from transformers import AutoTokenizer
from GPTFast.Core import gpt_fast
from GPTFast.Helpers import timed

torch._dynamo.reset()
os.environ["TOKENIZERS_PARALLELISM"] = "false"

device = "cuda" if torch.cuda.is_available() else "cpu"

def argmax_variation(self, probabilities:torch.Tensor, temperature:float = 1, k:int = 5):
# Apply temperature scaling
device = probabilities.device
scaled_probabilities = probabilities / temperature

# Ensure k is within a valid range
k = min(k, probabilities.size(-1))

# Get the indices of the top-k scaled probabilities along the specified dimension
top_k_indices = torch.topk(scaled_probabilities, k, dim=-1).indices

# Generate random indices for sampling
random_indices = torch.randint(0, k, (1,) * probabilities.dim()).to(device)

# Use gathered indices to get the final sampled token
sampled_token = top_k_indices.gather(-1, random_indices).to(device)

return sampled_token.unsqueeze(0)

def argmax(self, probabilities):
# Use argmax to get the token with the maximum probability
max_prob_index = torch.argmax(probabilities, dim=-1)
return max_prob_index.unsqueeze(0)

model_name = "facebook/opt-1.3b"
draft_model_name = "facebook/opt-125m"

tokenizer = AutoTokenizer.from_pretrained(model_name)
initial_string = "Write me a short story."
input_tokens = tokenizer.encode(initial_string, return_tensors="pt").to(device)

N_ITERS=10
MAX_TOKENS=50

gpt_fast_model = gpt_fast(model_name, draft_model_name=draft_model_name, sample_function=argmax)
gpt_fast_model.to(device)

fast_compile_times = []
for i in range(N_ITERS):
with torch.no_grad():
res, compile_time = timed(lambda: gpt_fast_model.generate(cur_tokens=input_tokens, max_tokens=MAX_TOKENS, speculate_k=6))
fast_compile_times.append(compile_time)
print(f"gpt fast eval time {i}: {compile_time}")
print("~" * 10)

0 comments on commit cba4528

Please sign in to comment.