Added latest GPT models, Mistral local and Mistral API
Bouzid MEDJDOUB committed Jan 27, 2024
1 parent a483a43 commit ccf64c0
Showing 7 changed files with 422 additions and 2 deletions.
4 changes: 4 additions & 0 deletions README.md
@@ -24,6 +24,10 @@ Add your openai api key to `~/.bashrc` with
```
echo "export OPENAI_API_KEY=<your key>" >> ~/.bashrc
```
Add your mistral api key to `~/.bashrc` with
```
echo "export MISTRAL_KEY=<your key>" >> ~/.bashrc
```
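Presumably the new Mistral API wrapper picks this variable up from the environment, mirroring how demo.py reads OPENAI_API_KEY; a minimal sketch (client class assumed from the `mistralai` package added in install.sh):
```python
import os
from mistralai.client import MistralClient  # installed via `pip install mistralai`

client = MistralClient(api_key=os.environ["MISTRAL_KEY"])  # name matches the export above
```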

Finally, if you have a different OpenAI API base, change it in your `~/.bashrc` with
```
echo "export OPENAI_API_BASE=<your base>" >> ~/.bashrc
```
134 changes: 134 additions & 0 deletions demo.py
@@ -0,0 +1,134 @@
import os
os.environ['OMP_NUM_THREADS'] = '4'  # must be set before importing torch/numpy to take effect
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import openai
openai.api_key = os.environ['OPENAI_API_KEY']
openai.api_base = os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1")
from data.serialize import SerializerSettings
from models.utils import grid_iter
from models.promptcast import get_promptcast_predictions_data
from models.darts import get_arima_predictions_data
from models.llmtime import get_llmtime_predictions_data
from data.small_context import get_datasets
from models.validation_likelihood_tuning import get_autotuned_predictions_data

def plot_preds(train, test, pred_dict, model_name, show_samples=False):
    pred = pred_dict['median']
    pred = pd.Series(pred, index=test.index)
    plt.figure(figsize=(8, 6), dpi=100)
    plt.plot(train)
    plt.plot(test, label='Truth', color='black')
    plt.plot(pred, label=model_name, color='purple')
    # shade 90% confidence interval
    samples = pred_dict['samples']
    lower = np.quantile(samples, 0.05, axis=0)
    upper = np.quantile(samples, 0.95, axis=0)
    plt.fill_between(pred.index, lower, upper, alpha=0.3, color='purple')
    if show_samples:
        samples = pred_dict['samples']
        # convert df to numpy array
        samples = samples.values if isinstance(samples, pd.DataFrame) else samples
        for i in range(min(10, samples.shape[0])):
            plt.plot(pred.index, samples[i], color='purple', alpha=0.3, linewidth=1)
    plt.legend(loc='upper left')
    if 'NLL/D' in pred_dict:
        nll = pred_dict['NLL/D']
        if nll is not None:
            plt.text(0.03, 0.85, f'NLL/D: {nll:.2f}', transform=plt.gca().transAxes, bbox=dict(facecolor='white', alpha=0.5))
    plt.show()



print(torch.cuda.max_memory_allocated())
print()

gpt4_hypers = dict(
    alpha=0.3,
    basic=True,
    temp=1.0,
    top_p=0.8,
    settings=SerializerSettings(base=10, prec=3, signed=True, time_sep=', ', bit_sep='', minus_sign='-')
)

mistral_api_hypers = dict(
    alpha=0.3,
    basic=True,
    temp=1.0,
    top_p=0.8,
    settings=SerializerSettings(base=10, prec=3, signed=True, time_sep=', ', bit_sep='', minus_sign='-')
)

gpt3_hypers = dict(
    temp=0.7,
    alpha=0.95,
    beta=0.3,
    basic=False,
    settings=SerializerSettings(base=10, prec=3, signed=True, half_bin_correction=True)
)


llma2_hypers = dict(
    temp=0.7,
    alpha=0.95,
    beta=0.3,
    basic=False,
    settings=SerializerSettings(base=10, prec=3, signed=True, half_bin_correction=True)
)


promptcast_hypers = dict(
    temp=0.7,
    settings=SerializerSettings(base=10, prec=0, signed=True,
                                time_sep=', ',
                                bit_sep='',
                                plus_sign='',
                                minus_sign='-',
                                half_bin_correction=False,
                                decimal_point='')
)

arima_hypers = dict(p=[12,30], d=[1,2], q=[0])

model_hypers = {
    'LLMTime GPT-3.5': {'model': 'gpt-3.5-turbo-instruct', **gpt3_hypers},
    'LLMTime GPT-4': {'model': 'gpt-4', **gpt4_hypers},
    'LLMTime GPT-3': {'model': 'text-davinci-003', **gpt3_hypers},
    'PromptCast GPT-3': {'model': 'text-davinci-003', **promptcast_hypers},
    'LLMA2': {'model': 'llama-7b', **llma2_hypers},
    'mistral': {'model': 'mistral', **llma2_hypers},
    'mistral-api-tiny': {'model': 'mistral-api-tiny', **mistral_api_hypers},
    'mistral-api-small': {'model': 'mistral-api-small', **mistral_api_hypers},
    'mistral-api-medium': {'model': 'mistral-api-medium', **mistral_api_hypers},
    'ARIMA': arima_hypers,
}


model_predict_fns = {
    #'LLMA2': get_llmtime_predictions_data,
    #'mistral': get_llmtime_predictions_data,
    #'LLMTime GPT-4': get_llmtime_predictions_data,
    'mistral-api-tiny': get_llmtime_predictions_data,
}


model_names = list(model_predict_fns.keys())

datasets = get_datasets()
ds_name = 'AirPassengersDataset'


data = datasets[ds_name]
train, test = data # or change to your own data
out = {}

for model in model_names: # GPT-4 takes about a minute to run
    model_hypers[model].update({'dataset_name': ds_name}) # for promptcast
    hypers = list(grid_iter(model_hypers[model]))
    num_samples = 10
    pred_dict = get_autotuned_predictions_data(train, test, hypers, num_samples, model_predict_fns[model], verbose=False, parallel=False)
    out[model] = pred_dict
    plot_preds(train, test, pred_dict, model, show_samples=True)
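Only the tiny hosted model is enabled in model_predict_fns above, but models/llms.py (below) registers all three Mistral API sizes; a minimal sketch of widening the comparison, assuming each works with the same LLMTime prediction function:
```python
model_predict_fns = {
    'mistral-api-tiny': get_llmtime_predictions_data,
    'mistral-api-small': get_llmtime_predictions_data,
    'mistral-api-medium': get_llmtime_predictions_data,
}
```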
1 change: 1 addition & 0 deletions install.sh
@@ -16,4 +16,5 @@ pip install multiprocess
pip install SentencePiece
pip install accelerate
pip install gdown
pip install mistralai # for mistral models
conda deactivate
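For reference, the `mistralai` package installed above is the client that models/mistral_api.py presumably wraps; a minimal sketch of a direct call against the hosted models, with class and method names assumed from the early-2024 mistralai SDK:
```python
import os
from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage

client = MistralClient(api_key=os.environ["MISTRAL_KEY"])
response = client.chat(
    model="mistral-tiny",  # the registry below also maps "mistral-small" and "mistral-medium"
    messages=[ChatMessage(role="user", content="0.5, 0.6, 0.7,")],  # serialized series, as in LLMTime
    temperature=1.0,
)
print(response.choices[0].message.content)
```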
4 changes: 2 additions & 2 deletions models/gpt.py
@@ -58,9 +58,9 @@ def gpt_completion_fn(model, input_str, steps, settings, num_samples, temp)
    allowed_tokens = [settings.bit_sep + str(i) for i in range(settings.base)]
    allowed_tokens += [settings.time_sep, settings.plus_sign, settings.minus_sign]
    allowed_tokens = [t for t in allowed_tokens if len(t) > 0] # remove empty tokens like an implicit plus sign
-   if (model not in ['gpt-3.5-turbo','gpt-4']): # logit bias not supported for chat models
+   if (model not in ['gpt-3.5-turbo','gpt-4','gpt-4-1106-preview']): # logit bias not supported for chat models
        logit_bias = {id: 30 for id in get_allowed_ids(allowed_tokens, model)}
-   if model in ['gpt-3.5-turbo','gpt-4']:
+   if model in ['gpt-3.5-turbo','gpt-4','gpt-4-1106-preview']:
        chatgpt_sys_message = "You are a helpful assistant that performs time series predictions. The user will provide a sequence and you will predict the remaining sequence. The sequence is represented by decimal strings separated by commas."
        extra_input = "Please continue the following sequence without producing any additional text. Do not say anything like 'the next terms in the sequence are', just return the numbers. Sequence:\n"
        response = openai.ChatCompletion.create(
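The diff cuts off mid-call; for chat models, the pre-1.0 OpenAI SDK call presumably continues along these lines (a sketch, not the file's exact arguments; `avg_tokens_per_step` is assumed to be computed earlier in gpt_completion_fn):
```python
response = openai.ChatCompletion.create(
    model=model,
    messages=[
        {"role": "system", "content": chatgpt_sys_message},
        {"role": "user", "content": extra_input + input_str},
    ],
    max_tokens=int(avg_tokens_per_step * steps),  # assumed token budget
    temperature=temp,
    n=num_samples,  # draw all samples in one request
)
```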
25 changes: 25 additions & 0 deletions models/llms.py
@@ -4,6 +4,13 @@
from models.llama import llama_completion_fn, llama_nll_fn
from models.llama import tokenize_fn as llama_tokenize_fn

from models.mistral import mistral_completion_fn, mistral_nll_fn
from models.mistral import tokenize_fn as mistral_tokenize_fn

from models.mistral_api import mistral_api_completion_fn, mistral_api_nll_fn
from models.mistral_api import tokenize_fn as mistral_api_tokenize_fn


# Required: Text completion function for each model
# -----------------------------------------------
# Each model is mapped to a function that samples text completions.
@@ -21,7 +28,12 @@
completion_fns = {
    'text-davinci-003': partial(gpt_completion_fn, model='text-davinci-003'),
    'gpt-4': partial(gpt_completion_fn, model='gpt-4'),
    'gpt-4-1106-preview': partial(gpt_completion_fn, model='gpt-4-1106-preview'),
    'gpt-3.5-turbo-instruct': partial(gpt_completion_fn, model='gpt-3.5-turbo-instruct'),
    'mistral': partial(mistral_completion_fn, model='mistral'),
    'mistral-api-tiny': partial(mistral_api_completion_fn, model='mistral-tiny'),
    'mistral-api-small': partial(mistral_api_completion_fn, model='mistral-small'),
    'mistral-api-medium': partial(mistral_api_completion_fn, model='mistral-medium'),
    'llama-7b': partial(llama_completion_fn, model='7b'),
    'llama-13b': partial(llama_completion_fn, model='13b'),
    'llama-70b': partial(llama_completion_fn, model='70b'),
@@ -49,6 +61,11 @@
# - float: Computed NLL per dimension for p(target_arr | input_arr).
nll_fns = {
    'text-davinci-003': partial(gpt_nll_fn, model='text-davinci-003'),
    'mistral': partial(mistral_nll_fn, model='mistral'),
    'mistral-api-tiny': partial(mistral_api_nll_fn, model='mistral-tiny'),
    'mistral-api-small': partial(mistral_api_nll_fn, model='mistral-small'),
    'mistral-api-medium': partial(mistral_api_nll_fn, model='mistral-medium'),
    'llama-7b': partial(llama_nll_fn, model='7b'),
    'llama-13b': partial(llama_nll_fn, model='13b'),
    'llama-70b': partial(llama_nll_fn, model='70b'),
@@ -67,6 +84,10 @@
tokenization_fns = {
    'text-davinci-003': partial(gpt_tokenize_fn, model='text-davinci-003'),
    'gpt-3.5-turbo-instruct': partial(gpt_tokenize_fn, model='gpt-3.5-turbo-instruct'),
    'mistral': partial(mistral_tokenize_fn, model='mistral'),
    'mistral-api-tiny': partial(mistral_api_tokenize_fn, model='mistral-tiny'),
    'mistral-api-small': partial(mistral_api_tokenize_fn, model='mistral-small'),
    'mistral-api-medium': partial(mistral_api_tokenize_fn, model='mistral-medium'),
    'llama-7b': partial(llama_tokenize_fn, model='7b'),
    'llama-13b': partial(llama_tokenize_fn, model='13b'),
    'llama-70b': partial(llama_tokenize_fn, model='70b'),
@@ -79,6 +100,10 @@
context_lengths = {
    'text-davinci-003': 4097,
    'gpt-3.5-turbo-instruct': 4097,
    'mistral-api-tiny': 4097,
    'mistral-api-small': 4097,
    'mistral-api-medium': 4097,
    'mistral': 4096,
    'llama-7b': 4096,
    'llama-13b': 4096,
    'llama-70b': 4096,
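Downstream code presumably dispatches through these registries by model name; a sketch of the lookup, using the keyword signature of gpt_completion_fn shown in models/gpt.py above:
```python
from models.llms import completion_fns

def sample_completions(model, input_str, steps, settings, num_samples, temp):
    completion_fn = completion_fns[model]  # the partials above already bind the concrete model id
    return completion_fn(input_str=input_str, steps=steps, settings=settings,
                         num_samples=num_samples, temp=temp)
```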
148 changes: 148 additions & 0 deletions models/mistral.py
@@ -0,0 +1,148 @@
import torch
import numpy as np
from jax import grad, vmap
from tqdm import tqdm
import argparse
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
)
from data.serialize import serialize_arr, deserialize_str, SerializerSettings

DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_UNK_TOKEN = "<unk>"

loaded = {}

def get_tokenizer():
    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
    special_tokens_dict = dict()
    if tokenizer.eos_token is None:
        special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN
    if tokenizer.bos_token is None:
        special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN
    if tokenizer.unk_token is None:
        special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN
    tokenizer.add_special_tokens(special_tokens_dict)
    tokenizer.pad_token = tokenizer.eos_token
    return tokenizer

def get_model_and_tokenizer(model_name, cache_model=False):
    if model_name in loaded:
        return loaded[model_name]
    tokenizer = get_tokenizer()

    model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", device_map="cpu")
    model.eval()
    if cache_model:
        loaded[model_name] = model, tokenizer
    return model, tokenizer

def tokenize_fn(str, model):
    tokenizer = get_tokenizer()
    return tokenizer(str)

def mistral_nll_fn(model, input_arr, target_arr, settings: SerializerSettings, transform, count_seps=True, temp=1, cache_model=True):
    """ Returns the NLL/dimension (log base e) of the target array (continuous) according to the LM
        conditioned on the input array. Applies relevant log determinant for transforms and
        converts from discrete NLL of the LLM to continuous by assuming uniform within the bins.
    inputs:
        input_arr: (n,) context array
        target_arr: (n,) ground truth array
        cache_model: whether to cache the model and tokenizer for faster repeated calls
    Returns: NLL/D
    """
    model, tokenizer = get_model_and_tokenizer(model, cache_model=cache_model)

    input_str = serialize_arr(vmap(transform)(input_arr), settings)
    target_str = serialize_arr(vmap(transform)(target_arr), settings)
    full_series = input_str + target_str

    batch = tokenizer(
        [full_series],
        return_tensors="pt",
        add_special_tokens=True
    )
    batch = {k: v.to(model.device) for k, v in batch.items()}  # keep inputs on the model's device (loaded on CPU above)

    with torch.no_grad():
        out = model(**batch)

    good_tokens_str = list("0123456789" + settings.time_sep)
    good_tokens = [tokenizer.convert_tokens_to_ids(token) for token in good_tokens_str]
    bad_tokens = [i for i in range(len(tokenizer)) if i not in good_tokens]
    out['logits'][:, :, bad_tokens] = -100

    input_ids = batch['input_ids'][0][1:]
    input_ids = input_ids.to('cpu')
    logprobs = torch.nn.functional.log_softmax(out['logits'], dim=-1)[0][:-1]
    logprobs = logprobs[torch.arange(len(input_ids)), input_ids].cpu().numpy()

    tokens = tokenizer.batch_decode(
        input_ids,
        skip_special_tokens=False,
        clean_up_tokenization_spaces=False
    )

    input_len = len(tokenizer([input_str], return_tensors="pt",)['input_ids'][0])
    input_len = input_len - 2 # remove the BOS token

    logprobs = logprobs[input_len:]
    tokens = tokens[input_len:]
    BPD = -logprobs.sum()/len(target_arr)

    #print("BPD unadjusted:", -logprobs.sum()/len(target_arr), "BPD adjusted:", BPD)
    # log p(x) = log p(token) - log bin_width = log p(token) + prec * log base
    transformed_nll = BPD - settings.prec*np.log(settings.base)
    avg_logdet_dydx = np.log(vmap(grad(transform))(target_arr)).mean()
    return transformed_nll - avg_logdet_dydx

def mistral_completion_fn(
    model,
    input_str,
    steps,
    settings,
    batch_size=5,
    num_samples=20,
    temp=0.9,
    top_p=0.9,
    cache_model=True
):
    avg_tokens_per_step = len(tokenize_fn(input_str, model)['input_ids']) / len(input_str.split(settings.time_sep))
    max_tokens = int(avg_tokens_per_step * steps)

    model, tokenizer = get_model_and_tokenizer(model, cache_model=cache_model)

    gen_strs = []
    for _ in tqdm(range(num_samples // batch_size)):
        batch = tokenizer(
            [input_str],
            return_tensors="pt",
        )

        batch = {k: v.repeat(batch_size, 1) for k, v in batch.items()}
        batch = {k: v.cpu() for k, v in batch.items()}
        num_input_ids = batch['input_ids'].shape[1]

        good_tokens_str = list("0123456789" + settings.time_sep)
        good_tokens = [tokenizer.convert_tokens_to_ids(token) for token in good_tokens_str]
        # good_tokens += [tokenizer.eos_token_id]
        bad_tokens = [i for i in range(len(tokenizer)) if i not in good_tokens]

        generate_ids = model.generate(
            **batch,
            do_sample=True,
            max_new_tokens=max_tokens,
            temperature=temp,
            top_p=top_p,
            bad_words_ids=[[t] for t in bad_tokens],
            renormalize_logits=True,
        )
        gen_strs += tokenizer.batch_decode(
            generate_ids[:, num_input_ids:],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False
        )
    return gen_strs
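The tail of mistral_nll_fn above converts the model's discrete token NLL into a continuous density. With $D$ the length of target_arr, prec and base from SerializerSettings, and $T$ the transform, the returned quantity is

$$\mathrm{NLL}/D = -\frac{1}{D}\sum_{t}\log p(\mathrm{token}_t) \;-\; \mathrm{prec}\cdot\log(\mathrm{base}) \;-\; \frac{1}{D}\sum_{t}\log T'(x_t),$$

where the middle term divides out the bin width $\mathrm{base}^{-\mathrm{prec}}$ implied by writing each value with prec digits (uniform within bins, per the docstring), and the last term is the average log-Jacobian of the transform at the targets.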
106 changes: 106 additions & 0 deletions models/mistral_api.py (diff did not load in this view)
