QKV projection with dot attention for CNN models. (awslabs#249)
* Optional QKV projection in CNN attention.

* Use projection in system test.

* Changelog.

* Tests fixed.

* Changelog fixed.

* Minor fixes.

* Make mypy happy
tdomhan authored and fhieber committed Dec 18, 2017
1 parent 70175ac commit 7baef3e
Showing 8 changed files with 113 additions and 13 deletions.
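In short, this commit lets the convolutional decoder apply learned query, key and value projections before the usual dot attention over the encoder states. The NumPy sketch below is an illustrative reconstruction of what the new `ProjectedDotAttention` layer in `sockeye/layers.py` (further down in this diff) computes; it is not the MXNet symbolic code from the commit, and the function name, the explicit softmax and the -1e9 masking constant are choices made for the example.

import numpy as np


def projected_dot_attention(queries, memory, memory_lengths, w_q, b_q, w_kv, b_kv):
    """Illustrative reconstruction of the projected QKV dot attention added in this commit.

    queries:        (batch, query_len, input_hidden)   decoder hidden states
    memory:         (batch, mem_len, input_hidden)     encoder hidden states
    memory_lengths: (batch,)                           number of valid source positions
    w_q, b_q:       query projection to num_hidden
    w_kv, b_kv:     combined key/value projection to 2 * num_hidden
    """
    num_hidden = w_q.shape[1]
    # project the source states to keys and values in one matrix product, then split
    combined = memory @ w_kv + b_kv                        # (batch, mem_len, 2 * num_hidden)
    keys, values = np.split(combined, 2, axis=2)           # each (batch, mem_len, num_hidden)
    # project the decoder states and scale by 1/sqrt(num_hidden)
    q = (queries @ w_q + b_q) * num_hidden ** -0.5         # (batch, query_len, num_hidden)
    # dot-product scores, masking padded source positions before the softmax
    scores = q @ keys.transpose(0, 2, 1)                   # (batch, query_len, mem_len)
    pad = np.arange(memory.shape[1])[None, None, :] >= memory_lengths[:, None, None]
    scores = np.where(pad, -1e9, scores)
    probs = np.exp(scores - scores.max(axis=-1, keepdims=True))
    probs /= probs.sum(axis=-1, keepdims=True)             # attention weights over source positions
    return probs @ values                                  # (batch, query_len, num_hidden)


# Example: batch of 2, source length 5, target length 3, 8 input units, 16 attention units.
rng = np.random.default_rng(0)
context = projected_dot_attention(rng.normal(size=(2, 3, 8)), rng.normal(size=(2, 5, 8)),
                                  np.array([5, 3]),
                                  rng.normal(size=(8, 16)), np.zeros(16),
                                  rng.normal(size=(8, 32)), np.zeros(32))
print(context.shape)  # (2, 3, 16)

With `--cnn-project-qkv` left off, the decoder instead builds `PlainDotAttention` layers, which skip the projections and attend over the raw encoder states, so the decoder hidden size must match the encoder output size for the dot products to be defined.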
7 changes: 6 additions & 1 deletion CHANGELOG.md
@@ -10,13 +10,18 @@ Note that Sockeye has checks in place to not translate with an old model that wa

Each version section may have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_.

## [1.15.5]
### Added
- Optionally apply query, key and value projections to the source and target hidden vectors in the CNN model
before applying the attention mechanism. CLI parameter: `--cnn-project-qkv`.

## [1.15.4]
### Added
- A warning will be printed if the checkpoint decoder slows down training.

## [1.15.3]
### Added
- Exposing the xavier random number generator through --weight-init-xavier-rand-type.
- Exposing the xavier random number generator through `--weight-init-xavier-rand-type`.

## [1.15.2]
### Added
2 changes: 1 addition & 1 deletion sockeye/__init__.py
@@ -11,4 +11,4 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

__version__ = '1.15.4'
__version__ = '1.15.5'
6 changes: 6 additions & 0 deletions sockeye/arguments.py
@@ -408,6 +408,12 @@ def add_model_parameters(params):
                              choices=C.POSITIONAL_EMBEDDING_TYPES,
                              default=C.LEARNED_POSITIONAL_EMBEDDING,
                              help='The type of positional embedding. Default: %(default)s.')
    model_params.add_argument('--cnn-project-qkv',
                              action='store_true',
                              default=False,
                              help="Optionally apply query, key and value projections to the source and target hidden "
                                   "vectors before applying the attention mechanism.")


    # rnn arguments
    model_params.add_argument('--rnn-cell-type',
24 changes: 14 additions & 10 deletions sockeye/decoder.py
@@ -897,6 +897,7 @@ def __init__(self,
                 encoder_num_hidden: int,
                 num_layers: int,
                 positional_embedding_type: str,
                 project_qkv: bool = False,
                 hidden_dropout: float = .0) -> None:
        super().__init__()
        self.cnn_config = cnn_config
@@ -905,6 +906,7 @@ def __init__(self,
        self.encoder_num_hidden = encoder_num_hidden
        self.num_layers = num_layers
        self.positional_embedding_type = positional_embedding_type
        self.project_qkv = project_qkv
        self.hidden_dropout = hidden_dropout


@@ -950,6 +952,12 @@ def __init__(self,
            config.cnn_config,
            pad_type='left',
            prefix="%s%d_" % (prefix, i)) for i in range(config.num_layers)]
        if self.config.project_qkv:
            self.attention_layers = [layers.ProjectedDotAttention("%s%d_" % (prefix, i),
                                                                   self.config.cnn_config.num_hidden)
                                     for i in range(config.num_layers)]
        else:
            self.attention_layers = [layers.PlainDotAttention() for _ in range(config.num_layers)]  # type: ignore

        self.i2h_weight = mx.sym.Variable('%si2h_weight' % prefix)

@@ -1013,16 +1021,13 @@ def _decode(self,

        drop_prob = self.config.hidden_dropout

        for layer in self.layers:
        for layer, att_layer in zip(self.layers, self.attention_layers):
            # (batch_size, target_seq_len, num_hidden)
            target_hidden = layer(mx.sym.Dropout(target_hidden, p=drop_prob) if drop_prob > 0 else target_hidden,
                                  target_embed_lengths, target_embed_max_length)

            # (batch_size, target_seq_len, num_embed)
            context = layers.dot_attention(queries=target_hidden,
                                           keys=source_encoded,
                                           values=source_encoded,
                                           lengths=source_encoded_lengths)
            context = att_layer(target_hidden, source_encoded, source_encoded_lengths)

            # residual connection:
            target_hidden = target_hidden_prev + target_hidden + context
@@ -1082,16 +1087,15 @@ def decode_step(self,

        drop_prob = self.config.hidden_dropout

        for layer, layer_state in zip(self.layers, cnn_layer_states):
        for layer, att_layer, layer_state in zip(self.layers, self.attention_layers, cnn_layer_states):
            # (batch_size, kernel_width, num_hidden) -> (batch_size, 1, num_hidden)
            target_hidden_step = layer.step(mx.sym.Dropout(target_hidden, p=drop_prob)
                                            if drop_prob > 0 else target_hidden)

            # (batch_size, 1, num_embed)
            context_step = layers.dot_attention(queries=target_hidden_step,
                                                keys=source_encoded,
                                                values=source_encoded,
                                                lengths=source_encoded_lengths)
            # TODO: compute the source encoded projection only once for efficiency reasons
            context_step = att_layer(target_hidden_step, source_encoded, source_encoded_lengths)

            # residual connection:
            target_hidden_step = target_hidden_step_prev + target_hidden_step + context_step
            target_hidden_step_prev = target_hidden_step
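The TODO above points at a possible follow-up: in step-wise decoding, `ProjectedDotAttention` re-projects `source_encoded` into keys and values at every decode step even though the source encoding never changes. One way that could be addressed, sketched below, is to split the memory projection out of `__call__` and let the decoder compute it once per input sentence, carrying the keys and values through its states. The class `CachedProjectedDotAttention` and its two methods are hypothetical and not part of this commit; the MXNet calls mirror the ones used in the diff, and `dot_attention` is Sockeye's existing helper in `sockeye/layers.py`.

import mxnet as mx

from sockeye.layers import dot_attention


class CachedProjectedDotAttention:
    """Hypothetical variant of ProjectedDotAttention: project the source encoding once,
    then reuse the resulting keys/values at every decode step."""

    def __init__(self, prefix: str, num_hidden: int) -> None:
        self.prefix = prefix
        self.num_hidden = num_hidden
        self.w_q2h = mx.sym.Variable("%sq2h_weight" % prefix)
        self.b_q2h = mx.sym.Variable("%sq2h_bias" % prefix)
        self.w_kv2h = mx.sym.Variable("%skv2h_weight" % prefix)
        self.b_kv2h = mx.sym.Variable("%skv2h_bias" % prefix)

    def project_memory(self, memory: mx.sym.Symbol):
        """Compute keys/values from the source encoding; called once per input sentence."""
        combined = mx.sym.FullyConnected(data=memory,
                                         weight=self.w_kv2h,
                                         bias=self.b_kv2h,
                                         num_hidden=self.num_hidden * 2,
                                         flatten=False,
                                         name="%skv_transform" % self.prefix)
        # pylint: disable=unbalanced-tuple-unpacking
        keys, values = mx.sym.split(data=combined, num_outputs=2, axis=2)
        return keys, values

    def attend(self, queries: mx.sym.Symbol, keys: mx.sym.Symbol, values: mx.sym.Symbol,
               memory_lengths: mx.sym.Symbol) -> mx.sym.Symbol:
        """Per-step attention over the cached keys/values."""
        queries = mx.sym.FullyConnected(data=queries,
                                        weight=self.w_q2h,
                                        bias=self.b_q2h,
                                        num_hidden=self.num_hidden,
                                        flatten=False,
                                        name="%sq_transform" % self.prefix)
        # scale queries by 1/sqrt(num_hidden)
        queries = queries * (self.num_hidden ** -0.5)
        return dot_attention(queries, keys, values, memory_lengths)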
83 changes: 83 additions & 0 deletions sockeye/layers.py
@@ -516,6 +516,89 @@ def __call__(self,
bias=bias)


class ProjectedDotAttention:
    """
    Dot attention layer for queries independent from keys/values.

    :param prefix: Attention prefix.
    :param num_hidden: Attention depth / number of hidden units.
    """

    def __init__(self,
                 prefix: str,
                 num_hidden: int) -> None:
        self.prefix = prefix
        self.num_hidden = num_hidden
        self.w_q2h = mx.sym.Variable("%sq2h_weight" % prefix)
        self.b_q2h = mx.sym.Variable("%sq2h_bias" % prefix)
        self.w_kv2h = mx.sym.Variable("%skv2h_weight" % prefix)
        self.b_kv2h = mx.sym.Variable("%skv2h_bias" % prefix)

    def __call__(self,
                 queries: mx.sym.Symbol,
                 memory: mx.sym.Symbol,
                 memory_lengths: mx.sym.Symbol) -> mx.sym.Symbol:
        """
        Apply the projections, then dot attention, and return the new context vectors.

        :param queries: Symbol of shape (batch, queries_max_length, input_num_hidden).
        :param memory: Symbol of shape (batch, memory_max_length, input_num_hidden).
        :param memory_lengths: Symbol of shape (batch, 1).
        :return: Symbol of shape (batch, queries_max_length, num_hidden).
        """
        # (batch, memory_max_length, num_hidden * 2)
        combined = mx.sym.FullyConnected(data=memory,
                                         weight=self.w_kv2h,
                                         bias=self.b_kv2h,
                                         num_hidden=self.num_hidden * 2,
                                         flatten=False,
                                         name="%skv_transform" % self.prefix)

        # split into keys and values
        # pylint: disable=unbalanced-tuple-unpacking
        keys, values = mx.sym.split(data=combined, num_outputs=2, axis=2)

        # (batch, queries_max_length, num_hidden)
        queries = mx.sym.FullyConnected(data=queries,
                                        weight=self.w_q2h,
                                        bias=self.b_q2h,
                                        num_hidden=self.num_hidden,
                                        flatten=False,
                                        name="%sq_transform" % self.prefix)
        # scale queries by 1/sqrt(num_hidden)
        queries = queries * (self.num_hidden ** -0.5)

        # (batch, queries_max_length, num_hidden)
        contexts = dot_attention(queries, keys, values, memory_lengths)

        return contexts


class PlainDotAttention:
    """
    Dot attention layer for queries independent from keys/values.
    """

    def __call__(self,
                 queries: mx.sym.Symbol,
                 memory: mx.sym.Symbol,
                 memory_lengths: mx.sym.Symbol) -> mx.sym.Symbol:
        """
        Returns a symbol of shape (batch, max_length, output_depth).

        :param queries: Symbol of shape (batch, queries_max_length, input_depth).
        :param memory: Symbol of shape (batch, memory_max_length, input_depth).
        :param memory_lengths: Symbol of shape (batch, 1).
        :return: Symbol of shape (batch, queries_max_length, output_depth).
        """

        # (batch, queries_max_length, output_depth)
        contexts = dot_attention(queries, memory, memory, memory_lengths)

        return contexts



class PositionalEncodings(mx.operator.CustomOp):
    """
    Returns a symbol of shape (1, max_seq_len, num_embed)
1 change: 1 addition & 0 deletions sockeye/train.py
@@ -431,6 +431,7 @@ def create_decoder_config(args: argparse.Namespace, encoder_num_hidden: int) ->
            encoder_num_hidden=encoder_num_hidden,
            num_layers=decoder_num_layers,
            positional_embedding_type=args.cnn_positional_embedding_type,
            project_qkv=args.cnn_project_qkv,
            hidden_dropout=args.cnn_hidden_dropout)

    else:
2 changes: 1 addition & 1 deletion test/system/test_seq_copy_sys.py
@@ -87,7 +87,7 @@
("Copy:cnn:cnn",
"--encoder cnn --decoder cnn "
" --batch-size 16 --num-layers 3 --max-updates 3000"
" --cnn-num-hidden 32 --cnn-positional-embedding-type fixed"
" --cnn-num-hidden 32 --cnn-positional-embedding-type fixed --cnn-project-qkv "
" --checkpoint-frequency 1000 --optimizer adam --initial-learning-rate 0.001",
"--beam-size 1",
1.02,
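For a rough sense of the cost of the new option in this system test (3 decoder layers, 32 hidden units for both encoder and decoder), each decoder layer gains a query projection (`w_q2h`, `b_q2h`) and a combined key/value projection (`w_kv2h`, `b_kv2h`). A back-of-the-envelope count, assuming the attention depth equals the encoder and decoder hidden sizes as it does here:

# Extra parameters per decoder layer introduced by --cnn-project-qkv,
# assuming input hidden size == attention num_hidden (as in the system test above).
def added_params_per_layer(num_hidden: int) -> int:
    q2h = num_hidden * num_hidden + num_hidden            # w_q2h weights + b_q2h bias
    kv2h = 2 * num_hidden * num_hidden + 2 * num_hidden   # w_kv2h weights + b_kv2h bias
    return q2h + kv2h

print(added_params_per_layer(32) * 3)  # 9504 additional parameters across the 3 decoder layers

The increase is negligible for this toy copy task, but the per-layer term grows quadratically with `--cnn-num-hidden`.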
1 change: 1 addition & 0 deletions test/unit/test_arguments.py
@@ -108,6 +108,7 @@ def test_device_args(test_params, expected_params):
cnn_kernel_width=(3, 5),
cnn_num_hidden=512,
cnn_positional_embedding_type="learned",
cnn_project_qkv=False,
layer_normalization=False,
weight_normalization=False,
encoder=C.RNN_NAME,