
Fix a bug in tydiqa model loading.
yizhongw committed Sep 25, 2023
1 parent 4a8bba3 commit ae27904
Showing 1 changed file with 18 additions and 15 deletions.
33 changes: 18 additions & 15 deletions eval/tydiqa/run_eval.py
@@ -99,14 +99,24 @@ def main(args):
 
     if args.model_name_or_path:
         print("Loading model and tokenizer...")
-        model, tokenizer = load_hf_lm_and_tokenizer(
-            model_name_or_path=args.model_name_or_path,
-            tokenizer_name_or_path=args.tokenizer_name_or_path,
-            load_in_8bit=args.load_in_8bit,
-            device_map="balanced_low_0" if torch.cuda.device_count() > 1 else "auto",
-            gptq_model=args.gptq,
-            use_fast_tokenizer=not args.use_slow_tokenizer,
-        )
+        if args.use_vllm:
+            model = vllm.LLM(
+                model=args.model_name_or_path,
+                tokenizer=args.tokenizer_name_or_path if args.tokenizer_name_or_path else args.model_name_or_path,
+                tokenizer_mode="slow" if args.use_slow_tokenizer else "auto",
+                tensor_parallel_size=torch.cuda.device_count(),
+                max_num_batched_tokens=4096,
+            )
+            tokenizer = model.llm_engine.tokenizer
+        else:
+            model, tokenizer = load_hf_lm_and_tokenizer(
+                model_name_or_path=args.model_name_or_path,
+                tokenizer_name_or_path=args.tokenizer_name_or_path,
+                load_in_8bit=args.load_in_8bit,
+                device_map="balanced_low_0" if torch.cuda.device_count() > 1 else "auto",
+                gptq_model=args.gptq,
+                use_fast_tokenizer=not args.use_slow_tokenizer,
+            )
     else:
         import tiktoken
         tokenizer = tiktoken.get_encoding("cl100k_base")
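
For reference, the loading logic after this change can be sketched as a standalone helper, shown below. This is a hedged sketch rather than the script's exact code: the import path for load_hf_lm_and_tokenizer is assumed, and model.llm_engine.tokenizer depends on the installed vLLM version exposing the tokenizer that way, as the diff does.

import torch
import vllm

# Assumed import path for the repo's existing HF loading helper; adjust as needed.
from eval.utils import load_hf_lm_and_tokenizer


def load_model_and_tokenizer(args):
    """Load either a vLLM engine or a plain HF model, mirroring the diff above."""
    if args.use_vllm:
        # vLLM path: shard the model across all visible GPUs and reuse vLLM's tokenizer.
        model = vllm.LLM(
            model=args.model_name_or_path,
            tokenizer=args.tokenizer_name_or_path or args.model_name_or_path,
            tokenizer_mode="slow" if args.use_slow_tokenizer else "auto",
            tensor_parallel_size=torch.cuda.device_count(),
            max_num_batched_tokens=4096,
        )
        tokenizer = model.llm_engine.tokenizer  # version-dependent attribute, as in the diff
    else:
        # HF path: the helper handles 8-bit / GPTQ loading and multi-GPU placement.
        model, tokenizer = load_hf_lm_and_tokenizer(
            model_name_or_path=args.model_name_or_path,
            tokenizer_name_or_path=args.tokenizer_name_or_path,
            load_in_8bit=args.load_in_8bit,
            device_map="balanced_low_0" if torch.cuda.device_count() > 1 else "auto",
            gptq_model=args.gptq,
            use_fast_tokenizer=not args.use_slow_tokenizer,
        )
    return model, tokenizer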
@@ -171,13 +181,6 @@ def main(args):
 
     if args.model_name_or_path:
         if args.use_vllm:
-            model = vllm.LLM(
-                model=args.model_name_or_path,
-                tokenizer=args.tokenizer_name_or_path if args.tokenizer_name_or_path else args.model_name_or_path,
-                tokenizer_mode="slow" if args.use_slow_tokenizer else "auto",
-                tensor_parallel_size=torch.cuda.device_count(),
-                max_num_batched_tokens=4096,
-            )
             sampling_params = vllm.SamplingParams(
                 temperature=0,
                 max_tokens=50,
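
This second hunk drops the duplicate vllm.LLM construction (the fix: the engine is now built once, during loading) and keeps only the greedy sampling parameters. A hedged sketch of how such parameters are typically consumed with a vllm.LLM instance; the prompt handling and answer extraction here are illustrative, not the script's exact code:

import vllm


def generate_predictions(model: vllm.LLM, prompts: list[str]) -> list[str]:
    """Run greedy generation over a batch of prompts and return the text outputs."""
    sampling_params = vllm.SamplingParams(
        temperature=0,  # greedy decoding for reproducible evaluation
        max_tokens=50,  # short answers are enough for this QA setup
    )
    outputs = model.generate(prompts, sampling_params)
    return [output.outputs[0].text.strip() for output in outputs]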
