Skip to content

Commit

Permalink
Fix model needing to be loaded on GPU to generate the sigmas.
Browse files Browse the repository at this point in the history
  • Loading branch information
comfyanonymous committed Apr 5, 2024
1 parent 1f8d8e6 commit 1a0486b
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 16 deletions.
6 changes: 6 additions & 0 deletions comfy/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,12 @@ def set_model_output_block_patch(self, patch):
def add_object_patch(self, name, obj):
self.object_patches[name] = obj

def get_model_object(self, name):
if name in self.object_patches:
return self.object_patches[name]
else:
return comfy.utils.get_attr(self.model, name)

def model_patches_to(self, device):
to = self.model_options["transformer_options"]
if "patches" in to:
Expand Down
28 changes: 14 additions & 14 deletions comfy/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,17 +274,17 @@ def __call__(self, x, sigma, denoise_mask, model_options={}, seed=None):
out = out * denoise_mask + self.latent_image * latent_mask
return out

def simple_scheduler(model, steps):
s = model.model_sampling
def simple_scheduler(model_sampling, steps):
s = model_sampling
sigs = []
ss = len(s.sigmas) / steps
for x in range(steps):
sigs += [float(s.sigmas[-(1 + int(x * ss))])]
sigs += [0.0]
return torch.FloatTensor(sigs)

def ddim_scheduler(model, steps):
s = model.model_sampling
def ddim_scheduler(model_sampling, steps):
s = model_sampling
sigs = []
ss = max(len(s.sigmas) // steps, 1)
x = 1
Expand All @@ -295,8 +295,8 @@ def ddim_scheduler(model, steps):
sigs += [0.0]
return torch.FloatTensor(sigs)

def normal_scheduler(model, steps, sgm=False, floor=False):
s = model.model_sampling
def normal_scheduler(model_sampling, steps, sgm=False, floor=False):
s = model_sampling
start = s.timestep(s.sigma_max)
end = s.timestep(s.sigma_min)

Expand Down Expand Up @@ -660,19 +660,19 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"]
SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"]

def calculate_sigmas_scheduler(model, scheduler_name, steps):
def calculate_sigmas(model_sampling, scheduler_name, steps):
if scheduler_name == "karras":
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max))
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model_sampling.sigma_min), sigma_max=float(model_sampling.sigma_max))
elif scheduler_name == "exponential":
sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max))
sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model_sampling.sigma_min), sigma_max=float(model_sampling.sigma_max))
elif scheduler_name == "normal":
sigmas = normal_scheduler(model, steps)
sigmas = normal_scheduler(model_sampling, steps)
elif scheduler_name == "simple":
sigmas = simple_scheduler(model, steps)
sigmas = simple_scheduler(model_sampling, steps)
elif scheduler_name == "ddim_uniform":
sigmas = ddim_scheduler(model, steps)
sigmas = ddim_scheduler(model_sampling, steps)
elif scheduler_name == "sgm_uniform":
sigmas = normal_scheduler(model, steps, sgm=True)
sigmas = normal_scheduler(model_sampling, steps, sgm=True)
else:
logging.error("error invalid scheduler {}".format(scheduler_name))
return sigmas
Expand Down Expand Up @@ -714,7 +714,7 @@ def calculate_sigmas(self, steps):
steps += 1
discard_penultimate_sigma = True

sigmas = calculate_sigmas_scheduler(self.model.model, self.scheduler, steps)
sigmas = calculate_sigmas(self.model.get_model_object("model_sampling"), self.scheduler, steps)

if discard_penultimate_sigma:
sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
Expand Down
3 changes: 1 addition & 2 deletions comfy_extras/nodes_custom_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@ def get_sigmas(self, model, scheduler, steps, denoise):
return (torch.FloatTensor([]),)
total_steps = int(steps/denoise)

comfy.model_management.load_models_gpu([model])
sigmas = comfy.samplers.calculate_sigmas_scheduler(model.model, scheduler, total_steps).cpu()
sigmas = comfy.samplers.calculate_sigmas(model.get_model_object("model_sampling"), scheduler, total_steps).cpu()
sigmas = sigmas[-(steps + 1):]
return (sigmas, )

Expand Down

0 comments on commit 1a0486b

Please sign in to comment.