forked from ngruver/llmtime
Commit
Added latest gpt models, mistral local and mistral api
Bouzid MEDJDOUB committed on Jan 27, 2024 · 1 parent a483a43 · commit ccf64c0
Showing 7 changed files with 422 additions and 2 deletions.
@@ -0,0 +1,134 @@
import os
import torch
os.environ['OMP_NUM_THREADS'] = '4'
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import openai
openai.api_key = os.environ['OPENAI_API_KEY']
openai.api_base = os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1")
from data.serialize import SerializerSettings
from models.utils import grid_iter
from models.promptcast import get_promptcast_predictions_data
from models.darts import get_arima_predictions_data
from models.llmtime import get_llmtime_predictions_data
from data.small_context import get_datasets
from models.validation_likelihood_tuning import get_autotuned_predictions_data

def plot_preds(train, test, pred_dict, model_name, show_samples=False):
    pred = pred_dict['median']
    pred = pd.Series(pred, index=test.index)
    plt.figure(figsize=(8, 6), dpi=100)
    plt.plot(train)
    plt.plot(test, label='Truth', color='black')
    plt.plot(pred, label=model_name, color='purple')
    # shade 90% confidence interval
    samples = pred_dict['samples']
    lower = np.quantile(samples, 0.05, axis=0)
    upper = np.quantile(samples, 0.95, axis=0)
    plt.fill_between(pred.index, lower, upper, alpha=0.3, color='purple')
    if show_samples:
        samples = pred_dict['samples']
        # convert df to numpy array
        samples = samples.values if isinstance(samples, pd.DataFrame) else samples
        for i in range(min(10, samples.shape[0])):
            plt.plot(pred.index, samples[i], color='purple', alpha=0.3, linewidth=1)
    plt.legend(loc='upper left')
    if 'NLL/D' in pred_dict:
        nll = pred_dict['NLL/D']
        if nll is not None:
            plt.text(0.03, 0.85, f'NLL/D: {nll:.2f}', transform=plt.gca().transAxes, bbox=dict(facecolor='white', alpha=0.5))
    plt.show()

print(torch.cuda.max_memory_allocated())
print()

gpt4_hypers = dict(
    alpha=0.3,
    basic=True,
    temp=1.0,
    top_p=0.8,
    settings=SerializerSettings(base=10, prec=3, signed=True, time_sep=', ', bit_sep='', minus_sign='-')
)

mistral_api_hypers = dict(
    alpha=0.3,
    basic=True,
    temp=1.0,
    top_p=0.8,
    settings=SerializerSettings(base=10, prec=3, signed=True, time_sep=', ', bit_sep='', minus_sign='-')
)

gpt3_hypers = dict(
    temp=0.7,
    alpha=0.95,
    beta=0.3,
    basic=False,
    settings=SerializerSettings(base=10, prec=3, signed=True, half_bin_correction=True)
)

llma2_hypers = dict(
    temp=0.7,
    alpha=0.95,
    beta=0.3,
    basic=False,
    settings=SerializerSettings(base=10, prec=3, signed=True, half_bin_correction=True)
)

promptcast_hypers = dict(
    temp=0.7,
    settings=SerializerSettings(base=10, prec=0, signed=True,
                                time_sep=', ',
                                bit_sep='',
                                plus_sign='',
                                minus_sign='-',
                                half_bin_correction=False,
                                decimal_point='')
)

arima_hypers = dict(p=[12,30], d=[1,2], q=[0])
model_hypers = {
    'LLMTime GPT-3.5': {'model': 'gpt-3.5-turbo-instruct', **gpt3_hypers},
    'LLMTime GPT-4': {'model': 'gpt-4', **gpt4_hypers},
    'LLMTime GPT-3': {'model': 'text-davinci-003', **gpt3_hypers},
    'PromptCast GPT-3': {'model': 'text-davinci-003', **promptcast_hypers},
    'LLMA2': {'model': 'llama-7b', **llma2_hypers},
    'mistral': {'model': 'mistral', **llma2_hypers},
    'mistral-api-tiny': {'model': 'mistral-api-tiny', **mistral_api_hypers},
    'mistral-api-small': {'model': 'mistral-api-small', **mistral_api_hypers},
    'mistral-api-medium': {'model': 'mistral-api-medium', **mistral_api_hypers},
    'ARIMA': arima_hypers,
}

model_predict_fns = {
    #'LLMA2': get_llmtime_predictions_data,
    #'mistral': get_llmtime_predictions_data,
    #'LLMTime GPT-4': get_llmtime_predictions_data,
    'mistral-api-tiny': get_llmtime_predictions_data
}

model_names = list(model_predict_fns.keys())

datasets = get_datasets()
ds_name = 'AirPassengersDataset'

data = datasets[ds_name]
train, test = data  # or change to your own data
out = {}
for model in model_names:  # GPT-4 takes about a minute to run
    model_hypers[model].update({'dataset_name': ds_name})  # for promptcast
    hypers = list(grid_iter(model_hypers[model]))
    num_samples = 10
    pred_dict = get_autotuned_predictions_data(train, test, hypers, num_samples, model_predict_fns[model], verbose=False, parallel=False)
    out[model] = pred_dict
    plot_preds(train, test, pred_dict, model, show_samples=True)
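The loop above runs on the Darts AirPassengers split, but `train, test = data` can be swapped for any pair of pandas Series, as the inline comment suggests. A minimal sketch of substituting a custom series, reusing the functions and configs defined in the script above; the toy values and the choice of the 'mistral-api-tiny' config are assumptions for illustration:

import numpy as np
import pandas as pd

# toy monthly series in place of datasets[ds_name]; values are illustrative only
idx = pd.date_range('2020-01-01', periods=36, freq='MS')
series = pd.Series(np.linspace(100.0, 200.0, 36) + np.random.randn(36), index=idx)
train, test = series.iloc[:-6], series.iloc[-6:]  # hold out the last 6 points

hypers = list(grid_iter(model_hypers['mistral-api-tiny']))  # grid over the config defined above
pred_dict = get_autotuned_predictions_data(train, test, hypers, 10,
                                           get_llmtime_predictions_data,
                                           verbose=False, parallel=False)
plot_preds(train, test, pred_dict, 'mistral-api-tiny', show_samples=True)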
@@ -0,0 +1,148 @@
import torch
import numpy as np
from jax import grad, vmap
from tqdm import tqdm
import argparse
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
)
from data.serialize import serialize_arr, deserialize_str, SerializerSettings

DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_UNK_TOKEN = "<unk>"

loaded = {}

def get_tokenizer():
    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
    special_tokens_dict = dict()
    if tokenizer.eos_token is None:
        special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN
    if tokenizer.bos_token is None:
        special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN
    if tokenizer.unk_token is None:
        special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN
    tokenizer.add_special_tokens(special_tokens_dict)
    tokenizer.pad_token = tokenizer.eos_token
    return tokenizer

def get_model_and_tokenizer(model_name, cache_model=False):
    if model_name in loaded:
        return loaded[model_name]
    tokenizer = get_tokenizer()

    model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", device_map="cpu")
    model.eval()
    if cache_model:
        loaded[model_name] = model, tokenizer
    return model, tokenizer

def tokenize_fn(text, model):
    tokenizer = get_tokenizer()
    return tokenizer(text)

def mistral_nll_fn(model, input_arr, target_arr, settings: SerializerSettings, transform, count_seps=True, temp=1, cache_model=True):
    """Returns the NLL/dimension (log base e) of the target array (continuous) according to the LM
    conditioned on the input array. Applies the relevant log determinant for transforms and
    converts the discrete NLL of the LLM to continuous by assuming uniform density within the bins.
    Inputs:
        input_arr: (n,) context array
        target_arr: (n,) ground truth array
        cache_model: whether to cache the model and tokenizer for faster repeated calls
    Returns: NLL/D
    """
    model, tokenizer = get_model_and_tokenizer(model, cache_model=cache_model)

    input_str = serialize_arr(vmap(transform)(input_arr), settings)
    target_str = serialize_arr(vmap(transform)(target_arr), settings)
    full_series = input_str + target_str

    batch = tokenizer(
        [full_series],
        return_tensors="pt",
        add_special_tokens=True
    )
    batch = {k: v.to(model.device) for k, v in batch.items()}

    with torch.no_grad():
        out = model(**batch)

    good_tokens_str = list("0123456789" + settings.time_sep)
    good_tokens = [tokenizer.convert_tokens_to_ids(token) for token in good_tokens_str]
    bad_tokens = [i for i in range(len(tokenizer)) if i not in good_tokens]
    out['logits'][:, :, bad_tokens] = -100

    input_ids = batch['input_ids'][0][1:]
    input_ids = input_ids.to('cpu')
    logprobs = torch.nn.functional.log_softmax(out['logits'], dim=-1)[0][:-1]
    logprobs = logprobs[torch.arange(len(input_ids)), input_ids].cpu().numpy()

    tokens = tokenizer.batch_decode(
        input_ids,
        skip_special_tokens=False,
        clean_up_tokenization_spaces=False
    )

    input_len = len(tokenizer([input_str], return_tensors="pt")['input_ids'][0])
    input_len = input_len - 2  # remove the BOS token

    logprobs = logprobs[input_len:]
    tokens = tokens[input_len:]
    BPD = -logprobs.sum() / len(target_arr)

    #print("BPD unadjusted:", -logprobs.sum()/len(target_arr), "BPD adjusted:", BPD)
    # log p(x) = log p(token) - log bin_width = log p(token) + prec * log base
    transformed_nll = BPD - settings.prec * np.log(settings.base)
    avg_logdet_dydx = np.log(vmap(grad(transform))(target_arr)).mean()
    return transformed_nll - avg_logdet_dydx

def mistral_completion_fn(
    model,
    input_str,
    steps,
    settings,
    batch_size=5,
    num_samples=20,
    temp=0.9,
    top_p=0.9,
    cache_model=True
):
    avg_tokens_per_step = len(tokenize_fn(input_str, model)['input_ids']) / len(input_str.split(settings.time_sep))
    max_tokens = int(avg_tokens_per_step * steps)

    model, tokenizer = get_model_and_tokenizer(model, cache_model=cache_model)

    gen_strs = []
    for _ in tqdm(range(num_samples // batch_size)):
        batch = tokenizer(
            [input_str],
            return_tensors="pt",
        )

        batch = {k: v.repeat(batch_size, 1) for k, v in batch.items()}
        batch = {k: v.cpu() for k, v in batch.items()}
        num_input_ids = batch['input_ids'].shape[1]

        good_tokens_str = list("0123456789" + settings.time_sep)
        good_tokens = [tokenizer.convert_tokens_to_ids(token) for token in good_tokens_str]
        # good_tokens += [tokenizer.eos_token_id]
        bad_tokens = [i for i in range(len(tokenizer)) if i not in good_tokens]

        generate_ids = model.generate(
            **batch,
            do_sample=True,
            max_new_tokens=max_tokens,
            temperature=temp,
            top_p=top_p,
            bad_words_ids=[[t] for t in bad_tokens],
            renormalize_logits=True,
        )
        gen_strs += tokenizer.batch_decode(
            generate_ids[:, num_input_ids:],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False
        )
    return gen_strs
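A minimal sketch of calling the new local-Mistral completion function directly, outside the demo script; the module path `models.mistral`, the toy history values, and the serializer settings (copied from the Mistral configs in the demo) are assumptions for illustration:

import numpy as np
from data.serialize import serialize_arr, SerializerSettings
from models.mistral import mistral_completion_fn  # assumed location of the file above

# serializer settings mirroring mistral_api_hypers in the demo script
settings = SerializerSettings(base=10, prec=3, signed=True, time_sep=', ', bit_sep='', minus_sign='-')

history = np.array([112.0, 118.0, 132.0, 129.0, 121.0])  # toy context values
input_str = serialize_arr(history, settings)              # digit-serialized history string

# one sampled continuation of three further steps; digits and the time separator
# are the only tokens the generate() call leaves unmasked
samples = mistral_completion_fn('mistral', input_str, steps=3, settings=settings,
                                batch_size=1, num_samples=1)
print(samples)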