Commit

Move Zipfs plotting to zipfs.py
louislefevre committed Apr 10, 2021
1 parent 118e120 commit 5092183
Showing 4 changed files with 36 additions and 21 deletions.
11 changes: 1 addition & 10 deletions retrieval/DatasetParser.py
@@ -1,13 +1,11 @@
 import os
-from collections import Counter
 
 from retrieval.data.Dataset import Dataset
 from retrieval.data.InvertedIndex import InvertedIndex
 from retrieval.models.BM25 import BM25
 from retrieval.models.QueryLikelihood import QueryLikelihood
 from retrieval.models.VectorSpace import VectorSpace
 from retrieval.util.FileManager import read_pickle, write_pickle
-from util.Plotter import zipfs
 
 
 class DatasetParser:
@@ -17,12 +15,9 @@ def __init__(self, dataset_path: str):
         self._queries = self._dataset.queries()
         self._mapping = self._dataset.id_mapping()
 
-    def parse(self, model: str, plot_freq: bool = False, smoothing: str = None,
-              index_path: str = 'index.p') -> dict[int, dict[int, float]]:
+    def parse(self, model: str, smoothing: str = None, index_path: str = 'index.p') -> dict[int, dict[int, float]]:
         index = self._generate_index(index_path, self._passages)
 
-        if plot_freq:
-            self._zipfs_law(index.counter)
         if model == 'bm25':
             model = BM25(index, self._mapping)
         elif model == 'vs':
@@ -35,10 +30,6 @@ def parse(self, model: str, plot_freq: bool = False, smoothing: str = None,
         print("Ranking queries against passages...")
         return {qid: model.rank(qid, query) for qid, query in self._queries.items()}
 
-    @staticmethod
-    def _zipfs_law(counter: Counter):
-        zipfs(counter)
-
     @staticmethod
     def _generate_index(file: str, passages: dict[int, str]) -> InvertedIndex:
         if os.path.isfile(file) and not os.stat(file).st_size == 0:
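With the plotting hook removed, DatasetParser is only responsible for building the index and ranking queries against passages. A minimal sketch of the new call path (the dataset file name is hypothetical, and the repository root is assumed to be the working directory):

from retrieval.DatasetParser import DatasetParser

parser = DatasetParser('dataset.tsv')   # hypothetical dataset file
results = parser.parse('bm25')          # plot_freq no longer exists; only model, smoothing, index_path remain
# results maps each query id to {passage id: score}, per the annotated return type
first_qid = next(iter(results))
print(sorted(results[first_qid].items(), key=lambda item: item[1], reverse=True)[:3])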
9 changes: 5 additions & 4 deletions retrieval/util/TextProcessor.py
@@ -7,15 +7,16 @@
 nltk.download('stopwords', quiet=True)
 
 
-def clean_collection(collection: dict[int, str]) -> dict[int, list[str]]:
-    return {pid: clean(passage) for pid, passage in collection.items()}
+def clean_collection(collection: dict[int, str], remove_sw: bool = True) -> dict[int, list[str]]:
+    return {pid: clean(passage, remove_sw=remove_sw) for pid, passage in collection.items()}
 
 
-def clean(text: str) -> list[str]:
+def clean(text: str, remove_sw: bool = True) -> list[str]:
     tokens = _tokenize(text)
     tokens = _convert_numbers(tokens)
     tokens = _normalise(tokens)
-    tokens = _remove_stopwords(tokens)
+    if remove_sw:
+        tokens = _remove_stopwords(tokens)
     tokens = _stem(tokens)
     return tokens
 
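The new remove_sw flag is what lets the Zipf's law script keep stopwords while the retrieval pipeline continues to strip them. A rough sketch of the difference in behaviour (the sample sentence is illustrative, and exact tokens depend on the module's private tokenise/normalise/stem helpers):

from retrieval.util.TextProcessor import clean

with_stopwords = clean("the cat sat on the mat", remove_sw=False)   # stopwords kept, then stemmed
without_stopwords = clean("the cat sat on the mat")                 # default remove_sw=True drops them
print(with_stopwords)
print(without_stopwords)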
4 changes: 1 addition & 3 deletions start.py
@@ -8,22 +8,20 @@ def main():
     parser = argparse.ArgumentParser(description='Information retrieval models.')
     parser.add_argument('dataset', help='dataset for retrieving passages and queries')
     parser.add_argument('model', help='model for ranking passages against queries')
-    parser.add_argument('-p', '--plot', action='store_true', help='generate term frequency graph')
     parser.add_argument('-s', '--smoothing', help='smoothing for the Query Likelihood model')
 
     args = parser.parse_args()
     dataset = args.dataset
     model = args.model
     smoothing = args.smoothing
-    plot = args.plot
 
     if model == 'lm' and smoothing is None:
         raise ValueError("Smoothing must be supplied when using the Query Likelihood model.")
     if not model == 'lm' and smoothing is not None:
         raise ValueError("Smoothing can only be applied to the Query Likelihood model.")
 
     parser = DatasetParser(dataset)
-    results = parser.parse(model, plot_freq=plot, smoothing=smoothing)
+    results = parser.parse(model, smoothing=smoothing)
 
     model = model.upper()
     smoothing = f'-{smoothing.capitalize()}' if smoothing is not None else ""
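With the -p/--plot flag gone, start.py handles only retrieval and its argument validation; plotting is now the standalone script's job. Invocation would look roughly like the following (the dataset file name and smoothing value are illustrative; only the 'lm' model accepts --smoothing, per the checks above):

python start.py dataset.tsv bm25
python start.py dataset.tsv lm --smoothing laplace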
33 changes: 29 additions & 4 deletions retrieval/util/Plotter.py → zipfs.py
@@ -1,12 +1,16 @@
+import argparse
+import itertools
 from collections import Counter
 
 from matplotlib import pyplot as plt
 from tabulate import tabulate
 
+from data.Dataset import Dataset
 from util.FileManager import write_txt
+from util.TextProcessor import clean_collection
 
 
-def zipfs(counter: Counter):
+def _zipfs_distribution(counter: Counter):
     prob_distribution = []
     rows = []
     total_count = sum(counter.values())
@@ -24,9 +28,9 @@
     _report_parameters(rows, c)
 
 
-def _plot_distribution(prob_distribution: list[float]):
-    zipf_distribution = [0.1 / i for i in range(1, 101)]
-    _generate_figure(prob_distribution, zipf_distribution, title="Zipf's Law", x_label="Rank",
+def _plot_distribution(prob_dist: list[float]):
+    zipf_dist = [0.1 / i for i in range(1, 101)]
+    _generate_figure(prob_dist, zipf_dist, title="Zipf's Law", x_label="Rank",
                      y_label="Probability", file_name='zipf-plot.png')


@@ -45,3 +49,24 @@ def _generate_figure(*data, title=None, x_label=None, y_label=None, file_name='f
     plt.ylabel(y_label)
     plt.grid()
     plt.savefig(file_name)
+
+
+def main():
+    parser = argparse.ArgumentParser(description='Zipfs law distribution plot and report.')
+    parser.add_argument('dataset', help='dataset for retrieving passages')
+    args = parser.parse_args()
+
+    print('Processing dataset...')
+    dataset = Dataset(args.dataset)
+    collection = clean_collection(dataset.passages(), remove_sw=False)
+
+    print('Analysing data...')
+    words = list(itertools.chain.from_iterable(collection.values()))
+    counter = Counter(words)
+
+    print('Generating plot...')
+    _zipfs_distribution(counter)
+
+
+if __name__ == '__main__':
+    main()
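Run on its own, zipfs.py now reproduces the analysis that the old --plot flag triggered: it cleans the collection with stopwords kept, counts every token, and plots the empirical probability-by-rank curve against the reference Zipf curve used in _plot_distribution, P(r) ≈ 0.1 / r (so the second-ranked term is expected near 0.05 and the tenth near 0.01). Assuming the script is run from a directory where the data.* and util.* imports resolve, an invocation would look roughly like:

python zipfs.py dataset.tsv

with dataset.tsv standing in for the real collection file.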
