
Commit

Adds document splitter
chrisc36 committed Nov 15, 2017
1 parent 026eb57 commit d4831e5
Showing 1 changed file with 24 additions and 5 deletions.
29 changes: 24 additions & 5 deletions docqa/data_processing/document_splitter.py
@@ -222,25 +222,29 @@ def reads_first_n(self):
        return None

    def split(self, doc: List[List[List[str]]]) -> List[ExtractedParagraph]:
        """
        Splits a list of paragraphs->sentences->words into a list of `ExtractedParagraph`
        """
        raise NotImplementedError()

    def split_annotated(self, doc: List[List[List[str]]], spans: np.ndarray) -> List[ExtractedParagraphWithAnswers]:
        """
        Split a document and additionally map the answer spans onto each output paragraph
        """
        out = []
        for para in self.split(doc):
            para_spans = spans[np.logical_and(spans[:, 0] >= para.start, spans[:, 1] < para.end)] - para.start
            out.append(ExtractedParagraphWithAnswers(para.text, para.start, para.end, para_spans))
        return out

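Illustrative sketch (not part of this commit): the numpy indexing in split_annotated keeps only the answer spans that fall entirely inside a paragraph and rebases them onto that paragraph's local token offsets. A minimal standalone example with hypothetical numbers:

import numpy as np

spans = np.array([[3, 5], [12, 14], [40, 42]])  # token spans over the whole document
para_start, para_end = 10, 30                   # hypothetical paragraph boundaries

# Keep spans that start at or after the paragraph start and end before its end,
# then shift them so they index into the paragraph rather than the document
mask = np.logical_and(spans[:, 0] >= para_start, spans[:, 1] < para_end)
print(spans[mask] - para_start)  # -> [[2 4]]
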
    def split_inverse(self, paras: List[ParagraphWithInverse]) -> List[ParagraphWithInverse]:
        """
        Split a document consisting of `ParagraphWithInverse` objects
        """
        full_para = ParagraphWithInverse.concat(paras, "\n")

        split_docs = self.split([x.text for x in paras])

        max_len = len(full_para.get_context())
        for para in split_docs:
            if para.end > max_len:
                raise RuntimeError("Split paragraph extends past the end of the document")

        out = []
        for para in split_docs:
            # Grab the correct inverses and convert back to the paragraph level
@@ -340,6 +344,21 @@ def split(self, doc: List[List[List[str]]]):
        return all_paragraphs


class PreserveParagraphs(DocumentSplitter):
    """
    Convenience class that preserves the document's natural paragraph boundaries
    """
    def split(self, doc: List[List[List[str]]]):
        out = []
        on_token = 0
        for para in doc:
            flattened_para = flatten_iterable(para)
            end = on_token + len(flattened_para)
            out.append(ExtractedParagraph([flattened_para], on_token, end))
            on_token = end
        return out
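Illustrative usage (not part of this commit), assuming the ExtractedParagraph constructor arguments above are exposed as .text, .start and .end:

doc = [
    [["The", "first", "paragraph", "."]],                       # one sentence
    [["Second", "paragraph", "."], ["Two", "sentences", "."]],  # two sentences
]
for para in PreserveParagraphs().split(doc):
    print(para.start, para.end, para.text)
# 0 4 [['The', 'first', 'paragraph', '.']]
# 4 10 [['Second', 'paragraph', '.', 'Two', 'sentences', '.']]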


def extract_tokens(paragraph: List[List[str]], n_tokens) -> List[List[str]]:
    output = []
    cur_tokens = 0
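The body of extract_tokens is cut off in this view; from its signature and the first lines of the loop it appears to collect sentences until a budget of n_tokens tokens is reached. A hedged sketch of that behaviour (an assumption, not the committed code):

def extract_tokens_sketch(paragraph: List[List[str]], n_tokens) -> List[List[str]]:
    output = []
    cur_tokens = 0
    for sentence in paragraph:
        remaining = n_tokens - cur_tokens
        if len(sentence) > remaining:
            if remaining > 0:
                output.append(sentence[:remaining])  # truncate the sentence that would overflow the budget
            break
        output.append(sentence)
        cur_tokens += len(sentence)
    return output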
