18
18
MaxHistoryTrackerFeaturizer ,
19
19
)
20
20
from rasa .core .policies .policy import Policy
21
- from rasa .core .constants import DEFAULT_POLICY_PRIORITY
21
+ from rasa .core .constants import DEFAULT_POLICY_PRIORITY , DIALOGUE
22
22
from rasa .core .trackers import DialogueStateTracker
23
23
from rasa .utils import train_utils
24
24
from rasa .utils .tensorflow import layers
25
25
from rasa .utils .tensorflow .transformer import TransformerEncoder
26
26
from rasa .utils .tensorflow .models import RasaModel
27
27
from rasa .utils .tensorflow .model_data import RasaModelData , FeatureSignature
28
28
from rasa .utils .tensorflow .constants import (
29
- HIDDEN_LAYERS_SIZES_LABEL ,
29
+ LABEL ,
30
+ HIDDEN_LAYERS_SIZES ,
30
31
TRANSFORMER_SIZE ,
31
32
NUM_TRANSFORMER_LAYERS ,
32
33
NUM_HEADS ,
48
49
MU_NEG ,
49
50
MU_POS ,
50
51
EMBED_DIM ,
51
- HIDDEN_LAYERS_SIZES_DIALOGUE ,
52
52
DROPRATE_DIALOGUE ,
53
53
DROPRATE_LABEL ,
54
54
DROPRATE_ATTENTION ,
@@ -72,12 +72,9 @@ class TEDPolicy(Policy):
72
72
# default properties (DOC MARKER - don't remove)
73
73
defaults = {
74
74
# nn architecture
75
- # a list of hidden layers sizes before user embed layer
75
+ # a list of hidden layers sizes before dialogue and action embed layers
76
76
# number of hidden layers is equal to the length of this list
77
- HIDDEN_LAYERS_SIZES_DIALOGUE : [],
78
- # a list of hidden layers sizes before bot embed layer
79
- # number of hidden layers is equal to the length of this list
80
- HIDDEN_LAYERS_SIZES_LABEL : [],
77
+ HIDDEN_LAYERS_SIZES : {DIALOGUE : [], LABEL : []},
81
78
# number of units in transformer
82
79
TRANSFORMER_SIZE : 128 ,
83
80
# number of transformer layers
@@ -262,8 +259,6 @@ def train(
262
259
) -> None :
263
260
"""Train the policy on given training trackers."""
264
261
265
- logger .debug ("Started training embedding policy." )
266
-
267
262
# set numpy random seed
268
263
np .random .seed (self .config [RANDOM_SEED ])
269
264
@@ -282,6 +277,12 @@ def train(
282
277
283
278
# extract actual training data to feed to model
284
279
model_data = self ._create_model_data (training_data .X , training_data .y )
280
+ if model_data .is_empty ():
281
+ logger .error (
282
+ f"Can not train '{ self .__class__ .__name__ } '. No data was provided. "
283
+ f"Skipping training of the policy."
284
+ )
285
+ return
285
286
286
287
# keep one example for persisting and loading
287
288
self .data_example = {k : [v [:1 ] for v in vs ] for k , vs in model_data .items ()}
@@ -452,14 +453,16 @@ def __init__(
452
453
config : Dict [Text , Any ],
453
454
max_history_tracker_featurizer_used : bool ,
454
455
label_data : RasaModelData ,
455
- ):
456
+ ) -> None :
456
457
super ().__init__ (name = "TED" , random_seed = config [RANDOM_SEED ])
457
458
458
459
self .config = config
459
460
self .max_history_tracker_featurizer_used = max_history_tracker_featurizer_used
460
461
461
462
# data
462
463
self .data_signature = data_signature
464
+ self ._check_data ()
465
+
463
466
self .predict_data_signature = {
464
467
k : vs for k , vs in data_signature .items () if "dialogue" in k
465
468
}
@@ -475,14 +478,26 @@ def __init__(
475
478
)
476
479
477
480
# metrics
478
- self .metric_loss = tf .keras .metrics .Mean (name = "loss" )
479
- self .metric_acc = tf .keras .metrics .Mean (name = "acc" )
481
+ self .action_loss = tf .keras .metrics .Mean (name = "loss" )
482
+ self .action_acc = tf .keras .metrics .Mean (name = "acc" )
480
483
self .metrics_to_log += ["loss" , "acc" ]
481
484
482
485
# set up tf layers
483
486
self ._tf_layers = {}
484
487
self ._prepare_layers ()
485
488
489
+ def _check_data (self ) -> None :
490
+ if "dialogue_features" not in self .data_signature :
491
+ raise ValueError (
492
+ f"No text features specified. "
493
+ f"Cannot train '{ self .__class__ .__name__ } ' model."
494
+ )
495
+ if "label_features" not in self .data_signature :
496
+ raise ValueError (
497
+ f"No label features specified. "
498
+ f"Cannot train '{ self .__class__ .__name__ } ' model."
499
+ )
500
+
486
501
def _prepare_layers (self ) -> None :
487
502
self ._tf_layers ["loss.label" ] = layers .DotProductLoss (
488
503
self .config [NUM_NEG ],
@@ -496,16 +511,16 @@ def _prepare_layers(self) -> None:
496
511
parallel_iterations = 1 if self .random_seed is not None else 1000 ,
497
512
)
498
513
self ._tf_layers ["ffnn.dialogue" ] = layers .Ffnn (
499
- self .config [HIDDEN_LAYERS_SIZES_DIALOGUE ],
514
+ self .config [HIDDEN_LAYERS_SIZES ][ DIALOGUE ],
500
515
self .config [DROPRATE_DIALOGUE ],
501
516
self .config [REGULARIZATION_CONSTANT ],
502
- layer_name_suffix = "dialogue" ,
517
+ layer_name_suffix = DIALOGUE ,
503
518
)
504
519
self ._tf_layers ["ffnn.label" ] = layers .Ffnn (
505
- self .config [HIDDEN_LAYERS_SIZES_LABEL ],
520
+ self .config [HIDDEN_LAYERS_SIZES ][ LABEL ],
506
521
self .config [DROPRATE_LABEL ],
507
522
self .config [REGULARIZATION_CONSTANT ],
508
- layer_name_suffix = "label" ,
523
+ layer_name_suffix = LABEL ,
509
524
)
510
525
self ._tf_layers ["transformer" ] = TransformerEncoder (
511
526
self .config [NUM_TRANSFORMER_LAYERS ],
@@ -520,18 +535,18 @@ def _prepare_layers(self) -> None:
520
535
use_key_relative_position = self .config [KEY_RELATIVE_ATTENTION ],
521
536
use_value_relative_position = self .config [VALUE_RELATIVE_ATTENTION ],
522
537
max_relative_position = self .config [MAX_RELATIVE_POSITION ],
523
- name = "dialogue_encoder " ,
538
+ name = DIALOGUE + "_encoder " ,
524
539
)
525
540
self ._tf_layers ["embed.dialogue" ] = layers .Embed (
526
541
self .config [EMBED_DIM ],
527
542
self .config [REGULARIZATION_CONSTANT ],
528
- "dialogue" ,
543
+ DIALOGUE ,
529
544
self .config [SIMILARITY_TYPE ],
530
545
)
531
546
self ._tf_layers ["embed.label" ] = layers .Embed (
532
547
self .config [EMBED_DIM ],
533
548
self .config [REGULARIZATION_CONSTANT ],
534
- "label" ,
549
+ LABEL ,
535
550
self .config [SIMILARITY_TYPE ],
536
551
)
537
552
@@ -588,8 +603,8 @@ def batch_loss(
588
603
dialogue_embed , label_embed , label_in , all_labels_embed , all_labels , mask
589
604
)
590
605
591
- self .metric_loss .update_state (loss )
592
- self .metric_acc .update_state (acc )
606
+ self .action_loss .update_state (loss )
607
+ self .action_acc .update_state (acc )
593
608
594
609
return loss
595
610
0 commit comments