1
1
import collections
2
+ import copy
2
3
import json
3
4
import logging
4
5
import os
42
43
CARRY_OVER_SLOTS_KEY = "carry_over_slots_to_new_session"
43
44
SESSION_EXPIRATION_TIME_KEY = "session_expiration_time"
44
45
SESSION_CONFIG_KEY = "session_config"
46
+ USED_ENTITIES_KEY = "used_entities"
47
+ USE_ENTITIES_KEY = "use_entities"
48
+ IGNORE_ENTITIES_KEY = "ignore_entities"
45
49
46
50
if typing .TYPE_CHECKING :
47
51
from rasa .core .trackers import DialogueStateTracker
@@ -75,7 +79,7 @@ class Domain:
75
79
"""The domain specifies the universe in which the bot's policy acts.
76
80
77
81
A Domain subclass provides the actions the bot can take, the intents
78
- and entities it can recognise"""
82
+ and entities it can recognise. """
79
83
80
84
@classmethod
81
85
def empty (cls ) -> "Domain" :
@@ -264,40 +268,97 @@ def collect_slots(slot_dict: Dict[Text, Any]) -> List[Slot]:
264
268
return slots
265
269
266
270
@staticmethod
def _transform_intent_properties_for_internal_use(
    intent: Dict[Text, Any], entities: List[Text]
) -> Dict[Text, Any]:
    """Transform intent properties coming from a domain file for internal use.

    In domain files, `use_entities` or `ignore_entities` is used. Internally, there
    is a single property `used_entities` instead that lists all entities to be used.

    The passed `intent` mapping and its properties are left untouched, so the
    same domain data can be transformed repeatedly with the same result.

    Args:
        intent: A mapping whose first entry maps the intent name to its
            properties as provided by a domain file.
        entities: All entities as provided by a domain file.

    Returns:
        A copy of `intent` in which the properties of the first entry have
        `use_entities` / `ignore_entities` replaced by a sorted `used_entities`
        list.
    """
    name, file_properties = list(intent.items())[0]

    # Work on a shallow copy: only top-level keys are added or removed below,
    # and the caller's dictionary must keep its `use_entities` and
    # `ignore_entities` entries.
    properties = dict(file_properties)

    use_entities = properties.pop(USE_ENTITIES_KEY, True)
    # An empty `ignore_entities:` key (None) means that nothing is ignored.
    ignore_entities = properties.pop(IGNORE_ENTITIES_KEY, None) or []
    if not use_entities:  # this covers False, None and []
        use_entities = []

    # `use_entities` is either a list of explicitly included entities
    # or `True` if all should be included
    if use_entities is True:
        included_entities = set(entities)
    else:
        included_entities = set(use_entities)
    excluded_entities = set(ignore_entities)

    # Only print warning for ambiguous configurations if entities were included
    # explicitly.
    explicitly_included = isinstance(use_entities, list)
    ambiguous_entities = included_entities.intersection(excluded_entities)
    if explicitly_included and ambiguous_entities:
        raise_warning(
            f"Entities: '{ambiguous_entities}' are explicitly included and"
            f" excluded for intent '{name}'. "
            f"Excluding takes precedence in this case. "
            f"Please resolve that ambiguity.",
            docs=f"{DOCS_URL_DOMAINS}#ignoring-entities-for-certain-intents",
        )

    # Excluding takes precedence over including; sort for a stable result.
    properties[USED_ENTITIES_KEY] = sorted(included_entities - excluded_entities)

    # Keep any further entries of `intent` as they are and only swap in the
    # transformed properties.
    transformed_intent = dict(intent)
    transformed_intent[name] = properties
    return transformed_intent
321
+
322
@classmethod
def collect_intent_properties(
    cls, intents: List[Union[Text, Dict[Text, Any]]], entities: List[Text]
) -> Dict[Text, Dict[Text, Union[bool, List]]]:
    """Get intent properties for a domain from what is provided by a domain file.

    Args:
        intents: The intents as provided by a domain file. Each item is either
            an intent name or a mapping from the intent name to its properties.
        entities: All entities as provided by a domain file.

    Returns:
        The intent properties to be stored in the domain, keyed by intent name.

    Raises:
        InvalidDomain: If an intent name occurs more than once in `intents`.
    """
    intent_properties = {}
    duplicates = set()

    for intent in intents:
        if not isinstance(intent, dict):
            # A bare intent name gets the default entity configuration.
            intent = {intent: {USE_ENTITIES_KEY: True, IGNORE_ENTITIES_KEY: []}}

        name = list(intent)[0]
        if name in intent_properties:
            # Collect all duplicates first so that a single error can list them.
            duplicates.add(name)

        intent = cls._transform_intent_properties_for_internal_use(intent, entities)
        intent_properties.update(intent)

    if duplicates:
        raise InvalidDomain(
            f"Intents are not unique! Found multiple intents with name(s) "
            f"{sorted(duplicates)}. "
            f"Either rename or remove the duplicate ones."
        )

    return intent_properties
294
356
295
357
@staticmethod
296
358
def collect_templates (
297
359
yml_templates : Dict [Text , List [Any ]]
298
360
) -> Dict [Text , List [Dict [Text , Any ]]]:
299
- """Go through the templates and make sure they are all in dict format
300
- """
361
+ """Go through the templates and make sure they are all in dict format."""
301
362
templates = {}
302
363
for template_key , template_variations in yml_templates .items ():
303
364
validated_variations = []
@@ -345,7 +406,7 @@ def __init__(
345
406
session_config : SessionConfig = SessionConfig .default (),
346
407
) -> None :
347
408
348
- self .intent_properties = self .collect_intent_properties (intents )
409
+ self .intent_properties = self .collect_intent_properties (intents , entities )
349
410
self .entities = entities
350
411
self .form_names = form_names
351
412
self .slots = slots
@@ -376,7 +437,7 @@ def __hash__(self) -> int:
376
437
377
438
@lazy_property
def user_actions_and_forms(self):
    """Return `self.user_actions` followed by `self.form_names`."""
    combined = self.user_actions + self.form_names
    return combined
382
443
@@ -394,7 +455,7 @@ def num_states(self):
394
455
return len (self .input_states )
395
456
396
457
def add_categorical_slot_default_value (self ) -> None :
397
- """Add a default value to all categorical slots
458
+ """Add a default value to all categorical slots.
398
459
399
460
All unseen values found for the slot will be mapped to this default value
400
461
for featurization.
@@ -439,7 +500,7 @@ def add_knowledge_base_slots(self) -> None:
439
500
def action_for_name (
440
501
self , action_name : Text , action_endpoint : Optional [EndpointConfig ]
441
502
) -> Optional [Action ]:
442
- """Looks up which action corresponds to this action name."""
503
+ """Look up which action corresponds to this action name."""
443
504
444
505
if action_name not in self .action_names :
445
506
self ._raise_action_not_found_exception (action_name )
@@ -470,7 +531,7 @@ def actions(self, action_endpoint) -> List[Optional[Action]]:
470
531
]
471
532
472
533
def index_for_action (self , action_name : Text ) -> Optional [int ]:
473
- """Looks up which action index corresponds to this action name"""
534
+ """Look up which action index corresponds to this action name. """
474
535
475
536
try :
476
537
return self .action_names .index (action_name )
@@ -532,13 +593,13 @@ def form_states(self) -> List[Text]:
532
593
return [f"active_form_{ f } " for f in self .form_names ]
533
594
534
595
def index_of_state(self, state_name: Text) -> Optional[int]:
    """Look up the index stored for `state_name` in `self.input_state_map`.

    Args:
        state_name: Name of the state to look up.

    Returns:
        The result of `self.input_state_map.get(state_name)`.
    """
    state_indices = self.input_state_map
    return state_indices.get(state_name)
538
599
539
600
@lazy_property
def input_state_map(self) -> Dict[Text, int]:
    """Build a mapping from each name in `self.input_states` to its position.

    Returns:
        A dictionary from state name to its index in `self.input_states`.
    """
    state_to_index = {}
    for position, state_name in enumerate(self.input_states):
        state_to_index[state_name] = position
    return state_to_index
543
604
544
605
@lazy_property
@@ -598,32 +659,14 @@ def _get_featurized_entities(self, latest_message: UserUttered) -> Set[Text]:
598
659
entity ["entity" ] for entity in entities if "entity" in entity .keys ()
599
660
}
600
661
601
- # `use_entities` is either a list of explicitly included entities
602
- # or `True` if all should be included
603
- include = intent_config .get ("use_entities" , True )
604
- included_entities = set (entity_names if include is True else include )
605
- excluded_entities = set (intent_config .get ("ignore_entities" , []))
606
- wanted_entities = included_entities - excluded_entities
607
-
608
- # Only print warning for ambiguous configurations if entities were included
609
- # explicitly.
610
- explicitly_included = isinstance (include , list )
611
- ambiguous_entities = included_entities .intersection (excluded_entities )
612
- if explicitly_included and ambiguous_entities :
613
- raise_warning (
614
- f"Entities: '{ ambiguous_entities } ' are explicitly included and"
615
- f" excluded for intent '{ intent_name } '."
616
- f"Excluding takes precedence in this case. "
617
- f"Please resolve that ambiguity." ,
618
- docs = DOCS_URL_DOMAINS + "#ignoring-entities-for-certain-intents" ,
619
- )
662
+ wanted_entities = set (intent_config .get (USED_ENTITIES_KEY , entity_names ))
620
663
621
664
return entity_names .intersection (wanted_entities )
622
665
623
666
def get_prev_action_states (
624
667
self , tracker : "DialogueStateTracker"
625
668
) -> Dict [Text , float ]:
626
- """Turns the previous taken action into a state name."""
669
+ """Turn the previous taken action into a state name."""
627
670
628
671
latest_action = tracker .latest_action_name
629
672
if latest_action :
@@ -637,15 +680,15 @@ def get_prev_action_states(
637
680
638
681
@staticmethod
def get_active_form(tracker: "DialogueStateTracker") -> Dict[Text, float]:
    """Turn the tracker's active form into a state name.

    Args:
        tracker: The tracker whose `active_form` mapping is inspected.

    Returns:
        A single-entry dictionary mapping the prefixed form name to `1.0`, or
        an empty dictionary if `active_form` has no `"name"` value.
    """
    form_name = tracker.active_form.get("name")
    if form_name is None:
        return {}

    return {ACTIVE_FORM_PREFIX + form_name: 1.0}
646
689
647
690
def get_active_states (self , tracker : "DialogueStateTracker" ) -> Dict [Text , float ]:
648
- """Return a bag of active states from the tracker state"""
691
+ """Return a bag of active states from the tracker state. """
649
692
state_dict = self .get_parsing_states (tracker )
650
693
state_dict .update (self .get_prev_action_states (tracker ))
651
694
state_dict .update (self .get_active_form (tracker ))
@@ -677,7 +720,7 @@ def slots_for_entities(self, entities: List[Dict[Text, Any]]) -> List[SlotSet]:
677
720
return []
678
721
679
722
def persist_specification (self , model_path : Text ) -> None :
680
- """Persists the domain specification to storage."""
723
+ """Persist the domain specification to storage."""
681
724
682
725
domain_spec_path = os .path .join (model_path , "domain.json" )
683
726
rasa .utils .io .create_directory_for_file (domain_spec_path )
@@ -694,7 +737,7 @@ def load_specification(cls, path: Text) -> Dict[Text, Any]:
694
737
return specification
695
738
696
739
def compare_with_specification (self , path : Text ) -> bool :
697
- """Compares the domain spec of the current and the loaded domain.
740
+ """Compare the domain spec of the current and the loaded domain.
698
741
699
742
Throws exception if the loaded domain specification is different
700
743
from the current domain."""
@@ -727,7 +770,7 @@ def as_dict(self) -> Dict[Text, Any]:
727
770
SESSION_EXPIRATION_TIME_KEY : self .session_config .session_expiration_time ,
728
771
CARRY_OVER_SLOTS_KEY : self .session_config .carry_over_slots ,
729
772
},
730
- "intents" : [{ k : v } for k , v in self .intent_properties . items ()] ,
773
+ "intents" : self ._transform_intents_for_file () ,
731
774
"entities" : self .entities ,
732
775
"slots" : self ._slot_definitions (),
733
776
"responses" : self .templates ,
@@ -741,16 +784,51 @@ def persist(self, filename: Union[Text, Path]) -> None:
741
784
domain_data = self .as_dict ()
742
785
utils .dump_obj_as_yaml_to_file (filename , domain_data )
743
786
787
def _transform_intents_for_file(self) -> List[Union[Text, Dict[Text, Any]]]:
    """Transform intent properties for displaying or writing into a domain file.

    Internally, there is a property `used_entities` that lists all entities to be
    used. In domain files, `use_entities` or `ignore_entities` is used instead to
    list individual entities to ex- or include, because this is easier to read.

    `self.intent_properties` itself is not modified. Entity lists are written
    in sorted order so that the output is deterministic.

    Returns:
        The intent properties as they are used in domain files: one
        single-entry dictionary per intent, mapping its name to its properties.
    """
    all_entities = set(self.entities)
    intents_for_file = []

    for intent_name, internal_props in self.intent_properties.items():
        # Copy so that the internal representation keeps `used_entities`.
        intent_props = copy.deepcopy(internal_props)
        used_entities = set(intent_props.pop(USED_ENTITIES_KEY))

        if used_entities == all_entities:
            # Compare the sets, not just their sizes: an explicit list of the
            # same length but with different names must not become `True`.
            intent_props[USE_ENTITIES_KEY] = True
        elif (
            used_entities <= all_entities
            and len(used_entities) > len(all_entities) / 2
        ):
            # More than half of the entities are used, so listing the ignored
            # ones is shorter. This is only lossless if every used entity is
            # a known entity.
            intent_props[IGNORE_ENTITIES_KEY] = sorted(all_entities - used_entities)
        else:
            intent_props[USE_ENTITIES_KEY] = sorted(used_entities)

        intents_for_file.append({intent_name: intent_props})

    return intents_for_file
813
+
744
814
def cleaned_domain (self ) -> Dict [Text , Any ]:
745
- """Fetch cleaned domain, replacing redundant keys with default values."""
815
+ """Fetch cleaned domain to display or write into a file.
746
816
817
+ The internal `used_entities` property is replaced by `use_entities` or
818
+ `ignore_entities` and redundant keys are replaced with default values
819
+ to make the domain easier readable.
820
+
821
+ Returns:
822
+ A cleaned dictionary version of the domain.
823
+ """
747
824
domain_data = self .as_dict ()
825
+
748
826
for idx , intent_info in enumerate (domain_data ["intents" ]):
749
827
for name , intent in intent_info .items ():
750
- if intent .get ("use_entities" ) is True :
751
- intent . pop ( "use_entities" )
752
- if not intent .get ("ignore_entities" ):
753
- intent .pop ("ignore_entities" , None )
828
+ if intent .get (USE_ENTITIES_KEY ) is True :
829
+ del intent [ USE_ENTITIES_KEY ]
830
+ if not intent .get (IGNORE_ENTITIES_KEY ):
831
+ intent .pop (IGNORE_ENTITIES_KEY , None )
754
832
if len (intent ) == 0 :
755
833
domain_data ["intents" ][idx ] = name
756
834
@@ -988,7 +1066,7 @@ def check_missing_templates(self) -> None:
988
1066
)
989
1067
990
1068
def is_empty(self) -> bool:
    """Check whether the domain is empty.

    Returns:
        `True` if this domain's `as_dict()` output equals that of
        `Domain.empty()`, `False` otherwise.
    """
    own_dict = self.as_dict()
    empty_dict = Domain.empty().as_dict()
    return own_dict == empty_dict
994
1072
0 commit comments