
Commit 3d1d431: minor fix

howard-yen committed Feb 26, 2024 (1 parent: 6783688)

Showing 1 changed file with 2 additions and 27 deletions: eval_lm.py
@@ -125,8 +125,8 @@ def main():
         config.train_batch_mode = model_args.train_batch_mode
         config.offload_hidden_states = model_args.offload_hidden_states
 
-    elif model_args.model_class == "vanilla" or model_args.model_class == "unlimiformer":
-        logger.info("Using vanilla/unlimiformer Llama")
+    elif model_args.model_class == "vanilla":
+        logger.info("Using vanilla Llama")
         model_cls = LlamaForCausalLM
         collator = ContextDataCollator()
@@ -169,31 +169,6 @@ def main():
     # model = model.to(device) ## shouldn't need this with device map
     logger.info(f"Loaded model: {model}")
 
-    if model_args.model_class == "unlimiformer":
-        # append the absolute path to the unlimiformer src directory to system path
-        sys.path.append(os.path.join(os.getcwd(), "unlimiformer/src"))
-        from unlimiformer import Unlimiformer
-        from usage import UnlimiformerArguments
-        defaults = UnlimiformerArguments()
-        unlimiformer_kwargs = {
-            'layer_begin': defaults.layer_begin,
-            'layer_end': defaults.layer_end,
-            'unlimiformer_head_num': defaults.unlimiformer_head_num,
-            'exclude_attention': defaults.unlimiformer_exclude,
-            'chunk_overlap': defaults.unlimiformer_chunk_overlap,
-            'model_encoder_max_len': defaults.unlimiformer_chunk_size,
-            'verbose': defaults.unlimiformer_verbose, 'tokenizer': tokenizer,
-            'unlimiformer_training': defaults.unlimiformer_training,
-            'use_datastore': defaults.use_datastore,
-            'flat_index': defaults.flat_index,
-            'test_datastore': defaults.test_datastore,
-            'reconstruct_embeddings': defaults.reconstruct_embeddings,
-            'gpu_datastore': defaults.gpu_datastore,
-            'gpu_index': defaults.gpu_index
-        }
-        logger.info("Converting model to unlimiformer")
-        model = Unlimiformer.convert_model(model, **unlimiformer_kwargs)
-
     if model_args.model_class == "streamingllm" and model_args.enable_positional_shift:
         from streaming_llm.pos_shift.modify_llama import enable_llama_pos_shift_attention
         enable_llama_pos_shift_attention(model)
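For anyone who still needs the Unlimiformer path after this commit, the deleted block above can be lifted into a standalone helper. A minimal sketch, using only the imports and calls visible in the removed code; the helper name `convert_to_unlimiformer` is hypothetical, and it assumes the unlimiformer repo is checked out at ./unlimiformer relative to the working directory, mirroring the sys.path setup the script used:

```python
import os
import sys


def convert_to_unlimiformer(model, tokenizer):
    """Apply the same conversion the deleted block performed.

    Assumes a local checkout of the unlimiformer repo under
    ./unlimiformer, as in the removed sys.path.append call.
    """
    sys.path.append(os.path.join(os.getcwd(), "unlimiformer/src"))
    from unlimiformer import Unlimiformer
    from usage import UnlimiformerArguments

    # Reuse the defaults the deleted block read from UnlimiformerArguments
    defaults = UnlimiformerArguments()
    unlimiformer_kwargs = {
        'layer_begin': defaults.layer_begin,
        'layer_end': defaults.layer_end,
        'unlimiformer_head_num': defaults.unlimiformer_head_num,
        'exclude_attention': defaults.unlimiformer_exclude,
        'chunk_overlap': defaults.unlimiformer_chunk_overlap,
        'model_encoder_max_len': defaults.unlimiformer_chunk_size,
        'verbose': defaults.unlimiformer_verbose,
        'tokenizer': tokenizer,
        'unlimiformer_training': defaults.unlimiformer_training,
        'use_datastore': defaults.use_datastore,
        'flat_index': defaults.flat_index,
        'test_datastore': defaults.test_datastore,
        'reconstruct_embeddings': defaults.reconstruct_embeddings,
        'gpu_datastore': defaults.gpu_datastore,
        'gpu_index': defaults.gpu_index,
    }
    # Same call the script made before this commit removed it
    return Unlimiformer.convert_model(model, **unlimiformer_kwargs)
```

The caller would pass in a loaded LlamaForCausalLM and its tokenizer, the same objects eval_lm.py constructs before this point in main().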
