Use the official setup of alpaca eval by default.
yizhongw committed Oct 19, 2023
1 parent 09613af commit 103c078
Showing 2 changed files with 23 additions and 15 deletions.
36 changes: 23 additions & 13 deletions eval/alpaca_farm/run_eval.py
@@ -15,11 +15,11 @@ def main(args):
     os.makedirs(args.save_dir, exist_ok=True)
 
     logging.info("loading data and model...")
-    alpaca_eval_data = datasets.load_dataset("tatsu-lab/alpaca_farm", "alpaca_farm_evaluation")["eval"]
+    alpaca_eval_data = datasets.load_dataset("tatsu-lab/alpaca_eval", "alpaca_eval")["eval"]
     prompts = []
     chat_formatting_function = dynamic_import_function(args.chat_formatting_function) if args.use_chat_format else None
     for example in alpaca_eval_data:
-        prompt = example["instruction"] + "\n\n" + example["input"] if example["input"] != "" else example["instruction"]
+        prompt = example["instruction"]
         if args.use_chat_format:
            messages = [{"role": "user", "content": prompt}]
            prompt = chat_formatting_function(messages, add_bos=False)
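
For context, the new data path amounts to the following minimal standalone sketch (an illustration only: it assumes the datasets library is installed and the tatsu-lab/alpaca_eval dataset is reachable on the Hugging Face Hub; unlike the old alpaca_farm set, these examples carry no separate "input" field, so each prompt is just the instruction):

    import datasets

    # Official AlpacaEval evaluation split (805 instructions).
    alpaca_eval_data = datasets.load_dataset("tatsu-lab/alpaca_eval", "alpaca_eval")["eval"]

    # No separate "input" field in this set: the prompt is the instruction itself.
    prompts = [example["instruction"] for example in alpaca_eval_data]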
@@ -30,7 +30,6 @@ def main(args):
         model = vllm.LLM(
             model=args.model_name_or_path,
             tokenizer=args.tokenizer_name_or_path if args.tokenizer_name_or_path is not None else args.model_name_or_path,
-            # tokenizer_mode="slow",
             tensor_parallel_size=torch.cuda.device_count(),
         )
         sampling_params = vllm.SamplingParams(
@@ -79,13 +78,21 @@ def main(args):
            fout.write(json.dumps(example) + "\n")
            model_results.append(example)
 
-    df_leaderboard, annotations = alpaca_farm_evaluate(
-        model_outputs=model_results,
-        reference_outputs=args.reference_path,
-        annotators_config="alpaca_eval_gpt4_0314",
-        output_path=args.save_dir,
-        is_return_instead_of_print=True,
-    )
+    if args.reference_path is not None:
+        df_leaderboard, annotations = alpaca_farm_evaluate(
+            model_outputs=model_results,
+            reference_outputs=args.reference_path,
+            annotators_config="alpaca_eval_gpt4_0314",
+            output_path=args.save_dir,
+            is_return_instead_of_print=True,
+        )
+    else:
+        df_leaderboard, annotations = alpaca_farm_evaluate(
+            model_outputs=model_results,
+            annotators_config="alpaca_eval_gpt4_0314",
+            output_path=args.save_dir,
+            is_return_instead_of_print=True,
+        )
 
     print(df_leaderboard.to_string(float_format="%.2f"))

@@ -99,10 +106,13 @@ def main(args):
     parser.add_argument(
         "--reference_path",
         type=str,
-        default="data/eval/alpaca_farm/davinci_003_outputs_2048_token.json",
+        default=None,
         help="Path to the reference outputs. "
-        "Alpaca_eval leaderboard use davinci_003 to generate the reference outputs, "
-        "but they limit the max_tokens to 300. Here we regenerated reference outputs with max_tokens=2048.",
+        "The alpaca_eval leaderboard uses text-davinci-003 to generate the reference outputs, "
+        "but limits max_tokens to 300, which is a bit unfair to text-davinci-003. "
+        "We keep that default setup here so the numbers stay comparable to their leaderboard. "
+        "You can also use the regenerated reference outputs with max_tokens=2048, "
+        "hosted at https://huggingface.co/datasets/hamishivi/alpaca-farm-davinci-003-2048-token.",
     )
     parser.add_argument(
         "--save_dir",
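
With --reference_path now defaulting to None, the script falls through to the official alpaca_eval setup. Below is a minimal sketch of that default call path (a sketch only: it assumes the alpaca_eval package is installed and OpenAI credentials are configured for the GPT-4 annotator; alpaca_farm_evaluate is taken to be the script's import alias for alpaca_eval's evaluate function, as its keyword arguments suggest, and the single model_results entry is purely illustrative):

    from alpaca_eval import evaluate as alpaca_farm_evaluate

    # Illustrative stand-in for the script's generated outputs:
    # one dict per instruction, in alpaca_eval's expected format.
    model_results = [
        {
            "instruction": "What is the capital of France?",
            "output": "The capital of France is Paris.",
            "generator": "tulu_v1_7B-greedy-long",
        },
    ]

    # Omitting reference_outputs lets alpaca_eval fall back to its official
    # text-davinci-003 references, matching the public leaderboard setup.
    df_leaderboard, annotations = alpaca_farm_evaluate(
        model_outputs=model_results,
        annotators_config="alpaca_eval_gpt4_0314",
        output_path="results/alpaca_farm/tulu_v1_7B/",
        is_return_instead_of_print=True,
    )
    print(df_leaderboard.to_string(float_format="%.2f"))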
2 changes: 0 additions & 2 deletions scripts/eval/alpaca_farm.sh
@@ -3,7 +3,6 @@
 # use vllm for generation
 python -m eval.alpaca_farm.run_eval \
     --model_name_or_path ../checkpoints/tulu_v1_7B/ \
-    --reference_path data/eval/alpaca_farm/davinci_003_outputs_2048_token.json \
     --save_dir results/alpaca_farm/tulu_v1_7B/ \
     --eval_batch_size 20 \
     --use_vllm \
@@ -14,7 +13,6 @@ python -m eval.alpaca_farm.run_eval \
 # use normal huggingface generation function
 python -m eval.alpaca_farm.run_eval \
     --model_name_or_path ../checkpoints/tulu_v1_7B/ \
-    --reference_path data/eval/alpaca_farm/davinci_003_outputs_2048_token.json \
     --save_dir results/alpaca_farm/tulu_v1_7B/ \
     --eval_batch_size 20 \
     --use_chat_format \
