Skip to content

Commit

Permalink
feat(experimental): add script to run eval in cli
Browse files Browse the repository at this point in the history
  • Loading branch information
wsxiaoys committed Jul 11, 2023
1 parent b135022 commit 7ca416a
Showing 1 changed file with 17 additions and 10 deletions.
27 changes: 17 additions & 10 deletions experimental/eval/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,13 @@ def run_eval(args):
except:
print(f"Tabby Server is not ready, please check if '{api}' is correct.")
return

items = [x for x in processing.items_from_filepattern(args.filepattern) if valid_item(x)];

items = [
x for x in processing.items_from_filepattern(args.filepattern) if valid_item(x)
]
if len(items) > args.max_records:
random.seed(0xbadbeef)
random.seed(0xBADBEEF)
items = random.sample(items, args.max_records)


for item in items:
if not valid_item(item):
Expand All @@ -56,10 +57,10 @@ def run_eval(args):
prediction = resp.choices[0].text

block_score = scorer(label, prediction)

label_lines = label.splitlines()
prediction_lines = prediction.splitlines()

if len(label_lines) > 0 and len(prediction_lines) > 0:
line_score = scorer(label_lines[0], prediction_lines[0])

Expand All @@ -71,13 +72,19 @@ def run_eval(args):
line_score=line_score,
)


if __name__ == "__main__":
logging.basicConfig(stream=sys.stderr, level=logging.INFO)

parser = argparse.ArgumentParser(description='SxS eval for tabby')
parser.add_argument('filepattern', type=str, help='File pattern to dataset.')
parser.add_argument('max_records', type=int, help='Max number of records to be evaluated.')
parser = argparse.ArgumentParser(
description="SxS eval for tabby",
epilog="Example usage: python main.py ./tabby/dataset/data.jsonl 5 > output.jsonl",
)
parser.add_argument("filepattern", type=str, help="File pattern to dataset.")
parser.add_argument(
"max_records", type=int, help="Max number of records to be evaluated."
)
args = parser.parse_args()
logging.info("args %s", args)
df = pd.DataFrame(run_eval(args))
print(df.to_json(orient='records', lines=True))
print(df.to_json(orient="records", lines=True))

0 comments on commit 7ca416a

Please sign in to comment.