
Commit

update
shine committed Apr 23, 2024
1 parent 6a6c2da commit de5a6ff
Showing 20 changed files with 899 additions and 60 deletions.
15 changes: 15 additions & 0 deletions README.md
@@ -12,6 +12,8 @@ Magic Clothing is a branch version of [OOTDiffusion](https://github.com/levihsu/

## News
🔥 [2024/4/23] In response to the enthusiasm for virtual try-on, we release ***gradio_virtual_tryon***; see our [guidance]!

🔥 [2024/4/19] A 1024 version trained on both VITON-HD and DressCode is now available on the early access branch!

🔥 [2024/4/19] We now support AnimateDiff for generating GIFs!
@@ -33,6 +35,19 @@ Have fun with ***gradio_ipadapter_faceid.py***
![demo](images/demo.png) 
![workflow](images/workflow.png) 

***Virtual Try-On Demo***
<div align="left">
<img src="virtual_tryon_img/a1.jpg" alt="图片1" width="10%">
<img src="virtual_tryon_img/a2.png" alt="图片2" width="10%">
<img src="virtual_tryon_img/a3.png" alt="图片3" width="10%">
<img src="virtual_tryon_img/b1.jpg" alt="图片4" width="10%">
<img src="virtual_tryon_img/b2.png" alt="图片5" width="10%">
<img src="virtual_tryon_img/b3.png" alt="图片6" width="10%">
<img src="virtual_tryon_img/c1.jpg" alt="图片7" width="10%">
<img src="virtual_tryon_img/c2.png" alt="图片8" width="10%">
<img src="virtual_tryon_img/c3.png" alt="图片9" width="10%">
</div>

***1024 version demo for upper-body, lower-body, and full-body clothes***
<div align="left">
<img src="images/a0.jpg" alt="图片1" width="15%">
189 changes: 164 additions & 25 deletions garment_adapter/attention_processor.py
@@ -26,7 +26,8 @@ def __call__(
scale: float = 1.0,
attn_store=None,
do_classifier_free_guidance=None,
enable_cloth_guidance=None
enable_cloth_guidance=None,
use_independent_condition=None
) -> torch.Tensor:
residual = hidden_states

@@ -99,7 +100,8 @@ def __call__(
scale: float = 1.0,
attn_store=None,
do_classifier_free_guidance=None,
enable_cloth_guidance=None
enable_cloth_guidance=None,
use_independent_condition=None
) -> torch.Tensor:
if self.type == "read":
attn_store[self.name] = hidden_states
@@ -108,7 +110,10 @@
if do_classifier_free_guidance:
empty_copy = torch.zeros_like(ref_hidden_states)
if enable_cloth_guidance:
ref_hidden_states = torch.cat([empty_copy, ref_hidden_states, ref_hidden_states])
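# use_independent_condition routes the garment reference only into the last (fully conditioned) chunk of the tripled CFG batch, presumably so garment and prompt guidance can be applied independently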
if use_independent_condition:
ref_hidden_states = torch.cat([empty_copy, empty_copy, ref_hidden_states])
else:
ref_hidden_states = torch.cat([empty_copy, ref_hidden_states, ref_hidden_states])
else:
ref_hidden_states = torch.cat([empty_copy, ref_hidden_states])
hidden_states = torch.cat([hidden_states, ref_hidden_states], dim=1)
@@ -192,7 +197,8 @@ def __call__(
scale: float = 1.0,
attn_store=None,
do_classifier_free_guidance=None,
enable_cloth_guidance=None
enable_cloth_guidance=None,
use_independent_condition=None,
) -> torch.FloatTensor:
residual = hidden_states
if attn.spatial_norm is not None:
@@ -279,7 +285,8 @@ def __call__(
scale: float = 1.0,
attn_store=None,
do_classifier_free_guidance=False,
enable_cloth_guidance=True
enable_cloth_guidance=True,
use_independent_condition=None
) -> torch.FloatTensor:
if self.type == "read":
attn_store[self.name] = hidden_states
@@ -288,7 +295,10 @@
if do_classifier_free_guidance:
empty_copy = torch.zeros_like(ref_hidden_states)
if enable_cloth_guidance:
ref_hidden_states = torch.cat([empty_copy, ref_hidden_states, ref_hidden_states])
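# same independent-condition routing as in the processor above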
if use_independent_condition:
ref_hidden_states = torch.cat([empty_copy, empty_copy, ref_hidden_states])
else:
ref_hidden_states = torch.cat([empty_copy, ref_hidden_states, ref_hidden_states])
else:
ref_hidden_states = torch.cat([empty_copy, ref_hidden_states])
hidden_states = torch.cat([hidden_states, ref_hidden_states], dim=1)
@@ -363,12 +373,14 @@ def __call__(


class REFAnimateDiffAttnProcessor2_0(nn.Module):
def __init__(self, name, type="read"):
def __init__(self, cross_attention_dim, hidden_size, name):
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
self.name = name
self.type = type
self.scale = 1.0
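# learned key/value projections for the stored garment reference (decoupled, IP-Adapter-style attention branch)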
self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

def __call__(
self,
@@ -381,19 +393,17 @@ def __call__(
attn_store=None,
do_classifier_free_guidance=False,
) -> torch.FloatTensor:
if self.type == "read":
attn_store[self.name] = hidden_states
elif self.type == "write":
ref_hidden_states = attn_store[self.name]
if do_classifier_free_guidance:
empty_copy = torch.zeros_like(ref_hidden_states)
ref_hidden_states = torch.cat([empty_copy, ref_hidden_states, ref_hidden_states])
if hidden_states.shape[0] % ref_hidden_states.shape[0] != 0:
raise ValueError("not evenly divisible")
# ref_hidden_states = ref_hidden_states*1.05
hidden_states = torch.cat([hidden_states, ref_hidden_states.repeat(hidden_states.shape[0] // ref_hidden_states.shape[0], 1, 1)], dim=1)
else:
raise ValueError("unsupport type")
ref_hidden_states = attn_store[self.name]
if do_classifier_free_guidance:
empty_copy = torch.zeros_like(ref_hidden_states)
repeat_num = hidden_states.shape[0] // 3
ref_hidden_states = torch.cat(
[empty_copy.repeat(repeat_num, 1, 1), ref_hidden_states.repeat(repeat_num, 1, 1),
ref_hidden_states.repeat(repeat_num, 1, 1)])

if hidden_states.shape[0] % ref_hidden_states.shape[0] != 0:
raise ValueError("not evenly divisible")

residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
@@ -445,8 +455,16 @@ def __call__(
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)

if self.type == "write":
hidden_states, _ = torch.chunk(hidden_states, 2, dim=1)
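# decoupled reference attention: project the garment features with to_k_ip/to_v_ip, attend against the same query, and blend the result in via self.scale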
ref_key = self.to_k_ip(ref_hidden_states.float()).view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ref_value = self.to_v_ip(ref_hidden_states.float()).view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ref_hidden_states = F.scaled_dot_product_attention(
query.float(), ref_key, ref_value, attn_mask=None, dropout_p=0.0, is_causal=False
)

ref_hidden_states = ref_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
ref_hidden_states = ref_hidden_states.to(query.dtype)

hidden_states = hidden_states + self.scale * ref_hidden_states
# linear proj
hidden_states = attn.to_out[0](hidden_states, *args)
# dropout
@@ -484,7 +502,8 @@ def __call__(
temb=None,
attn_store=None,
do_classifier_free_guidance=None,
enable_cloth_guidance=None
enable_cloth_guidance=None,
use_independent_condition=None
):
residual = hidden_states

@@ -585,7 +604,8 @@ def __call__(
temb=None,
attn_store=None,
do_classifier_free_guidance=None,
enable_cloth_guidance=None
enable_cloth_guidance=None,
use_independent_condition=None
):
residual = hidden_states

@@ -680,3 +700,122 @@ def __call__(
hidden_states = hidden_states / attn.rescale_output_factor

return hidden_states

class StableREFAttnProcessor2_0(nn.Module):
r"""
Reference-attention processor built on PyTorch 2.0 scaled dot-product attention; it injects garment features stored in attn_store through the learned to_k_ip / to_v_ip projections.
"""

def __init__(self, cross_attention_dim, hidden_size, name):
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
self.scale = 1.
self.name = name

def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
attn_store=None,
do_classifier_free_guidance=False,
enable_cloth_guidance=True
) -> torch.FloatTensor:

residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
ip_hidden_states = attn_store[self.name]
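# stored garment features; duplicated below to match the CFG batch (tripled when cloth guidance is enabled, doubled otherwise)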
if do_classifier_free_guidance:
empty_copy = torch.zeros_like(ip_hidden_states)
if enable_cloth_guidance:
ip_hidden_states = torch.cat([empty_copy, ip_hidden_states, ip_hidden_states])
else:
ip_hidden_states = torch.cat([empty_copy, ip_hidden_states])
encoder_hidden_states = None
input_ndim = hidden_states.ndim

if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)

if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

args = () if USE_PEFT_BACKEND else (scale,)
query = attn.to_q(hidden_states, *args)

if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

key = attn.to_k(encoder_hidden_states, *args)
value = attn.to_v(encoder_hidden_states, *args)

inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads

query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)

hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)

# for ip-adapter
ip_key = self.to_k_ip(ip_hidden_states.float())
ip_value = self.to_v_ip(ip_hidden_states.float())

ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
ip_hidden_states = F.scaled_dot_product_attention(
query.float(), ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)

ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
ip_hidden_states = ip_hidden_states.to(query.dtype)

hidden_states = hidden_states + self.scale * ip_hidden_states

# linear proj
hidden_states = attn.to_out[0](hidden_states, *args)
# dropout
hidden_states = attn.to_out[1](hidden_states)

if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

if attn.residual_connection:
hidden_states = hidden_states + residual

hidden_states = hidden_states / attn.rescale_output_factor

return hidden_states
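
For orientation, here is a minimal wiring sketch (not part of this commit) showing how a processor like StableREFAttnProcessor2_0 could be attached to a diffusers UNet. It assumes the repository follows the common pattern of keying processors by unet.attn_processors names and placing the reference processors on the self-attention (attn1) layers; the helper name load_stable_ref_processors is an illustrative assumption, and the trained to_k_ip / to_v_ip weights would still need to be loaded into the new modules afterwards.

from diffusers import UNet2DConditionModel
from garment_adapter.attention_processor import StableREFAttnProcessor2_0


def load_stable_ref_processors(unet: UNet2DConditionModel) -> dict:
    # Hypothetical helper: replace the UNet's self-attention processors with
    # StableREFAttnProcessor2_0 so stored garment features can be injected.
    attn_procs = {}
    for name, proc in unet.attn_processors.items():
        if not name.endswith("attn1.processor"):
            attn_procs[name] = proc  # leave cross-attention layers as they are
            continue
        # hidden size of the block this attention layer belongs to
        # (standard IP-Adapter-style lookup from the processor name)
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        else:  # down_blocks
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]
        # cross_attention_dim=None: the stored garment features share the
        # self-attention width, so to_k_ip / to_v_ip map hidden_size -> hidden_size
        attn_procs[name] = StableREFAttnProcessor2_0(
            cross_attention_dim=None, hidden_size=hidden_size, name=name
        )
    unet.set_attn_processor(attn_procs)
    return attn_procs

After this, the adapter checkpoint would be loaded into the new to_k_ip / to_v_ip parameters and the UNet moved to the target device and dtype as usual; the attn_store dict and the guidance flags are presumably passed to the processors through cross_attention_kwargs at inference time.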
