Skip to content

Commit

Permalink
Improves doc
Browse files Browse the repository at this point in the history
  • Loading branch information
chrisc36 committed Nov 21, 2017
1 parent d14f37b commit 01bedd4
Showing 1 changed file with 15 additions and 3 deletions.
18 changes: 15 additions & 3 deletions docqa/data_processing/multi_paragraph_qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,9 +347,18 @@ def __init__(self,
true_len: int,
batch_size: int,
force_answer: bool,
overample_first_answer: List[int],
oversample_first_answer: List[int],
merge: bool):
self.overample_first_answer = overample_first_answer
"""
:param true_len: Number questions before any filtering was done
:param batch_size: Batch size to use
:param force_answer: Require an answer exists for at least
one paragraph for each question each batch
:param oversample_first_answer: Over sample the top-ranked answer-containing paragraphs
by duplicating them the specified amount
:param merge: Merge all selected paragraphs for each question into a single super-paragraph
"""
self.overample_first_answer = oversample_first_answer
self.questions = questions
self.merge = merge
self.true_len = true_len
Expand Down Expand Up @@ -402,6 +411,7 @@ def get_epoch(self):

def _build_expanded_batches(self, questions):
out = []
# Decide what paragraphs to use for each question
for i, q in enumerate(questions):
order = self._order[i]
out.append(ParagraphSelection(q, order[self._on[i]]))
Expand All @@ -410,16 +420,18 @@ def _build_expanded_batches(self, questions):
self._on[i] = 0
np.random.shuffle(order)

# Sort by context length
out.sort(key=lambda x: x.n_context_words)

# Yield the correct batches
group = 0
for selection_batch in self.batcher.get_epoch(out):
batch = []
for selected in selection_batch:
q = selected.question
if self.merge:
paras = [q.paragraphs[i] for i in selected.selection]
# Sort paragraph my reading order, not rank order
# Sort paragraph by reading order, not rank order
paras.sort(key=lambda x: x.get_order())
answer_spans = []
text = []
Expand Down

0 comments on commit 01bedd4

Please sign in to comment.