Skip to content

Commit

Permalink
static script cache
Browse files Browse the repository at this point in the history
  • Loading branch information
ljleb committed Mar 4, 2023
1 parent 721fd78 commit 57ed434
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 13 deletions.
7 changes: 3 additions & 4 deletions scripts/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,15 +178,14 @@ def create_cn_script_runner(script_runner, control_unit_requests: List[ControlNe

cn_script_runner = copy.copy(script_runner)

cn_script_args = [False] # is_img2img
cn_script_args: List[Any] = [False] # is_img2img
for control_unit_request in control_unit_requests:
cn_script_args += create_cn_unit_args(control_unit_request)

script_titles = [script.title().lower() for script in script_runner.alwayson_scripts]
cn_script_id = script_titles.index('controlnet')
cn_script = copy.copy(script_runner.alwayson_scripts[cn_script_id])
cn_script.args_from = 0
cn_script.args_to = len(cn_script_args)
cn_script = script_runner.alwayson_scripts[cn_script_id]
cn_script_args = ([None] * cn_script.args_from) + cn_script_args

def make_script_runner_f_hijack(fixed_original_f):
def script_runner_f_hijack(p, *args, **kwargs):
Expand Down
20 changes: 11 additions & 9 deletions scripts/controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,10 +167,13 @@ def update_cn_models():
update_cn_models()




class Script(scripts.Script):
model_cache = {}

def __init__(self) -> None:
super().__init__()
self.model_cache = {}
self.latest_network = None
self.preprocessor = {
"none": lambda x, *args, **kwargs: x,
Expand Down Expand Up @@ -583,15 +586,14 @@ def process(self, p, is_img2img=False, *args):
hook_lowvram = False

# cache stuff
models_changed = self.latest_model_hash != p.sd_model.sd_model_hash or self.model_cache == {} or self.model_cache is None
if models_changed or len(self.model_cache) >= shared.opts.data.get("control_net_model_cache_size", 2):
for key, model in self.model_cache.items():
models_changed = self.latest_model_hash != p.sd_model.sd_model_hash or not Script.model_cache or Script.model_cache is None
if models_changed or len(Script.model_cache) >= shared.opts.data.get("control_net_model_cache_size", 2):
for key, model in Script.model_cache.items():
model.to("cpu")
del self.model_cache
Script.model_cache.clear()
gc.collect()
devices.torch_gc()
self.model_cache = {}


# unload unused preproc
module_list = [mod[0] for mod in control_groups]
for key in self.unloadable:
Expand All @@ -608,12 +610,12 @@ def process(self, p, is_img2img=False, *args):
if lowvram:
hook_lowvram = True

model_net = self.model_cache[model] if model in self.model_cache \
model_net = Script.model_cache[model] if model in Script.model_cache \
else self.build_control_model(p, unet, model, lowvram)

model_net.reset()
networks.append(model_net)
self.model_cache[model] = model_net
Script.model_cache[model] = model_net

is_img2img_batch_tab = is_img2img and img2img_tab_tracker.submit_img2img_tab == 'img2img_batch_tab'
if is_img2img_batch_tab and hasattr(p, "image_control") and p.image_control is not None:
Expand Down

0 comments on commit 57ed434

Please sign in to comment.