diff --git a/demucs/pretrained.py b/demucs/pretrained.py index 23cb8b0e..65851cb6 100644 --- a/demucs/pretrained.py +++ b/demucs/pretrained.py @@ -14,6 +14,7 @@ from .hdemucs import HDemucs from .repo import RemoteRepo, LocalRepo, ModelOnlyRepo, BagOnlyRepo, AnyModelRepo, ModelLoadingError # noqa +from .states import _check_diffq logger = logging.getLogger(__name__) ROOT_URL = "https://dl.fbaipublicfiles.com/demucs/" @@ -71,7 +72,13 @@ def get_model(name: str, model_repo = LocalRepo(repo) bag_repo = BagOnlyRepo(repo, model_repo) any_repo = AnyModelRepo(model_repo, bag_repo) - model = any_repo.get_model(name) + try: + model = any_repo.get_model(name) + except ImportError as exc: + if 'diffq' in exc.args[0]: + _check_diffq() + raise + model.eval() return model diff --git a/demucs/states.py b/demucs/states.py index 48252c5e..361bb419 100644 --- a/demucs/states.py +++ b/demucs/states.py @@ -16,19 +16,32 @@ import warnings from omegaconf import OmegaConf -from diffq import DiffQuantizer, UniformQuantizer, restore_quantized_state +from dora.log import fatal import torch +def _check_diffq(): + try: + import diffq # noqa + except ImportError: + fatal('Trying to use DiffQ, but diffq is not installed.\n' + 'On Windows run: python.exe -m pip install diffq \n' + 'On Linux/Mac, run: python3 -m pip install diffq') + + def get_quantizer(model, args, optimizer=None): """Return the quantizer given the XP quantization args.""" quantizer = None if args.diffq: + _check_diffq() + from diffq import DiffQuantizer quantizer = DiffQuantizer( model, min_size=args.min_size, group_size=args.group_size) if optimizer is not None: quantizer.setup_optimizer(optimizer) elif args.qat: + _check_diffq() + from diffq import UniformQuantizer quantizer = UniformQuantizer( model, bits=args.qat, min_size=args.min_size) return quantizer @@ -86,6 +99,8 @@ def set_state(model, state, quantizer=None): if quantizer is not None: quantizer.restore_quantized_state(model, state['quantized']) else: + _check_diffq() + from diffq import restore_quantized_state restore_quantized_state(model, state) else: model.load_state_dict(state) diff --git a/docs/release.md b/docs/release.md index 3870e165..005055d7 100644 --- a/docs/release.md +++ b/docs/release.md @@ -5,6 +5,8 @@ Various improvements by @CarlGao4. Support for `segment` param inside of HTDemucs model. +Made diffq an optional dependency, with an error message if not installed. + ## V4.0.0, 7th of December 2022 Adding hybrid transformer Demucs model. diff --git a/hubconf.py b/hubconf.py index 019b2e89..0cdb553e 100644 --- a/hubconf.py +++ b/hubconf.py @@ -4,7 +4,7 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -dependencies = ['dora-search', 'diffq', 'julius', 'lameenc', 'openunmix', 'pyyaml', +dependencies = ['dora-search', 'julius', 'lameenc', 'openunmix', 'pyyaml', 'torch', 'torchaudio', 'tqdm'] from demucs.pretrained import get_model diff --git a/requirements_minimal.txt b/requirements_minimal.txt index f1ccb059..8c6f1e57 100644 --- a/requirements_minimal.txt +++ b/requirements_minimal.txt @@ -1,6 +1,5 @@ # please make sure you have already a pytorch install that is cuda enabled! dora-search -diffq>=0.2.1 einops julius>=0.2.3 lameenc>=1.2