Commit 9d20ed3

Perturbed-Attention Guidance (huggingface#7512)
* pag_initial
* pag_docs
* edit_docs
* custom
* typo
* delete_docs
* whitespace
* make style

Co-authored-by: Sayak Paul <[email protected]>
1 parent bda1d4f commit 9d20ed3

2 files changed

+1554
-0
lines changed

examples/community/README.md

Lines changed: 77 additions & 0 deletions
@@ -3743,3 +3743,80 @@ onestep_image = pipe(prompt, num_inference_steps=1).images[0]
# Multistep sampling
multistep_image = pipe(prompt, num_inference_steps=4).images[0]
```

# Perturbed-Attention Guidance

[Project](https://ku-cvlab.github.io/Perturbed-Attention-Guidance/) / [arXiv](https://arxiv.org/abs/2403.17377) / [GitHub](https://github.com/KU-CVLAB/Perturbed-Attention-Guidance)

This implementation is based on [Diffusers](https://huggingface.co/docs/diffusers/index). `StableDiffusionPAGPipeline` is a modification of `StableDiffusionPipeline` that adds support for Perturbed-Attention Guidance (PAG).
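
As a rough illustration of the idea behind PAG, as described in the linked paper, the perturbed branch replaces the self-attention map with an identity matrix, so each token attends only to itself. The sketch below is plain Python on tiny lists, not the pipeline's code; `softmax`, `self_attention`, and the `perturbed` flag are hypothetical names used only for illustration.

```python
import math


def softmax(row):
    # Numerically stable softmax over one row of attention scores.
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def self_attention(q, k, v, perturbed=False):
    """Single-head self-attention over lists of token vectors.

    With perturbed=True the attention map is replaced by the identity
    matrix (an assumption based on the paper's description), so each
    token simply receives its own value vector.
    """
    n, d = len(q), len(q[0])
    if perturbed:
        attn = [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]
    else:
        scale = 1.0 / math.sqrt(d)
        attn = [
            softmax([scale * sum(a * b for a, b in zip(q[i], k[j])) for j in range(n)])
            for i in range(n)
        ]
    # Weighted sum of value vectors for every token.
    return [
        [sum(attn[i][j] * v[j][c] for j in range(n)) for c in range(len(v[0]))]
        for i in range(n)
    ]


q = [[1.0, 0.0], [0.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0]]

normal = self_attention(q, k, v)
perturbed = self_attention(q, k, v, perturbed=True)
```

In this toy setup `perturbed` is identical to `v`, while `normal` mixes information across tokens. The difference between the two branches is what the guidance term builds on.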

## Example Usage
```
import os

import torch
from accelerate.utils import set_seed

from diffusers import StableDiffusionPipeline
from diffusers.utils import make_image_grid
from diffusers.utils.torch_utils import randn_tensor

# Load Stable Diffusion v1.5 with the community PAG pipeline.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    custom_pipeline="hyoungwoncho/sd_perturbed_attention_guidance",
    torch_dtype=torch.float16,
)

device = "cuda"
pipe = pipe.to(device)

pag_scale = 5.0
pag_applied_layers_index = ["m0"]

batch_size = 4
seed = 10

# Output directory, e.g. ./results/pag5.0
base_dir = "./results"
grid_dir = os.path.join(base_dir, "pag" + str(pag_scale))
os.makedirs(grid_dir, exist_ok=True)

set_seed(seed)

# Both runs start from the same initial latents so the outputs are comparable.
latent_input = randn_tensor(
    shape=(batch_size, 4, 64, 64), generator=None, device=device, dtype=torch.float16
)

# Baseline: empty prompt, guidance_scale=0.0 and pag_scale=0.0.
output_baseline = pipe(
    "",
    width=512,
    height=512,
    num_inference_steps=50,
    guidance_scale=0.0,
    pag_scale=0.0,
    pag_applied_layers_index=pag_applied_layers_index,
    num_images_per_prompt=batch_size,
    latents=latent_input,
).images

# Same settings and latents, with PAG enabled.
output_pag = pipe(
    "",
    width=512,
    height=512,
    num_inference_steps=50,
    guidance_scale=0.0,
    pag_scale=pag_scale,
    pag_applied_layers_index=pag_applied_layers_index,
    num_images_per_prompt=batch_size,
    latents=latent_input,
).images

# Top row: baseline samples; bottom row: PAG samples.
grid_image = make_image_grid(output_baseline + output_pag, rows=2, cols=batch_size)
grid_image.save(os.path.join(grid_dir, "sample.png"))
```

## PAG Parameters

`pag_scale`: guidance scale of PAG (e.g. 5.0; the example uses 0.0 for the baseline run)

`pag_applied_layers_index`: indices of the layers to which the perturbation is applied (e.g. ['m0'])
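
The role of `pag_scale` can be sketched as follows. This assumes the formulation in the linked paper, where the final noise prediction is pushed away from the prediction of the perturbed branch. `apply_pag` is a hypothetical helper, not part of the pipeline's API, and the scalar inputs stand in for noise tensors.

```python
def apply_pag(noise_pred, noise_pred_perturbed, pag_scale):
    """Combine the regular and perturbed noise predictions.

    Assumed formulation (hypothetical helper, not the pipeline's API):
    move the regular prediction further away from the perturbed one,
    scaled by pag_scale. Works on floats or on any array type that
    supports elementwise arithmetic.
    """
    return noise_pred + pag_scale * (noise_pred - noise_pred_perturbed)


# With pag_scale = 0.0 the regular prediction is returned unchanged.
baseline = apply_pag(1.0, 0.5, pag_scale=0.0)

# A larger pag_scale amplifies the difference between the two branches.
guided = apply_pag(1.0, 0.5, pag_scale=5.0)
```

Under this assumption, larger values of `pag_scale` amplify the gap between the regular and perturbed predictions, which is why the example compares 0.0 against 5.0 on the same latents.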
