diff --git a/generate.py b/generate.py
index 61cf8987..42fc777e 100644
--- a/generate.py
+++ b/generate.py
@@ -10,6 +10,9 @@
 tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")
 
+BASE_MODEL = "decapoda-research/llama-7b-hf"
+LORA_WEIGHTS = "tloen/alpaca-lora-7b"
+
 if torch.cuda.is_available():
     device = "cuda"
 else:
@@ -23,33 +26,31 @@
 if device == "cuda":
     model = LlamaForCausalLM.from_pretrained(
-        "decapoda-research/llama-7b-hf",
+        "chavinlo/alpaca-native",
         load_in_8bit=True,
         torch_dtype=torch.float16,
         device_map="auto",
     )
-    model = PeftModel.from_pretrained(
-        model, "tloen/alpaca-lora-7b", torch_dtype=torch.float16
-    )
+    # model = PeftModel.from_pretrained(model, LORA_WEIGHTS, torch_dtype=torch.float16)
 elif device == "mps":
     model = LlamaForCausalLM.from_pretrained(
-        "decapoda-research/llama-7b-hf",
+        BASE_MODEL,
         device_map={"": device},
         torch_dtype=torch.float16,
     )
     model = PeftModel.from_pretrained(
         model,
-        "tloen/alpaca-lora-7b",
+        LORA_WEIGHTS,
         device_map={"": device},
         torch_dtype=torch.float16,
     )
 else:
     model = LlamaForCausalLM.from_pretrained(
-        "decapoda-research/llama-7b-hf", device_map={"": device}, low_cpu_mem_usage=True
+        BASE_MODEL, device_map={"": device}, low_cpu_mem_usage=True
     )
     model = PeftModel.from_pretrained(
         model,
-        "tloen/alpaca-lora-7b",
+        LORA_WEIGHTS,
         device_map={"": device},
     )
@@ -75,6 +76,8 @@ def generate_prompt(instruction, input=None):
 model.eval()
+if torch.__version__ >= "2":
+    model = torch.compile(model)
 
 
 def evaluate(
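
A minimal smoke test for the patched loading path, included here only as a sketch and not part of the patch: it reuses the tokenizer, model, device, and generate_prompt names already defined in generate.py, and the prompt text and generation settings below are illustrative placeholders rather than values from the original script.

    # Sketch: exercise the model loaded above (chavinlo/alpaca-native on CUDA, or
    # BASE_MODEL + LORA_WEIGHTS on MPS/CPU). Prompt and sampling values are examples only.
    import torch
    from transformers import GenerationConfig

    prompt = generate_prompt("Tell me about alpacas.")          # generate_prompt() is defined later in generate.py
    inputs = tokenizer(prompt, return_tensors="pt").to(device)  # tokenizer/device come from the code above
    with torch.no_grad():
        output = model.generate(
            input_ids=inputs["input_ids"],
            generation_config=GenerationConfig(temperature=0.1, top_p=0.75, num_beams=4),
            max_new_tokens=128,
        )
    print(tokenizer.decode(output[0], skip_special_tokens=True))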