Skip to content

Commit

Permalink
Merge pull request stanford-oval#23 from DaEpicR/main
Browse files Browse the repository at this point in the history
Small improvements
  • Loading branch information
shaoyijia authored Apr 17, 2024
2 parents abf3dac + 870c6b4 commit fdb111f
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 17 deletions.
22 changes: 6 additions & 16 deletions eval/eval_outline_quality.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,14 @@
import os.path
import re
from argparse import ArgumentParser

import pandas as pd
from tqdm import tqdm

from metrics import heading_soft_recall, heading_entity_recall


def load_str(path):
    """Read the text file at *path* and return its contents as one string.

    Fix: the previous implementation joined ``f.readlines()`` with ``'\n'``,
    which doubled every newline (each element of ``readlines()`` already ends
    with ``'\n'``). ``f.read()`` returns the text verbatim. Callers split the
    result on ``'\n'`` and ignore blank lines, so this is behavior-compatible.
    """
    with open(path, 'r') as f:
        return f.read()


def get_sections(path):
s = load_str(path)
s = re.sub(r"\d+\.\ ", '#', s)
Expand All @@ -30,12 +26,11 @@ def get_sections(path):
if "# References" in line:
break
if line.startswith('#'):
if "references" in line.lower() or "external links" in line.lower() or "see also" in line.lower() or "notes" in line.lower():
if any(keyword in line.lower() for keyword in ["references", "external links", "see also", "notes"]):
break
sections.append(line.strip('#').strip())
return sections


def main(args):
df = pd.read_csv(args.input_path)
entity_recalls = []
Expand All @@ -45,27 +40,22 @@ def main(args):
topic_name = row['topic'].replace(' ', '_').replace('/', '_')
gt_sections = get_sections(os.path.join(args.gt_dir, 'txt', f'{topic_name}.txt'))
pred_sections = get_sections(os.path.join(args.pred_dir, topic_name, args.pred_file_name))
entity_recalls.append(heading_entity_recall(golden_headings=gt_sections,
predicted_headings=pred_sections))
entity_recalls.append(heading_entity_recall(golden_headings=gt_sections, predicted_headings=pred_sections))
heading_soft_recalls.append(heading_soft_recall(gt_sections, pred_sections))
topics.append(row['topic'])

results = pd.DataFrame(
{'topic': topics, 'entity_recall': entity_recalls, 'heading_soft_recall': heading_soft_recalls})
results = pd.DataFrame({'topic': topics, 'entity_recall': entity_recalls, 'heading_soft_recall': heading_soft_recalls})
results.to_csv(args.result_output_path, index=False)
avg_entity_recall = sum(entity_recalls) / len(entity_recalls)
avg_heading_soft_recall = sum(heading_soft_recalls) / len(heading_soft_recalls)
print(f'Entity recall: {avg_entity_recall}')
print(f'Heading soft recall: {avg_heading_soft_recall}')

print(f'Average Entity Recall: {avg_entity_recall}')
print(f'Average Heading Soft Recall: {avg_heading_soft_recall}')

if __name__ == '__main__':
    # CLI entry point: collect the evaluation paths and delegate to main().
    # Fix: the diff-rendered source carried BOTH the old and new
    # `--input-path` add_argument calls; registering the same option string
    # twice raises argparse.ArgumentError, so only the current one is kept.
    parser = ArgumentParser()
    parser.add_argument('--input-path', type=str, help='Path to the CSV file storing topics and ground truth URLs.')
    parser.add_argument('--gt-dir', type=str, help='Path of human-written articles.')
    parser.add_argument('--pred-dir', type=str, help='Path of generated outlines.')
    parser.add_argument('--pred-file-name', type=str, help='Name of the outline file.')
    parser.add_argument('--result-output-path', type=str, help='Path to save the results.')

    main(parser.parse_args())
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@ fastchat
wikipedia==1.4.0
Wikipedia-API==0.6.0
rouge-score
toml
tqdm==4.66.2

0 comments on commit fdb111f

Please sign in to comment.