@@ -74,6 +74,9 @@ class EmbeddingPolicy(Policy):
74
74
"similarity_type" : "auto" , # string 'auto' or 'cosine' or 'inner'
75
75
# the type of the loss function
76
76
"loss_type" : "softmax" , # string 'softmax' or 'margin'
77
+ # number of top actions to normalize scores for softmax loss_type
78
+ # set to 0 to turn off normalization
79
+ "ranking_length" : 10 ,
77
80
# how similar the algorithm should try
78
81
# to make embedding vectors for correct labels
79
82
"mu_pos" : 0.8 , # should be 0.0 < ... < 1.0 for 'cosine'
@@ -192,6 +195,7 @@ def _load_embedding_params(self, config: Dict[Text, Any]) -> None:
192
195
self .similarity_type = "inner"
193
196
elif self .loss_type == "margin" :
194
197
self .similarity_type = "cosine"
198
+ self .ranking_length = config ["ranking_length" ]
195
199
196
200
self .mu_pos = config ["mu_pos" ]
197
201
self .mu_neg = config ["mu_neg" ]
@@ -567,8 +571,12 @@ def predict_action_probabilities(
567
571
tf_feed_dict = self .tf_feed_dict_for_prediction (tracker , domain )
568
572
569
573
confidence = self .session .run (self .pred_confidence , feed_dict = tf_feed_dict )
574
+ confidence = confidence [0 , - 1 , :]
570
575
571
- return confidence [0 , - 1 , :].tolist ()
576
+ if self .loss_type == "softmax" and self .ranking_length > 0 :
577
+ confidence = train_utils .normalize (confidence , self .ranking_length )
578
+
579
+ return confidence .tolist ()
572
580
573
581
def persist (self , path : Text ) -> None :
574
582
"""Persists the policy to a storage."""
@@ -583,7 +591,11 @@ def persist(self, path: Text) -> None:
583
591
584
592
self .featurizer .persist (path )
585
593
586
- meta = {"priority" : self .priority }
594
+ meta = {
595
+ "priority" : self .priority ,
596
+ "loss_type" : self .loss_type ,
597
+ "ranking_length" : self .ranking_length ,
598
+ }
587
599
588
600
meta_file = os .path .join (path , "embedding_policy.json" )
589
601
rasa .utils .io .dump_obj_as_json_to_file (meta_file , meta )
@@ -665,7 +677,7 @@ def load(cls, path: Text) -> "EmbeddingPolicy":
665
677
666
678
return cls (
667
679
featurizer = featurizer ,
668
- priority = meta [ "priority" ] ,
680
+ priority = meta . pop ( "priority" ) ,
669
681
graph = graph ,
670
682
session = session ,
671
683
user_placeholder = a_in ,
@@ -677,4 +689,5 @@ def load(cls, path: Text) -> "EmbeddingPolicy":
677
689
bot_embed = bot_embed ,
678
690
all_bot_embed = all_bot_embed ,
679
691
attention_weights = attention_weights ,
692
+ ** meta ,
680
693
)
0 commit comments