Skip to content

Commit 24a8242

Browse files
authored
Merge pull request RasaHQ#5185 from RasaHQ/tf2-rel-attn
relative attention
2 parents fa5857d + 4108753 commit 24a8242

File tree

15 files changed

+717
-366
lines changed

15 files changed

+717
-366
lines changed

rasa/core/policies/embedding_policy.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,10 @@
77
from rasa.constants import DOCS_BASE_URL
88
from rasa.utils.tensorflow.constants import (
99
HIDDEN_LAYERS_SIZES_LABEL,
10+
TRANSFORMER_SIZE,
1011
NUM_TRANSFORMER_LAYERS,
12+
NUM_HEADS,
13+
MAX_SEQ_LENGTH,
1114
BATCH_SIZES,
1215
BATCH_STRATEGY,
1316
EPOCHS,
@@ -18,22 +21,23 @@
1821
NUM_NEG,
1922
EVAL_NUM_EXAMPLES,
2023
EVAL_NUM_EPOCHS,
21-
C_EMB,
22-
C2,
24+
NEG_MARGIN_SCALE,
25+
REGULARIZATION_CONSTANT,
2326
SCALE_LOSS,
2427
USE_MAX_SIM_NEG,
2528
MU_NEG,
2629
MU_POS,
2730
EMBED_DIM,
2831
HIDDEN_LAYERS_SIZES_DIALOGUE,
29-
TRANSFORMER_SIZE,
30-
MAX_SEQ_LENGTH,
31-
NUM_HEADS,
3232
DROPRATE_DIALOGUE,
3333
DROPRATE_LABEL,
34+
DROPRATE_ATTENTION,
35+
KEY_RELATIVE_ATTENTION,
36+
VALUE_RELATIVE_ATTENTION,
37+
MAX_RELATIVE_POSITION,
3438
)
3539
from rasa.utils.common import raise_warning
36-
from rasa.utils.tensorflow.tf_models import RasaModel
40+
from rasa.utils.tensorflow.models import RasaModel
3741

3842
logger = logging.getLogger(__name__)
3943

@@ -94,20 +98,28 @@ class EmbeddingPolicy(TEDPolicy):
9498
# scale loss inverse proportionally to confidence of correct prediction
9599
SCALE_LOSS: True,
96100
# regularization
97-
# the scale of L2 regularization
98-
C2: 0.001,
101+
# the scale of regularization
102+
REGULARIZATION_CONSTANT: 0.001,
99103
# the scale of how important it is to minimize the maximum similarity
100104
# between embeddings of different labels
101-
C_EMB: 0.8,
105+
NEG_MARGIN_SCALE: 0.8,
102106
# dropout rate for dialogue nn
103107
DROPRATE_DIALOGUE: 0.1,
104108
# dropout rate for bot nn
105109
DROPRATE_LABEL: 0.0,
110+
# dropout rate for attention
111+
DROPRATE_ATTENTION: 0,
106112
# visualization of accuracy
107113
# how often calculate validation accuracy
108114
EVAL_NUM_EPOCHS: 20, # small values may hurt performance
109115
# how many examples to use for hold out validation set
110116
EVAL_NUM_EXAMPLES: 0, # large values may hurt performance
117+
# if true use key relative embeddings in attention
118+
KEY_RELATIVE_ATTENTION: False,
119+
# if true use value relative embeddings in attention
120+
VALUE_RELATIVE_ATTENTION: False,
121+
# max position for relative embeddings
122+
MAX_RELATIVE_POSITION: None,
111123
}
112124
# end default properties (DOC MARKER - don't remove)
113125

rasa/core/policies/ted_policy.py

Lines changed: 37 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,10 @@
2121
from rasa.core.constants import DEFAULT_POLICY_PRIORITY
2222
from rasa.core.trackers import DialogueStateTracker
2323
from rasa.utils import train_utils
24-
from rasa.utils.tensorflow import tf_layers
25-
from rasa.utils.tensorflow.tf_models import RasaModel
26-
from rasa.utils.tensorflow.tf_model_data import RasaModelData, FeatureSignature
24+
from rasa.utils.tensorflow import layers
25+
from rasa.utils.tensorflow.transformer import TransformerEncoder
26+
from rasa.utils.tensorflow.models import RasaModel
27+
from rasa.utils.tensorflow.model_data import RasaModelData, FeatureSignature
2728
from rasa.utils.tensorflow.constants import (
2829
HIDDEN_LAYERS_SIZES_LABEL,
2930
TRANSFORMER_SIZE,
@@ -40,8 +41,8 @@
4041
NUM_NEG,
4142
EVAL_NUM_EXAMPLES,
4243
EVAL_NUM_EPOCHS,
43-
C_EMB,
44-
C2,
44+
NEG_MARGIN_SCALE,
45+
REGULARIZATION_CONSTANT,
4546
SCALE_LOSS,
4647
USE_MAX_SIM_NEG,
4748
MU_NEG,
@@ -50,6 +51,10 @@
5051
HIDDEN_LAYERS_SIZES_DIALOGUE,
5152
DROPRATE_DIALOGUE,
5253
DROPRATE_LABEL,
54+
DROPRATE_ATTENTION,
55+
KEY_RELATIVE_ATTENTION,
56+
VALUE_RELATIVE_ATTENTION,
57+
MAX_RELATIVE_POSITION,
5358
)
5459

5560

@@ -114,20 +119,28 @@ class TEDPolicy(Policy):
114119
# scale loss inverse proportionally to confidence of correct prediction
115120
SCALE_LOSS: True,
116121
# regularization
117-
# the scale of L2 regularization
118-
C2: 0.001,
122+
# the scale of regularization
123+
REGULARIZATION_CONSTANT: 0.001,
119124
# the scale of how important it is to minimize the maximum similarity
120125
# between embeddings of different labels
121-
C_EMB: 0.8,
126+
NEG_MARGIN_SCALE: 0.8,
122127
# dropout rate for dialogue nn
123128
DROPRATE_DIALOGUE: 0.1,
124129
# dropout rate for bot nn
125130
DROPRATE_LABEL: 0.0,
131+
# dropout rate for attention
132+
DROPRATE_ATTENTION: 0,
126133
# visualization of accuracy
127134
# how often calculate validation accuracy
128135
EVAL_NUM_EPOCHS: 20, # small values may hurt performance
129136
# how many examples to use for hold out validation set
130137
EVAL_NUM_EXAMPLES: 0, # large values may hurt performance
138+
# if true use key relative embeddings in attention
139+
KEY_RELATIVE_ATTENTION: False,
140+
# if true use value relative embeddings in attention
141+
VALUE_RELATIVE_ATTENTION: False,
142+
# max position for relative embeddings
143+
MAX_RELATIVE_POSITION: None,
131144
}
132145
# end default properties (DOC MARKER - don't remove)
133146

@@ -471,50 +484,53 @@ def __init__(
471484
self._prepare_layers()
472485

473486
def _prepare_layers(self) -> None:
474-
self._tf_layers["loss.label"] = tf_layers.DotProductLoss(
487+
self._tf_layers["loss.label"] = layers.DotProductLoss(
475488
self.config[NUM_NEG],
476489
self.config[LOSS_TYPE],
477490
self.config[MU_POS],
478491
self.config[MU_NEG],
479492
self.config[USE_MAX_SIM_NEG],
480-
self.config[C_EMB],
493+
self.config[NEG_MARGIN_SCALE],
481494
self.config[SCALE_LOSS],
482495
# set to 1 to get deterministic behaviour
483496
parallel_iterations=1 if self.random_seed is not None else 1000,
484497
)
485-
self._tf_layers["ffnn.dialogue"] = tf_layers.Ffnn(
498+
self._tf_layers["ffnn.dialogue"] = layers.Ffnn(
486499
self.config[HIDDEN_LAYERS_SIZES_DIALOGUE],
487500
self.config[DROPRATE_DIALOGUE],
488-
self.config[C2],
501+
self.config[REGULARIZATION_CONSTANT],
489502
layer_name_suffix="dialogue",
490503
)
491-
self._tf_layers["ffnn.label"] = tf_layers.Ffnn(
504+
self._tf_layers["ffnn.label"] = layers.Ffnn(
492505
self.config[HIDDEN_LAYERS_SIZES_LABEL],
493506
self.config[DROPRATE_LABEL],
494-
self.config[C2],
507+
self.config[REGULARIZATION_CONSTANT],
495508
layer_name_suffix="label",
496509
)
497-
self._tf_layers["transformer"] = tf_layers.TransformerEncoder(
510+
self._tf_layers["transformer"] = TransformerEncoder(
498511
self.config[NUM_TRANSFORMER_LAYERS],
499512
self.config[TRANSFORMER_SIZE],
500513
self.config[NUM_HEADS],
501514
self.config[TRANSFORMER_SIZE] * 4,
502515
self.config[MAX_SEQ_LENGTH],
503-
self.config[C2],
516+
self.config[REGULARIZATION_CONSTANT],
504517
dropout_rate=self.config[DROPRATE_DIALOGUE],
505-
attention_dropout_rate=0,
518+
attention_dropout_rate=self.config[DROPRATE_ATTENTION],
506519
unidirectional=True,
520+
use_key_relative_position=self.config[KEY_RELATIVE_ATTENTION],
521+
use_value_relative_position=self.config[VALUE_RELATIVE_ATTENTION],
522+
max_relative_position=self.config[MAX_RELATIVE_POSITION],
507523
name="dialogue_encoder",
508524
)
509-
self._tf_layers["embed.dialogue"] = tf_layers.Embed(
525+
self._tf_layers["embed.dialogue"] = layers.Embed(
510526
self.config[EMBED_DIM],
511-
self.config[C2],
527+
self.config[REGULARIZATION_CONSTANT],
512528
"dialogue",
513529
self.config[SIMILARITY_TYPE],
514530
)
515-
self._tf_layers["embed.label"] = tf_layers.Embed(
531+
self._tf_layers["embed.label"] = layers.Embed(
516532
self.config[EMBED_DIM],
517-
self.config[C2],
533+
self.config[REGULARIZATION_CONSTANT],
518534
"label",
519535
self.config[SIMILARITY_TYPE],
520536
)

0 commit comments

Comments
 (0)