[ip-adapter] fix problem using embeds with the plus version of ip adapters (huggingface#7189)

* initial

* check_inputs fix to the rest of pipelines

* add fix for no cfg too

* use of variable

---------

Co-authored-by: Sayak Paul <[email protected]>
asomoza and sayakpaul authored Mar 3, 2024
1 parent f55873b commit 001b140
Showing 20 changed files with 240 additions and 100 deletions.
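
Context for the change repeated in every file below: the "plus" IP-Adapter variants produce image embeds with one more dimension than the base adapters, so the hard-coded .repeat(num_images_per_prompt, 1, 1) used when user-supplied embeds are passed in only works for 3D tensors and raises a RuntimeError on 4D ones. The fix builds the repeat pattern from the tensor's own rank. A minimal standalone sketch of the idea (the shapes are illustrative, not taken from this diff):

import torch

num_images_per_prompt = 2
repeat_dims = [1]

# Illustrative shapes: a base IP-Adapter embed is 3D, a "plus" embed is 4D
# (both already stacked as a negative/positive pair for classifier-free guidance).
base_embeds = torch.randn(2, 1, 768)
plus_embeds = torch.randn(2, 1, 257, 1280)

for single_image_embeds in (base_embeds, plus_embeds):
    # The old call, single_image_embeds.repeat(num_images_per_prompt, 1, 1),
    # assumes exactly three dimensions and fails on the 4D "plus" embeds.
    single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
    single_image_embeds = single_image_embeds.repeat(
        num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
    )
    single_negative_image_embeds = single_negative_image_embeds.repeat(
        num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
    )
    out = torch.cat([single_negative_image_embeds, single_image_embeds])
    print(out.shape)  # batch axis scaled by num_images_per_prompt; all other axes untouched

Only the first axis is repeated: repeat_dims * len(shape[1:]) expands to a 1 for every remaining axis, so the same line covers both ranks.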
17 changes: 12 additions & 5 deletions src/diffusers/pipelines/animatediff/pipeline_animatediff.py
@@ -400,15 +400,22 @@ def prepare_ip_adapter_image_embeds(

                 image_embeds.append(single_image_embeds)
         else:
+            repeat_dims = [1]
             image_embeds = []
             for single_image_embeds in ip_adapter_image_embeds:
                 if do_classifier_free_guidance:
                     single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
-                    single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1)
-                    single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
+                    single_image_embeds = single_image_embeds.repeat(
+                        num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
+                    )
+                    single_negative_image_embeds = single_negative_image_embeds.repeat(
+                        num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
+                    )
                     single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
                 else:
-                    single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
+                    single_image_embeds = single_image_embeds.repeat(
+                        num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
+                    )
                 image_embeds.append(single_image_embeds)

         return image_embeds
@@ -509,9 +516,9 @@ def check_inputs(
                 raise ValueError(
                     f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
                 )
-            elif ip_adapter_image_embeds[0].ndim != 3:
+            elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
                 raise ValueError(
-                    f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D"
+                    f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
                 )

     # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents
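
The matching check_inputs change, repeated in each pipeline below, simply widens the validation so pre-computed 4D embeds are no longer rejected. A standalone sketch of the new rule (the helper name is illustrative, not part of the library):

import torch

def validate_ip_adapter_image_embeds(ip_adapter_image_embeds):
    # Mirrors the relaxed rule from this commit: a list of tensors,
    # each either 3D (base IP-Adapter) or 4D ("plus" variants).
    if not isinstance(ip_adapter_image_embeds, list):
        raise ValueError(
            f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
        )
    if ip_adapter_image_embeds[0].ndim not in [3, 4]:
        raise ValueError(
            f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
        )

validate_ip_adapter_image_embeds([torch.randn(2, 1, 768)])        # accepted before and after
validate_ip_adapter_image_embeds([torch.randn(2, 1, 257, 1280)])  # accepted only after this fix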
@@ -478,15 +478,22 @@ def prepare_ip_adapter_image_embeds(

                 image_embeds.append(single_image_embeds)
         else:
+            repeat_dims = [1]
             image_embeds = []
             for single_image_embeds in ip_adapter_image_embeds:
                 if do_classifier_free_guidance:
                     single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
-                    single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1)
-                    single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
+                    single_image_embeds = single_image_embeds.repeat(
+                        num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
+                    )
+                    single_negative_image_embeds = single_negative_image_embeds.repeat(
+                        num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
+                    )
                     single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
                 else:
-                    single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
+                    single_image_embeds = single_image_embeds.repeat(
+                        num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
+                    )
                 image_embeds.append(single_image_embeds)

         return image_embeds
@@ -589,9 +596,9 @@ def check_inputs(
                 raise ValueError(
                     f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
                 )
-            elif ip_adapter_image_embeds[0].ndim != 3:
+            elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
                 raise ValueError(
-                    f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D"
+                    f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
                 )

     def get_timesteps(self, num_inference_steps, timesteps, strength, device):
17 changes: 12 additions & 5 deletions src/diffusers/pipelines/controlnet/pipeline_controlnet.py
@@ -510,15 +510,22 @@ def prepare_ip_adapter_image_embeds(

                 image_embeds.append(single_image_embeds)
         else:
+            repeat_dims = [1]
             image_embeds = []
             for single_image_embeds in ip_adapter_image_embeds:
                 if do_classifier_free_guidance:
                     single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
-                    single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1)
-                    single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
+                    single_image_embeds = single_image_embeds.repeat(
+                        num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
+                    )
+                    single_negative_image_embeds = single_negative_image_embeds.repeat(
+                        num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
+                    )
                     single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
                 else:
-                    single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
+                    single_image_embeds = single_image_embeds.repeat(
+                        num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
+                    )
                 image_embeds.append(single_image_embeds)

         return image_embeds
@@ -726,9 +733,9 @@ def check_inputs(
                 raise ValueError(
                     f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
                 )
-            elif ip_adapter_image_embeds[0].ndim != 3:
+            elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
                 raise ValueError(
-                    f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D"
+                    f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
                 )

     def check_image(self, image, prompt, prompt_embeds):
17 changes: 12 additions & 5 deletions src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
@@ -503,15 +503,22 @@ def prepare_ip_adapter_image_embeds(

                 image_embeds.append(single_image_embeds)
         else:
+            repeat_dims = [1]
             image_embeds = []
             for single_image_embeds in ip_adapter_image_embeds:
                 if do_classifier_free_guidance:
                     single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
-                    single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1)
-                    single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
+                    single_image_embeds = single_image_embeds.repeat(
+                        num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
+                    )
+                    single_negative_image_embeds = single_negative_image_embeds.repeat(
+                        num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
+                    )
                     single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
                 else:
-                    single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
+                    single_image_embeds = single_image_embeds.repeat(
+                        num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
+                    )
                 image_embeds.append(single_image_embeds)

         return image_embeds
@@ -713,9 +720,9 @@ def check_inputs(
                 raise ValueError(
                     f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
                 )
-            elif ip_adapter_image_embeds[0].ndim != 3:
+            elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
                 raise ValueError(
-                    f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D"
+                    f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
                 )

     # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
17 changes: 12 additions & 5 deletions src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
@@ -628,15 +628,22 @@ def prepare_ip_adapter_image_embeds(

                 image_embeds.append(single_image_embeds)
         else:
+            repeat_dims = [1]
             image_embeds = []
             for single_image_embeds in ip_adapter_image_embeds:
                 if do_classifier_free_guidance:
                     single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
-                    single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1)
-                    single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
+                    single_image_embeds = single_image_embeds.repeat(
+                        num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
+                    )
+                    single_negative_image_embeds = single_negative_image_embeds.repeat(
+                        num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
+                    )
                     single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
                 else:
-                    single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
+                    single_image_embeds = single_image_embeds.repeat(
+                        num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
+                    )
                 image_embeds.append(single_image_embeds)

         return image_embeds
@@ -871,9 +878,9 @@ def check_inputs(
                 raise ValueError(
                     f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
                 )
-            elif ip_adapter_image_embeds[0].ndim != 3:
+            elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
                 raise ValueError(
-                    f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D"
+                    f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
                 )

     # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
@@ -537,15 +537,22 @@ def prepare_ip_adapter_image_embeds(

                 image_embeds.append(single_image_embeds)
         else:
+            repeat_dims = [1]
             image_embeds = []
             for single_image_embeds in ip_adapter_image_embeds:
                 if do_classifier_free_guidance:
                     single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
-                    single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1)
-                    single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
+                    single_image_embeds = single_image_embeds.repeat(
+                        num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
+                    )
+                    single_negative_image_embeds = single_negative_image_embeds.repeat(
+                        num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
+                    )
                     single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
                 else:
-                    single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
+                    single_image_embeds = single_image_embeds.repeat(
+                        num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
+                    )
                 image_embeds.append(single_image_embeds)

         return image_embeds
@@ -817,9 +824,9 @@ def check_inputs(
                 raise ValueError(
                     f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
                 )
-            elif ip_adapter_image_embeds[0].ndim != 3:
+            elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
                 raise ValueError(
-                    f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D"
+                    f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
                 )

     def prepare_control_image(
17 changes: 12 additions & 5 deletions src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
@@ -515,15 +515,22 @@ def prepare_ip_adapter_image_embeds(

                 image_embeds.append(single_image_embeds)
         else:
+            repeat_dims = [1]
             image_embeds = []
             for single_image_embeds in ip_adapter_image_embeds:
                 if do_classifier_free_guidance:
                     single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
-                    single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1)
-                    single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
+                    single_image_embeds = single_image_embeds.repeat(
+                        num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
+                    )
+                    single_negative_image_embeds = single_negative_image_embeds.repeat(
+                        num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
+                    )
                     single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
                 else:
-                    single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
+                    single_image_embeds = single_image_embeds.repeat(
+                        num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
+                    )
                 image_embeds.append(single_image_embeds)

         return image_embeds
@@ -730,9 +737,9 @@ def check_inputs(
                 raise ValueError(
                     f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
                 )
-            elif ip_adapter_image_embeds[0].ndim != 3:
+            elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
                 raise ValueError(
-                    f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D"
+                    f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
                 )

     # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
@@ -567,15 +567,22 @@ def prepare_ip_adapter_image_embeds(

                 image_embeds.append(single_image_embeds)
         else:
+            repeat_dims = [1]
             image_embeds = []
             for single_image_embeds in ip_adapter_image_embeds:
                 if do_classifier_free_guidance:
                     single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
-                    single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1)
-                    single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
+                    single_image_embeds = single_image_embeds.repeat(
+                        num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
+                    )
+                    single_negative_image_embeds = single_negative_image_embeds.repeat(
+                        num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
+                    )
                     single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
                 else:
-                    single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
+                    single_image_embeds = single_image_embeds.repeat(
+                        num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
+                    )
                 image_embeds.append(single_image_embeds)

         return image_embeds
@@ -794,9 +801,9 @@ def check_inputs(
                 raise ValueError(
                     f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
                 )
-            elif ip_adapter_image_embeds[0].ndim != 3:
+            elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
                 raise ValueError(
-                    f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D"
+                    f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
                 )

     # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.check_image
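
End to end, the point of the fix is being able to pre-compute embeds for a "plus" checkpoint once and feed them back through ip_adapter_image_embeds. A hedged sketch of that workflow; the model IDs, weight name, image URL, and the keyword arguments of prepare_ip_adapter_image_embeds are assumptions rather than details taken from this diff:

import torch
from diffusers import AutoPipelineForText2Image
from diffusers.utils import load_image

pipeline = AutoPipelineForText2Image.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
# Weight name assumed; any IP-Adapter "plus" checkpoint exercises the 4D embeds path.
pipeline.load_ip_adapter(
    "h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.safetensors"
)

reference = load_image("https://example.com/reference.png")  # placeholder URL

# Compute the embeds once; the keyword names below are an assumed signature.
image_embeds = pipeline.prepare_ip_adapter_image_embeds(
    ip_adapter_image=reference,
    ip_adapter_image_embeds=None,
    device="cuda",
    num_images_per_prompt=1,
    do_classifier_free_guidance=True,
)
torch.save(image_embeds, "ip_adapter_plus_embeds.pt")

# Later: reload the cached embeds and pass them straight to the pipeline,
# which is the call path this commit repairs for the plus variants.
image_embeds = torch.load("ip_adapter_plus_embeds.pt")
result = pipeline(
    prompt="a photo in the style of the reference image",
    ip_adapter_image_embeds=image_embeds,
    num_inference_steps=30,
).images[0]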