Skip to content

Commit

Permalink
generate.py tweaks
Browse files Browse the repository at this point in the history
  • Loading branch information
tloen committed Mar 19, 2023
1 parent 80fd983 commit c83e30a
Showing 1 changed file with 11 additions and 8 deletions.
19 changes: 11 additions & 8 deletions generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@

# Hub IDs for the base LLaMA checkpoint and the LoRA adapter weights.
# Defined before the tokenizer so every downstream load (tokenizer and
# model alike) reads the same constant instead of a duplicated literal.
BASE_MODEL = "decapoda-research/llama-7b-hf"
LORA_WEIGHTS = "tloen/alpaca-lora-7b"

# Tokenizer must come from the same checkpoint family as BASE_MODEL,
# otherwise token ids won't line up with the model's embedding table.
tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)

if torch.cuda.is_available():
device = "cuda"
else:
Expand All @@ -23,33 +26,31 @@

if device == "cuda":
model = LlamaForCausalLM.from_pretrained(
"decapoda-research/llama-7b-hf",
"chavinlo/alpaca-native",
load_in_8bit=True,
torch_dtype=torch.float16,
device_map="auto",
)
model = PeftModel.from_pretrained(
model, "tloen/alpaca-lora-7b", torch_dtype=torch.float16
)
# model = PeftModel.from_pretrained(model, LORA_WEIGHTS, torch_dtype=torch.float16)
elif device == "mps":
model = LlamaForCausalLM.from_pretrained(
"decapoda-research/llama-7b-hf",
BASE_MODEL,
device_map={"": device},
torch_dtype=torch.float16,
)
model = PeftModel.from_pretrained(
model,
"tloen/alpaca-lora-7b",
LORA_WEIGHTS,
device_map={"": device},
torch_dtype=torch.float16,
)
else:
model = LlamaForCausalLM.from_pretrained(
"decapoda-research/llama-7b-hf", device_map={"": device}, low_cpu_mem_usage=True
BASE_MODEL, device_map={"": device}, low_cpu_mem_usage=True
)
model = PeftModel.from_pretrained(
model,
"tloen/alpaca-lora-7b",
LORA_WEIGHTS,
device_map={"": device},
)

Expand All @@ -75,6 +76,8 @@ def generate_prompt(instruction, input=None):


# Switch to inference mode (disables dropout etc.).
model.eval()
# torch.compile is available from PyTorch 2.0 onward. Compare the parsed
# major version, not the raw string: a lexicographic check like
# `torch.__version__ >= "2"` would wrongly classify a future "10.x"
# release as older than 2. str() guards against TorchVersion subclasses.
if int(str(torch.__version__).split(".")[0]) >= 2:
    model = torch.compile(model)


def evaluate(
Expand Down

0 comments on commit c83e30a

Please sign in to comment.