# edit_image.py
import time

from utils.debug_utils import enable_deterministic

# Called before torch and the project modules are imported, presumably so that
# seeding and determinism flags set by this project helper take effect before
# any model code runs.
enable_deterministic()

import torch
from pathlib import Path
import cv2
import argparse

from modules import load_diffusion_model, load_inverter, load_editor
from modules.inversion.diffusion_inversion import DiffusionInversion
from modules import StablePreprocess, StablePostProc
from diffusers import StableDiffusionPipeline
from typing import List, Optional, Tuple

from utils.utils import add_argparse_arg


def split_to_words(prompt: str) -> List[str]:
    """Split a prompt into words.

    Args:
        prompt (str): Prompt to split

    Returns:
        List[str]: Words
    """
    # remove a trailing dot (endswith also handles an empty prompt safely)
    if prompt.endswith("."):
        prompt = prompt[:-1]
    return prompt.split(" ")
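
# Example (for illustration):
#   split_to_words("a photo of a cat.")  ->  ["a", "photo", "of", "a", "cat"]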


def get_edit_word(source_prompt: str, target_prompt: str) -> Optional[Tuple[str, str]]:
    """Get the single word that differs between the source and target prompt.

    Args:
        source_prompt (str): Source prompt
        target_prompt (str): Target prompt

    Returns:
        Optional[Tuple[str, str]]: The differing (source, target) word pair, or
        None if the prompts differ in length or in more than one word.
    """
    source_words = split_to_words(source_prompt)
    target_words = split_to_words(target_prompt)

    if len(source_words) != len(target_words):
        return None

    diffs = [(s, t) for s, t in zip(source_words, target_words) if s != t]

    if len(diffs) != 1:
        return None

    return diffs[0]
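
# Examples (for illustration):
#   get_edit_word("a photo of a cat", "a photo of a dog")  ->  ("cat", "dog")
#   get_edit_word("a cat", "a big dog")                    ->  None (length differs)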


@torch.no_grad()
def main(input: str, model: str, source_prompt: str, target_prompt: str, output: str, inv_method: str, edit_method: str,
         scheduler: str, steps: int, guidance_scale_bwd: float, guidance_scale_fwd: float, edit_cfg: str, prec: str) -> None:
    enable_deterministic()

    input = Path(input)
    if output is None:
        # default output path (use the stem so the file extension is not duplicated)
        output = str(input.parent / (input.stem + "_inv" + input.suffix))
    device = "cuda"

    # load models
    ldm_stable, (preproc, postproc) = load_diffusion_model(model, device, variant=prec)

    if edit_cfg is None:
        # Use a default config for prompt-to-prompt if no edit_cfg yaml is specified
        if edit_method in ("ptp", "etaedit"):
            # Get the blend word (the single word that differs between the prompts)
            blended_word = get_edit_word(source_prompt, target_prompt)
            if blended_word is None:
                print("Provide an edit_cfg for prompt-to-prompt if the source and target prompt differ in more than one word.")
                return
            edit_cfg = dict(
                is_replace_controller=False,
                prompts=[source_prompt, target_prompt],
                cross_replace_steps={"default_": 0.4},
                self_replace_steps=0.6,
                blend_words=((blended_word[0],), (blended_word[1],)) if len(blended_word) else None,
                equilizer_params={
                    "words": (blended_word[1],),
                    "values": (2,),
                } if len(blended_word) else None,
            )
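
            # The values above follow the prompt-to-prompt convention:
            # cross_replace_steps and self_replace_steps are (roughly) the
            # fraction of denoising steps during which the source prompt's
            # cross- and self-attention maps are injected, while
            # equilizer_params (key name as consumed by this project's editor)
            # amplifies the attention of the target word. The 0.4 / 0.6 / 2
            # values are this script's defaults, not universally optimal ones.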
print(f"Using default ptp config:\n{edit_cfg}")
else:
edit_cfg = None
# load inverter and editor module
inverter = load_inverter(model=ldm_stable, type=inv_method, scheduler=scheduler, num_inference_steps=steps, guidance_scale_bwd=guidance_scale_bwd, guidance_scale_fwd=guidance_scale_fwd)
editor = load_editor(inverter=inverter, type=edit_method)
image = preproc(input) # load and preprocess image
edit_word_idx_src = next((i for i, (s, t) in enumerate(zip(source_prompt.split(" "), target_prompt.split(" "))) if s != t), None)
inv_cfg = dict(edit_word_idx=(edit_word_idx_src, edit_word_idx_src))

    t1 = time.time()
    edit_res = editor.edit(image, source_prompt, target_prompt, cfg=edit_cfg, inv_cfg=inv_cfg)  # edit image
    t2 = time.time()

    img_edit = postproc(edit_res["image"])  # postprocess output

    # save result (OpenCV expects BGR channel order)
    cv2.imwrite(output, cv2.cvtColor(img_edit, cv2.COLOR_RGB2BGR))

    if "image_inv" in edit_res:
        # also save the reconstruction from inversion, if the editor returns one
        img_inv = postproc(edit_res["image_inv"])  # postprocess output
        output_inv = Path(output)
        output_inv = output_inv.parent / (output_inv.stem + "_inv" + output_inv.suffix)
        cv2.imwrite(str(output_inv), cv2.cvtColor(img_inv, cv2.COLOR_RGB2BGR))

    print(f"Saved result to {output}")
    print(f"Took {t2 - t1:.2f}s")


def parse_args():
    parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter, description="Edits a single image.")
    parser.add_argument("--input", required=True, help="Path to the image to edit.")
    parser.add_argument("--model", default="CompVis/stable-diffusion-v1-4", help="Diffusion model.")
    parser.add_argument("--source_prompt", required=True, help="Prompt describing the input image (used for inversion).")
    parser.add_argument("--target_prompt", required=True, help="Prompt describing the desired edited image.")
    parser.add_argument("--output", help="Path for the output image.")
    add_argparse_arg(parser, "--inv_method")
    add_argparse_arg(parser, "--edit_method")
    parser.add_argument("--edit_cfg", help="Path to a yaml file with the editor configuration. Often needed for prompt-to-prompt.")
parser.add_argument("--scheduler", help="Which scheduler to use.", choices=DiffusionInversion.get_available_schedulers())
parser.add_argument("--steps", type=int, help="How many diffusion steps to use.")
parser.add_argument("--guidance_scale_bwd", type=int, help="Classifier free guidance scale to use for backward diffusion (denoising).")
parser.add_argument("--guidance_scale_fwd", type=int, help="Classifier free guidance scale to use for forward diffusion (inversion).")
parser.add_argument("--prec", choices=["fp16", "fp32"], help="Precision for diffusion.")
args = parser.parse_args()
return vars(args)


if __name__ == "__main__":
    main(**parse_args())
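
# Example invocation (paths, prompts, and method names are illustrative; the
# available --inv_method / --edit_method choices are registered via
# add_argparse_arg in utils.utils):
#   python edit_image.py --input cat.png \
#       --source_prompt "a photo of a cat" \
#       --target_prompt "a photo of a dog" \
#       --inv_method diffinv --edit_method ptp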