Typing, mypy and code formatting (awslabs#274)
fhieber authored Jan 18, 2018
1 parent 7391058 commit 17b4d32
Showing 5 changed files with 13 additions and 11 deletions.
10 changes: 5 additions & 5 deletions sockeye/encoder.py
@@ -30,10 +30,10 @@
from . import utils

logger = logging.getLogger(__name__)
-EncoderConfigs = Union['RecurrentEncoderConfig', transformer.TransformerConfig, 'ConvolutionalEncoderConfig']
+EncoderConfig = Union['RecurrentEncoderConfig', transformer.TransformerConfig, 'ConvolutionalEncoderConfig']


-def get_encoder(config: EncoderConfigs) -> 'Encoder':
+def get_encoder(config: EncoderConfig) -> 'Encoder':
if isinstance(config, RecurrentEncoderConfig):
return get_recurrent_encoder(config)
elif isinstance(config, transformer.TransformerConfig):
@@ -125,8 +125,8 @@ def get_recurrent_encoder(config: RecurrentEncoderConfig) -> 'Encoder':
remaining_rnn_config = config.rnn_config.copy(num_layers=config.rnn_config.num_layers - 1,
first_residual_layer=config.rnn_config.first_residual_layer - 1)
encoders.append(RecurrentEncoder(rnn_config=remaining_rnn_config,
-prefix=C.STACKEDRNN_PREFIX,
-layout=C.TIME_MAJOR))
+prefix=C.STACKEDRNN_PREFIX,
+layout=C.TIME_MAJOR))

return EncoderSequence(encoders)

@@ -800,7 +800,7 @@ def encode(self,
max_length=seq_len,
num_heads=self.config.attention_heads,
fold_heads=True,
name="%sbias"% self.prefix), axis=1)
name="%sbias" % self.prefix), axis=1)

for i, layer in enumerate(self.layers):
# (batch_size, seq_len, config.model_size)
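Aside: the alias rename (EncoderConfigs → EncoderConfig) reflects that each call site takes exactly one config, and the Union still lets mypy narrow the concrete type per branch. A minimal sketch of the idiom, using hypothetical stand-in classes rather than the real sockeye configs:

from typing import Union

# Hypothetical stand-ins for the three real config classes.
class RecurrentEncoderConfig: ...
class TransformerConfig: ...
class ConvolutionalEncoderConfig: ...

# Singular alias, as after the rename: one config, not a collection.
EncoderConfig = Union[RecurrentEncoderConfig, TransformerConfig, ConvolutionalEncoderConfig]

def describe_encoder(config: EncoderConfig) -> str:
    # mypy narrows the Union member inside each isinstance branch.
    if isinstance(config, RecurrentEncoderConfig):
        return "recurrent"
    if isinstance(config, TransformerConfig):
        return "transformer"
    return "convolutional"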
1 change: 0 additions & 1 deletion sockeye/layers.py
@@ -598,7 +598,6 @@ def __call__(self,
return contexts


-
class PositionalEncodings(mx.operator.CustomOp):
"""
Returns a symbol of shape (1, max_seq_len, num_embed)
5 changes: 3 additions & 2 deletions sockeye/model.py
@@ -50,6 +50,7 @@ class ModelConfig(Config):
:param weight_tying: Enables weight tying if True.
:param weight_tying_type: Determines which weights get tied. Must be set if weight_tying is enabled.
"""

def __init__(self,
config_data: data_io.DataConfig,
max_seq_len_source: int,
@@ -58,8 +59,8 @@ def __init__(self,
vocab_target_size: int,
config_embed_source: encoder.EmbeddingConfig,
config_embed_target: encoder.EmbeddingConfig,
-config_encoder: Config,
-config_decoder: Config,
+config_encoder: encoder.EncoderConfig,
+config_decoder: decoder.DecoderConfig,
config_loss: loss.LossConfig,
weight_tying: bool = False,
weight_tying_type: Optional[str] = C.WEIGHT_TYING_TRG_SOFTMAX,
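Tightening config_encoder and config_decoder from the broad Config base class to the Union aliases means mypy can reject a mismatched config at the ModelConfig boundary rather than deep inside graph construction. A hedged sketch of the effect, with simplified hypothetical classes:

from typing import Union

class Config: ...
class RecurrentEncoderConfig(Config): ...
class TransformerConfig(Config): ...
class LossConfig(Config): ...

EncoderConfig = Union[RecurrentEncoderConfig, TransformerConfig]

def build_model(config_encoder: EncoderConfig) -> None:
    # Under the old `config_encoder: Config` annotation, any Config
    # subclass type-checked here; the Union alias is stricter.
    ...

build_model(TransformerConfig())   # accepted by mypy
build_model(LossConfig())          # flagged by mypy: incompatible argument type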
6 changes: 4 additions & 2 deletions sockeye/train.py
@@ -390,7 +390,8 @@ def create_lr_scheduler(args: argparse.Namespace, resume_training: bool,


def create_encoder_config(args: argparse.Namespace,
-config_conv: Optional[encoder.ConvolutionalEmbeddingConfig]) -> Tuple[Config, int]:
+config_conv: Optional[encoder.ConvolutionalEmbeddingConfig]) -> Tuple[encoder.EncoderConfig,
+int]:
"""
Create the encoder config.
@@ -456,11 +457,12 @@ def create_encoder_config(args: argparse.Namespace,
return config_encoder, encoder_num_hidden


-def create_decoder_config(args: argparse.Namespace, encoder_num_hidden: int) -> Config:
+def create_decoder_config(args: argparse.Namespace, encoder_num_hidden: int) -> decoder.DecoderConfig:
"""
Create the config for the decoder.
:param args: Arguments as returned by argparse.
:param encoder_num_hidden: Number of hidden units of the Encoder.
+:return: The config for the decoder.
"""
_, decoder_num_layers = args.num_layers
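Annotating the factory's return as Tuple[encoder.EncoderConfig, int] documents that it hands back the config together with the encoder's hidden size. A minimal sketch of the pattern — the signature and values here are simplified assumptions, not sockeye's actual logic:

from typing import Tuple, Union

class RecurrentEncoderConfig: ...   # stand-in
class TransformerConfig: ...        # stand-in

EncoderConfig = Union[RecurrentEncoderConfig, TransformerConfig]

def create_encoder_config(use_transformer: bool) -> Tuple[EncoderConfig, int]:
    # Return the chosen config together with its number of hidden units,
    # matching the Tuple[EncoderConfig, int] annotation.
    if use_transformer:
        return TransformerConfig(), 512
    return RecurrentEncoderConfig(), 1024

config, num_hidden = create_encoder_config(use_transformer=True)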
2 changes: 1 addition & 1 deletion sockeye/transformer.py
@@ -164,7 +164,6 @@ def __call__(self,
source: mx.sym.Symbol,
source_bias: mx.sym.Symbol,
cache: Optional[Dict[str, Optional[mx.sym.Symbol]]] = None) -> mx.sym.Symbol:
-
# self-attention
target_self_att = self.self_attention(inputs=self.pre_self_attention(target, None),
bias=target_bias,
@@ -243,6 +242,7 @@ class TransformerFeedForward:
"""
Position-wise feed-forward network with activation.
"""

def __init__(self,
num_hidden: int,
num_model: int,
