Skip to content

Commit

Permalink
fix mask
Browse files Browse the repository at this point in the history
  • Loading branch information
cene555 committed Nov 20, 2022
1 parent c536e0f commit eadda52
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 2 deletions.
5 changes: 3 additions & 2 deletions natalle/natalle_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from copy import deepcopy
import torch.nn.functional as F
import numpy as np
from .utils import prepare_image, q_sample, process_images
from .utils import prepare_image, q_sample, process_images, prepare_mask


class Natalle:
Expand Down Expand Up @@ -243,7 +243,8 @@ def generate_inpainting(self, prompt, pil_img, img_mask,
img_mask = torch.from_numpy(img_mask).unsqueeze(0).unsqueeze(0)
img_mask = F.interpolate(
img_mask, image_shape, mode="nearest",
).to(self.device)
)
img_mask = prepare_mask(img_mask).to(self.device)
if self.use_fp16:
img_mask = img_mask.half()
image = image.repeat(2, 1, 1, 1)
Expand Down
21 changes: 21 additions & 0 deletions natalle/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,28 @@
import torch.nn as nn
import importlib
from .model.utils import get_named_beta_schedule, _extract_into_tensor
from copy import deepcopy

def prepare_mask(mask):
mask = mask.float()[0]
old_mask = deepcopy(mask)
for i in range(mask.shape[1]):
for j in range(mask.shape[2]):
if old_mask[0][i][j] == 1:
continue
if i != 0:
mask[:, i - 1:, j] = 0
if j != 0:
mask[:, i:, j - 1] = 0
if i != 0 and j != 0:
mask[:, i - 1:, j - 1] = 0
if i != mask.shape[1] - 1:
mask[:, i + 1:, j] = 0
if j != mask.shape[2] - 1:
mask[:, i:, j + 1] = 0
if i != mask.shape[1] - 1 and j != mask.shape[2] - 1:
mask[:, i + 1:, j + 1] = 0
return mask.unsqueeze(0)

def prepare_image(pil_image):
w, h = pil_image.size
Expand Down

0 comments on commit eadda52

Please sign in to comment.