Skip to content

Commit

Permalink
Merge pull request ddPn08#81 from Isotr0py/hires_fix
Browse files Browse the repository at this point in the history
Add Feature: Hires.fix
  • Loading branch information
ddPn08 authored May 14, 2023
2 parents 51eb7b5 + a232bc6 commit 0eb692e
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 7 deletions.
3 changes: 3 additions & 0 deletions api/models/diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ class ImageGenerationOptions:
strength: Optional[float] = 1.0

image: PIL.Image.Image = field(default_factory=PIL.Image.Image)
hiresfix: bool = False
hiresfix_mode: str = "bilinear"
hiresfix_scale: float = 1.5

def dict(self):
return asdict(self)
Expand Down
28 changes: 24 additions & 4 deletions modules/components/image_generation_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,15 +88,35 @@ def common_options_ui():
)


def hires_options_ui():
with gr.Row():
with gr.Accordion("Hires.fix", open=False):
enable_upscale = gr.Checkbox(label="Hires.fix")
with gr.Row():
upscaler_mode = gr.Dropdown(
choices=[
"bilinear",
"bilinear-antialiased",
"bicubic",
"bicubic-antialiased",
"nearest",
"nearest-exact",
],
value="bilinear",
label="Latent upscaler mode",
)
scale_slider = gr.Slider(
value=1.5, minimum=1, maximum=4, step=0.05, label="Upscale by"
)
return enable_upscale, upscaler_mode, scale_slider


def img2img_options_ui():
with gr.Column():
with gr.Accordion("Img2Img", open=False):
init_image = gr.Image(label="Init Image", type="pil")
strength_slider = gr.Slider(
value=0.5,
minimum=0,
maximum=1,
step=0.01,
value=0.5, minimum=0, maximum=1, step=0.01, label="Strength"
)
return init_image, strength_slider

Expand Down
33 changes: 33 additions & 0 deletions modules/diffusion/pipelines/diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def __init__(

self.plugin_data = None
self.opts = None
self.stage_1st = None

def to(self, device: torch.device = None, dtype: torch.dtype = None):
if device is None:
Expand Down Expand Up @@ -384,6 +385,34 @@ def __call__(
self.plugin_data = plugin_data
self.opts = opts

# Hires.fix
if opts.hiresfix:
opts.hiresfix, self.stage_1st = False, True
opts.image = self.__call__(
opts,
generator,
eta,
latents,
prompt_embeds,
negative_prompt_embeds,
"latent",
return_dict,
callback,
callback_steps,
cross_attention_kwargs,
plugin_data,
).images
opts.height = int(opts.height * opts.hiresfix_scale)
opts.width = int(opts.width * opts.hiresfix_scale)

opts.image = torch.nn.functional.interpolate(
opts.image,
(opts.height // 8, opts.width // 8),
mode=opts.hiresfix_mode.split("-")[0],
antialias=True if "antialiased" in opts.hiresfix_mode else False,
)
opts.image = self.create_output(opts.image, "pil", True).images[0]

# 1. Define call parameters
num_images_per_prompt = 1
prompt = [opts.prompt] * opts.batch_size
Expand Down Expand Up @@ -465,6 +494,10 @@ def __call__(
for enterer in enterers:
enterer.__exit__(None, None, None)

if self.stage_1st:
self.stage_1st = None
return outputs

self.plugin_data = None
self.opts = None

Expand Down
21 changes: 18 additions & 3 deletions modules/tabs/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,14 @@ def wrapper(self, data):
seed,
width,
height,
hiresfix,
hiresfix_mode,
hiresfix_scale,
init_image,
strength,
) = as_list[0:12]
) = as_list[0:15]

plugin_values = dict(list(data.items())[12:])
plugin_values = dict(list(data.items())[15:])

opts = ImageGenerationOptions(
prompt=prompt,
Expand All @@ -41,6 +44,9 @@ def wrapper(self, data):
strength=strength,
seed=seed,
image=init_image,
hiresfix=hiresfix,
hiresfix_mode=hiresfix_mode,
hiresfix_scale=hiresfix_scale,
)
yield from fn(self, opts, plugin_values)

Expand Down Expand Up @@ -77,10 +83,18 @@ def generate_image(self, opts, plugin_values):

count = 0

# pre-calculate inference steps
if opts.hiresfix:
inference_steps = opts.num_inference_steps + int(
opts.num_inference_steps * opts.strength
)
else:
inference_steps = opts.num_inference_steps

for data in model_manager.sd_model(opts, plugin_data):
if type(data) == tuple:
step, preview = data
progress = step / (opts.batch_count * opts.num_inference_steps)
progress = step / (opts.batch_count * inference_steps)
previews = []
for images, opts in preview:
previews.extend(images)
Expand Down Expand Up @@ -115,6 +129,7 @@ def ui(self, outlet):
with gr.Column(scale=1.25):
options = image_generation_options.common_options_ui()

options += image_generation_options.hires_options_ui()
options += image_generation_options.img2img_options_ui()

plugin_values = image_generation_options.plugin_options_ui()
Expand Down

0 comments on commit 0eb692e

Please sign in to comment.