From 3d1d431f0ccbf936ff4be13eeff2eb9e06feba03 Mon Sep 17 00:00:00 2001
From: Howard Yen
Date: Mon, 26 Feb 2024 11:11:49 -0500
Subject: [PATCH] minor fix

---
 eval_lm.py | 29 ++---------------------------
 1 file changed, 2 insertions(+), 27 deletions(-)

diff --git a/eval_lm.py b/eval_lm.py
index ed874a2..ed6916a 100644
--- a/eval_lm.py
+++ b/eval_lm.py
@@ -125,8 +125,8 @@ def main():
         config.train_batch_mode = model_args.train_batch_mode
         config.offload_hidden_states = model_args.offload_hidden_states
 
-    elif model_args.model_class == "vanilla" or model_args.model_class == "unlimiformer":
-        logger.info("Using vanilla/unlimiformer Llama")
+    elif model_args.model_class == "vanilla":
+        logger.info("Using vanilla Llama")
         model_cls = LlamaForCausalLM
         collator = ContextDataCollator()
 
@@ -169,31 +169,6 @@ def main():
     # model = model.to(device) ## shouldn't need this with device map
     logger.info(f"Loaded model: {model}")
 
-    if model_args.model_class == "unlimiformer":
-        # append the absolute path to the unlimiformer src directory to system path
-        sys.path.append(os.path.join(os.getcwd(), "unlimiformer/src"))
-        from unlimiformer import Unlimiformer
-        from usage import UnlimiformerArguments
-        defaults = UnlimiformerArguments()
-        unlimiformer_kwargs = {
-            'layer_begin': defaults.layer_begin,
-            'layer_end': defaults.layer_end,
-            'unlimiformer_head_num': defaults.unlimiformer_head_num,
-            'exclude_attention': defaults.unlimiformer_exclude,
-            'chunk_overlap': defaults.unlimiformer_chunk_overlap,
-            'model_encoder_max_len': defaults.unlimiformer_chunk_size,
-            'verbose': defaults.unlimiformer_verbose, 'tokenizer': tokenizer,
-            'unlimiformer_training': defaults.unlimiformer_training,
-            'use_datastore': defaults.use_datastore,
-            'flat_index': defaults.flat_index,
-            'test_datastore': defaults.test_datastore,
-            'reconstruct_embeddings': defaults.reconstruct_embeddings,
-            'gpu_datastore': defaults.gpu_datastore,
-            'gpu_index': defaults.gpu_index
-        }
-        logger.info("Converting model to unlimiformer")
-        model = Unlimiformer.convert_model(model, **unlimiformer_kwargs)
-
     if model_args.model_class == "streamingllm" and model_args.enable_positional_shift:
         from streaming_llm.pos_shift.modify_llama import enable_llama_pos_shift_attention
         enable_llama_pos_shift_attention(model)