rewrote scoring to load data on demand and not skip sentences (awslabs#593)

The current scoring implementation (a) loads the entire training data into memory and (b) skips sentences that are too long, so the output is not parallel with the input. This PR fixes both issues: data is now loaded from disk one batch at a time via a new `BatchRawParallelSampleIter`, and sentences that are too long are truncated instead of skipped.
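For illustration, here is a minimal sketch of the on-demand idea (not the actual `BatchRawParallelSampleIter` code; the function and parameter names below are made up): read the parallel files lazily, truncate each side to the maximum lengths, and yield fixed-size batches so the full corpus is never held in memory and no line is dropped.

```python
from itertools import islice
from typing import Iterator, List, Tuple


def iter_score_batches(source_path: str,
                       target_path: str,
                       batch_size: int,
                       max_len_source: int,
                       max_len_target: int) -> Iterator[List[Tuple[List[str], List[str]]]]:
    """Yield batches of (source, target) token lists, read lazily from disk.

    Over-long sentences are truncated rather than skipped, so the number and
    order of output pairs always matches the input files.
    """
    with open(source_path, encoding="utf-8") as src, open(target_path, encoding="utf-8") as trg:
        pairs = ((s.split()[:max_len_source], t.split()[:max_len_target])
                 for s, t in zip(src, trg))
        while True:
            batch = list(islice(pairs, batch_size))
            if not batch:
                break
            yield batch
```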

One application of this is [dual cross-entropy filtering](http://aclweb.org/anthology/W18-6478) of the training data.
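For context, a rough sketch of how per-sentence scores from a forward and a reverse model might be combined for that kind of filtering, assuming the word-normalized dual conditional cross-entropy described in the linked paper; the helper below and its inputs are hypothetical. The two score lists would come from scoring the corpus once in each translation direction.

```python
import math
from typing import List


def dual_xent_scores(fwd_neg_logprob: List[float],
                     rev_neg_logprob: List[float],
                     fwd_lengths: List[int],
                     rev_lengths: List[int]) -> List[float]:
    """Combine forward and reverse per-sentence scores into a filtering score.

    Inputs are total negative log-probabilities and output lengths per line;
    the penalty rewards agreement between the two directions and low average
    cross-entropy, and exp(-penalty) maps it into (0, 1] for thresholding.
    """
    scores = []
    for f, r, lf, lr in zip(fwd_neg_logprob, rev_neg_logprob, fwd_lengths, rev_lengths):
        hf = f / max(lf, 1)  # word-normalized cross-entropy, forward model
        hr = r / max(lr, 1)  # word-normalized cross-entropy, reverse model
        penalty = abs(hf - hr) + 0.5 * (hf + hr)
        scores.append(math.exp(-penalty))
    return scores
```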
mjpost authored and fhieber committed Feb 21, 2019
1 parent cd7069f commit 02ff337
Showing 16 changed files with 346 additions and 171 deletions.
7 changes: 6 additions & 1 deletion CHANGELOG.md
@@ -10,6 +10,11 @@ Note that Sockeye has checks in place to not translate with an old model that wa

Each version section may have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_.


## [1.18.77]
### Added
- `sockeye.score` now loads data on demand and doesn't skip any input lines

## [1.18.76]
### Changed
- Do not compare scores from translation and scoring in integration tests.
@@ -28,7 +33,7 @@ In case this is turned on a checkpoint decoder is launched right when training s

## [1.18.73]
### Fixed
- Fixed a bug where `source-factors-num-embed` was not correctly adjusted to `num-embed`
- Fixed a bug where `source-factors-num-embed` was not correctly adjusted to `num-embed`
when using prepared data & `source-factor-combine` sum.

## [1.18.72]
4 changes: 2 additions & 2 deletions sockeye/__init__.py
@@ -1,4 +1,4 @@
# Copyright 2017, 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# Copyright 2017--2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not
# use this file except in compliance with the License. A copy of the License
@@ -11,4 +11,4 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

__version__ = '1.18.76'
__version__ = '1.18.77'
5 changes: 0 additions & 5 deletions sockeye/arguments.py
@@ -801,11 +801,6 @@ def add_training_args(params):
action='store_true',
help='Pre-train a decoder. This is currently for RNN decoders only. '
'Default: %(default)s.')
train_params.add_argument('--fill-up',
type=str,
default=C.FILL_UP_DEFAULT,
choices=C.FILL_UP_CHOICES,
help=argparse.SUPPRESS)

train_params.add_argument('--loss',
default=C.CROSS_ENTROPY,
5 changes: 0 additions & 5 deletions sockeye/constants.py
@@ -423,11 +423,6 @@
PREPARED_DATA_VERSION_FILE = "data.version"
PREPARED_DATA_VERSION = 2

FILL_UP_REPLICATE = 'replicate'
FILL_UP_ZEROS = 'zeros'
FILL_UP_DEFAULT = FILL_UP_REPLICATE
FILL_UP_CHOICES = [FILL_UP_REPLICATE, FILL_UP_ZEROS]

# reranking
RERANK_BLEU = "bleu"
RERANK_CHRF = "chrf"
297 changes: 234 additions & 63 deletions sockeye/data_io.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion sockeye/encoder.py
@@ -379,7 +379,7 @@ def __init__(self,

self.embed_factor_weights = [] # type: List[mx.sym.Symbol]
if self.config.factor_configs is not None:
# Factors weights aren't shared so they're not passed in and we create them here.
# Factor weights aren't shared so they're not passed in and we create them here.
for i, fc in enumerate(self.config.factor_configs):
self.embed_factor_weights.append(mx.sym.Variable(prefix + "factor%d_weight" % i,
shape=(fc.vocab_size, fc.num_embed)))
9 changes: 2 additions & 7 deletions sockeye/image_captioning/data_io.py
@@ -132,7 +132,6 @@ def get_validation_image_text_data_iter(data_loader: RawParallelDatasetLoader,
vocab_target: vocab.Vocab,
max_seq_len_target: int,
batch_size: int,
fill_up: str,
use_feature_loader: bool = False,
preload_features: bool = False) -> 'ParallelSampleIter':
"""
@@ -156,8 +155,7 @@ def get_validation_image_text_data_iter(data_loader: RawParallelDatasetLoader,

validation_data = data_loader.load(validation_source_images[0],
validation_target_sentences,
validation_data_statistics.num_sents_per_bucket).fill_up(bucket_batch_sizes,
fill_up)
validation_data_statistics.num_sents_per_bucket).fill_up(bucket_batch_sizes)
return ImageTextSampleIter(data=validation_data,
buckets=buckets,
batch_size=batch_size,
@@ -177,7 +175,6 @@ def get_training_image_text_data_iters(source_root: str,
batch_by_words: bool,
batch_num_devices: int,
source_image_size: tuple,
fill_up: str,
max_seq_len_target: int,
bucketing: bool,
bucket_width: int,
@@ -200,7 +197,6 @@ def get_training_image_text_data_iters(source_root: str,
:param batch_by_words: Size batches by words rather than sentences.
:param batch_num_devices: Number of devices batches will be parallelized across.
:param source_image_size: size to resize the image to (for iterator)
:param fill_up: Fill-up strategy for buckets.
:param max_seq_len_target: Maximum target sequence length.
:param bucketing: Whether to use bucketing.
:param bucket_width: Size of buckets.
@@ -241,7 +237,7 @@
pad_id=C.PAD_ID)

training_data = data_loader.load(source_images[0], target_sentences,
data_statistics.num_sents_per_bucket).fill_up(bucket_batch_sizes, fill_up)
data_statistics.num_sents_per_bucket).fill_up(bucket_batch_sizes)

data_info = DataInfo(sources=source_images,
target=target,
@@ -278,7 +274,6 @@ def get_training_image_text_data_iters(source_root: str,
vocab_target=vocab_target,
max_seq_len_target=max_seq_len_target,
batch_size=batch_size,
fill_up=fill_up,
use_feature_loader=use_feature_loader,
preload_features=preload_features)

1 change: 0 additions & 1 deletion sockeye/image_captioning/train.py
@@ -168,7 +168,6 @@ def create_data_iters_and_vocab(args: argparse.Namespace,
batch_by_words=batch_by_words,
batch_num_devices=batch_num_devices,
source_image_size=args.source_image_size,
fill_up=args.fill_up,
max_seq_len_target=max_seq_len_target,
bucketing=not args.no_bucketing,
bucket_width=args.bucket_width,
16 changes: 12 additions & 4 deletions sockeye/log.py
@@ -14,7 +14,7 @@
import logging
import logging.config
import sys
from typing import Optional
from typing import Optional, Dict, Any

FORMATTERS = {
'verbose': {
@@ -90,10 +90,16 @@
}
}

NO_LOGGING = {
'version': 1,
'disable_existing_loggers': True,
}

LOGGING_CONFIGS = {
"file_only": FILE_LOGGING,
"console_only": CONSOLE_LOGGING,
"file_console": FILE_CONSOLE_LOGGING,
"none": NO_LOGGING,
}


@@ -111,16 +117,18 @@ def setup_main_logger(file_logging=True, console=True, path: Optional[str] = Non
:param path: Optional path to write logfile to.
"""
if file_logging and console:
log_config = LOGGING_CONFIGS["file_console"]
log_config = LOGGING_CONFIGS["file_console"] # type: ignore
elif file_logging:
log_config = LOGGING_CONFIGS["file_only"]
else:
elif console:
log_config = LOGGING_CONFIGS["console_only"]
else:
log_config = LOGGING_CONFIGS["none"]

if path:
log_config["handlers"]["rotating"]["filename"] = path # type: ignore

logging.config.dictConfig(log_config)
logging.config.dictConfig(log_config) # type: ignore

def exception_hook(exc_type, exc_value, exc_traceback):
if is_python34():
42 changes: 8 additions & 34 deletions sockeye/score.py
@@ -12,7 +12,7 @@
# permissions and limitations under the License.

"""
Simple Training CLI.
Scoring CLI.
"""
import argparse
import os
@@ -37,7 +37,6 @@


def main():
setup_main_logger(file_logging=False, console=True)
params = arguments.ConfigArgumentParser(description='Score data with an existing model.')
arguments.add_score_cli_args(params)
args = params.parse_args()
@@ -46,15 +45,13 @@ def main():

def get_data_iters_and_vocabs(args: argparse.Namespace,
model_folder: Optional[str]) -> Tuple['data_io.BaseParallelSampleIter',
'data_io.DataConfig',
List[vocab.Vocab], vocab.Vocab, model.ModelConfig]:
"""
Loads the data iterators and vocabularies.
:param args: Arguments as returned by argparse.
:param model_folder: Output folder.
:return: The data iterators (train, validation, config_data) as well as the source and target vocabularies,
and data_info if not using prepared data.
:return: The scoring data iterator as well as the source and target vocabularies.
"""

model_config = model.SockeyeModel.load_config(os.path.join(args.model, C.CONFIG_NAME))
@@ -66,7 +63,6 @@ def get_data_iters_and_vocabs(args: argparse.Namespace,
max_seq_len_source, max_seq_len_target = args.max_seq_len

batch_num_devices = 1 if args.use_cpu else sum(-di if di < 0 else 1 for di in args.device_ids)
batch_by_words = args.batch_type == C.BATCH_TYPE_WORD

# Load the existing vocabs created when starting the training run.
source_vocabs = vocab.load_source_vocabs(model_folder)
@@ -75,32 +71,23 @@ def get_data_iters_and_vocabs(args: argparse.Namespace,
sources = [args.source] + args.source_factors
sources = [str(os.path.abspath(source)) for source in sources]

train_iter, _, config_data, data_info = data_io.get_training_data_iters(
score_iter = data_io.get_scoring_data_iters(
sources=sources,
target=os.path.abspath(args.target),
validation_sources=None,
validation_target=None,
source_vocabs=source_vocabs,
target_vocab=target_vocab,
source_vocab_paths=[None],
target_vocab_path=None,
shared_vocab=False,
batch_size=args.batch_size,
batch_by_words=batch_by_words,
batch_num_devices=batch_num_devices,
fill_up=C.FILL_UP_ZEROS,
permute=False,
max_seq_len_source=max_seq_len_source,
max_seq_len_target=max_seq_len_target,
bucketing=False,
bucket_width=args.bucket_width,
allow_empty=True)
max_seq_len_target=max_seq_len_target)

return train_iter, config_data, source_vocabs, target_vocab, model_config
return score_iter, source_vocabs, target_vocab, model_config


def score(args: argparse.Namespace):

setup_main_logger(file_logging=False, console=not args.quiet)

utils.log_basic_info(args)

with ExitStack() as exit_stack:
Expand All @@ -119,12 +106,10 @@ def score(args: argparse.Namespace):
# one-for-one and in the same order as the input data.
# To enable code reuse, we stuff the `args` parameter with some values.
# Bucketing and permuting need to be turned off in order to preserve the ordering of sentences.
# The 'zeros' fill_up strategy fills underfilled buckets with zeros which can then be used to find the last item.
# Finally, 'resume_training' needs to be set to True because it causes the model to be loaded instead of initialized.
args.no_bucketing = True
args.fill_up = 'zeros'
args.bucket_width = 10
score_iter, config_data, source_vocabs, target_vocab, model_config = get_data_iters_and_vocabs(
score_iter, source_vocabs, target_vocab, model_config = get_data_iters_and_vocabs(
args=args,
model_folder=args.model)

Expand All @@ -135,7 +120,6 @@ def score(args: argparse.Namespace):
provide_label=score_iter.provide_label,
default_bucket_key=score_iter.default_bucket_key,
score_type=args.score_type,
bucketing=False,
length_penalty=inference.LengthPenalty(alpha=args.length_penalty_alpha,
beta=args.length_penalty_beta),
softmax_temperature=args.softmax_temperature)
@@ -146,16 +130,6 @@ def score(args: argparse.Namespace):
output_handler=get_output_handler(output_type=args.output_type,
output_fname=args.output))

if config_data.data_statistics.num_discarded != 0:
num_discarded = config_data.data_statistics.num_discarded
logger.warning('Warning: %d %s longer than %s %s skipped. '
'As a result, the output won\'t be parallel with the input. '
'Increase the maximum length (--max-seq-len M:N) or trim your training data.',
num_discarded,
utils.inflect('sentence', num_discarded),
args.max_seq_len,
utils.inflect('was', num_discarded))


if __name__ == "__main__":
main()
49 changes: 28 additions & 21 deletions sockeye/scoring.py
@@ -15,11 +15,13 @@
Code for scoring.
"""
import logging
import math
import os
import time
from typing import List, Optional, Tuple

import mxnet as mx
import numpy as np

from . import constants as C
from . import data_io
@@ -56,14 +58,12 @@ def __init__(self,
context: List[mx.context.Context],
provide_data: List[mx.io.DataDesc],
provide_label: List[mx.io.DataDesc],
bucketing: bool,
default_bucket_key: Tuple[int, int],
score_type: str,
length_penalty: inference.LengthPenalty,
softmax_temperature: Optional[float] = None) -> None:
super().__init__(config)
self.context = context
self.bucketing = bucketing
self.score_type = score_type
self.length_penalty = length_penalty
self.softmax_temperature = softmax_temperature
@@ -170,19 +170,12 @@ def sym_gen(seq_lens):
# sums: (batch_size,) target_dists: (batch_size, target_seq_len, target_vocab_size)
return mx.sym.Group([sums, target_dists]), data_names, label_names

if self.bucketing:
logger.info("Using bucketing. Default max_seq_len=%s", default_bucket_key)
self.module = mx.mod.BucketingModule(sym_gen=sym_gen,
logger=logger,
default_bucket_key=default_bucket_key,
context=self.context)
else:
symbol, _, __ = sym_gen(default_bucket_key)
self.module = mx.mod.Module(symbol=symbol,
data_names=data_names,
label_names=label_names,
logger=logger,
context=self.context)
symbol, _, __ = sym_gen(default_bucket_key)
self.module = mx.mod.Module(symbol=symbol,
data_names=data_names,
label_names=label_names,
logger=logger,
context=self.context)

self.module.bind(data_shapes=provide_data,
label_shapes=provide_label,
@@ -225,8 +218,8 @@ def score(self,

total_time = 0.
sentence_no = 0
for i, batch in enumerate(score_iter):

batch_no = 0
for batch_no, batch in enumerate(score_iter, 1):
batch_tic = time.time()

# Run the model and get the outputs
@@ -235,10 +228,12 @@
batch_time = time.time() - batch_tic
total_time += batch_time

for source, target, score in zip(batch.data[0], batch.data[1], scores):
batch_size = len(batch.data[0])

for sentno, (source, target, score) in enumerate(zip(batch.data[0], batch.data[1], scores), 1):

# The "zeros" padding method will have filled remainder batches with zeros, so we can skip them here
if source[0][0] == C.PAD_ID:
# The last batch may be underfilled, in which case batch.pad will be set
if sentno > (batch_size - batch.pad):
break

sentence_no += 1
@@ -249,9 +244,21 @@
target_ids = [int(x) for x in target.asnumpy().tolist()]
target_string = C.TOKEN_SEPARATOR.join(
data_io.ids2tokens(target_ids, self.target_vocab_inv, self.exclude_list))
score = score.asscalar()

# Report a score of -inf for invalid sentence pairs (empty source and/or target)
if source[0][0] == C.PAD_ID or target[0] == C.PAD_ID:
score = -np.inf
else:
score = score.asscalar()

# Output handling routines require us to make use of inference classes.
output_handler.handle(TranslatorInput(sentence_no, source_tokens),
TranslatorOutput(sentence_no, target_string, None, None, score),
batch_time)

if sentence_no != 0:
logger.info("Processed %d lines in %d batches. Total time: %.4f, sec/sent: %.4f, sent/sec: %.4f",
sentence_no, math.ceil(sentence_no / batch_no), total_time,
total_time / sentence_no, sentence_no / total_time)
else:
logger.info("Processed 0 lines.")