Skip to content

Commit 4d66948

Browse files
authored
Merge branch 'master' into issue_3975
2 parents 78b2966 + aaf4e6a commit 4d66948

File tree

4 files changed

+219
-102
lines changed

4 files changed

+219
-102
lines changed

changelog/3923.misc.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
Internally, intents now have only one property ``used_entities`` to indicate which
2+
entities should be used. For displaying purposes and in ``domain.yml`` files, the
3+
properties ``use_entities`` and/or ``ignore_entities`` will be used as before.

rasa/core/domain.py

Lines changed: 137 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import collections
2+
import copy
23
import json
34
import logging
45
import os
@@ -42,6 +43,9 @@
4243
CARRY_OVER_SLOTS_KEY = "carry_over_slots_to_new_session"
4344
SESSION_EXPIRATION_TIME_KEY = "session_expiration_time"
4445
SESSION_CONFIG_KEY = "session_config"
46+
USED_ENTITIES_KEY = "used_entities"
47+
USE_ENTITIES_KEY = "use_entities"
48+
IGNORE_ENTITIES_KEY = "ignore_entities"
4549

4650
if typing.TYPE_CHECKING:
4751
from rasa.core.trackers import DialogueStateTracker
@@ -75,7 +79,7 @@ class Domain:
7579
"""The domain specifies the universe in which the bot's policy acts.
7680
7781
A Domain subclass provides the actions the bot can take, the intents
78-
and entities it can recognise"""
82+
and entities it can recognise."""
7983

8084
@classmethod
8185
def empty(cls) -> "Domain":
@@ -264,40 +268,97 @@ def collect_slots(slot_dict: Dict[Text, Any]) -> List[Slot]:
264268
return slots
265269

266270
@staticmethod
271+
def _transform_intent_properties_for_internal_use(
272+
intent: Dict[Text, Any], entities: List
273+
) -> Dict[Text, Any]:
274+
"""Transform intent properties coming from a domain file for internal use.
275+
276+
In domain files, `use_entities` or `ignore_entities` is used. Internally, there
277+
is a property `used_entities` instead that lists all entities to be used.
278+
279+
Args:
280+
intent: The intents as provided by a domain file.
281+
entities: All entities as provided by a domain file.
282+
283+
Returns:
284+
The intents as they should be used internally.
285+
"""
286+
name, properties = list(intent.items())[0]
287+
288+
properties.setdefault(USE_ENTITIES_KEY, True)
289+
properties.setdefault(IGNORE_ENTITIES_KEY, [])
290+
if not properties[USE_ENTITIES_KEY]: # this covers False, None and []
291+
properties[USE_ENTITIES_KEY] = []
292+
293+
# `use_entities` is either a list of explicitly included entities
294+
# or `True` if all should be included
295+
if properties[USE_ENTITIES_KEY] is True:
296+
included_entities = set(entities)
297+
else:
298+
included_entities = set(properties[USE_ENTITIES_KEY])
299+
excluded_entities = set(properties[IGNORE_ENTITIES_KEY])
300+
used_entities = list(included_entities - excluded_entities)
301+
used_entities.sort()
302+
303+
# Only print warning for ambiguous configurations if entities were included
304+
# explicitly.
305+
explicitly_included = isinstance(properties[USE_ENTITIES_KEY], list)
306+
ambiguous_entities = included_entities.intersection(excluded_entities)
307+
if explicitly_included and ambiguous_entities:
308+
raise_warning(
309+
f"Entities: '{ambiguous_entities}' are explicitly included and"
310+
f" excluded for intent '{name}'."
311+
f"Excluding takes precedence in this case. "
312+
f"Please resolve that ambiguity.",
313+
docs=f"{DOCS_URL_DOMAINS}#ignoring-entities-for-certain-intents",
314+
)
315+
316+
properties[USED_ENTITIES_KEY] = used_entities
317+
del properties[USE_ENTITIES_KEY]
318+
del properties[IGNORE_ENTITIES_KEY]
319+
320+
return intent
321+
322+
@classmethod
267323
def collect_intent_properties(
268-
intents: List[Union[Text, Dict[Text, Any]]]
324+
cls, intents: List[Union[Text, Dict[Text, Any]]], entities: List[Text]
269325
) -> Dict[Text, Dict[Text, Union[bool, List]]]:
326+
"""Get intent properties for a domain from what is provided by a domain file.
327+
328+
Args:
329+
intents: The intents as provided by a domain file.
330+
entities: All entities as provided by a domain file.
331+
332+
Returns:
333+
The intent properties to be stored in the domain.
334+
"""
270335
intent_properties = {}
336+
duplicates = set()
271337
for intent in intents:
272-
if isinstance(intent, dict):
273-
name = list(intent.keys())[0]
274-
for properties in intent.values():
275-
properties.setdefault("use_entities", True)
276-
properties.setdefault("ignore_entities", [])
277-
if (
278-
properties["use_entities"] is None
279-
or properties["use_entities"] is False
280-
):
281-
properties["use_entities"] = []
282-
else:
283-
name = intent
284-
intent = {intent: {"use_entities": True, "ignore_entities": []}}
338+
if not isinstance(intent, dict):
339+
intent = {intent: {USE_ENTITIES_KEY: True, IGNORE_ENTITIES_KEY: []}}
285340

341+
name = list(intent.keys())[0]
286342
if name in intent_properties.keys():
287-
raise InvalidDomain(
288-
"Intents are not unique! Found two intents with name '{}'. "
289-
"Either rename or remove one of them.".format(name)
290-
)
343+
duplicates.add(name)
344+
345+
intent = cls._transform_intent_properties_for_internal_use(intent, entities)
291346

292347
intent_properties.update(intent)
348+
349+
if duplicates:
350+
raise InvalidDomain(
351+
f"Intents are not unique! Found multiple intents with name(s) {sorted(duplicates)}. "
352+
f"Either rename or remove the duplicate ones."
353+
)
354+
293355
return intent_properties
294356

295357
@staticmethod
296358
def collect_templates(
297359
yml_templates: Dict[Text, List[Any]]
298360
) -> Dict[Text, List[Dict[Text, Any]]]:
299-
"""Go through the templates and make sure they are all in dict format
300-
"""
361+
"""Go through the templates and make sure they are all in dict format."""
301362
templates = {}
302363
for template_key, template_variations in yml_templates.items():
303364
validated_variations = []
@@ -345,7 +406,7 @@ def __init__(
345406
session_config: SessionConfig = SessionConfig.default(),
346407
) -> None:
347408

348-
self.intent_properties = self.collect_intent_properties(intents)
409+
self.intent_properties = self.collect_intent_properties(intents, entities)
349410
self.entities = entities
350411
self.form_names = form_names
351412
self.slots = slots
@@ -376,7 +437,7 @@ def __hash__(self) -> int:
376437

377438
@lazy_property
378439
def user_actions_and_forms(self):
379-
"""Returns combination of user actions and forms"""
440+
"""Returns combination of user actions and forms."""
380441

381442
return self.user_actions + self.form_names
382443

@@ -394,7 +455,7 @@ def num_states(self):
394455
return len(self.input_states)
395456

396457
def add_categorical_slot_default_value(self) -> None:
397-
"""Add a default value to all categorical slots
458+
"""Add a default value to all categorical slots.
398459
399460
All unseen values found for the slot will be mapped to this default value
400461
for featurization.
@@ -439,7 +500,7 @@ def add_knowledge_base_slots(self) -> None:
439500
def action_for_name(
440501
self, action_name: Text, action_endpoint: Optional[EndpointConfig]
441502
) -> Optional[Action]:
442-
"""Looks up which action corresponds to this action name."""
503+
"""Look up which action corresponds to this action name."""
443504

444505
if action_name not in self.action_names:
445506
self._raise_action_not_found_exception(action_name)
@@ -470,7 +531,7 @@ def actions(self, action_endpoint) -> List[Optional[Action]]:
470531
]
471532

472533
def index_for_action(self, action_name: Text) -> Optional[int]:
473-
"""Looks up which action index corresponds to this action name"""
534+
"""Look up which action index corresponds to this action name."""
474535

475536
try:
476537
return self.action_names.index(action_name)
@@ -532,13 +593,13 @@ def form_states(self) -> List[Text]:
532593
return [f"active_form_{f}" for f in self.form_names]
533594

534595
def index_of_state(self, state_name: Text) -> Optional[int]:
535-
"""Provides the index of a state."""
596+
"""Provide the index of a state."""
536597

537598
return self.input_state_map.get(state_name)
538599

539600
@lazy_property
540601
def input_state_map(self) -> Dict[Text, int]:
541-
"""Provides a mapping from state names to indices."""
602+
"""Provide a mapping from state names to indices."""
542603
return {f: i for i, f in enumerate(self.input_states)}
543604

544605
@lazy_property
@@ -598,32 +659,14 @@ def _get_featurized_entities(self, latest_message: UserUttered) -> Set[Text]:
598659
entity["entity"] for entity in entities if "entity" in entity.keys()
599660
}
600661

601-
# `use_entities` is either a list of explicitly included entities
602-
# or `True` if all should be included
603-
include = intent_config.get("use_entities", True)
604-
included_entities = set(entity_names if include is True else include)
605-
excluded_entities = set(intent_config.get("ignore_entities", []))
606-
wanted_entities = included_entities - excluded_entities
607-
608-
# Only print warning for ambiguous configurations if entities were included
609-
# explicitly.
610-
explicitly_included = isinstance(include, list)
611-
ambiguous_entities = included_entities.intersection(excluded_entities)
612-
if explicitly_included and ambiguous_entities:
613-
raise_warning(
614-
f"Entities: '{ambiguous_entities}' are explicitly included and"
615-
f" excluded for intent '{intent_name}'."
616-
f"Excluding takes precedence in this case. "
617-
f"Please resolve that ambiguity.",
618-
docs=DOCS_URL_DOMAINS + "#ignoring-entities-for-certain-intents",
619-
)
662+
wanted_entities = set(intent_config.get(USED_ENTITIES_KEY, entity_names))
620663

621664
return entity_names.intersection(wanted_entities)
622665

623666
def get_prev_action_states(
624667
self, tracker: "DialogueStateTracker"
625668
) -> Dict[Text, float]:
626-
"""Turns the previous taken action into a state name."""
669+
"""Turn the previous taken action into a state name."""
627670

628671
latest_action = tracker.latest_action_name
629672
if latest_action:
@@ -637,15 +680,15 @@ def get_prev_action_states(
637680

638681
@staticmethod
639682
def get_active_form(tracker: "DialogueStateTracker") -> Dict[Text, float]:
640-
"""Turns tracker's active form into a state name."""
683+
"""Turn tracker's active form into a state name."""
641684
form = tracker.active_form.get("name")
642685
if form is not None:
643686
return {ACTIVE_FORM_PREFIX + form: 1.0}
644687
else:
645688
return {}
646689

647690
def get_active_states(self, tracker: "DialogueStateTracker") -> Dict[Text, float]:
648-
"""Return a bag of active states from the tracker state"""
691+
"""Return a bag of active states from the tracker state."""
649692
state_dict = self.get_parsing_states(tracker)
650693
state_dict.update(self.get_prev_action_states(tracker))
651694
state_dict.update(self.get_active_form(tracker))
@@ -677,7 +720,7 @@ def slots_for_entities(self, entities: List[Dict[Text, Any]]) -> List[SlotSet]:
677720
return []
678721

679722
def persist_specification(self, model_path: Text) -> None:
680-
"""Persists the domain specification to storage."""
723+
"""Persist the domain specification to storage."""
681724

682725
domain_spec_path = os.path.join(model_path, "domain.json")
683726
rasa.utils.io.create_directory_for_file(domain_spec_path)
@@ -694,7 +737,7 @@ def load_specification(cls, path: Text) -> Dict[Text, Any]:
694737
return specification
695738

696739
def compare_with_specification(self, path: Text) -> bool:
697-
"""Compares the domain spec of the current and the loaded domain.
740+
"""Compare the domain spec of the current and the loaded domain.
698741
699742
Throws exception if the loaded domain specification is different
700743
to the current domain are different."""
@@ -727,7 +770,7 @@ def as_dict(self) -> Dict[Text, Any]:
727770
SESSION_EXPIRATION_TIME_KEY: self.session_config.session_expiration_time,
728771
CARRY_OVER_SLOTS_KEY: self.session_config.carry_over_slots,
729772
},
730-
"intents": [{k: v} for k, v in self.intent_properties.items()],
773+
"intents": self._transform_intents_for_file(),
731774
"entities": self.entities,
732775
"slots": self._slot_definitions(),
733776
"responses": self.templates,
@@ -741,16 +784,51 @@ def persist(self, filename: Union[Text, Path]) -> None:
741784
domain_data = self.as_dict()
742785
utils.dump_obj_as_yaml_to_file(filename, domain_data)
743786

787+
def _transform_intents_for_file(self) -> List[Union[Text, Dict[Text, Any]]]:
788+
"""Transform intent properties for displaying or writing into a domain file.
789+
790+
Internally, there is a property `used_entities` that lists all entities to be
791+
used. In domain files, `use_entities` or `ignore_entities` is used instead to
792+
list individual entities to ex- or include, because this is easier to read.
793+
794+
Returns:
795+
The intent properties as they are used in domain files.
796+
"""
797+
intent_properties = copy.deepcopy(self.intent_properties)
798+
intents_for_file = []
799+
800+
for intent_name, intent_props in intent_properties.items():
801+
use_entities = set(intent_props[USED_ENTITIES_KEY])
802+
ignore_entities = set(self.entities) - use_entities
803+
if len(use_entities) == len(self.entities):
804+
intent_props[USE_ENTITIES_KEY] = True
805+
elif len(use_entities) <= len(self.entities) / 2:
806+
intent_props[USE_ENTITIES_KEY] = list(use_entities)
807+
else:
808+
intent_props[IGNORE_ENTITIES_KEY] = list(ignore_entities)
809+
intent_props.pop(USED_ENTITIES_KEY)
810+
intents_for_file.append({intent_name: intent_props})
811+
812+
return intents_for_file
813+
744814
def cleaned_domain(self) -> Dict[Text, Any]:
745-
"""Fetch cleaned domain, replacing redundant keys with default values."""
815+
"""Fetch cleaned domain to display or write into a file.
746816
817+
The internal `used_entities` property is replaced by `use_entities` or
818+
`ignore_entities` and redundant keys are replaced with default values
819+
to make the domain easier readable.
820+
821+
Returns:
822+
A cleaned dictionary version of the domain.
823+
"""
747824
domain_data = self.as_dict()
825+
748826
for idx, intent_info in enumerate(domain_data["intents"]):
749827
for name, intent in intent_info.items():
750-
if intent.get("use_entities") is True:
751-
intent.pop("use_entities")
752-
if not intent.get("ignore_entities"):
753-
intent.pop("ignore_entities", None)
828+
if intent.get(USE_ENTITIES_KEY) is True:
829+
del intent[USE_ENTITIES_KEY]
830+
if not intent.get(IGNORE_ENTITIES_KEY):
831+
intent.pop(IGNORE_ENTITIES_KEY, None)
754832
if len(intent) == 0:
755833
domain_data["intents"][idx] = name
756834

@@ -988,7 +1066,7 @@ def check_missing_templates(self) -> None:
9881066
)
9891067

9901068
def is_empty(self) -> bool:
991-
"""Checks whether the domain is empty."""
1069+
"""Check whether the domain is empty."""
9921070

9931071
return self.as_dict() == Domain.empty().as_dict()
9941072

0 commit comments

Comments
 (0)