Skip to content

Commit

Permalink
Rename TextStatistics to Plotter
Browse files (browse the repository at this point in the history)
Also moved to util/
  • Loading branch information
louislefevre committed Mar 27, 2021
1 parent 548f566 commit 0babaef
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 9 deletions.
11 changes: 3 additions & 8 deletions retrieval/DatasetParser.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import os

from retrieval.TextStatistics import plotter
from retrieval.data.Dataset import Dataset
from retrieval.data.InvertedIndex import InvertedIndex
from retrieval.models.BM25 import BM25
@@ -16,12 +15,12 @@ def __init__(self, dataset_path: str):
self._queries = self._dataset.queries()
self._mapping = self._dataset.id_mapping()

def parse(self, model: str, plot: bool = False, smoothing: str = None,
def parse(self, model: str, plot_freq: bool = False, smoothing: str = None,
index_path: str = 'index.p') -> dict[int, dict[int, float]]:
index = self._generate_index(index_path, self._passages)

if plot:
self._plot_frequency(index)
if plot_freq:
index.plot()

if model == 'bm25':
model = BM25(index, self._mapping)
@@ -34,10 +33,6 @@ def parse(self, model: str, plot: bool = False, smoothing: str = None,

return {qid: model.rank(qid, query) for qid, query in self._queries.items()}

@staticmethod
def _plot_frequency(index: InvertedIndex):
plotter(index.counter)

@staticmethod
def _generate_index(file: str, passages: dict[int, str]) -> InvertedIndex:
if os.path.isfile(file) and not os.stat(file).st_size == 0:
Expand Down
4 changes: 4 additions & 0 deletions retrieval/data/InvertedIndex.py
Original file line number Diff line number Diff line change
@@ -3,6 +3,7 @@
from dataclasses import dataclass
from math import log

from retrieval.util.Plotter import plot_frequency
from retrieval.util.TextProcessor import clean_collection


@@ -15,6 +16,9 @@ def parse(self):
self._index_passages()
self._tfidf_passages()

def plot(self):
plot_frequency(self.counter)

def _index_passages(self):
for pid, passage in self._collection.items():
for term in passage:
Expand Down
2 changes: 1 addition & 1 deletion retrieval/TextStatistics.py → retrieval/util/Plotter.py
Original file line number Diff line number Diff line change
@@ -3,7 +3,7 @@
from matplotlib import pyplot as plt


def plot(counter: Counter):
def plot_frequency(counter: Counter):
frequencies = _normalise(counter.values())
frequencies.sort(reverse=True)
frequencies = frequencies[:100]
Expand Down

0 comments on commit 0babaef

Please sign in to comment.