[From pretrained] Speed-up loading from cache (huggingface#2515)
* [From pretrained] Speed-up loading from cache

* up

* Fix more

* fix one more bug

* make style

* bigger refactor

* factor out function

* Improve more

* better

* deprecate return cache folder

* clean up

* improve tests

* up

* upload

* add nice tests

* simplify

* finish

* correct

* fix version

* rename

* Apply suggestions from code review

Co-authored-by: Lucain <[email protected]>

* rename

* correct doc string

* correct more

* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <[email protected]>

* apply code suggestions

* finish

---------

Co-authored-by: Lucain <[email protected]>
Co-authored-by: Pedro Cuenca <[email protected]>
Authored by 3 people on Mar 10, 2023 · 1 parent 7fe638c · commit d761b58
Showing 12 changed files with 636 additions and 318 deletions.
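In practical terms, the commit reuses the commit hash resolved while fetching a model's config for every file downloaded afterwards, so warm-cache loads need far fewer network round-trips. A minimal sketch of the scenario this speeds up (the pipeline class is existing diffusers API; the model id is an illustrative assumption, not part of the diff):

from diffusers import DiffusionPipeline

# First load resolves the repo revision online and populates the local cache.
pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# A second load reuses the commit hash resolved from the cached config, so the
# remaining files can be served from disk with far fewer HTTP requests.
pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
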
4 changes: 2 additions & 2 deletions scripts/convert_original_stable_diffusion_to_diffusers.py
@@ -16,7 +16,7 @@

import argparse

-from diffusers.pipelines.stable_diffusion.convert_from_ckpt import load_pipeline_from_original_stable_diffusion_ckpt
+from diffusers.pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt


if __name__ == "__main__":
@@ -125,7 +125,7 @@
)
args = parser.parse_args()

-pipe = load_pipeline_from_original_stable_diffusion_ckpt(
+pipe = download_from_original_stable_diffusion_ckpt(
checkpoint_path=args.checkpoint_path,
original_config_file=args.original_config_file,
image_size=args.image_size,
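For reference, the renamed helper can also be called directly outside the script; a hedged sketch (the import path and keyword names follow the diff above, the file paths are placeholders):

from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
    download_from_original_stable_diffusion_ckpt,
)

# Convert an original Stable Diffusion checkpoint into a diffusers pipeline.
pipe = download_from_original_stable_diffusion_ckpt(
    checkpoint_path="./v1-5-pruned-emaonly.ckpt",  # placeholder path
    original_config_file="./v1-inference.yaml",    # placeholder path
)
pipe.save_pretrained("./stable-diffusion-v1-5-diffusers")
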
4 changes: 3 additions & 1 deletion setup.py
@@ -86,7 +86,8 @@
"filelock",
"flax>=0.4.1",
"hf-doc-builder>=0.3.0",
"huggingface-hub>=0.10.0",
"huggingface-hub>=0.13.0",
"requests-mock==1.10.0",
"importlib_metadata",
"isort>=5.5.4",
"jax>=0.2.8,!=0.3.2",
@@ -192,6 +193,7 @@ def run(self):
"pytest",
"pytest-timeout",
"pytest-xdist",
"requests-mock",
"safetensors",
"sentencepiece",
"scipy",
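The new `requests-mock` pin is presumably there so the added cache tests can count HTTP traffic; a generic sketch of that pattern (the URL is illustrative, not one used by the test suite):

import requests
import requests_mock

# Register a fake endpoint and count how many HTTP calls a code path makes;
# a warm-cache load should need far fewer requests than a cold one.
with requests_mock.Mocker() as m:
    m.get("https://huggingface.co/api/models/some-repo", json={})
    requests.get("https://huggingface.co/api/models/some-repo")
    assert m.call_count == 1
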
39 changes: 33 additions & 6 deletions src/diffusers/configuration_utils.py
@@ -31,7 +31,15 @@
from requests import HTTPError

from . import __version__
-from .utils import DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, DummyObject, deprecate, logging
+from .utils import (
+DIFFUSERS_CACHE,
+HUGGINGFACE_CO_RESOLVE_ENDPOINT,
+DummyObject,
+deprecate,
+extract_commit_hash,
+http_user_agent,
+logging,
+)


logger = logging.get_logger(__name__)
@@ -231,7 +239,11 @@ def get_config_dict(cls, *args, **kwargs):

@classmethod
def load_config(
-cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs
+cls,
+pretrained_model_name_or_path: Union[str, os.PathLike],
+return_unused_kwargs=False,
+return_commit_hash=False,
+**kwargs,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
r"""
Instantiate a Python class from a config dictionary
@@ -271,6 +283,10 @@ def load_config(
subfolder (`str`, *optional*, defaults to `""`):
In case the relevant files are located inside a subfolder of the model repo (either remote in
huggingface.co or downloaded locally), you can specify the folder name here.
+return_unused_kwargs (`bool`, *optional*, defaults to `False`):
+Whether unused keyword arguments of the config shall be returned.
+return_commit_hash (`bool`, *optional*, defaults to `False`):
+Whether the commit_hash of the loaded configuration shall be returned.
<Tip>
@@ -295,8 +311,10 @@
revision = kwargs.pop("revision", None)
_ = kwargs.pop("mirror", None)
subfolder = kwargs.pop("subfolder", None)
+user_agent = kwargs.pop("user_agent", {})

user_agent = {"file_type": "config"}
user_agent = {**user_agent, "file_type": "config"}
user_agent = http_user_agent(user_agent)

pretrained_model_name_or_path = str(pretrained_model_name_or_path)

@@ -336,7 +354,6 @@
subfolder=subfolder,
revision=revision,
)

except RepositoryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
@@ -378,13 +395,23 @@ def load_config(
try:
# Load config dict
config_dict = cls._dict_from_json_file(config_file)

+commit_hash = extract_commit_hash(config_file)
except (json.JSONDecodeError, UnicodeDecodeError):
raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")

+if not (return_unused_kwargs or return_commit_hash):
+return config_dict

+outputs = (config_dict,)

if return_unused_kwargs:
-return config_dict, kwargs
+outputs += (kwargs,)

+if return_commit_hash:
+outputs += (commit_hash,)

-return config_dict
+return outputs

@staticmethod
def _get_init_keys(cls):
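A hedged sketch of the new `return_commit_hash` flag (the model id and subfolder are illustrative; with both return flags set, `load_config` returns the config dict, the unused kwargs, and the commit hash, matching the tuple logic above):

from diffusers import UNet2DConditionModel

# With both flags enabled the classmethod returns a 3-tuple.
config, unused_kwargs, commit_hash = UNet2DConditionModel.load_config(
    "runwayml/stable-diffusion-v1-5",  # illustrative repo id
    subfolder="unet",
    return_unused_kwargs=True,
    return_commit_hash=True,
)
print(commit_hash)  # commit SHA parsed from the cached file path (may be None for local files)
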
3 changes: 2 additions & 1 deletion src/diffusers/dependency_versions_table.py
@@ -10,7 +10,8 @@
"filelock": "filelock",
"flax": "flax>=0.4.1",
"hf-doc-builder": "hf-doc-builder>=0.3.0",
"huggingface-hub": "huggingface-hub>=0.10.0",
"huggingface-hub": "huggingface-hub>=0.13.0",
"requests-mock": "requests-mock==1.10.0",
"importlib_metadata": "importlib_metadata",
"isort": "isort>=5.5.4",
"jax": "jax>=0.2.8,!=0.3.2",
78 changes: 28 additions & 50 deletions src/diffusers/models/modeling_utils.py
@@ -458,18 +458,34 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
" dispatching. Please make sure to set `low_cpu_mem_usage=True`."
)

# Load config if we don't provide a configuration
config_path = pretrained_model_name_or_path

user_agent = {
"diffusers": __version__,
"file_type": "model",
"framework": "pytorch",
}

# Load config if we don't provide a configuration
config_path = pretrained_model_name_or_path

# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
# Load model

+# load config
+config, unused_kwargs, commit_hash = cls.load_config(
+config_path,
+cache_dir=cache_dir,
+return_unused_kwargs=True,
+return_commit_hash=True,
+force_download=force_download,
+resume_download=resume_download,
+proxies=proxies,
+local_files_only=local_files_only,
+use_auth_token=use_auth_token,
+revision=revision,
+subfolder=subfolder,
+device_map=device_map,
+user_agent=user_agent,
+**kwargs,
+)

# load model
model_file = None
if from_flax:
model_file = _get_model_file(
@@ -484,20 +500,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
-)
-config, unused_kwargs = cls.load_config(
-config_path,
-cache_dir=cache_dir,
-return_unused_kwargs=True,
-force_download=force_download,
-resume_download=resume_download,
-proxies=proxies,
-local_files_only=local_files_only,
-use_auth_token=use_auth_token,
-revision=revision,
-subfolder=subfolder,
-device_map=device_map,
-**kwargs,
+commit_hash=commit_hash,
)
model = cls.from_config(config, **unused_kwargs)

@@ -520,6 +523,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
+commit_hash=commit_hash,
)
except: # noqa: E722
pass
@@ -536,25 +540,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
+commit_hash=commit_hash,
)

if low_cpu_mem_usage:
# Instantiate model with empty weights
with accelerate.init_empty_weights():
-config, unused_kwargs = cls.load_config(
-config_path,
-cache_dir=cache_dir,
-return_unused_kwargs=True,
-force_download=force_download,
-resume_download=resume_download,
-proxies=proxies,
-local_files_only=local_files_only,
-use_auth_token=use_auth_token,
-revision=revision,
-subfolder=subfolder,
-device_map=device_map,
-**kwargs,
-)
model = cls.from_config(config, **unused_kwargs)

# if device_map is None, load the state dict and move the params from meta device to the cpu
@@ -593,20 +584,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
"error_msgs": [],
}
else:
-config, unused_kwargs = cls.load_config(
-config_path,
-cache_dir=cache_dir,
-return_unused_kwargs=True,
-force_download=force_download,
-resume_download=resume_download,
-proxies=proxies,
-local_files_only=local_files_only,
-use_auth_token=use_auth_token,
-revision=revision,
-subfolder=subfolder,
-device_map=device_map,
-**kwargs,
-)
model = cls.from_config(config, **unused_kwargs)

state_dict = load_state_dict(model_file, variant=variant)
@@ -803,6 +780,7 @@ def _get_model_file(
use_auth_token,
user_agent,
revision,
+commit_hash=None,
):
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
if os.path.isfile(pretrained_model_name_or_path):
@@ -840,7 +818,7 @@ def _get_model_file(
use_auth_token=use_auth_token,
user_agent=user_agent,
subfolder=subfolder,
-revision=revision,
+revision=revision or commit_hash,
)
warnings.warn(
f"Loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` is deprecated. Loading instead from `revision='main'` with `variant={revision}`. Loading model variants via `revision='{revision}'` will be removed in diffusers v1. Please use `variant='{revision}'` instead.",
@@ -865,7 +843,7 @@
use_auth_token=use_auth_token,
user_agent=user_agent,
subfolder=subfolder,
-revision=revision,
+revision=revision or commit_hash,
)
return model_file

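The point of threading `commit_hash` down to `_get_model_file` is that `hf_hub_download` (from huggingface_hub, now pinned to >= 0.13.0) can serve a file directly from the local cache when it is requested at an explicit commit-hash revision, skipping extra round-trips to the Hub. A hedged illustration using huggingface_hub directly (repo id, filename, and hash are placeholders):

from huggingface_hub import hf_hub_download

# In from_pretrained the hash comes from load_config(..., return_commit_hash=True)
# and is forwarded as `revision` to the weight download below.
weights_path = hf_hub_download(
    repo_id="runwayml/stable-diffusion-v1-5",
    filename="diffusion_pytorch_model.bin",
    subfolder="unet",
    revision="0123456789abcdef0123456789abcdef01234567",  # hypothetical commit hash
)
print(weights_path)
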
(Diffs for the remaining changed files are not shown.)
