Skip to content

Commit

Permalink
update sd3 batch parallel
Browse files Browse the repository at this point in the history
  • Loading branch information
chang-wenbin committed Oct 10, 2024
1 parent 62155dd commit 77f0726
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
5 changes: 3 additions & 2 deletions ppdiffusers/deploy/sd3/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@ python text_to_image_generation-stable_diffusion_3.py --dtype float16 --height

## Paddle Stable Diffusion 3 模型多卡推理:
### batch parallel 实现原理
- 由于SD3 MMDiT部分的输入batch为2,所以我们考虑在多卡并行的方案中,将batch为2的输入拆分到两张卡上进行计算,这样单卡的计算量就减少为原来的一半,降低了单卡所承载的浮点计算量。
计算完成后,我们再把两张卡的计算结果 聚合在一起,结果与单卡计算完全一致。
- 在SD3中,对于输入是一个prompt时,使用CFG需要同时进行unconditional guide和text guide的生成,此时 MM-DiT-blocks 的输入batch_size=2;
所以我们考虑在多卡并行的方案中,将batch为2的输入拆分到两张卡上进行计算,这样单卡的计算量就减少为原来的一半,降低了单卡所承载的浮点计算量。
计算完成后,我们再把两张卡的计算结果 聚合在一起,结果与单卡计算完全一致。
### 开启多卡推理方法
- Paddle Inference 提供了SD3模型的多卡推理功能,用户可以通过设置 `--inference_optimize_bp 1` 来开启这一功能,
使用 `python -m paddle.distributed.launch --gpus 0,1` 指定使用哪些卡进行推理。
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ def encode_prompt(

prompt_embeds = paddle.concat([clip_prompt_embeds, t5_prompt_embed], axis=-2)
pooled_prompt_embeds = paddle.concat([pooled_prompt_embed, pooled_prompt_2_embed], axis=-1)

breakpoint()
if do_classifier_free_guidance and negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
negative_prompt_2 = negative_prompt_2 or negative_prompt
Expand Down Expand Up @@ -801,8 +801,7 @@ def __call__(
latent_model_input = paddle.concat([latents] * 2) if self.do_classifier_free_guidance else latents
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0])

if self.inference_optimize_bp:
if self.inference_optimize_bp and self.do_classifier_free_guidance:
latent_input ,latent_model_input_ = paddle.split(latent_model_input,2,axis=0)
timestep_input ,timestep_ = paddle.split(timestep,2,axis=0)
prompt_embeds_input ,prompt_embeds_ = paddle.split(prompt_embeds,2,axis=0)
Expand Down

0 comments on commit 77f0726

Please sign in to comment.