-
-
Notifications
You must be signed in to change notification settings - Fork 449
/
Copy path model_omnigen.py
31 lines (27 loc) · 1.13 KB
/
model_omnigen.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
def load_omnigen(checkpoint_info, diffusers_load_config=None): # pylint: disable=unused-argument
    """Load an OmniGen pipeline for ``checkpoint_info`` and move it to the active device.

    Args:
        checkpoint_info: checkpoint descriptor; only its ``name`` attribute is read,
            and it is resolved to a repository id via ``sd_models.path_to_repo``.
        diffusers_load_config: currently unused by this loader. Defaults to ``None``
            rather than a shared mutable ``{}``. Presumably kept so the signature
            matches the other model loaders; confirm against callers.

    Returns:
        The ``OmniGenPipeline`` with its ``model`` and ``vae`` moved to
        ``devices.device`` / ``devices.dtype``. The model is put in eval mode
        when ``shared.opts.diffusers_eval`` is set.
    """
    from modules import shared, devices, sd_models
    repo_id = sd_models.path_to_repo(checkpoint_info.name)

    # load
    from modules.omnigen import OmniGenPipeline
    pipe = OmniGenPipeline.from_pretrained(
        model_name=repo_id,
        # The VAE is always taken from this fixed repo, independent of the checkpoint.
        vae_path='madebyollin/sdxl-vae-fp16-fix',
        cache_dir=shared.opts.diffusers_dir,
    )

    # init
    # NOTE(review): the pipeline's device/dtype are assigned as plain attributes here.
    # Presumably OmniGenPipeline reads these instead of inferring them from its
    # modules; confirm.
    pipe.device = devices.device
    pipe.dtype = devices.dtype
    pipe.model.device = devices.device
    # NOTE(review): separate CFG passes and no KV cache presumably trade speed
    # for lower peak memory; confirm against OmniGenPipeline.
    pipe.separate_cfg_infer = True
    pipe.use_kv_cache = False
    pipe.model.to(device=devices.device, dtype=devices.dtype)
    if shared.opts.diffusers_eval:
        pipe.model.eval()
    pipe.vae.to(devices.device, dtype=devices.dtype)
    devices.torch_gc()

    # register (disabled): mapping the pipeline class into diffusers' auto-pipeline tables
    # from diffusers import pipelines
    # pipelines.auto_pipeline.AUTO_TEXT2IMAGE_PIPELINES_MAPPING["omnigen"] = pipe.__class__
    # pipelines.auto_pipeline.AUTO_IMAGE2IMAGE_PIPELINES_MAPPING["omnigen"] = pipe.__class__
    # pipelines.auto_pipeline.AUTO_INPAINT_PIPELINES_MAPPING["omnigen"] = pipe.__class__
    return pipe