Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

speed up domain loading #12987

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion rasa/graph_components/providers/domain_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def load(
) -> DomainProvider:
"""Creates provider using a persisted version of itself."""
with model_storage.read_from(resource) as resource_directory:
domain = Domain.from_path(resource_directory)
domain = Domain.from_path(resource_directory, is_validated=True)
return cls(model_storage, resource, domain)

def _persist(self, domain: Domain) -> None:
Expand Down
40 changes: 28 additions & 12 deletions rasa/shared/core/domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,20 +188,21 @@ def load(cls, paths: Union[List[Union[Path, Text]], Text, Path]) -> "Domain":
return domain

@classmethod
def from_path(cls, path: Union[Text, Path]) -> "Domain":
def from_path(cls, path: Union[Text, Path], is_validated: bool = False) -> "Domain":
"""Loads the `Domain` from a path."""
logger.debug(f"Loading from {path}")
path = os.path.abspath(path)

if os.path.isfile(path):
domain = cls.from_file(path)
elif os.path.isdir(path):
domain = cls.from_directory(path)
domain = cls.from_directory(path, is_validated=is_validated)
else:
raise InvalidDomain(
"Failed to load domain specification from '{}'. "
"File not found!".format(os.path.abspath(path))
)

logger.debug(f"done loading domain from {path}")
return domain

@classmethod
Expand Down Expand Up @@ -287,20 +288,30 @@ def _get_session_config(session_config: Dict) -> SessionConfig:
return SessionConfig(session_expiration_time_min, carry_over_slots)

@classmethod
def from_directory(cls, path: Text) -> "Domain":
def from_directory(cls, path: Text, is_validated: bool = False) -> "Domain":
"""Loads and merges multiple domain files recursively from a directory tree."""
combined: Dict[Text, Any] = {}
for root, _, files in os.walk(path, followlinks=True):
for file in files:
logger.debug(f"Processing {file=}")
full_path = os.path.join(root, file)
if Domain.is_domain_file(full_path):
_ = Domain.from_file(full_path) # does the validation here only
other_dict = rasa.shared.utils.io.read_yaml(
rasa.shared.utils.io.read_file(full_path)
)
logger.debug(f"Checking file type of {file=}")
if other_dict := Domain.is_domain_file(full_path):
if not is_validated:
logger.debug(f"Validating {file=}")
_ = Domain.from_dict(
other_dict
) # does the validation here only
# logger.debug(f"Reading {file=}")
# other_dict = rasa.shared.utils.io.read_yaml(
# rasa.shared.utils.io.read_file(full_path)
# )
logger.debug(f"Merging {file=}")
combined = Domain.merge_domain_dicts(other_dict, combined)

logger.debug("Building domain from dict.")
domain = Domain.from_dict(combined)
logger.debug(f"Merged domain from directory {path=}")
return domain

def merge(
Expand Down Expand Up @@ -1795,14 +1806,16 @@ def is_empty(self) -> bool:
return self.as_dict() == Domain.empty().as_dict()

@staticmethod
def is_domain_file(filename: Union[Text, Path]) -> bool:
def is_domain_file(
filename: Union[Text, Path]
) -> Union[bool, Dict[str, Any], List[Any]]:
"""Checks whether the given file path is a Rasa domain file.

Args:
filename: Path of the file which should be checked.

Returns:
`True` if it's a domain file, otherwise `False`.
a domain file, otherwise `False`.

Raises:
YamlException: if the file seems to be a YAML file (extension) but
Expand All @@ -1824,7 +1837,10 @@ def is_domain_file(filename: Union[Text, Path]) -> bool:
)
return False

return any(key in content for key in ALL_DOMAIN_KEYS)
if any(key in content for key in ALL_DOMAIN_KEYS):
return content
else:
return False

def required_slots_for_form(self, form_name: Text) -> List[Text]:
"""Retrieve the list of required slot names for a form defined in the domain.
Expand Down
Loading