From 8bb33f9a1b1ece35d0679a19b39474412362f501 Mon Sep 17 00:00:00 2001
From: Leonard Bruns
Date: Tue, 16 Apr 2024 15:52:59 +0200
Subject: [PATCH] Update depth-guided stable diffusion example to diffusers 0.27.2 (#5985)

### What

Closes #5976

### Checklist
* [x] I have read and agree to [Contributor Guide](https://github.com/rerun-io/rerun/blob/main/CONTRIBUTING.md) and the [Code of Conduct](https://github.com/rerun-io/rerun/blob/main/CODE_OF_CONDUCT.md)
* Using examples from latest `main` build: [rerun.io/viewer](https://rerun.io/viewer/pr/5985?manifest_url=https://app.rerun.io/version/main/examples_manifest.json)
* Using full set of examples from `nightly` build: [rerun.io/viewer](https://rerun.io/viewer/pr/5985?manifest_url=https://app.rerun.io/version/nightly/examples_manifest.json)
* [x] The PR title and labels are set such as to maximize their usefulness for the next release's CHANGELOG

- [PR Build Summary](https://build.rerun.io/pr/5985)
- [Recent benchmark results](https://build.rerun.io/graphs/crates.html)
- [Wasm size tracking](https://build.rerun.io/graphs/sizes.html)

To run all checks from `main`, comment on the PR with `@rerun-bot full-check`.
---
 .../depth_guided_stable_diffusion/README.md | 6 +-
 .../huggingface_pipeline.py | 781 ++++++++++++------
 .../depth_guided_stable_diffusion/main.py | 4 +-
 .../requirements.txt | 2 +-
 4 files changed, 544 insertions(+), 249 deletions(-)

diff --git a/examples/python/depth_guided_stable_diffusion/README.md b/examples/python/depth_guided_stable_diffusion/README.md
index ac014c2628b1..3dfda2e44952 100644
--- a/examples/python/depth_guided_stable_diffusion/README.md
+++ b/examples/python/depth_guided_stable_diffusion/README.md
@@ -40,10 +40,10 @@ rr.log("prompt/uncond_input/ids", rr.Tensor(uncond_input.input_ids))
 ```

 ### Text embeddings
-Visualizing the text embeddings. The text embeddings are generated in response to the specific prompts used while the unconditional text embeddings represent a neutral or baseline state without specific input conditions.
+Visualizing the text embeddings (i.e., numerical representation of the input texts) from the prompt and negative prompt.
 ```python
-rr.log("prompt/text_embeddings", rr.Tensor(text_embeddings))
-rr.log("prompt/uncond_embeddings", rr.Tensor(uncond_embeddings))
+rr.log("prompt/text_embeddings", rr.Tensor(prompt_embeds))
+rr.log("prompt/negative_text_embeddings", rr.Tensor(negative_prompt_embeds))
 ```

 ### Depth map
diff --git a/examples/python/depth_guided_stable_diffusion/huggingface_pipeline.py b/examples/python/depth_guided_stable_diffusion/huggingface_pipeline.py
index eab92e43bd89..9928f4843ff0 100644
--- a/examples/python/depth_guided_stable_diffusion/huggingface_pipeline.py
+++ b/examples/python/depth_guided_stable_diffusion/huggingface_pipeline.py
@@ -1,13 +1,13 @@
+# Modified StableDiffusionDepth2ImgPipeline to log internal data to Rerun
 #
 # This file was originally copied from
-# https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py
-# commit 9b37ed33b5fa09e594b38e4e6f7477beff3bd66a
+# https://github.com/huggingface/diffusers/blob/8e963d1c2a2b6ebf2880675516809bd9a39359a4/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py
 #
 # The original copyright and license note are included below.
 #
 # This file has only been modified to use the Rerun SDK to log internal data for visualization.
 #
-# Copyright 2022 The HuggingFace Team. All rights reserved.
+# Copyright 2024 The HuggingFace Team. All rights reserved.
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,48 +20,59 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# # type: ignore from __future__ import annotations import contextlib import inspect -from typing import Callable -from typing import List -from typing import Optional -from typing import Union +from typing import Any, Callable, Optional, Union import numpy as np -import PIL +import PIL.Image import rerun as rr # pip install rerun-sdk import torch from diffusers.configuration_utils import FrozenDict -from diffusers.models import AutoencoderKL -from diffusers.models import UNet2DConditionModel -from diffusers.pipeline_utils import DiffusionPipeline -from diffusers.pipeline_utils import ImagePipelineOutput -from diffusers.schedulers import DDIMScheduler -from diffusers.schedulers import DPMSolverMultistepScheduler -from diffusers.schedulers import EulerAncestralDiscreteScheduler -from diffusers.schedulers import EulerDiscreteScheduler -from diffusers.schedulers import LMSDiscreteScheduler -from diffusers.schedulers import PNDMScheduler -from diffusers.utils import deprecate -from diffusers.utils import is_accelerate_available -from diffusers.utils import logging -from diffusers.utils import PIL_INTERPOLATION +from diffusers.image_processor import PipelineImageInput, VaeImageProcessor +from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.models.lora import adjust_lora_scale_text_encoder +from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import ( + PIL_INTERPOLATION, + USE_PEFT_BACKEND, + deprecate, + logging, + scale_lora_layers, + unscale_lora_layers, +) from diffusers.utils.torch_utils import randn_tensor from packaging import version -from transformers import CLIPTextModel -from transformers import CLIPTokenizer -from transformers import DPTFeatureExtractor -from transformers import DPTForDepthEstimation +from transformers import CLIPTextModel, CLIPTokenizer, DPTFeatureExtractor, DPTForDepthEstimation logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger.addHandler(rr.LoggingHandler("logs")) +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess def preprocess(image): + deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. 
Please use VaeImageProcessor.preprocess(…) instead" + deprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False) if isinstance(image, torch.Tensor): return image elif isinstance(image, PIL.Image.Image): @@ -69,7 +80,7 @@ def preprocess(image): if isinstance(image[0], PIL.Image.Image): w, h = image[0].size - w, h = map(lambda x: x - x % 8, (w, h)) # resize to integer multiple of 8 + w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] image = np.concatenate(image, axis=0) @@ -82,41 +93,44 @@ def preprocess(image): return image -class StableDiffusionDepth2ImgPipeline(DiffusionPipeline): +class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): r""" - Pipeline for text-guided image to image generation using Stable Diffusion. - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + Pipeline for text-guided depth-based image-to-image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights + Args: + ---- vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`CLIPTextModel`]): - Frozen text-encoder. Stable Diffusion uses the text portion of - [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically - the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. - tokenizer (`CLIPTokenizer`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). - unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 
+ """ + model_cpu_offload_seq = "text_encoder->unet->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "depth_mask"] + def __init__( self, vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, - scheduler: ( - DDIMScheduler - | PNDMScheduler - | LMSDiscreteScheduler - | EulerDiscreteScheduler - | EulerAncestralDiscreteScheduler - | DPMSolverMultistepScheduler - ), + scheduler: KarrasDiffusionSchedulers, depth_estimator: DPTForDepthEstimation, feature_extractor: DPTFeatureExtractor, ): @@ -153,49 +167,60 @@ def __init__( feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) - def enable_sequential_cpu_offload(self, gpu_id=0): - r""" - Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, - text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a - `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. - """ - if is_accelerate_available(): - from accelerate import cpu_offload - else: - raise ImportError("Please install accelerate via `pip install accelerate`") - - device = torch.device(f"cuda:{gpu_id}") + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) - for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.depth_estimator]: - if cpu_offloaded_model is not None: - cpu_offload(cpu_offloaded_model, device) + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) - @property - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device - def _execution_device(self): - r""" - Returns the device on which the pipeline's models will be executed. After calling - `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module - hooks. 
- """ - if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): - return self.device - for module in self.unet.modules(): - if ( - hasattr(module, "_hf_hook") - and hasattr(module._hf_hook, "execution_device") - and module._hf_hook.execution_device is not None - ): - return torch.device(module._hf_hook.execution_device) - return self.device + return prompt_embeds - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt - def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): r""" Encodes the prompt into text encoder hidden states. + Args: - prompt (`str` or `list(int)`): + ---- + prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device @@ -203,54 +228,113 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. 
+ """ - batch_size = len(prompt) if isinstance(prompt, list) else 1 rr.log("prompt/text", rr.TextLog(prompt)) - rr.log("prompt/text_negative", rr.TextLog(negative_prompt)) - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - rr.log("prompt/text_input/ids", rr.BarChart(text_input_ids)) - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" + rr.log("prompt/text_negative", rr.TextLog(negative_prompt if negative_prompt is not None else "")) + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", ) + text_input_ids = text_inputs.input_ids + rr.log("prompt/text_input/ids", rr.BarChart(text_input_ids)) + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + rr.log("prompt/text_input/attention_mask", rr.BarChart(text_inputs.attention_mask)) + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - rr.log("prompt/text_input/attention_mask", rr.BarChart(text_inputs.attention_mask)) - attention_mask = text_inputs.attention_mask.to(device) + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. 
Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype else: - attention_mask = None + prompt_embeds_dtype = prompt_embeds.dtype - text_embeddings = self.text_encoder( - text_input_ids.to(device), - attention_mask=attention_mask, - ) - text_embeddings = text_embeddings[0] + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method - bs_embed, seq_len, _ = text_embeddings.shape - text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) - text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance: + if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: list[str] if negative_prompt is None: uncond_tokens = [""] * batch_size - elif type(prompt) is not type(negative_prompt): + elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." @@ -266,7 +350,11 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr else: uncond_tokens = negative_prompt - max_length = text_input_ids.shape[-1] + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", @@ -281,43 +369,53 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr else: attention_mask = None - uncond_embeddings = self.text_encoder( + negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) - uncond_embeddings = uncond_embeddings[0] + negative_prompt_embeds = negative_prompt_embeds[0] + if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = uncond_embeddings.shape[1] - uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) - uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - # For classifier free guidance, we need to do two forward passes. 
- # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - rr.log("prompt/text_embeddings", rr.Tensor(text_embeddings)) - rr.log("prompt/uncond_embeddings", rr.Tensor(uncond_embeddings)) - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) - return text_embeddings + rr.log("prompt/text_embeddings", rr.Tensor(prompt_embeds)) + rr.log("prompt/negative_text_embeddings", rr.Tensor(negative_prompt_embeds)) + return prompt_embeds, negative_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): - if self.safety_checker is not None: - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) - else: - has_nsfw_concept = None return image, has_nsfw_concept # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): - latents = 1 / 0.18215 * latents - image = self.vae.decode(latents).sample + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. 
Please use VaeImageProcessor.postprocess(…) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() return image @@ -339,28 +437,66 @@ def prepare_extra_step_kwargs(self, generator, eta): extra_step_kwargs["generator"] = generator return extra_step_kwargs - def check_inputs(self, prompt, strength, callback_steps): - if not isinstance(prompt, str) and not isinstance(prompt, list): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - + def check_inputs( + self, + prompt, + strength, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): if strength < 0 or strength > 1: - raise ValueError(f"The value of strength should in [1.0, 1.0] but is {strength}") + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." 
+ ) + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps def get_timesteps(self, num_inference_steps, strength, device): # get the original timestep using init_timestep init_timestep = min(int(num_inference_steps * strength), num_inference_steps) t_start = max(num_inference_steps - init_timestep, 0) - timesteps = self.scheduler.timesteps[t_start:] + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) return timesteps, num_inference_steps - t_start @@ -370,19 +506,37 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt raise ValueError( f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" ) + image = image.to(device=device, dtype=dtype) - init_latent_dist = self.vae.encode(image).latent_dist - init_latents = init_latent_dist.sample(generator=generator) - rr.log("encoded_input_image", rr.Tensor(init_latents, dim_names=["b", "c", "h", "w"])) - # Decode the initial latents so we can inspect the image->latent->image loop - decoded = self.vae.decode(init_latents).sample - decoded = (decoded / 2 + 0.5).clamp(0, 1) - decoded = decoded.cpu().permute(0, 2, 3, 1).float().numpy() - decoded = (decoded * 255).round().astype("uint8") - rr.log("decoded_init_latents", rr.Image(decoded)) + batch_size = batch_size * num_images_per_prompt - init_latents = 0.18215 * init_latents + if image.shape[1] == 4: + init_latents = image + + else: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + elif isinstance(generator, list): + init_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + rr.log("encoded_input_image", rr.Tensor(init_latents, dim_names=["b", "c", "h", "w"])) + + decoded = self.vae.decode(init_latents / self.vae.config.scaling_factor, return_dict=False)[0] + decoded = self.image_processor.postprocess(decoded) + rr.log("decoded_init_latents", rr.Image(decoded[0])) + + init_latents = self.vae.config.scaling_factor * init_latents if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: # expand init_latents for batch_size @@ -394,16 +548,16 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt ) deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) additional_image_per_prompt = batch_size // init_latents.shape[0] - init_latents = torch.cat([init_latents] * additional_image_per_prompt * num_images_per_prompt, dim=0) + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: raise ValueError( f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." 
) else: - init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0) + init_latents = torch.cat([init_latents], dim=0) - # add noise to latents using the timesteps - noise = randn_tensor(init_latents.shape, generator=generator, device=device, dtype=dtype) + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) # get latents init_latents = self.scheduler.add_noise(init_latents, noise, timestep) @@ -415,12 +569,14 @@ def prepare_depth_map(self, image, depth_map, batch_size, do_classifier_free_gui if isinstance(image, PIL.Image.Image): image = [image] else: - image = [img for img in image] + image = list(image) if isinstance(image[0], PIL.Image.Image): width, height = image[0].size + elif isinstance(image[0], np.ndarray): + width, height = image[0].shape[:-1] else: - width, height = image[0].shape[-2:] + height, width = image[0].shape[-2:] if depth_map is None: pixel_values = self.feature_extractor(images=image, return_tensors="pt").pixel_values @@ -451,174 +607,313 @@ def prepare_depth_map(self, image, depth_map, batch_size, do_classifier_free_gui # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method if depth_map.shape[0] < batch_size: - depth_map = depth_map.repeat(batch_size, 1, 1, 1) + repeat_by = batch_size // depth_map.shape[0] + depth_map = depth_map.repeat(repeat_by, 1, 1, 1) depth_map = torch.cat([depth_map] * 2) if do_classifier_free_guidance else depth_map return depth_map + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. 
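+    # Note: in the denoising loop below the two UNet outputs are recombined as
+    #   noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond),
+    # which at `guidance_scale == 1` reduces to the plain text-conditional prediction,
+    # so guidance is only considered enabled for `guidance_scale > 1`.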
+ @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + @torch.no_grad() def __call__( self, - prompt: str | list[str], - image: torch.FloatTensor | PIL.Image.Image, - depth_map: torch.FloatTensor | None = None, + prompt: Union[str, list[str]] = None, + image: PipelineImageInput = None, + depth_map: Optional[torch.FloatTensor] = None, strength: float = 0.8, - num_inference_steps: int | None = 50, - guidance_scale: float | None = 7.5, - negative_prompt: str | list[str] | None = None, - num_images_per_prompt: int | None = 1, - eta: float | None = 0.0, - generator: torch.Generator | None = None, - output_type: str | None = "pil", + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + negative_prompt: Optional[Union[str, list[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, + generator: Optional[Union[torch.Generator, list[torch.Generator]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", return_dict: bool = True, - callback: Callable[[int, int, torch.FloatTensor], None] | None = None, - callback_steps: int | None = 1, + cross_attention_kwargs: Optional[dict[str, Any]] = None, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, dict], None]] = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + **kwargs, ): r""" - Function invoked when calling the pipeline for generation. + The call function to the pipeline for generation. + Args: - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. - image (`torch.FloatTensor` or `PIL.Image.Image`): - `Image`, or tensor representing an image batch, that will be used as the starting point for the - process. + ---- + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image` or tensor representing an image batch to be used as the starting point. Can accept image + latents as `image` only if `depth_map` is not `None`. + depth_map (`torch.FloatTensor`, *optional*): + Depth prediction to be used as additional conditioning for the image generation process. If not + defined, it automatically predicts the depth with `self.depth_estimator`. strength (`float`, *optional*, defaults to 0.8): - Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` - will be used as a starting point, adding more noise to it the larger the `strength`. The number of - denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will - be maximum and the denoising process will run for the full number of iterations specified in - `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. 
When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. This parameter will be modulated by `strength`. + expense of slower inference. This parameter is modulated by `strength`. guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. 
The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + kwargs: + Only used for deprecated arguments. + + Examples: + -------- + ```py + >>> import torch + >>> import requests + >>> from PIL import Image + + >>> from diffusers import StableDiffusionDepth2ImgPipeline + + >>> pipe = StableDiffusionDepth2ImgPipeline.from_pretrained( + … "stabilityai/stable-diffusion-2-depth", + … torch_dtype=torch.float16, + … ) + >>> pipe.to("cuda") + + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> init_image = Image.open(requests.get(url, stream=True).raw) + >>> prompt = "two tigers" + >>> prompt = "bad, deformed, ugly, bad anotomy" + >>> image = pipe(prompt=prompt, image=init_image, negative_prompt=prompt, strength=0.7).images[0] + ``` + Returns: + ------- [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + rr.set_time_sequence("step", -1) + # 1. 
Check inputs - self.check_inputs(prompt, strength, callback_steps) + self.check_inputs( + prompt, + strength, + callback_steps, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + + if image is None: + raise ValueError("`image` input cannot be undefined.") # 2. Define call parameters - batch_size = 1 if isinstance(prompt, str) else len(prompt) + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + device = self._execution_device - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt - text_embeddings = self._encode_prompt( - prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) # 4. Prepare depth mask depth_mask = self.prepare_depth_map( image, depth_map, batch_size * num_images_per_prompt, - do_classifier_free_guidance, - text_embeddings.dtype, + self.do_classifier_free_guidance, + prompt_embeds.dtype, device, ) # 5. Preprocess image rr.log("image/original", rr.Image(image)) - image = preprocess(image) + image = self.image_processor.preprocess(image) rr.log("input_image/preprocessed", rr.Tensor(image)) - # 6. set timesteps + # 6. Set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) # 7. Prepare latent variables latents = self.prepare_latents( - image, latent_timestep, batch_size, num_images_per_prompt, text_embeddings.dtype, device, generator + image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator ) rr.log("diffusion/latents", rr.Tensor(latents, dim_names=["b", "c", "h", "w"])) # 8. Prepare extra step kwargs. - # TODO(someone): Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 9. 
Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): rr.set_time_sequence("step", i) rr.set_time_sequence("timestep", t) # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) latent_model_input = torch.cat([latent_model_input, depth_mask], dim=1) rr.log("diffusion/latent_model_input", rr.Tensor(latent_model_input)) # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=self.cross_attention_kwargs, + return_dict=False, + )[0] rr.log("diffusion/noise_pred", rr.Tensor(noise_pred, dim_names=["b", "c", "h", "w"])) # perform guidance - if do_classifier_free_guidance: + if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] rr.log("diffusion/latents", rr.Tensor(latents, dim_names=["b", "c", "h", "w"])) # Decode the latents for visualization purposes - image = self.decode_latents(latents) - image = (image * 255).round().astype("uint8") - rr.log("image/diffused", rr.Image(image)) + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + rr.log("image/diffused", rr.Image(image[0])) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + depth_mask = callback_outputs.pop("depth_mask", depth_mask) # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: - callback(i, t, latents) + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) - # 10. Post-processing - image = self.decode_latents(latents) - image_8 = (image * 255).round().astype("uint8") - rr.log("image/diffused", rr.Image(image_8)) + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + image = latents - # 11. 
Convert to PIL
-        if output_type == "pil":
-            image = self.numpy_to_pil(image)
+        image = self.image_processor.postprocess(image, output_type=output_type)
+        self.maybe_free_model_hooks()

         if not return_dict:
             return (image,)
diff --git a/examples/python/depth_guided_stable_diffusion/main.py b/examples/python/depth_guided_stable_diffusion/main.py
index 370573ce2413..dafccd255def 100755
--- a/examples/python/depth_guided_stable_diffusion/main.py
+++ b/examples/python/depth_guided_stable_diffusion/main.py
@@ -134,8 +134,8 @@ def main() -> None:
         rrb.Vertical(
             rrb.TextLogView(name="Prompt", contents=["prompt/text", "prompt/text_negative"]),
             rrb.Tabs(
-                rrb.TensorView(name="Text embeddings", origin="prompt/text_embeddings"),
-                rrb.TensorView(name="Unconditional embeddings", origin="prompt/uncond_embeddings"),
+                rrb.TensorView(name="Prompt embeddings", origin="prompt/text_embeddings"),
+                rrb.TensorView(name="Negative prompt embeddings", origin="prompt/negative_text_embeddings"),
             ),
             rrb.BarChartView(name="Prompt ids", origin="prompt/text_input"),
             name="Prompt Inputs",
diff --git a/examples/python/depth_guided_stable_diffusion/requirements.txt b/examples/python/depth_guided_stable_diffusion/requirements.txt
index f48c42477775..141b6f665518 100644
--- a/examples/python/depth_guided_stable_diffusion/requirements.txt
+++ b/examples/python/depth_guided_stable_diffusion/requirements.txt
@@ -1,5 +1,5 @@
 accelerate
-diffusers>=0.12.1,<=0.21
+diffusers==0.27.2
 ftfy
 numpy
 packaging
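
For readers trying out the updated example, below is a minimal sketch of how the patched pipeline can be driven. The model id and prompts follow the docstring example in the patched file; the input image path and running from the example directory (so `huggingface_pipeline.py` is importable) are assumptions, and the example's own `main.py` remains the actual entry point with its Rerun blueprint setup.

```python
# Minimal sketch (assumptions: run from the example directory, a CUDA GPU,
# and a local "input.jpg"; model id is the one used in the pipeline docstring).
import rerun as rr
import torch
from PIL import Image

from huggingface_pipeline import StableDiffusionDepth2ImgPipeline

# Internal tensors (embeddings, latents, depth map, intermediate images) are
# logged by the modified pipeline and show up in the spawned Rerun Viewer.
rr.init("depth_guided_stable_diffusion", spawn=True)

pipe = StableDiffusionDepth2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-depth",
    torch_dtype=torch.float16,
).to("cuda")

init_image = Image.open("input.jpg")  # hypothetical input image
result = pipe(
    prompt="two tigers",
    image=init_image,
    negative_prompt="bad, deformed, ugly, bad anatomy",
    strength=0.7,
)
result.images[0].save("output.png")
```

This sketch assumes the dependencies from the updated `requirements.txt` (including the pinned `diffusers==0.27.2`) are installed.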