forked from huggingface/controlnet_aux
TEED, Lineart Standard and Anyline added (huggingface#105)
* teed, lineart_standard and anyline added
* removed test save
Showing 9 changed files with 683 additions and 8 deletions.
src/controlnet_aux/__init__.py
@@ -1,18 +1,20 @@
 __version__ = "0.0.8"
 
+from .anyline import AnylineDetector
+from .canny import CannyDetector
+from .dwpose import DWposeDetector
 from .hed import HEDdetector
 from .leres import LeresDetector
 from .lineart import LineartDetector
 from .lineart_anime import LineartAnimeDetector
+from .lineart_standard import LineartStandardDetector
+from .mediapipe_face import MediapipeFaceDetector
 from .midas import MidasDetector
 from .mlsd import MLSDdetector
 from .normalbae import NormalBaeDetector
 from .open_pose import OpenposeDetector
 from .pidi import PidiNetDetector
-from .zoe import ZoeDetector
-
-from .canny import CannyDetector
-from .mediapipe_face import MediapipeFaceDetector
 from .segment_anything import SamDetector
 from .shuffle import ContentShuffleDetector
-from .dwpose import DWposeDetector
+from .teed import TEEDdetector
+from .zoe import ZoeDetector
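With this hunk the three new detectors are exported from the package root alongside the existing ones. A quick import smoke test, assuming the package is installed from this branch:

# Smoke test: the new detectors should be importable from the package root.
from controlnet_aux import (
    AnylineDetector,
    LineartStandardDetector,
    TEEDdetector,
)

print(AnylineDetector, LineartStandardDetector, TEEDdetector)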
src/controlnet_aux/anyline/__init__.py
@@ -0,0 +1,118 @@
# Code based on https://github.com/TheMistoAI/ComfyUI-Anyline/blob/main/anyline.py
import os

import cv2
import numpy as np
import torch
from einops import rearrange
from huggingface_hub import hf_hub_download
from PIL import Image
from skimage import morphology

from ..teed.ted import TED
from ..util import HWC3, resize_image, safe_step


class AnylineDetector:
    def __init__(self, model):
        self.model = model

    @classmethod
    def from_pretrained(cls, pretrained_model_or_path, filename=None, subfolder=None):
        if os.path.isdir(pretrained_model_or_path):
            model_path = os.path.join(pretrained_model_or_path, filename)
        else:
            model_path = hf_hub_download(
                pretrained_model_or_path, filename, subfolder=subfolder
            )

        model = TED()
        model.load_state_dict(torch.load(model_path, map_location="cpu"))

        return cls(model)

    def to(self, device):
        self.model.to(device)
        return self

    def __call__(
        self,
        input_image,
        detect_resolution=1280,
        guassian_sigma=2.0,
        intensity_threshold=3,
        output_type="pil",
    ):
        device = next(iter(self.model.parameters())).device

        if not isinstance(input_image, np.ndarray):
            input_image = np.array(input_image, dtype=np.uint8)
            output_type = output_type or "pil"
        else:
            output_type = output_type or "np"

        original_height, original_width, _ = input_image.shape

        input_image = HWC3(input_image)
        input_image = resize_image(input_image, detect_resolution)

        assert input_image.ndim == 3
        height, width, _ = input_image.shape

        # Stage 1: coarse edge map from the MTEED network.
        with torch.no_grad():
            image_teed = torch.from_numpy(input_image.copy()).float().to(device)
            image_teed = rearrange(image_teed, "h w c -> 1 c h w")
            edges = self.model(image_teed)
            edges = [e.detach().cpu().numpy().astype(np.float32)[0, 0] for e in edges]
            edges = [
                cv2.resize(e, (width, height), interpolation=cv2.INTER_LINEAR)
                for e in edges
            ]
            edges = np.stack(edges, axis=2)
            edge = 1 / (1 + np.exp(-np.mean(edges, axis=2).astype(np.float64)))
            edge = safe_step(edge, 2)
            edge = (edge * 255.0).clip(0, 255).astype(np.uint8)

        mteed_result = edge
        mteed_result = HWC3(mteed_result)

        # Stage 2: fine lineart from a Gaussian-blur intensity difference.
        x = input_image.astype(np.float32)
        g = cv2.GaussianBlur(x, (0, 0), guassian_sigma)
        intensity = np.min(g - x, axis=2).clip(0, 255)
        intensity /= max(16, np.median(intensity[intensity > intensity_threshold]))
        intensity *= 127
        lineart_result = intensity.clip(0, 255).astype(np.uint8)

        lineart_result = HWC3(lineart_result)

        lineart_result = self.get_intensity_mask(
            lineart_result, lower_bound=0, upper_bound=255
        )

        # Drop small speckles, then blend the fine lineart over the MTEED map.
        cleaned = morphology.remove_small_objects(
            lineart_result.astype(bool), min_size=36, connectivity=1
        )
        lineart_result = lineart_result * cleaned
        final_result = self.combine_layers(mteed_result, lineart_result)

        final_result = cv2.resize(
            final_result,
            (original_width, original_height),
            interpolation=cv2.INTER_LINEAR,
        )

        if output_type == "pil":
            final_result = Image.fromarray(final_result)

        return final_result

    def get_intensity_mask(self, image_array, lower_bound, upper_bound):
        mask = image_array[:, :, 0]
        mask = np.where((mask >= lower_bound) & (mask <= upper_bound), mask, 0)
        mask = np.expand_dims(mask, 2).repeat(3, axis=2)
        return mask

    def combine_layers(self, base_layer, top_layer):
        # Screen-style blend where the top layer has content; base elsewhere.
        mask = top_layer.astype(bool)
        temp = 1 - (1 - top_layer) * (1 - base_layer)
        result = base_layer * (~mask) + temp * mask
        return result
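A minimal usage sketch for the new detector. The checkpoint repo, filename, and subfolder below follow the ComfyUI-Anyline distribution of the MTEED weights and are assumptions, not part of this diff:

from PIL import Image

from controlnet_aux import AnylineDetector

# Assumed checkpoint location (TheMistoAI/MistoLine hosts the MTEED weights);
# point this at wherever the .pth actually lives.
anyline = AnylineDetector.from_pretrained(
    "TheMistoAI/MistoLine", filename="MTEED.pth", subfolder="Anyline"
).to("cuda")

image = Image.open("input.png").convert("RGB")
edges = anyline(image, detect_resolution=1280)  # returns a PIL image by default
edges.save("anyline_edges.png")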
src/controlnet_aux/lineart_standard/__init__.py
@@ -0,0 +1,47 @@
# Code based on the comfyui_controlnet_aux repository:
# https://github.com/Fannovel16/comfyui_controlnet_aux/blob/main/src/controlnet_aux/lineart_standard/__init__.py
import cv2
import numpy as np
from PIL import Image

from ..util import HWC3, resize_image


class LineartStandardDetector:
    def __call__(
        self,
        input_image=None,
        guassian_sigma=6.0,
        intensity_threshold=8,
        detect_resolution=512,
        output_type="pil",
    ):
        if not isinstance(input_image, np.ndarray):
            input_image = np.array(input_image, dtype=np.uint8)
            output_type = output_type or "pil"
        else:
            output_type = output_type or "np"

        original_height, original_width, _ = input_image.shape

        input_image = HWC3(input_image)
        input_image = resize_image(input_image, detect_resolution)

        # Difference between the blurred and original image makes dark strokes
        # stand out; normalize by the median intensity above the threshold.
        x = input_image.astype(np.float32)
        g = cv2.GaussianBlur(x, (0, 0), guassian_sigma)
        intensity = np.min(g - x, axis=2).clip(0, 255)
        intensity /= max(16, np.median(intensity[intensity > intensity_threshold]))
        intensity *= 127
        detected_map = intensity.clip(0, 255).astype(np.uint8)

        detected_map = HWC3(detected_map)

        detected_map = cv2.resize(
            detected_map,
            (original_width, original_height),
            interpolation=cv2.INTER_CUBIC,
        )

        if output_type == "pil":
            detected_map = Image.fromarray(detected_map)

        return detected_map
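Unlike the other two detectors, this one is pure OpenCV/NumPy and needs no checkpoint; a minimal sketch:

from PIL import Image

from controlnet_aux import LineartStandardDetector

lineart = LineartStandardDetector()  # no weights to download
image = Image.open("input.png").convert("RGB")
out = lineart(
    image, guassian_sigma=6.0, intensity_threshold=8, detect_resolution=512
)
out.save("lineart_standard.png")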
src/controlnet_aux/teed/Fsmish.py
@@ -0,0 +1,19 @@
""" | ||
Script based on: | ||
Wang, Xueliang, Honge Ren, and Achuan Wang. | ||
"Smish: A Novel Activation Function for Deep Learning Methods. | ||
" Electronics 11.4 (2022): 540. | ||
""" | ||
|
||
# import pytorch | ||
import torch | ||
|
||
|
||
@torch.jit.script | ||
def smish(input): | ||
""" | ||
Applies the mish function element-wise: | ||
mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(sigmoid(x)))) | ||
See additional documentation for mish class. | ||
""" | ||
return input * torch.tanh(torch.log(1 + torch.sigmoid(input))) |
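A quick numerical sanity check of the formula above; `log1p` is the numerically safer spelling of `log(1 + ·)` and should agree with the scripted function (module path assumed):

import torch

from controlnet_aux.teed.Fsmish import smish  # assumed module path

x = torch.linspace(-6.0, 6.0, steps=25)
expected = x * torch.tanh(torch.log1p(torch.sigmoid(x)))
assert torch.allclose(smish(x), expected, atol=1e-6)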
src/controlnet_aux/teed/LICENSE
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2022 Xavier Soria Poma

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
src/controlnet_aux/teed/Xsmish.py
@@ -0,0 +1,41 @@
""" | ||
Script based on: | ||
Wang, Xueliang, Honge Ren, and Achuan Wang. | ||
"Smish: A Novel Activation Function for Deep Learning Methods. | ||
" Electronics 11.4 (2022): 540. | ||
smish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + sigmoid(x))) | ||
""" | ||
|
||
# import pytorch | ||
# import activation functions | ||
from torch import nn | ||
|
||
from .Fsmish import smish | ||
|
||
|
||
class Smish(nn.Module): | ||
""" | ||
Applies the mish function element-wise: | ||
mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x))) | ||
Shape: | ||
- Input: (N, *) where * means, any number of additional | ||
dimensions | ||
- Output: (N, *), same shape as the input | ||
Examples: | ||
>>> m = Mish() | ||
>>> input = torch.randn(2) | ||
>>> output = m(input) | ||
Reference: https://pytorch.org/docs/stable/generated/torch.nn.Mish.html | ||
""" | ||
|
||
def __init__(self): | ||
""" | ||
Init method. | ||
""" | ||
super().__init__() | ||
|
||
def forward(self, input): | ||
""" | ||
Forward pass of the function. | ||
""" | ||
return smish(input) |
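Since the module is stateless, it drops into `nn.Sequential` like any built-in activation; a tiny sketch (module path assumed):

import torch
from torch import nn

from controlnet_aux.teed.Xsmish import Smish  # assumed module path

block = nn.Sequential(nn.Linear(4, 8), Smish(), nn.Linear(8, 1))
print(block(torch.randn(2, 4)).shape)  # torch.Size([2, 1])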
src/controlnet_aux/teed/__init__.py
@@ -0,0 +1,84 @@
import os

import cv2
import numpy as np
import torch
from einops import rearrange
from huggingface_hub import hf_hub_download
from PIL import Image

from ..util import HWC3, resize_image, safe_step
from .ted import TED


class TEEDdetector:
    def __init__(self, model):
        self.model = model

    @classmethod
    def from_pretrained(cls, pretrained_model_or_path, filename=None, subfolder=None):
        if os.path.isdir(pretrained_model_or_path):
            model_path = os.path.join(pretrained_model_or_path, filename)
        else:
            model_path = hf_hub_download(
                pretrained_model_or_path, filename, subfolder=subfolder
            )

        model = TED()
        model.load_state_dict(torch.load(model_path, map_location="cpu"))

        return cls(model)

    def to(self, device):
        self.model.to(device)
        return self

    def __call__(
        self,
        input_image,
        detect_resolution=512,
        safe_steps=2,
        output_type="pil",
    ):
        device = next(iter(self.model.parameters())).device
        if not isinstance(input_image, np.ndarray):
            input_image = np.array(input_image, dtype=np.uint8)
            output_type = output_type or "pil"
        else:
            output_type = output_type or "np"

        original_height, original_width, _ = input_image.shape

        input_image = HWC3(input_image)
        input_image = resize_image(input_image, detect_resolution)

        assert input_image.ndim == 3
        height, width, _ = input_image.shape
        with torch.no_grad():
            image_teed = torch.from_numpy(input_image.copy()).float().to(device)
            image_teed = rearrange(image_teed, "h w c -> 1 c h w")
            edges = self.model(image_teed)
            edges = [e.detach().cpu().numpy().astype(np.float32)[0, 0] for e in edges]
            edges = [
                cv2.resize(e, (width, height), interpolation=cv2.INTER_LINEAR)
                for e in edges
            ]
            # Average the multi-scale predictions, squash to (0, 1) with a
            # sigmoid, then optionally quantize with safe_step.
            edges = np.stack(edges, axis=2)
            edge = 1 / (1 + np.exp(-np.mean(edges, axis=2).astype(np.float64)))
            if safe_steps != 0:
                edge = safe_step(edge, safe_steps)
            edge = (edge * 255.0).clip(0, 255).astype(np.uint8)

        detected_map = edge
        detected_map = HWC3(detected_map)

        detected_map = cv2.resize(
            detected_map,
            (original_width, original_height),
            interpolation=cv2.INTER_LINEAR,
        )

        if output_type == "pil":
            detected_map = Image.fromarray(detected_map)

        return detected_map
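A usage sketch mirroring the Anyline example; the Hub repo and filename for the TEED checkpoint are assumptions here (the upstream weights are distributed as `5_model.pth`):

from PIL import Image

from controlnet_aux import TEEDdetector

# Assumed checkpoint location; substitute the repo that actually hosts the
# TEED weights (distributed upstream as 5_model.pth).
teed = TEEDdetector.from_pretrained("fal-ai/teed", filename="5_model.pth").to("cuda")

image = Image.open("input.png").convert("RGB")
edge_map = teed(image, detect_resolution=1024, safe_steps=2)
edge_map.save("teed_edges.png")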
src/controlnet_aux/teed/ted.py
@@ -0,0 +1,332 @@
# Original from: https://github.com/xavysp/TEED
# TEED: a Tiny but Efficient Edge Detector, derived from LDC-B3
# with a slight modification.
# Parameter counts: LDC ~155,665; TED ~58K.

import torch
import torch.nn as nn
import torch.nn.functional as F

from .Fsmish import smish as Fsmish
from .Xsmish import Smish

def weight_init(m):
    if isinstance(m, (nn.Conv2d,)):
        torch.nn.init.xavier_normal_(m.weight, gain=1.0)
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)

    # for the fusion layer
    if isinstance(m, (nn.ConvTranspose2d,)):
        torch.nn.init.xavier_normal_(m.weight, gain=1.0)
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)


class CoFusion(nn.Module):
    # from LDC
    def __init__(self, in_ch, out_ch):
        super(CoFusion, self).__init__()
        self.conv1 = nn.Conv2d(
            in_ch, 32, kernel_size=3, stride=1, padding=1
        )  # before 64
        self.conv3 = nn.Conv2d(
            32, out_ch, kernel_size=3, stride=1, padding=1
        )  # before 64 instead of 32
        self.relu = nn.ReLU()
        self.norm_layer1 = nn.GroupNorm(4, 32)  # before 64

    def forward(self, x):
        # per-channel attention weights, then a weighted sum over channels
        attn = self.relu(self.norm_layer1(self.conv1(x)))
        attn = F.softmax(self.conv3(attn), dim=1)
        return ((x * attn).sum(1)).unsqueeze(1)


class CoFusion2(nn.Module):
    # TEDv14-3
    def __init__(self, in_ch, out_ch):
        super(CoFusion2, self).__init__()
        self.conv1 = nn.Conv2d(
            in_ch, 32, kernel_size=3, stride=1, padding=1
        )  # before 64
        self.conv3 = nn.Conv2d(
            32, out_ch, kernel_size=3, stride=1, padding=1
        )  # before 64 instead of 32
        self.smish = Smish()  # nn.ReLU(inplace=True)

    def forward(self, x):
        attn = self.conv1(self.smish(x))
        attn = self.conv3(self.smish(attn))
        return ((x * attn).sum(1)).unsqueeze(1)


class DoubleFusion(nn.Module):
    # TED fusion before the final edge-map prediction
    def __init__(self, in_ch, out_ch):
        super(DoubleFusion, self).__init__()
        self.DWconv1 = nn.Conv2d(
            in_ch, in_ch * 8, kernel_size=3, stride=1, padding=1, groups=in_ch
        )  # depthwise: 3 -> 24 channels
        self.PSconv1 = nn.PixelShuffle(1)

        self.DWconv2 = nn.Conv2d(
            24, 24 * 1, kernel_size=3, stride=1, padding=1, groups=24
        )

        self.AF = Smish()

    def forward(self, x):
        attn = self.PSconv1(self.DWconv1(self.AF(x)))  # TEED best res TEDv14

        attn2 = self.PSconv1(self.DWconv2(self.AF(attn)))

        return Fsmish(((attn2 + attn).sum(1)).unsqueeze(1))  # TED best res


class _DenseLayer(nn.Sequential):
    def __init__(self, input_features, out_features):
        super(_DenseLayer, self).__init__()

        # padding=2 followed by an unpadded 3x3 conv keeps the spatial size
        self.add_module(
            "conv1",
            nn.Conv2d(
                input_features,
                out_features,
                kernel_size=3,
                stride=1,
                padding=2,
                bias=True,
            ),
        )
        self.add_module("smish1", Smish())
        self.add_module(
            "conv2",
            nn.Conv2d(out_features, out_features, kernel_size=3, stride=1, bias=True),
        )

    def forward(self, x):
        x1, x2 = x
        new_features = super(_DenseLayer, self).forward(Fsmish(x1))
        # average the new features with the skip input, carry the skip along
        return 0.5 * (new_features + x2), x2


class _DenseBlock(nn.Sequential):
    def __init__(self, num_layers, input_features, out_features):
        super(_DenseBlock, self).__init__()
        for i in range(num_layers):
            layer = _DenseLayer(input_features, out_features)
            self.add_module("denselayer%d" % (i + 1), layer)
            input_features = out_features


class UpConvBlock(nn.Module):
    def __init__(self, in_features, up_scale):
        super(UpConvBlock, self).__init__()
        self.up_factor = 2
        self.constant_features = 16

        layers = self.make_deconv_layers(in_features, up_scale)
        assert layers is not None, layers
        self.features = nn.Sequential(*layers)

    def make_deconv_layers(self, in_features, up_scale):
        layers = []
        all_pads = [0, 0, 1, 3, 7]
        for i in range(up_scale):
            kernel_size = 2**up_scale
            pad = all_pads[up_scale]  # kernel_size - 1
            out_features = self.compute_out_features(i, up_scale)
            layers.append(nn.Conv2d(in_features, out_features, 1))
            layers.append(Smish())
            layers.append(
                nn.ConvTranspose2d(
                    out_features, out_features, kernel_size, stride=2, padding=pad
                )
            )
            in_features = out_features
        return layers

    def compute_out_features(self, idx, up_scale):
        # only the last upsampling stage reduces to a single channel
        return 1 if idx == up_scale - 1 else self.constant_features

    def forward(self, x):
        return self.features(x)


class SingleConvBlock(nn.Module):
    def __init__(self, in_features, out_features, stride, use_ac=False):
        super(SingleConvBlock, self).__init__()
        self.use_ac = use_ac
        self.conv = nn.Conv2d(in_features, out_features, 1, stride=stride, bias=True)
        if self.use_ac:
            self.smish = Smish()

    def forward(self, x):
        x = self.conv(x)
        if self.use_ac:
            return self.smish(x)
        else:
            return x


class DoubleConvBlock(nn.Module):
    def __init__(
        self, in_features, mid_features, out_features=None, stride=1, use_act=True
    ):
        super(DoubleConvBlock, self).__init__()

        self.use_act = use_act
        if out_features is None:
            out_features = mid_features
        self.conv1 = nn.Conv2d(in_features, mid_features, 3, padding=1, stride=stride)
        self.conv2 = nn.Conv2d(mid_features, out_features, 3, padding=1)
        self.smish = Smish()  # nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv1(x)
        x = self.smish(x)
        x = self.conv2(x)
        if self.use_act:
            x = self.smish(x)
        return x


class TED(nn.Module):
    """Definition of the Tiny and Efficient Edge Detector model."""

    def __init__(self):
        super(TED, self).__init__()
        self.block_1 = DoubleConvBlock(
            3,
            16,
            16,
            stride=2,
        )
        self.block_2 = DoubleConvBlock(16, 32, use_act=False)
        self.dblock_3 = _DenseBlock(1, 32, 48)  # [32,48,100,100] before (2, 32, 64)

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # skip1 connection, see fig. 2
        self.side_1 = SingleConvBlock(16, 32, 2)

        # skip2 connection, see fig. 2
        self.pre_dense_3 = SingleConvBlock(32, 48, 1)  # before (32, 64, 1)

        # USNet
        self.up_block_1 = UpConvBlock(16, 1)
        self.up_block_2 = UpConvBlock(32, 1)
        self.up_block_3 = UpConvBlock(48, 2)  # (32, 64, 1)

        self.block_cat = DoubleFusion(3, 3)  # TEED: DoubleFusion

        self.apply(weight_init)

    def slice(self, tensor, slice_shape):
        t_shape = tensor.shape
        img_h, img_w = slice_shape
        if img_w != t_shape[-1] or img_h != t_shape[2]:
            new_tensor = F.interpolate(
                tensor, size=(img_h, img_w), mode="bicubic", align_corners=False
            )
        else:
            new_tensor = tensor
        return new_tensor

    def resize_input(self, tensor):
        # pad H and W up to the next multiple of 8, e.g. 350x500 -> 352x504
        t_shape = tensor.shape
        if t_shape[2] % 8 != 0 or t_shape[3] % 8 != 0:
            img_w = ((t_shape[3] // 8) + 1) * 8
            img_h = ((t_shape[2] // 8) + 1) * 8
            new_tensor = F.interpolate(
                tensor, size=(img_h, img_w), mode="bicubic", align_corners=False
            )
        else:
            new_tensor = tensor
        return new_tensor

    @staticmethod
    def crop_bdcn(data1, h, w, crop_h, crop_w):
        # Based on the BDCN implementation @ https://github.com/pkuCactus/BDCN
        _, _, h1, w1 = data1.size()
        assert h <= h1 and w <= w1
        data = data1[:, :, crop_h : crop_h + h, crop_w : crop_w + w]
        return data

    def forward(self, x, single_test=False):
        assert x.ndim == 4, x.shape
        # suppose the image size is 352x352

        # Block 1
        block_1 = self.block_1(x)  # [8,16,176,176]
        block_1_side = self.side_1(block_1)  # 16 [8,32,88,88]

        # Block 2
        block_2 = self.block_2(block_1)  # 32 # [8,32,176,176]
        block_2_down = self.maxpool(block_2)  # [8,32,88,88]
        block_2_add = block_2_down + block_1_side  # [8,32,88,88]

        # Block 3
        block_3_pre_dense = self.pre_dense_3(
            block_2_down
        )  # [8,48,88,88] block 3 L connection
        block_3, _ = self.dblock_3([block_2_add, block_3_pre_dense])  # [8,48,88,88]

        # upsampling blocks
        out_1 = self.up_block_1(block_1)
        out_2 = self.up_block_2(block_2)
        out_3 = self.up_block_3(block_3)

        results = [out_1, out_2, out_3]

        # concatenate the multiscale predictions and fuse them into one map
        block_cat = torch.cat(results, dim=1)  # Bx3xHxW
        block_cat = self.block_cat(block_cat)  # Bx1xHxW DoubleFusion

        results.append(block_cat)
        return results


if __name__ == "__main__":
    batch_size = 8
    img_height = 352
    img_width = 352

    # device = "cuda" if torch.cuda.is_available() else "cpu"
    device = "cpu"
    input = torch.rand(batch_size, 3, img_height, img_width).to(device)
    # target = torch.rand(batch_size, 1, img_height, img_width).to(device)
    print(f"input shape: {input.shape}")
    model = TED().to(device)
    output = model(input)
    print(f"output shapes: {[t.shape for t in output]}")

    # for i in range(20000):
    #     print(i)
    #     output = model(input)
    #     loss = nn.MSELoss()(output[-1], target)
    #     loss.backward()
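The header comment claims TED stays in the ~58K-parameter range; a quick check (module path assumed):

import torch

from controlnet_aux.teed.ted import TED  # assumed module path

model = TED()
n_params = sum(p.numel() for p in model.parameters())
print(f"TED parameters: {n_params:,}")  # expected to land in the ~58K range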