Skip to content

Commit

Permalink
Fix a bug in iterating in eval/predict.py
Browse files Browse the repository at this point in the history
  • Loading branch information
yizhongw committed Sep 17, 2023
1 parent ca7c85f commit cbc02e9
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions eval/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
Then you can run this script with the following command:
python eval/predict.py \
--model_name_or_path <huggingface_model_name_or_path> \
--input_files <input_files> \
--input_files <input_file_1> <input_file_2> ... \
--output_file <output_file> \
--batch_size <batch_size> \
--load_in_8bit
--use_vllm
'''


Expand All @@ -33,12 +33,15 @@ def parse_args():
parser.add_argument(
"--model_name_or_path",
type=str,
default=None,
help="Huggingface model name or path.")
parser.add_argument(
"--tokenizer_name_or_path",
type=str,
help="Huggingface tokenizer name or path."
)
parser.add_argument(
"--openai_engine",
type=str,
default=None,
help="OpenAI engine name.")
parser.add_argument(
"--input_files",
Expand Down Expand Up @@ -127,13 +130,13 @@ def parse_args():
prompts = []
chat_formatting_function = dynamic_import_function(args.chat_formatting_function) if args.use_chat_format else None
for instance in instances:
if "messages" in instances:
if "messages" in instance:
if not args.use_chat_format:
raise ValueError("If `messages` is in the instance, `use_chat_format` should be True.")
assert all("role" in message and "content" in message for message in instance["messages"]), \
"Each message should have a `role` and a `content` field."
prompt = eval(args.chat_formatting_function)(instance["messages"], add_bos=False)
elif "prompt" in instances:
elif "prompt" in instance:
if args.use_chat_format:
messages = [{"role": "user", "content": instance["prompt"]}]
prompt = chat_formatting_function(messages, add_bos=False)
Expand Down

0 comments on commit cbc02e9

Please sign in to comment.