Skip to content

Commit

Permalink
Support for CosXL models.
Browse files Browse the repository at this point in the history
  • Loading branch information
comfyanonymous committed Apr 5, 2024
1 parent 41ed7e8 commit 1088d18
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
6 changes: 4 additions & 2 deletions comfy/model_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,8 +503,10 @@ def __init__(self, model_config, model_type=ModelType.EPS, device=None):
class SDXL_instructpix2pix(IP2P, SDXL):
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
super().__init__(model_config, model_type, device=device)
# self.process_ip2p_image_in = lambda image: comfy.latent_formats.SDXL().process_in(image)
self.process_ip2p_image_in = lambda image: image
if model_type == ModelType.V_PREDICTION_EDM:
self.process_ip2p_image_in = lambda image: comfy.latent_formats.SDXL().process_in(image) #cosxl ip2p
else:
self.process_ip2p_image_in = lambda image: image #diffusers ip2p


class StableCascade_C(BaseModel):
Expand Down
7 changes: 6 additions & 1 deletion comfy/supported_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,11 @@ def model_type(self, state_dict, prefix=""):
self.sampling_settings["sigma_max"] = 80.0
self.sampling_settings["sigma_min"] = 0.002
return model_base.ModelType.EDM
elif "edm_vpred.sigma_max" in state_dict:
self.sampling_settings["sigma_max"] = float(state_dict["edm_vpred.sigma_max"].item())
if "edm_vpred.sigma_min" in state_dict:
self.sampling_settings["sigma_min"] = float(state_dict["edm_vpred.sigma_min"].item())
return model_base.ModelType.V_PREDICTION_EDM
elif "v_pred" in state_dict:
return model_base.ModelType.V_PREDICTION
else:
Expand Down Expand Up @@ -469,7 +474,7 @@ class SDXL_instructpix2pix(SDXL):
}

def get_model(self, state_dict, prefix="", device=None):
return model_base.SDXL_instructpix2pix(self, device=device)
return model_base.SDXL_instructpix2pix(self, model_type=self.model_type(state_dict, prefix), device=device)

models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p]

Expand Down

0 comments on commit 1088d18

Please sign in to comment.