Skip to content

Commit

Permalink
Add saving results to file
Browse files Browse the repository at this point in the history
  • Loading branch information
louislefevre committed Mar 27, 2021
1 parent ef00ec6 commit 9807b20
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 2 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

# Data
dataset/
results/
*.p
*.png

Expand Down
10 changes: 10 additions & 0 deletions retrieval/util/FileManager.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import csv
import os
from typing import Union, Any
import pickle

Expand All @@ -10,6 +11,15 @@ def read_tsv(file_name: str) -> Union[list, str]:
raise RuntimeError("Invalid file type: only '.tsv' files are accepted.")


def write_txt(file_name: str, data: str, mode='w'):
    """Write ``data`` to the text file ``file_name``.

    Any missing parent directories of ``file_name`` are created first.

    Args:
        file_name: Path of the file to write. May be a bare file name
            (no directory component), in which case no directory is created.
        data: Text to write.
        mode: Mode passed to ``open`` (for example ``'w'`` to overwrite or
            ``'a'`` to append).
    """
    directory = os.path.dirname(file_name)
    # dirname() is '' for a bare file name, and os.makedirs('') raises
    # FileNotFoundError, so only create a directory when one is named.
    if directory:
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(directory, exist_ok=True)
    # The context manager closes the handle even if write() raises.
    with open(file_name, mode=mode) as file:
        file.write(data)


def write_pickle(data: object, file_name: str):
    """Serialize ``data`` with ``pickle`` into the file ``file_name``.

    Args:
        data: Object to serialize.
        file_name: Path of the file to write; it is opened in binary write
            mode, so an existing file is overwritten.
    """
    # Use a context manager so the handle is flushed and closed
    # deterministically instead of relying on garbage collection.
    with open(file_name, "wb") as file:
        pickle.dump(data, file)

Expand Down
21 changes: 19 additions & 2 deletions start.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import argparse

from retrieval.DatasetParser import DatasetParser
from retrieval.util.FileManager import write_txt


def main():
Expand All @@ -11,8 +12,24 @@ def main():
parser.add_argument('-s', '--smoothing', help='smoothing for the Query Likelihood model')

args = parser.parse_args()
parser = DatasetParser(args.dataset)
results = parser.parse(args.model, plot_freq=args.plot, smoothing=args.smoothing)
dataset = args.dataset
model = args.model
smoothing = args.smoothing
plot = args.plot
parser = DatasetParser(dataset)
results = parser.parse(model, plot_freq=plot, smoothing=smoothing)

models = {'bm25': 'BM25', 'vector': 'VS', 'query': 'LM'}
smoothers = {'laplace': '-Laplace', 'lidstone': '-Lidstone', 'dirichlet': '-Dirichlet'}
model = models[model]
smoothing = smoothers[smoothing] if smoothing is not None else ""

data = ''
for qid, passages in results.items():
for rank, (pid, score) in enumerate(passages.items()):
data += f"{qid}\t{'A1'}\t{pid}\t{rank}\t{format(score, '.2f')}\t{model}{smoothing}\n"

write_txt(f'results/{model}.txt', data)


if __name__ == '__main__':
Expand Down

0 comments on commit 9807b20

Please sign in to comment.