Skip to content

Commit

Permalink
Add a local attention hparam set for pg19.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 372158240
  • Loading branch information
royaurko authored and copybara-github committed May 5, 2021
1 parent 38dee8f commit 8bd05d6
Showing 1 changed file with 15 additions and 14 deletions.
29 changes: 15 additions & 14 deletions routing_transformer/sparse_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,39 +581,40 @@ def wikitext103_local_hash1k():


@registry.register_hparams
def pg19_local_cluster8k():
"""Routing attention on sequence length 8k."""
def pg19_local8k():
"""Local attention on sequence length 8k."""
hparams = wikitext103_local4k()
hparams.max_length = 8192
hparams.batch_size = 8192
hparams.max_target_length = 8192
hparams.hidden_size = 1032
hparams.embedding_dims = 1032
hparams.filter_size = 4096
hparams.local_num_heads = 6
hparams.sparsity_cluster_num_heads = 2
hparams.num_decoder_layers = 22
hparams.sparsity_skip_first = 21
hparams.sparsity_cluster_size = 16
hparams.local_num_heads = 8
hparams.sparsity_cluster_num_heads = 0
hparams.num_decoder_layers = 24
hparams.query_shape = (256,)
hparams.memory_flange = (256,)
hparams.attention_dropout = 0.0
hparams.relu_dropout = 0.0
hparams.dropout = 0.0
hparams.target_dropout = 0.0
hparams.sparsity_cluster_attention_window = 512
hparams.max_relative_position = 513
hparams.weight_decay = 0
return hparams


@registry.register_hparams
def pg19_local8k():
"""Local attention on sequence length 8k."""
hparams = pg19_local_cluster8k()
hparams.local_num_heads = 8
hparams.sparsity_cluster_num_heads = 0
hparams.num_decoder_layers = 24
def pg19_local_cluster8k():
"""Routing attention on sequence length 8k."""
hparams = pg19_local8k()
hparams.local_num_heads = 6
hparams.sparsity_cluster_num_heads = 2
hparams.num_decoder_layers = 22
hparams.sparsity_skip_first = 21
hparams.sparsity_cluster_size = 16
hparams.sparsity_cluster_attention_window = 512
hparams.max_relative_position = 513
return hparams


Expand Down

0 comments on commit 8bd05d6

Please sign in to comment.