
Commit

update llama
Ubuntu committed Oct 24, 2023
1 parent c4ca1f6 · commit 273fd7a
Showing 5 changed files with 40 additions and 17 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -12,7 +12,7 @@ by Nate Gruver, Marc Finzi, Shikai Qiu and Andrew Gordon Wilson (NeurIPS 2023).
 ## 🛠 Installation
 Run the following command to install all dependencies in a conda environment named `llmtime`. Change the cuda version for torch if you don't have cuda 11.8.
 ```
-sh install.sh
+source install.sh
 ```
 After installation, activate the environment with
 ```
2 changes: 2 additions & 0 deletions install.sh
@@ -13,4 +13,6 @@ pip install gpytorch
 pip install transformers
 pip install datasets
 pip install multiprocess
+pip install SentencePiece
+pip install accelerate
 conda deactivate llmtime
31 changes: 17 additions & 14 deletions models/llama.py
@@ -27,6 +27,8 @@
 DEFAULT_BOS_TOKEN = "<s>"
 DEFAULT_UNK_TOKEN = "<unk>"
 
+loaded = {}
+
 def llama2_model_string(model_size, chat):
     chat = "chat-" if chat else ""
     return f"meta-llama/Llama-2-{model_size.lower()}-{chat}hf"
@@ -55,38 +57,42 @@ def get_tokenizer(model):
 
     return tokenizer
 
-def get_model_and_tokenizer(model):
-    name_parts = model.split("-")
+def get_model_and_tokenizer(model_name, cache_model=False):
+    if model_name in loaded:
+        return loaded[model_name]
+    name_parts = model_name.split("-")
     model_size = name_parts[0]
     chat = len(name_parts) > 1
 
     assert model_size in ["7b", "13b", "70b"]
 
-    tokenizer = get_tokenizer(model)
+    tokenizer = get_tokenizer(model_name)
 
     model = LlamaForCausalLM.from_pretrained(
         llama2_model_string(model_size, chat),
         device_map="auto",
         torch_dtype=torch.float16,
     )
     model.eval()
 
+    if cache_model:
+        loaded[model_name] = model, tokenizer
     return model, tokenizer
 
 def tokenize_fn(str, model):
     tokenizer = get_tokenizer(model)
     return tokenizer(str)
 
-def llama_nll_fn(model, input_arr, target_arr, settings:SerializerSettings, transform, count_seps=True, temp=1):
+def llama_nll_fn(model, input_arr, target_arr, settings:SerializerSettings, transform, count_seps=True, temp=1, cache_model=True):
     """ Returns the NLL/dimension (log base e) of the target array (continuous) according to the LM
     conditioned on the input array. Applies relevant log determinant for transforms and
     converts from discrete NLL of the LLM to continuous by assuming uniform within the bins.
     inputs:
         input_arr: (n,) context array
         target_arr: (n,) ground truth array
+        cache_model: whether to cache the model and tokenizer for faster repeated calls
     Returns: NLL/D
     """
-    model, tokenizer = get_model_and_tokenizer(model)
+    model, tokenizer = get_model_and_tokenizer(model, cache_model=cache_model)
 
     input_str = serialize_arr(vmap(transform)(input_arr), settings)
     target_str = serialize_arr(vmap(transform)(target_arr), settings)
@@ -140,10 +146,10 @@ def llama_completion_fn(
     temp=0.9,
     top_p=0.9,
 ):
-    avg_tokens_per_step = len(input_str)/steps
+    avg_tokens_per_step = len(tokenize_fn(input_str, model)['input_ids']) / len(input_str.split(settings.time_sep))
     max_tokens = int(avg_tokens_per_step*steps)
 
-    model, tokenizer = get_model_and_tokenizer(model)
+    model, tokenizer = get_model_and_tokenizer(model, cache_model=True)
 
     gen_strs = []
     for _ in tqdm(range(num_samples // batch_size)):
@@ -154,10 +160,11 @@
 
         batch = {k: v.repeat(batch_size, 1) for k, v in batch.items()}
         batch = {k: v.cuda() for k, v in batch.items()}
+        num_input_ids = batch['input_ids'].shape[1]
 
         good_tokens_str = list("0123456789" + settings.time_sep)
         good_tokens = [tokenizer.convert_tokens_to_ids(token) for token in good_tokens_str]
-        good_tokens += [tokenizer.eos_token_id]
+        # good_tokens += [tokenizer.eos_token_id]
         bad_tokens = [i for i in range(len(tokenizer)) if i not in good_tokens]
 
         generate_ids = model.generate(
@@ -169,13 +176,9 @@
             bad_words_ids=[[t] for t in bad_tokens],
             renormalize_logits=True,
         )
-
         gen_strs += tokenizer.batch_decode(
-            generate_ids,
+            generate_ids[:, num_input_ids:],
             skip_special_tokens=True,
             clean_up_tokenization_spaces=False
         )
-
-    # gen_strs = [x.replace(input_str, '').strip() for x in gen_strs]
-
     return gen_strs
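
Note on the `max_tokens` change above: the old budget divided the character length of the serialized history by the number of forecast steps, whereas the new code measures the actual tokens per history step (token count of `input_str` divided by the number of `time_sep`-separated steps) and scales that by the forecast horizon. Below is a minimal runnable sketch of that arithmetic only; the per-character tokenizer and the separator are illustrative stand-ins, not the repo's `tokenize_fn` or default settings.

```python
# Sketch of the new token budget in llama_completion_fn.
def toy_tokenize(s):
    # Stand-in tokenizer: roughly one token per digit/space, as with LLaMA
    # on serialized numbers. The real code calls tokenize_fn (LLaMA tokenizer).
    return {"input_ids": list(s)}

time_sep = ", "                      # assumed settings.time_sep
input_str = "4 2 1, 5 1 3, 6 0 2"    # serialized 3-step history
steps = 5                            # forecast horizon

history_steps = len(input_str.split(time_sep))
avg_tokens_per_step = len(toy_tokenize(input_str)["input_ids"]) / history_steps
max_tokens = int(avg_tokens_per_step * steps)
print(avg_tokens_per_step, max_tokens)  # budget now tracks tokens per step
```
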
6 changes: 4 additions & 2 deletions models/llmtime.py
@@ -6,7 +6,7 @@
 from dataclasses import dataclass
 from models.llms import completion_fns, nll_fns, tokenization_fns, context_lengths
 
-STEP_MULTIPLIER = 1.1
+STEP_MULTIPLIER = 1.2
 
 @dataclass
 class Scaler:
@@ -45,6 +45,8 @@ def inv_transform(x):
     else:
         min_ = np.min(history) - beta*(np.max(history)-np.min(history))
         q = np.quantile(history-min_, alpha)
+        if q == 0:
+            q = 1
         def transform(x):
             return (x - min_) / q
         def inv_transform(x):
@@ -239,7 +241,7 @@ def get_llmtime_predictions_data(train, test, model, settings, num_samples=10, t
         'completions_list': completions_list,
         'input_strs': input_strs,
     }
-    # # Compute NLL/D on the true test series conditioned on the (truncated) input series
+    # Compute NLL/D on the true test series conditioned on the (truncated) input series
     # if nll_fn is not None:
    #     BPDs = [nll_fn(input_arr=input_arrs[i], target_arr=test[i].values, settings=settings, transform=scalers[i].transform, count_seps=True, temp=temp) for i in range(len(train))]
    #     out_dict['NLL/D'] = np.mean(BPDs)
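
The `q == 0` guard added above keeps the quantile scaler from dividing by zero when the history is constant, so its shifted alpha-quantile is zero. A small runnable sketch of that path follows; the alpha/beta defaults and the inverse transform here are illustrative assumptions that mirror the forward transform shown in the diff, not code taken from the repo.

```python
import numpy as np

def make_quantile_scaler(history, alpha=0.95, beta=0.3):
    # Shift by the (margin-adjusted) minimum, then divide by the alpha-quantile.
    min_ = np.min(history) - beta * (np.max(history) - np.min(history))
    q = np.quantile(history - min_, alpha)
    if q == 0:  # constant history: avoid a zero divisor
        q = 1

    def transform(x):
        return (x - min_) / q

    def inv_transform(x):
        return x * q + min_

    return transform, inv_transform

# A flat series now scales cleanly instead of producing inf/nan:
transform, inv_transform = make_quantile_scaler(np.full(24, 7.0))
scaled = transform(np.full(4, 7.0))
print(scaled, inv_transform(scaled))  # [0. 0. 0. 0.] [7. 7. 7. 7.]
```
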
16 changes: 16 additions & 0 deletions precomputed_outputs/deterministic_csvs/monash.csv
@@ -0,0 +1,16 @@
Dataset,SES,Theta,TBATS,ETS,(DHR-)ARIMA,PR,CatBoost,FFNN,DeepAR,N-BEATS,WaveNet,Transformer,Last Value,GPT-3,LLaMA-2
tourism_yearly,95579.23,90653.6,94121.08,94818.89,95033.24,82682.97,79567.22,79593.22,71471.29,70951.8,69905.47,74316.52,99456.0540551959,140081.78090930352,98285.7930257561
tourism_quarterly,15014.19,7656.49,9972.42,8925.52,10475.47,9092.58,10267.97,8981.04,9511.37,8640.56,9137.12,9521.67,15845.100306204946,14121.091945053764,9311.977175664577
tourism_monthly,5302.1,2069.96,2940.08,2004.51,2536.77,2187.28,2537.04,2022.21,1871.69,2003.02,2095.13,2146.98,5636.83029361023,4724.946037531374,3145.478392676293
cif_2016,581875.97,714818.58,855578.4,642421.42,469059.49,563205.57,603551.3,1495923.44,3200418.0,679034.8,5998224.62,4057973.04,386526.3670424068,715086.3357994701,684057.8733782925
australian_electricity_demand,659.6,665.04,370.74,1282.99,1045.92,247.18,241.77,258.76,302.41,213.83,227.5,231.45,659.600688770839,459.96580780843107,560.4788875748806
pedestrian_counts,170.87,170.94,222.38,216.5,635.16,44.18,43.41,46.41,44.78,66.84,46.46,47.29,170.8838383838384,70.20704154040405,65.92206279671719
weather,2.24,2.51,2.3,2.35,2.45,8.17,2.51,2.09,2.02,2.34,2.29,2.03,2.362190193902301,2.3233456922229125,2.0936375345546554
nn5_weekly,15.66,15.3,14.98,15.7,15.38,14.94,15.29,15.02,14.69,14.19,19.34,20.34,16.708553516113007,15.919500032916373,15.604412195386313
solar_weekly,1202.39,1210.83,908.65,1131.01,839.88,1044.98,1513.49,1050.84,721.59,1172.64,1996.89,576.35,1729.4092503457175,2049.095696041108,1457.1885582200116
fred_md,2798.22,3492.84,1989.97,2041.42,2957.11,8921.94,2475.68,2339.57,4264.36,2557.8,2508.4,4666.04,2825.672461360778,2013.4892331450442,1781.4105318251325
traffic_weekly,1.12,1.13,1.17,1.14,1.22,1.13,1.17,1.15,1.18,1.11,1.2,1.42,1.1855844384623289,1.1736949748405259,1.1493040136761261
hospital,21.76,18.54,17.43,17.97,19.6,19.24,19.17,22.86,18.25,20.18,19.35,36.19,24.06573229030856,24.623895849630596,22.74912730008692
covid_deaths,353.71,321.32,96.29,85.59,85.77,347.98,475.15,144.14,201.98,158.81,1049.48,408.66,353.70939849624057,304.68550233082703,66.14112923433585
saugeenday,21.5,21.49,22.26,30.69,22.38,25.24,21.28,22.98,23.51,27.92,22.17,28.06,21.496667098999023,28.636295507014577,23.009772120148344
us_births,1192.2,586.93,399.0,419.73,526.33,574.93,441.7,557.87,424.93,422.0,504.4,452.87,1152.6666666666667,459.43626666666665,638.8182753333334
