Skip to content

Commit

Permalink
Updates train_bidaf.py
Browse files Browse the repository at this point in the history
  • Loading branch information
chrisc36 committed May 14, 2018
1 parent 21fc045 commit 48c598d
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions docqa/scripts/train_bidaf.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
from docqa.data_processing.paragraph_qa import ContextLenKey, ContextLenBucketedKey, DocumentQaTrainingData
from docqa.squad.squad_eval import SentenceSpanEvaluator, SquadSpanEvaluator, BoundedSquadSpanEvaluator

from docqa import model_dir
from docqa import trainer
from docqa.data_processing.qa_training_data import ContextLenBucketedKey, ContextLenKey
from docqa.dataset import ClusteredBatcher
from docqa.doc_qa_models import Attention
from docqa.encoder import DocumentAndQuestionEncoder, SingleSpanAnswerEncoder
from docqa.evaluator import LossEvaluator
from docqa.evaluator import LossEvaluator, SpanEvaluator
from docqa.nn.attention import BiAttention
from docqa.nn.embedder import FixedWordEmbedder, CharWordEmbedder, LearnedCharEmbedder
from docqa.nn.layers import NullBiMapper, NullMapper, SequenceMapperSeq, ReduceLayer, Conv1d, HighwayLayer, ChainConcat, \
Expand All @@ -15,6 +13,7 @@
from docqa.nn.similarity_layers import TriLinear
from docqa.nn.span_prediction import BoundsPredictor
from docqa.squad.build_squad_dataset import SquadCorpus
from docqa.squad.squad_data import DocumentQaTrainingData
from docqa.trainer import SerializableOptimizer, TrainParams
from docqa.utils import get_output_name_from_cli

Expand Down Expand Up @@ -52,6 +51,7 @@ def main():
embed_mapper=SequenceMapperSeq(
HighwayLayer(activation="relu"), HighwayLayer(activation="relu"),
recurrent_layer),
preprocess=None,
question_mapper=None,
context_mapper=None,
memory_builder=NullBiMapper(),
Expand All @@ -64,20 +64,21 @@ def main():
recurrent_layer),
end_layer=recurrent_layer
)
)
),

)

with open(__file__, "r") as f:
notes = f.read()

eval = [LossEvaluator(), SquadSpanEvaluator(), BoundedSquadSpanEvaluator([18]), SentenceSpanEvaluator()]
eval = [LossEvaluator(), SpanEvaluator(bound=[17], text_eval="squad")]

corpus = SquadCorpus()
train_batching = ClusteredBatcher(60, ContextLenBucketedKey(3), True, False)
eval_batching = ClusteredBatcher(60, ContextLenKey(), False, False)
data = DocumentQaTrainingData(corpus, None, train_batching, eval_batching)

trainer.start_training(data, model, train_params, eval, model_dir.ModelDir(out), notes, False)
trainer.start_training(data, model, train_params, eval, model_dir.ModelDir(out), notes)


if __name__ == "__main__":
Expand Down

0 comments on commit 48c598d

Please sign in to comment.