update demo transformer_2d & simplified_facebook_dit
chang-wenbin committed Aug 7, 2024
1 parent 400ab19 commit bfe8c41
Showing 3 changed files with 55 additions and 65 deletions.
@@ -17,8 +17,8 @@
from paddlenlp.trainer import set_seed
from ppdiffusers import DDIMScheduler, DiTPipeline

os.environ["Inference_Optimize"] = "False"
os.environ["INFOPTIMIZE"] = "False"

dtype = paddle.float16
pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256", paddle_dtype=dtype)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
@@ -27,25 +27,6 @@
words = ["golden retriever"] # class_ids [207]
class_ids = pipe.get_label_ids(words)

# warmup
for i in range(5):
    image = pipe(class_labels=class_ids, num_inference_steps=25).images[0]


import datetime
import time
repeat_times = 10
paddle.device.synchronize()
starttime = datetime.datetime.now()

for i in range(repeat_times):
    image = pipe(class_labels=class_ids, num_inference_steps=25).images[0]

paddle.device.synchronize()
endtime = datetime.datetime.now()
duringtime = endtime - starttime
time_ms = duringtime.seconds * 1000 + duringtime.microseconds / 1000.0

print("The ave end to end time : ", time_ms / repeat_times, "ms")
image.save("class_conditional_image_generation-dit-result.png")

image = pipe(class_labels=class_ids, num_inference_steps=25).images[0]
image.save("class_conditional_image_generation-dit-result.png")
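
For reference, a minimal usage sketch (not part of this commit) of the updated demo with the optimized path turned on. Based on the transformer_2d.py change below, which reads os.getenv('INFOPTIMIZE'), setting the variable to "True" before the pipeline is built is assumed to route DiT inference through SimplifiedFacebookDIT, while "False" keeps the original transformer blocks:

```python
# Sketch only (assumption): INFOPTIMIZE gates the SimplifiedFacebookDIT fast path.
import os

os.environ["INFOPTIMIZE"] = "True"  # must be set before the model is constructed

import paddle
from ppdiffusers import DDIMScheduler, DiTPipeline

dtype = paddle.float16
pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256", paddle_dtype=dtype)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

class_ids = pipe.get_label_ids(["golden retriever"])  # class id 207
image = pipe(class_labels=class_ids, num_inference_steps=25).images[0]
image.save("class_conditional_image_generation-dit-result.png")
```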
57 changes: 31 additions & 26 deletions ppdiffusers/ppdiffusers/models/simplified_facebook_dit.py
@@ -3,6 +3,7 @@
import paddle.nn.functional as F
import math


class SimplifiedFacebookDIT(nn.Layer):
    def __init__(self, num_layers: int, dim: int, num_attention_heads: int, attention_head_dim: int):
        super().__init__()
@@ -15,29 +16,28 @@ def __init__(self, num_layers: int, dim: int, num_attention_heads: int, attention_head_dim: int):
        self.timestep_embedder_time_embed_dim_out = self.timestep_embedder_time_embed_dim
        self.LabelEmbedding_num_classes = 1001
        self.LabelEmbedding_num_hidden_size = 1152
        self.fcs0 = nn.LayerList([nn.Linear(self.timestep_embedder_in_channels,

        self.fcs0 = nn.LayerList([nn.Linear(self.timestep_embedder_in_channels,
            self.timestep_embedder_time_embed_dim) for i in range(num_layers)])

        self.fcs1 = nn.LayerList([nn.Linear(self.timestep_embedder_time_embed_dim,
            self.timestep_embedder_time_embed_dim_out) for i in range(num_layers)])

        self.fcs2 = nn.LayerList([nn.Linear(self.timestep_embedder_time_embed_dim,
            6 * self.timestep_embedder_time_embed_dim) for i in range(num_layers)])
        self.embs = nn.LayerList([nn.Embedding(self.LabelEmbedding_num_classes,

        self.embs = nn.LayerList([nn.Embedding(self.LabelEmbedding_num_classes,
            self.LabelEmbedding_num_hidden_size) for i in range(num_layers)])


        self.q = nn.LayerList([nn.Linear(dim, dim ) for i in range(num_layers)])
        self.k = nn.LayerList([nn.Linear(dim, dim ) for i in range(num_layers)])
        self.v = nn.LayerList([nn.Linear(dim, dim ) for i in range(num_layers)])
        self.q = nn.LayerList([nn.Linear(dim, dim) for i in range(num_layers)])
        self.k = nn.LayerList([nn.Linear(dim, dim) for i in range(num_layers)])
        self.v = nn.LayerList([nn.Linear(dim, dim) for i in range(num_layers)])
        self.out_proj = nn.LayerList([nn.Linear(dim, dim) for i in range(num_layers)])
        self.ffn1 = nn.LayerList([nn.Linear(dim, dim*4) for i in range(num_layers)])
        self.ffn2 = nn.LayerList([nn.Linear(dim*4, dim) for i in range(num_layers)])
        self.ffn1 = nn.LayerList([nn.Linear(dim, dim * 4) for i in range(num_layers)])
        self.ffn2 = nn.LayerList([nn.Linear(dim * 4, dim) for i in range(num_layers)])

    def forward(self, hidden_states, timesteps, class_labels):

        # below code are copied from PaddleMIX/ppdiffusers/ppdiffusers/models/embeddings.py
        num_channels = 256
        max_period = 10000
@@ -49,8 +49,8 @@ def forward(self, hidden_states, timesteps, class_labels):
        emb = timesteps[:, None].cast("float32") * emb[None, :]
        emb = paddle.concat([paddle.cos(emb), paddle.sin(emb)], axis=-1)
        common_emb = emb.cast(hidden_states.dtype)
        for i in range(self.num_layers):

        for i in range(self.num_layers):
            emb = self.fcs0[i](common_emb)
            emb = F.silu(emb)
            emb = self.fcs1[i](emb)
@@ -59,24 +59,29 @@ def forward(self, hidden_states, timesteps, class_labels):
            emb = self.fcs2[i](emb)
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, axis=1)
            import paddlemix
            norm_hidden_states =paddlemix.triton_ops.adaptive_layer_norm(hidden_states, scale_msa, shift_msa)
            q = self.q[i](norm_hidden_states).reshape([0,0,self.heads_num,self.head_dim])
            k = self.k[i](norm_hidden_states).reshape([0,0,self.heads_num,self.head_dim])
            v = self.v[i](norm_hidden_states).reshape([0,0,self.heads_num,self.head_dim])
            norm_hidden_states = paddlemix.triton_ops.adaptive_layer_norm(hidden_states, scale_msa, shift_msa)
            q = self.q[i](norm_hidden_states).reshape([0, 0, self.heads_num, self.head_dim])
            k = self.k[i](norm_hidden_states).reshape([0, 0, self.heads_num, self.head_dim])
            v = self.v[i](norm_hidden_states).reshape([0, 0, self.heads_num, self.head_dim])

            norm_hidden_states = F.scaled_dot_product_attention_(q, k, v, scale=self.head_dim**-0.5)
            norm_hidden_states = norm_hidden_states.reshape([norm_hidden_states.shape[0],norm_hidden_states.shape[1],self.dim])
            norm_hidden_states = norm_hidden_states.reshape([norm_hidden_states.shape[0], norm_hidden_states.shape[1], self.dim])
            norm_hidden_states = self.out_proj[i](norm_hidden_states)
            # hidden_states = hidden_states + norm_hidden_states * gate_msa.reshape([b,1,self.dim])

            # hidden_states = hidden_states + norm_hidden_states * gate_msa.reshape([b,1,self.dim])
            # norm_hidden_states =paddlemix.triton_ops.adaptive_layer_norm(hidden_states, scale_mlp, shift_mlp)
            hidden_states,norm_hidden_states =paddlemix.triton_ops.fused_adaLN_scale_residual(hidden_states, norm_hidden_states, gate_msa, scale_mlp, shift_mlp)
            hidden_states, norm_hidden_states = paddlemix.triton_ops.fused_adaLN_scale_residual(
                hidden_states,
                norm_hidden_states,
                gate_msa,
                scale_mlp,
                shift_mlp
            )

            norm_hidden_states = self.ffn1[i](norm_hidden_states)
            norm_hidden_states = F.gelu(norm_hidden_states, approximate=True)
            norm_hidden_states = self.ffn2[i](norm_hidden_states)

            hidden_states = hidden_states + norm_hidden_states * gate_mlp.reshape([norm_hidden_states.shape[0],1,self.dim])
            hidden_states = hidden_states + norm_hidden_states * gate_mlp.reshape([norm_hidden_states.shape[0], 1, self.dim])

        return hidden_states

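For context, a rough unfused sketch of what the paddlemix.triton_ops.fused_adaLN_scale_residual call above replaces, pieced together from the commented-out lines in the diff (gated residual add followed by adaptive LayerNorm). The function below is reference-only; its semantics are an assumption, not the actual Triton kernel:

```python
# Reference-only sketch (assumed semantics of fused_adaLN_scale_residual):
# 1) residual add scaled by gate_msa, 2) LayerNorm modulated by scale_mlp/shift_mlp.
import paddle
import paddle.nn.functional as F


def fused_adaln_scale_residual_reference(hidden_states, attn_out, gate_msa, scale_mlp, shift_mlp, epsilon=1e-6):
    # residual branch: x = x + attn_out * gate
    hidden_states = hidden_states + attn_out * gate_msa.unsqueeze(1)
    # adaptive LayerNorm: normalize without affine parameters, then scale/shift
    norm = F.layer_norm(hidden_states, hidden_states.shape[-1:], epsilon=epsilon)
    norm = norm * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
    return hidden_states, norm
```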
36 changes: 20 additions & 16 deletions ppdiffusers/ppdiffusers/models/transformer_2d.py
@@ -38,7 +38,6 @@
import os



@dataclass
class Transformer2DModelOutput(BaseOutput):
"""
@@ -118,8 +117,8 @@ def __init__(
        self.inner_dim = inner_dim = num_attention_heads * attention_head_dim
        self.data_format = data_format

        self.Inference_Optimize = os.getenv('Inference_Optimize') == "True"
        self.inference_optimize = os.getenv('INFOPTIMIZE') == "True"

        conv_cls = nn.Conv2D if USE_PEFT_BACKEND else LoRACompatibleConv
        linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear

@@ -219,14 +218,20 @@ def __init__(
                for d in range(num_layers)
            ]
        )
        if self.Inference_Optimize:
            self.simplified_facebookDIT = SimplifiedFacebookDIT(num_layers, inner_dim, num_attention_heads, attention_head_dim)
            self.simplified_facebookDIT = paddle.incubate.jit.inference(self.simplified_facebookDIT,
                enable_new_ir=True,
                cache_static_model=False,
                exp_enable_use_cutlass=True,
                delete_pass_lists=["add_norm_fuse_pass"],
            )
        if self.inference_optimize:
            self.simplified_facebookDIT = SimplifiedFacebookDIT(
                num_layers,
                inner_dim,
                num_attention_heads,
                attention_head_dim
            )
            self.simplified_facebookDIT = paddle.incubate.jit.inference(
                self.simplified_facebookDIT,
                enable_new_ir=True,
                cache_static_model=False,
                exp_enable_use_cutlass=True,
                delete_pass_lists=["add_norm_fuse_pass"],
            )

        # 4. Define output layers
        self.out_channels = in_channels if out_channels is None else out_channels
@@ -264,7 +269,6 @@ def __init__(
            self.caption_projection = CaptionProjection(in_features=caption_channels, hidden_size=inner_dim)

        self.gradient_checkpointing = False


    def _set_gradient_checkpointing(self, module, value=False):
        if hasattr(module, "gradient_checkpointing"):
@@ -399,9 +403,9 @@ def forward(
            batch_size = hidden_states.shape[0]
            encoder_hidden_states = self.caption_projection(encoder_hidden_states)
            encoder_hidden_states = encoder_hidden_states.reshape([batch_size, -1, hidden_states.shape[-1]])
        if self.Inference_Optimize:
            hidden_states =self.simplified_facebookDIT(hidden_states, timestep, class_labels)

        if self.inference_optimize:
            hidden_states = self.simplified_facebookDIT(hidden_states, timestep, class_labels)
        else:
            for block in self.transformer_blocks:
                if self.gradient_checkpointing and not hidden_states.stop_gradient and not use_old_recompute():
@@ -503,7 +507,7 @@ def custom_forward(*inputs):

    @classmethod
    def custom_modify_weight(cls, state_dict):
        if os.getenv('Inference_Optimize') != "True":
        if os.getenv('INFOPTIMIZE') != "True":
            return
        map_from_my_dit = {}
        for i in range(28):
