Skip to content

Commit

Permalink
Support nested NeMo models (NVIDIA#5671)
Browse files Browse the repository at this point in the history
Nested NeMo models support

Signed-off-by: Vladimir Bataev <[email protected]>

Co-authored-by: Oleksii Kuchaiev <[email protected]>
Co-authored-by: Eric Harper <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Sean Naren <[email protected]>
  • Loading branch information
5 people authored Jan 23, 2023
1 parent ac50e59 commit 97973c5
Show file tree
Hide file tree
Showing 5 changed files with 856 additions and 90 deletions.
59 changes: 58 additions & 1 deletion docs/source/core/core.rst
Original file line number Diff line number Diff line change
Expand Up @@ -635,7 +635,64 @@ The resulting .nemo file will then have the following file:
4978b28103264263a03439aaa6560e5e_tokenizer.model
If ``verify_src_exists`` is set to ``False``, then the artifact is optional. This means that ``.register_artifact`` will return ``None``
if the ``src`` cannot be found.
if the ``src`` cannot be found.

Nested NeMo Models
------------------

In some cases, it may be helpful to use NeMo models inside other NeMo models. For example, we can incorporate language models into ASR models to use in a decoding process to improve accuracy or use hybrid ASR-TTS models to generate audio from the text on the fly to train or finetune the ASR model.

There are 3 ways to instantiate child models inside parent models:

- use subconfig directly
- use the ``.nemo`` checkpoint path to load the child model
- use a pretrained NeMo model

To register a child model, use the ``register_nemo_submodule`` method of the parent model. This method will add the child model to a provided model attribute and, in the serialization process, will handle child artifacts correctly and store the child model config in the parent model config in ``config_field``.

.. code-block:: python
from nemo.core.classes import ModelPT
class ChildModel(ModelPT):
... # implement necessary methods
class ParentModel(ModelPT):
def __init__(self, cfg, trainer=None):
super().__init__(cfg=cfg, trainer=trainer)
# optionally annotate type for IDE autocompletion and type checking
self.child_model: Optional[ChildModel]
if cfg.get("child_model") is not None:
# load directly from config
# either if config provided initially, or automatically
# after model restoration
self.register_nemo_submodule(
name="child_model",
config_field="child_model",
model=ChildModel(self.cfg.child_model, trainer=trainer),
)
elif cfg.get('child_model_path') is not None:
# load from .nemo model checkpoint
# while saving, config will be automatically assigned/updated
# in cfg.child_model
self.register_nemo_submodule(
name="child_model",
config_field="child_model",
model=ChildModel.restore_from(self.cfg.child_model_path, trainer=trainer),
)
elif cfg.get('child_model_name') is not None:
# load from pretrained model
# while saving, config will be automatically assigned/updated
# in cfg.child_model
self.register_nemo_submodule(
name="child_model",
config_field="child_model",
model=ChildModel.from_pretrained(self.cfg.child_model_name, trainer=trainer),
)
else:
self.child_model = None
Neural Modules
==============
Expand Down
110 changes: 107 additions & 3 deletions nemo/core/classes/modelPT.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import copy
import inspect
Expand All @@ -19,7 +20,7 @@
from abc import abstractmethod
from os import path
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union

import hydra
import torch
Expand All @@ -35,6 +36,7 @@
from nemo.utils import logging, model_utils
from nemo.utils.app_state import AppState
from nemo.utils.debug_hook import register_debug_hooks
from nemo.utils.exceptions import NeMoBaseException
from nemo.utils.get_rank import get_rank, is_global_rank_zero

__all__ = ['ModelPT']
Expand Down Expand Up @@ -110,6 +112,9 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):

self._cfg = cfg

# init mapping submodule attribute -> config_field for nested NeMo models
self._nemo_submodule_name_to_config_field = dict()

self.save_hyperparameters("cfg")
self._train_dl = None
self._validation_dl = None
Expand Down Expand Up @@ -221,11 +226,15 @@ def register_artifact(
str: If src is not None or empty it always returns absolute path which is guaranteed to exist during model instance life
"""

app_state = AppState()

if src is None or src == "":
return src

if Path(src).suffix == ".nemo":
raise NeMoBaseException(
"Registering .nemo files as artifacts not supported. "
"If you are trying to make a nested model, use `register_nemo_submodule`."
)

if not hasattr(self, 'artifacts'):
self.artifacts = {}

Expand All @@ -240,6 +249,101 @@ def register_artifact(

return self._save_restore_connector.register_artifact(self, config_path, src, verify_src_exists)

def has_artifacts(self) -> bool:
"""Returns True if model has artifacts registered"""
return hasattr(self, 'artifacts') and self.artifacts is not None and len(self.artifacts) > 0

def has_native_or_submodules_artifacts(self) -> bool:
"""Returns True if it has artifacts or any of the submodules have artifacts"""
for module in self.modules():
if (
isinstance(module, ModelPT)
and hasattr(module, 'artifacts')
and module.artifacts is not None
and len(module.artifacts) > 0
):
return True
return False

def has_nemo_submodules(self) -> bool:
"""Returns True if it has any registered NeMo submodules"""
return len(self._nemo_submodule_name_to_config_field) > 0

def register_nemo_submodule(self, name: str, config_field: str, model: "ModelPT") -> None:
"""
Adds a NeMo model as a submodule. Submodule can be accessed via the `name` attribute on the parent NeMo model this submodule was registered on (`self`).
In the saving process, the whole parent model (self) is held as a solid model with artifacts
from the child submodule, the submodule config will be saved to the `config_field` of the parent model.
This method is necessary to create a nested model, e.g.
.. code-block:: python
class ParentModel(ModelPT):
def __init__(self, cfg, trainer=None):
super().__init__(cfg=cfg, trainer=trainer)
# annotate type for autocompletion and type checking (optional)
self.child_model: Optional[ChildModel] = None
if cfg.get("child_model") is not None:
self.register_nemo_submodule(
name="child_model",
config_field="child_model",
model=ChildModel(self.cfg.child_model, trainer=trainer),
)
# ... other code
Args:
name: name of the attribute for the submodule
config_field: field in config, where submodule config should be saved
model: NeMo model, instance of ModelPT
"""
# check it is a real NeMo model
if not isinstance(model, ModelPT):
raise NeMoBaseException(
f"Model is not and instance of ModelPT, so can't be registered. Got {type(model).__name__}"
)
# check if it is called after __init__
if not hasattr(self, "_nemo_submodule_name_to_config_field"):
raise NeMoBaseException(
"You are trying to register a submodule before the model is initialized. This is not allowed. "
"Did you forget to call `super().__init__`?"
)
# assign attribute to self
setattr(self, name, model)
# add to the submodules mapping
self._nemo_submodule_name_to_config_field[name] = config_field

def named_nemo_modules(
self, prefix_name: str = "", prefix_config: str = ""
) -> Iterator[Tuple[str, str, "ModelPT"]]:
"""
Returns an iterator over all NeMo submodules recursively, yielding
tuples of (attribute path, path in config, submodule), starting from the core module
Args:
prefix_name: prefix for the name path
prefix_config: prefix for the path in config
Returns:
Iterator over (attribute path, path in config, submodule), starting from (prefix, self)
"""
if not hasattr(self, "_nemo_submodule_name_to_config_field"):
raise NeMoBaseException(
"Model is not fully initialized. Calling `named_nemo_modules` before __init__ not allowed. "
"Did you forget to call `super().__init__`?"
)

yield prefix_name, prefix_config, self

# recursive iteration over all NeMo submodules
for name, config_field in self._nemo_submodule_name_to_config_field.items():
attribute_path = f"{prefix_name}.{name}" if prefix_name else name
config_path = f"{prefix_config}.{config_field}" if prefix_config else config_field
module: ModelPT = getattr(self, name)
for submodule_name, subconfig_path, submodule in module.named_nemo_modules(
prefix_name=attribute_path, prefix_config=config_path
):
yield submodule_name, subconfig_path, submodule

def save_to(self, save_path: str):
"""
Saves model instance (weights and configuration) into .nemo file
Expand Down
105 changes: 79 additions & 26 deletions nemo/core/connectors/save_restore_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations # necessary for lazy types evaluation

import os
import shutil
import tarfile
import tempfile
import uuid
from typing import Optional, Union
from typing import Optional, Set, Union

import torch
from omegaconf import DictConfig, OmegaConf
from omegaconf.omegaconf import open_dict
from pytorch_lightning.trainer.trainer import Trainer

from nemo.core import classes as nemo_classes # to avoid circular import do not import ModelPT directly
from nemo.utils import logging, model_utils
from nemo.utils.app_state import AppState
from nemo.utils.get_rank import is_global_rank_zero
Expand All @@ -37,14 +39,14 @@ def __init__(self) -> None:
self._model_weights_ckpt = "model_weights.ckpt"
self._model_extracted_dir = None

def save_to(self, model, save_path: str):
def save_to(self, model: "nemo_classes.ModelPT", save_path: str):
"""
Saves model instance (weights and configuration) into .nemo file.
You can use "restore_from" method to fully restore instance from .nemo file.
.nemo file is an archive (tar.gz) with the following:
model_config.yaml - model configuration in .yaml format. You can deserialize this into cfg argument for model's constructor
model_wights.chpt - model checkpoint
model_wights.ckpt - model checkpoint
Args:
model: ModelPT object to be saved.
Expand All @@ -56,7 +58,9 @@ def save_to(self, model, save_path: str):
config_yaml = os.path.join(tmpdir, self.model_config_yaml)
model_weights = os.path.join(tmpdir, self.model_weights_ckpt)
model.to_config_file(path2yaml_file=config_yaml)
if hasattr(model, 'artifacts') and model.artifacts is not None:
# update subconfigs, if there are child model, since child model can change its config
self._update_subconfigs(model, path2yaml_file=config_yaml)
if model.has_native_or_submodules_artifacts():
self._handle_artifacts(model, nemo_file_folder=tmpdir)
# We should not update self._cfg here - the model can still be in use
self._update_artifact_paths(model, path2yaml_file=config_yaml)
Expand Down Expand Up @@ -400,40 +404,70 @@ def register_artifact(self, model, config_path: str, src: str, verify_src_exists
def _handle_artifacts(self, model, nemo_file_folder):
tarfile_artifacts = []
app_state = AppState()
for conf_path, artiitem in model.artifacts.items():
if artiitem.path_type == model_utils.ArtifactPathType.LOCAL_PATH:
if not os.path.exists(artiitem.path):
raise FileNotFoundError(f"Artifact {conf_path} not found at location: {artiitem.path}")

# Generate new uniq artifact name and copy it to nemo_file_folder
# Note uuid.uuid4().hex is guaranteed to be 32 character long
artifact_base_name = os.path.basename(artiitem.path)
artifact_uniq_name = f"{uuid.uuid4().hex}_{artifact_base_name}"
shutil.copy2(artiitem.path, os.path.join(nemo_file_folder, artifact_uniq_name))
# aggregate artifacts from self and all children recursively
artifacts_containers = []
for _, config_path, module in model.named_nemo_modules():
if module.has_artifacts(): # NeMo model with artifacts
artifacts_containers.append((config_path, module.artifacts))

if len(artifacts_containers) > 0 and (not hasattr(model, "artifacts") or model.artifacts is None):
# model has no artifacts, but submodules have some
model.artifacts = dict()
for config_path, artifacts in artifacts_containers:
for subconf_path, artiitem in artifacts.items():
conf_path = f"{config_path}.{subconf_path}" if config_path else f"{subconf_path}"
if artiitem.path_type == model_utils.ArtifactPathType.LOCAL_PATH:
if not os.path.exists(artiitem.path):
raise FileNotFoundError(f"Artifact {conf_path} not found at location: {artiitem.path}")

# Generate new uniq artifact name and copy it to nemo_file_folder
# Note uuid.uuid4().hex is guaranteed to be 32 character long
artifact_base_name = os.path.basename(artiitem.path)
artifact_uniq_name = f"{uuid.uuid4().hex}_{artifact_base_name}"
shutil.copy2(artiitem.path, os.path.join(nemo_file_folder, artifact_uniq_name))

# Update artifacts registry
artiitem.hashed_path = "nemo:" + artifact_uniq_name
model.artifacts[conf_path] = artiitem

elif artiitem.path_type == model_utils.ArtifactPathType.TAR_PATH:
# process all tarfile artifacts in one go, so preserve key-value pair
tarfile_artifacts.append((conf_path, artiitem))
if subconf_path: # artifact from submodule
model.artifacts[conf_path] = artiitem

# Update artifacts registry
artiitem.hashed_path = "nemo:" + artifact_uniq_name
model.artifacts[conf_path] = artiitem

elif artiitem.path_type == model_utils.ArtifactPathType.TAR_PATH:
# process all tarfile artifacts in one go, so preserve key-value pair
tarfile_artifacts.append((conf_path, artiitem))

else:
raise ValueError(f"Directly referencing artifacts from other nemo files isn't supported yet")
else:
raise ValueError(f"Directly referencing artifacts from other nemo files isn't supported yet")

# Process current tarfile artifacts by unpacking the previous tarfile and extract the artifacts
# that are currently required.
# artifacts can be native (from the model itself) and from submodules
restoration_paths: Set[str] = set() # model + submodules restoration paths, handle only unique paths
model_metadata = app_state.get_model_metadata_from_guid(model.model_guid)
if len(tarfile_artifacts) > 0 and model_metadata.restoration_path is not None:
if model_metadata.restoration_path is not None:
restoration_paths.add(model_metadata.restoration_path)
# aggregate restoration paths for all submodules recursively
for module in model.modules():
if isinstance(module, nemo_classes.ModelPT): # if NeMo model
submodule_restoration_path = app_state.get_model_metadata_from_guid(module.model_guid).restoration_path
if submodule_restoration_path is not None:
restoration_paths.add(submodule_restoration_path)
if len(tarfile_artifacts) > 0 and len(restoration_paths) == 0:
# TODO: see cases when this can occur, and if we can fix them
logging.warning("Model contains registered artifacts, but no restoration paths found")
if len(tarfile_artifacts) > 0 and len(restoration_paths) > 0:
# Need to step into nemo archive to extract file
# Get path where the command is executed - the artifacts will be "retrieved" there
# (original .nemo behavior)
cwd = os.getcwd()
try:
# Step into the nemo archive to try and find the file
with tempfile.TemporaryDirectory() as archive_dir:
self._unpack_nemo_file(path2file=model_metadata.restoration_path, out_folder=archive_dir)
# unpack all restorations paths (nemo checkpoints)
# in nemo checkpoints all resources contain hash in name, so there should be no collisions
for path in restoration_paths:
self._unpack_nemo_file(path2file=path, out_folder=archive_dir)
os.chdir(archive_dir)
for conf_path, artiitem in tarfile_artifacts:
# Get basename and copy it to nemo_file_folder
Expand All @@ -454,8 +488,27 @@ def _handle_artifacts(self, model, nemo_file_folder):
# change back working directory
os.chdir(cwd)

@staticmethod
def _update_subconfigs(model: "nemo_classes.ModelPT", path2yaml_file):
"""
Update subconfigs of the model if ModelPT has submodules
Should be called before updating artifacts paths
"""
if not model.has_nemo_submodules():
# no submodules => nothing to update
return
conf = OmegaConf.load(path2yaml_file)
# update subconfigs for all children recoursively
# parent configs updated before children
for _, conf_path, submodule in model.named_nemo_modules():
if not conf_path: # self
continue
OmegaConf.update(conf, conf_path, submodule.cfg)
with open(path2yaml_file, 'w', encoding='utf-8') as fout:
OmegaConf.save(config=conf, f=fout, resolve=True)

def _update_artifact_paths(self, model, path2yaml_file):
if model.artifacts is not None and len(model.artifacts) > 0:
if hasattr(model, "artifacts") and model.artifacts is not None and len(model.artifacts) > 0:
conf = OmegaConf.load(path2yaml_file)
for conf_path, item in model.artifacts.items():
if item.hashed_path is None:
Expand Down
Loading

0 comments on commit 97973c5

Please sign in to comment.