
Commit 1b4b5e6: first commit

priya-dwivedi committed Mar 13, 2018
Showing 36 changed files with 6,760 additions and 0 deletions.
544 changes: 544 additions & 0 deletions analysis_data_length.ipynb


299 changes: 299 additions & 0 deletions char_vocab.ipynb


11 changes: 11 additions & 0 deletions code/.idea/code.iml


4 changes: 4 additions & 0 deletions code/.idea/misc.xml


8 changes: 8 additions & 0 deletions code/.idea/modules.xml


1,100 changes: 1,100 additions & 0 deletions code/.idea/workspace.xml


Empty file added code/__init__.py
220 changes: 220 additions & 0 deletions code/data_batcher.py
@@ -0,0 +1,220 @@
# Copyright 2018 Stanford University
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""This file contains code to read tokenized data from file,
truncate, pad and process it into batches ready for training"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import random
import time
import re

import numpy as np
from six.moves import xrange
from vocab import PAD_ID, UNK_ID


class Batch(object):
    """A class to hold the information needed for a training batch"""

    def __init__(self, context_ids, context_mask, context_tokens, qn_ids, qn_mask, qn_tokens, ans_span, ans_tokens, uuids=None):
        """
        Inputs:
          {context/qn}_ids: Numpy arrays.
            Shape (batch_size, {context_len/question_len}). Contains padding.
          {context/qn}_mask: Numpy arrays, same shape as _ids.
            Contains 1s where there is real data, 0s where there is padding.
          {context/qn/ans}_tokens: Lists (length batch_size) of unpadded lists of tokens (strings)
          ans_span: numpy array, shape (batch_size, 2)
          uuids: a list (length batch_size) of strings.
            Not needed for training. Used by official_eval mode.
        """
        self.context_ids = context_ids
        self.context_mask = context_mask
        self.context_tokens = context_tokens

        self.qn_ids = qn_ids
        self.qn_mask = qn_mask
        self.qn_tokens = qn_tokens

        self.ans_span = ans_span
        self.ans_tokens = ans_tokens

        self.uuids = uuids

        self.batch_size = len(self.context_tokens)


def split_by_whitespace(sentence):
    """Splits a sentence string on whitespace, returning a list of non-empty tokens (strings)"""
    words = []
    for space_separated_fragment in sentence.strip().split():
        words.extend(re.split(" ", space_separated_fragment))
    return [w for w in words if w]
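# e.g. split_by_whitespace("  i  do n't  know ") -> ["i", "do", "n't", "know"]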


def intstr_to_intlist(string):
    """Given a string e.g. '311 9 1334 635 6192 56 639', returns it as a list of integers"""
    return [int(s) for s in string.split()]


def sentence_to_token_ids(sentence, word2id):
    """Turns an already-tokenized sentence string into word indices
    e.g. "i do n't know" -> [9, 32, 16, 96]
    Note any token that isn't in the word2id mapping gets mapped to the id for UNK
    """
    tokens = split_by_whitespace(sentence) # list of strings
    ids = [word2id.get(w, UNK_ID) for w in tokens]
    return tokens, ids


def padded(token_batch, batch_pad=0):
    """
    Inputs:
      token_batch: List (length batch_size) of lists of ints.
      batch_pad: Int. Length to pad to. If 0, pad to the maximum length sequence in token_batch.
    Returns:
      List (length batch_size) of padded lists of ints.
        All are the same length: batch_pad if batch_pad != 0, otherwise the maximum length in token_batch
    """
    maxlen = max(len(x) for x in token_batch) if batch_pad == 0 else batch_pad
    return [token_list + [PAD_ID] * (maxlen - len(token_list)) for token_list in token_batch]
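# Illustrative: assuming PAD_ID == 0, padded([[5, 3], [7]], 4) -> [[5, 3, 0, 0], [7, 0, 0, 0]]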


def refill_batches(batches, word2id, context_file, qn_file, ans_file, batch_size, context_len, question_len, discard_long):
    """
    Adds more batches into the "batches" list.
    Inputs:
      batches: list to add batches to
      word2id: dictionary mapping word (string) to word id (int)
      context_file, qn_file, ans_file: open file objects for the {train/dev}.{context/question/answer} data files
      batch_size: int. how big to make the batches
      context_len, question_len: max length of context and question respectively
      discard_long: If True, discard any examples that are longer than context_len or question_len.
        If False, truncate those examples instead.
    """
    print("Refilling batches...")
    tic = time.time()
    examples = [] # list of (context_ids, context_tokens, qn_ids, qn_tokens, ans_span, ans_tokens) tuples
    context_line, qn_line, ans_line = context_file.readline(), qn_file.readline(), ans_file.readline() # read the next line from each

    while context_line and qn_line and ans_line: # while you haven't reached the end

        # Convert tokens to word ids
        context_tokens, context_ids = sentence_to_token_ids(context_line, word2id)
        qn_tokens, qn_ids = sentence_to_token_ids(qn_line, word2id)
        ans_span = intstr_to_intlist(ans_line)

        # read the next line from each file
        context_line, qn_line, ans_line = context_file.readline(), qn_file.readline(), ans_file.readline()

        # get ans_tokens from ans_span
        assert len(ans_span) == 2
        if ans_span[1] < ans_span[0]:
            print("Found an ill-formed gold span: start=%i end=%i" % (ans_span[0], ans_span[1]))
            continue
        ans_tokens = context_tokens[ans_span[0] : ans_span[1]+1] # list of strings

        # discard or truncate too-long questions
        if len(qn_ids) > question_len:
            if discard_long:
                continue
            else: # truncate
                qn_ids = qn_ids[:question_len]

        # discard or truncate too-long contexts
        if len(context_ids) > context_len:
            if discard_long:
                continue
            else: # truncate
                context_ids = context_ids[:context_len]

        # add to examples
        examples.append((context_ids, context_tokens, qn_ids, qn_tokens, ans_span, ans_tokens))

        # stop refilling once you have enough examples for 160 batches
        if len(examples) == batch_size * 160:
            break

    # Once you've either got 160 batches' worth of examples or you've reached the end of file:

    # Sort by question length
    # Note: if you sort by context length, then you'll have batches which contain the same context many times (because each context appears several times, with different questions)
    examples = sorted(examples, key=lambda e: len(e[2]))

    # Make into batches and append to the list batches
    for batch_start in xrange(0, len(examples), batch_size):

        # Note: each of these is a list length batch_size of lists of ints (except on the last iteration, when it might be shorter than batch_size)
        context_ids_batch, context_tokens_batch, qn_ids_batch, qn_tokens_batch, ans_span_batch, ans_tokens_batch = zip(*examples[batch_start:batch_start+batch_size])

        batches.append((context_ids_batch, context_tokens_batch, qn_ids_batch, qn_tokens_batch, ans_span_batch, ans_tokens_batch))

    # shuffle the order of the batches
    random.shuffle(batches)

    toc = time.time()
    print("Refilling batches took %.2f seconds" % (toc-tic))
    return


def get_batch_generator(word2id, context_path, qn_path, ans_path, batch_size, context_len, question_len, discard_long):
    """
    This function returns a generator object that yields batches.
    The last batch in the dataset may be a partial batch (smaller than batch_size).
    Read this to understand generators and the yield keyword in Python: https://stackoverflow.com/questions/231767/what-does-the-yield-keyword-do
    Inputs:
      word2id: dictionary mapping word (string) to word id (int)
      context_path, qn_path, ans_path: paths to the {train/dev}.{context/question/answer} data files
      batch_size: int. how big to make the batches
      context_len, question_len: max length of context and question respectively
      discard_long: If True, discard any examples that are longer than context_len or question_len.
        If False, truncate those examples instead.
    """
    context_file, qn_file, ans_file = open(context_path), open(qn_path), open(ans_path)
    batches = []

    while True:
        if len(batches) == 0: # add more batches
            refill_batches(batches, word2id, context_file, qn_file, ans_file, batch_size, context_len, question_len, discard_long)
        if len(batches) == 0:
            break

        # Get next batch. These are all lists length batch_size
        (context_ids, context_tokens, qn_ids, qn_tokens, ans_span, ans_tokens) = batches.pop(0)

        # Pad context_ids and qn_ids
        qn_ids = padded(qn_ids, question_len) # pad questions to length question_len
        context_ids = padded(context_ids, context_len) # pad contexts to length context_len

        # Make qn_ids into a np array and create qn_mask
        qn_ids = np.array(qn_ids) # shape (batch_size, question_len)
        qn_mask = (qn_ids != PAD_ID).astype(np.int32) # shape (batch_size, question_len)

        # Make context_ids into a np array and create context_mask
        context_ids = np.array(context_ids) # shape (batch_size, context_len)
        context_mask = (context_ids != PAD_ID).astype(np.int32) # shape (batch_size, context_len)
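        # e.g. assuming PAD_ID == 0, a padded row [12, 7, 0, 0] yields the mask row [1, 1, 0, 0]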

        # Make ans_span into a np array
        ans_span = np.array(ans_span) # shape (batch_size, 2)

        # Make into a Batch object
        batch = Batch(context_ids, context_mask, context_tokens, qn_ids, qn_mask, qn_tokens, ans_span, ans_tokens)

        yield batch

    return
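A minimal usage sketch of the generator above (not part of the commit): the data file paths, hyperparameter values, and the toy word2id mapping below are illustrative assumptions; a real word2id would come from the companion vocab module.

# Illustrative only: assumes data/train.{context,question,span} exist on disk.
from data_batcher import get_batch_generator
from vocab import PAD_ID, UNK_ID

# Toy vocabulary for demonstration; unknown words map to UNK_ID inside the batcher.
word2id = {"<pad>": PAD_ID, "<unk>": UNK_ID, "the": 2, "cat": 3, "sat": 4}

for batch in get_batch_generator(word2id, "data/train.context", "data/train.question",
                                 "data/train.span", batch_size=32, context_len=300,
                                 question_len=30, discard_long=True):
    print(batch.batch_size, batch.context_ids.shape, batch.qn_ids.shape)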
Binary file added code/data_batcher.pyc
