Skip to content

Commit

Permalink
[fix] multi t2i adapter set total_downscale_factor (huggingface#4621)
Browse files Browse the repository at this point in the history
* [fix] multi t2i adapter set total_downscale_factor

* move image checks into check inputs

* remove copied from
  • Loading branch information
williamberman authored Aug 24, 2023
1 parent 58f5f74 commit 3105c71
Show file tree
Hide file tree
Showing 3 changed files with 301 additions and 27 deletions.
33 changes: 26 additions & 7 deletions src/diffusers/models/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,31 @@ def __init__(self, adapters: List["T2IAdapter"]):
self.num_adapter = len(adapters)
self.adapters = nn.ModuleList(adapters)

if len(adapters) == 0:
raise ValueError("Expecting at least one adapter")

if len(adapters) == 1:
raise ValueError("For a single adapter, please use the `T2IAdapter` class instead of `MultiAdapter`")

# The outputs from each adapter are added together with a weight
# This means that the change in dimenstions from downsampling must
# be the same for all adapters. Inductively, it also means the total
# downscale factor must also be the same for all adapters.

first_adapter_total_downscale_factor = adapters[0].total_downscale_factor

for idx in range(1, len(adapters)):
adapter_idx_total_downscale_factor = adapters[idx].total_downscale_factor

if adapter_idx_total_downscale_factor != first_adapter_total_downscale_factor:
raise ValueError(
f"Expecting all adapters to have the same total_downscale_factor, "
f"but got adapters[0].total_downscale_factor={first_adapter_total_downscale_factor} and "
f"adapter[`{idx}`]={adapter_idx_total_downscale_factor}"
)

self.total_downscale_factor = adapters[0].total_downscale_factor

def forward(self, xs: torch.Tensor, adapter_weights: Optional[List[float]] = None) -> List[torch.Tensor]:
r"""
Args:
Expand All @@ -56,14 +81,8 @@ def forward(self, xs: torch.Tensor, adapter_weights: Optional[List[float]] = Non
else:
adapter_weights = torch.tensor(adapter_weights)

if xs.shape[1] % self.num_adapter != 0:
raise ValueError(
f"Expecting multi-adapter's input have number of channel that cab be evenly divisible "
f"by num_adapter: {xs.shape[1]} % {self.num_adapter} != 0"
)
x_list = torch.chunk(xs, self.num_adapter, dim=1)
accume_state = None
for x, w, adapter in zip(x_list, adapter_weights, self.adapters):
for x, w, adapter in zip(xs, adapter_weights, self.adapters):
features = adapter(x)
if accume_state is None:
accume_state = features
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -453,13 +453,13 @@ def prepare_extra_step_kwargs(self, generator, eta):
extra_step_kwargs["generator"] = generator
return extra_step_kwargs

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
def check_inputs(
self,
prompt,
height,
width,
callback_steps,
image,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
Expand Down Expand Up @@ -501,6 +501,17 @@ def check_inputs(
f" {negative_prompt_embeds.shape}."
)

if isinstance(self.adapter, MultiAdapter):
if not isinstance(image, list):
raise ValueError(
"MultiAdapter is enabled, but `image` is not a list. Please pass a list of images to `image`."
)

if len(image) != len(self.adapter.adapters):
raise ValueError(
f"MultiAdapter requires passing the same number of images as adapters. Given {len(image)} images and {len(self.adapter.adapters)} adapters."
)

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
Expand Down Expand Up @@ -653,17 +664,19 @@ def __call__(

# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
prompt, height, width, callback_steps, image, negative_prompt, prompt_embeds, negative_prompt_embeds
)

is_multi_adapter = isinstance(self.adapter, MultiAdapter)
if is_multi_adapter:
adapter_input = [_preprocess_adapter_image(img, height, width).to(device) for img in image]
n, c, h, w = adapter_input[0].shape
adapter_input = torch.stack([x.reshape([n * c, h, w]) for x in adapter_input])
if isinstance(self.adapter, MultiAdapter):
adapter_input = []

for one_image in image:
one_image = _preprocess_adapter_image(one_image, height, width)
one_image = one_image.to(device=device, dtype=self.adapter.dtype)
adapter_input.append(one_image)
else:
adapter_input = _preprocess_adapter_image(image, height, width).to(device)
adapter_input = adapter_input.to(self.adapter.dtype)
adapter_input = _preprocess_adapter_image(image, height, width)
adapter_input = adapter_input.to(device=device, dtype=self.adapter.dtype)

# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
Expand Down
Loading

0 comments on commit 3105c71

Please sign in to comment.