# gradio_ipadapter_faceid.py
# Forked from ShineChen1024/MagicClothing.
import argparse
import os

import cv2
import gradio as gr
import torch
from diffusers import UniPCMultistepScheduler, AutoencoderKL
from diffusers.pipelines import StableDiffusionPipeline
from PIL import Image

from pipelines.OmsDiffusionPipeline import OmsDiffusionPipeline
parser = argparse.ArgumentParser(description='oms diffusion')
parser.add_argument('--model_path', type=str, required=True)
parser.add_argument('--pipe_path', type=str, default="SG161222/Realistic_Vision_V4.0_noVAE")
# Boolean flag (--enable_cloth_guidance / --no-enable_cloth_guidance, Python 3.9+);
# declaring it with type=str would turn any command-line value, even "False",
# into a truthy string.
parser.add_argument('--enable_cloth_guidance', action=argparse.BooleanOptionalAction, default=True)
parser.add_argument('--faceid_version', type=str, default="FaceIDPlusV2", choices=['FaceID', 'FaceIDPlus', 'FaceIDPlusV2'])
args = parser.parse_args()
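# Example invocation; the checkpoint path below is illustrative only, point
# --model_path at your downloaded MagicClothing garment weights:
#   python gradio_ipadapter_faceid.py \
#       --model_path ./checkpoints/magic_clothing_768_vitonhd_joint.safetensors \
#       --faceid_version FaceIDPlusV2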
device = "cuda"
# Swap in the fine-tuned ft-mse VAE for better image reconstructions.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(dtype=torch.float16)
# Choose the pipeline: OmsDiffusionPipeline adds a separate cloth guidance
# branch; otherwise fall back to the plain Stable Diffusion pipeline.
if args.enable_cloth_guidance:
    pipe = OmsDiffusionPipeline.from_pretrained(args.pipe_path, vae=vae, torch_dtype=torch.float16)
else:
    pipe = StableDiffusionPipeline.from_pretrained(args.pipe_path, vae=vae, torch_dtype=torch.float16)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
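# Any diffusers scheduler with a compatible config could be swapped in here;
# a hypothetical alternative, not part of the original script:
#   from diffusers import DPMSolverMultistepScheduler
#   pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)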
# Plain FaceID variant: conditions on the face ID embedding only, with no
# CLIP image encoder.
if args.faceid_version == "FaceID":
    ip_lora = "./checkpoints/ipadapter_faceid/ip-adapter-faceid_sd15_lora.safetensors"
    ip_ckpt = "./checkpoints/ipadapter_faceid/ip-adapter-faceid_sd15.bin"
    pipe.load_lora_weights(ip_lora)
    pipe.fuse_lora()
    from garment_adapter.garment_ipadapter_faceid import IPAdapterFaceID
    ip_model = IPAdapterFaceID(pipe, args.model_path, ip_ckpt, device, args.enable_cloth_guidance)
else:
    # FaceIDPlus / FaceIDPlusV2 additionally condition on CLIP image features;
    # v2 selects the shortcut weighting used by the PlusV2 checkpoint.
    if args.faceid_version == "FaceIDPlus":
        ip_ckpt = "./checkpoints/ipadapter_faceid/ip-adapter-faceid-plus_sd15.bin"
        ip_lora = "./checkpoints/ipadapter_faceid/ip-adapter-faceid-plus_sd15_lora.safetensors"
        v2 = False
    else:
        ip_ckpt = "./checkpoints/ipadapter_faceid/ip-adapter-faceid-plusv2_sd15.bin"
        ip_lora = "./checkpoints/ipadapter_faceid/ip-adapter-faceid-plusv2_sd15_lora.safetensors"
        v2 = True
    pipe.load_lora_weights(ip_lora)
    pipe.fuse_lora()
    image_encoder_path = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
    from garment_adapter.garment_ipadapter_faceid import IPAdapterFaceIDPlus as IPAdapterFaceID
    ip_model = IPAdapterFaceID(pipe, args.model_path, image_encoder_path, ip_ckpt, device, args.enable_cloth_guidance)
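# Minimal programmatic sketch of ip_model.generate, mirroring the argument
# order used in process() below (the image paths are hypothetical; for the
# Plus variants, pass shortcut=v2 as well):
#   cloth = Image.open("./images/cloth.jpg")
#   face = Image.open("./images/face.jpg")
#   images, mask = ip_model.generate(cloth, face, None, "a photograph",
#                                    "best quality", "lowres", 1, 1234,
#                                    2.5, 3.0, 20, 768, 576)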
def process(cloth_image, face_img, cloth_mask_image, prompt, a_prompt, n_prompt, num_samples, width, height, sample_steps, scale, cloth_guidance_scale, seed):
    if args.faceid_version == "FaceID":
        result = ip_model.generate(cloth_image, face_img, cloth_mask_image, prompt, a_prompt, n_prompt, num_samples, seed, scale, cloth_guidance_scale, sample_steps, height, width)
    else:
        result = ip_model.generate(cloth_image, face_img, cloth_mask_image, prompt, a_prompt, n_prompt, num_samples, seed, scale, cloth_guidance_scale, sample_steps, height, width, shortcut=v2)
    if result is None:
        # generate returns None when no face is detected in the portrait.
        raise gr.Error("Face detection failed; please try a different portrait.")
    images, cloth_mask_image = result
    return images, cloth_mask_image
block = gr.Blocks().queue()
with block:
    with gr.Row():
        gr.Markdown("## You can enlarge the image resolution to get a better face, but the cloth may lose control; we will release a high-resolution checkpoint soon")
    with gr.Row():
        with gr.Column():
            face_img = gr.Image(label="Face image", type="pil")
            cloth_image = gr.Image(label="Cloth image", type="pil")
            cloth_mask_image = gr.Image(label="Cloth mask image (optional; if omitted, a mask is produced by the built-in segmentation algorithm)", type="pil")
            prompt = gr.Textbox(label="Prompt", value='a photograph')
            run_button = gr.Button(value="Run")
            with gr.Accordion("Advanced options", open=False):
                num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
                height = gr.Slider(label="Height", minimum=256, maximum=1024, value=768, step=64)
                width = gr.Slider(label="Width", minimum=192, maximum=768, value=576, step=64)
                sample_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
                if args.enable_cloth_guidance:
                    guidance_scale = gr.Slider(label="Guidance Scale", minimum=1, maximum=10., value=3., step=0.1)
                    cloth_guidance_scale = gr.Slider(label="Cloth Guidance Scale", minimum=1, maximum=10., value=3., step=0.1)
                else:
                    guidance_scale = gr.Slider(label="Guidance Scale", minimum=1, maximum=10., value=2.5, step=0.1)
                    # gr.State keeps the inputs list valid when cloth guidance is
                    # disabled; a bare None is not a Gradio input component.
                    cloth_guidance_scale = gr.State(value=None)
                seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, value=1234)
                a_prompt = gr.Textbox(label="Added Prompt", value='best quality, high quality')
                n_prompt = gr.Textbox(label="Negative Prompt", value='bare, monochrome, lowres, bad anatomy, worst quality, low quality')
        with gr.Column():
            result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery")
            cloth_seg_image = gr.Image(label="Cloth mask", type="pil", width=192, height=256)
    ips = [cloth_image, face_img, cloth_mask_image, prompt, a_prompt, n_prompt, num_samples, width, height, sample_steps, guidance_scale, cloth_guidance_scale, seed]
    run_button.click(fn=process, inputs=ips, outputs=[result_gallery, cloth_seg_image])
block.launch(server_name="0.0.0.0", server_port=7860)
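# The app serves on all interfaces at port 7860, so it is reachable at
# http://localhost:7860 once running. For a temporary public URL, passing
# share=True to launch() is one option.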