Teed, Lineart Standard and Anyline added (huggingface#105)
* teed, lineart_standard and anyline added

* removed test save
asomoza authored May 22, 2024
1 parent 6367d57 commit 907ae1b
Showing 9 changed files with 683 additions and 8 deletions.
17 changes: 14 additions & 3 deletions README.md
@@ -2,9 +2,9 @@

This is a PyPI-installable package of [lllyasviel's ControlNet Annotators](https://github.com/lllyasviel/ControlNet/tree/main/annotator).

The code is copy-pasted from the respective folders in <https://github.com/lllyasviel/ControlNet/tree/main/annotator> and connected to [the 🤗 Hub](https://huggingface.co/lllyasviel/Annotators).

All credit & copyright goes to <https://github.com/lllyasviel>.

## Install

@@ -13,17 +13,19 @@ pip install controlnet-aux==0.0.7
```

To support DWPose, which depends on MMDetection, MMCV, and MMPose, install:

```
pip install -U openmim
mim install mmengine
mim install "mmcv>=2.0.1"
mim install "mmdet>=3.1.0"
mim install "mmpose>=1.1.0"
```

## Usage

You can use the processor class, which can load each of the auxiliary models, as in the following code:

```python
import requests
from PIL import Image
@@ -51,6 +53,7 @@ processed_image = processor(img, to_pil=True)
```

Each model can be loaded individually by importing and instantiating it as follows:

```python
from PIL import Image
import requests
@@ -76,6 +79,10 @@ zoe = ZoeDetector.from_pretrained("lllyasviel/Annotators")
sam = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
mobile_sam = SamDetector.from_pretrained("dhkim2810/MobileSAM", model_type="vit_t", filename="mobile_sam.pt")
leres = LeresDetector.from_pretrained("lllyasviel/Annotators")
teed = TEEDdetector.from_pretrained("fal-ai/teed", filename="5_model.pth")
anyline = AnylineDetector.from_pretrained(
"TheMistoAI/MistoLine", filename="MTEED.pth", subfolder="Anyline"
)

# specify configs, ckpts and device, or they will be downloaded automatically and the CPU will be used by default
# det_config: ./src/controlnet_aux/dwpose/yolox_config/yolox_l_8xb8-300e_coco.py
@@ -90,6 +97,7 @@ dwpose = DWposeDetector(det_config=det_config, det_ckpt=det_ckpt, pose_config=po
canny = CannyDetector()
content = ContentShuffleDetector()
face_detector = MediapipeFaceDetector()
lineart_standard = LineartStandardDetector()


# process
@@ -104,11 +112,14 @@ processed_image_lineart_anime = lineart_anime(img)
processed_image_zoe = zoe(img)
processed_image_sam = sam(img)
processed_image_leres = leres(img)
processed_image_teed = teed(img, detect_resolution=1024)
processed_image_anyline = anyline(img, detect_resolution=1280)

processed_image_canny = canny(img)
processed_image_content = content(img)
processed_image_mediapipe_face = face_detector(img)
processed_image_dwpose = dwpose(img)
processed_image_lineart_standard = lineart_standard(img, detect_resolution=1024)
```
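
The new `TEEDdetector` and `AnylineDetector` run on the CPU by default and expose a `to(device)` method. A minimal sketch, assuming a CUDA device is available:

```python
# move the TEED-based detectors to the GPU before processing (optional)
teed = teed.to("cuda")
anyline = anyline.to("cuda")

processed_image_teed = teed(img, detect_resolution=1024)
processed_image_anyline = anyline(img, detect_resolution=1280)
```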

### Image resolution
12 changes: 7 additions & 5 deletions src/controlnet_aux/__init__.py
@@ -1,18 +1,20 @@
__version__ = "0.0.8"

from .anyline import AnylineDetector
from .canny import CannyDetector
from .dwpose import DWposeDetector
from .hed import HEDdetector
from .leres import LeresDetector
from .lineart import LineartDetector
from .lineart_anime import LineartAnimeDetector
from .lineart_standard import LineartStandardDetector
from .mediapipe_face import MediapipeFaceDetector
from .midas import MidasDetector
from .mlsd import MLSDdetector
from .normalbae import NormalBaeDetector
from .open_pose import OpenposeDetector
from .pidi import PidiNetDetector
from .segment_anything import SamDetector
from .shuffle import ContentShuffleDetector
from .teed import TEEDdetector
from .zoe import ZoeDetector
118 changes: 118 additions & 0 deletions src/controlnet_aux/anyline/__init__.py
@@ -0,0 +1,118 @@
# Code based on https://github.com/TheMistoAI/ComfyUI-Anyline/blob/main/anyline.py
import os

import cv2
import numpy as np
import torch
from einops import rearrange
from huggingface_hub import hf_hub_download
from PIL import Image
from skimage import morphology

from ..teed.ted import TED
from ..util import HWC3, resize_image, safe_step


class AnylineDetector:
def __init__(self, model):
self.model = model

@classmethod
def from_pretrained(cls, pretrained_model_or_path, filename=None, subfolder=None):
if os.path.isdir(pretrained_model_or_path):
model_path = os.path.join(pretrained_model_or_path, filename)
else:
model_path = hf_hub_download(
pretrained_model_or_path, filename, subfolder=subfolder
)

model = TED()
model.load_state_dict(torch.load(model_path, map_location="cpu"))

return cls(model)

def to(self, device):
self.model.to(device)
return self

def __call__(
self,
input_image,
detect_resolution=1280,
guassian_sigma=2.0,
intensity_threshold=3,
output_type="pil",
):
device = next(iter(self.model.parameters())).device

if not isinstance(input_image, np.ndarray):
input_image = np.array(input_image, dtype=np.uint8)
output_type = output_type or "pil"
else:
output_type = output_type or "np"

original_height, original_width, _ = input_image.shape

input_image = HWC3(input_image)
input_image = resize_image(input_image, detect_resolution)

assert input_image.ndim == 3
height, width, _ = input_image.shape
with torch.no_grad():
image_teed = torch.from_numpy(input_image.copy()).float().to(device)
image_teed = rearrange(image_teed, "h w c -> 1 c h w")
edges = self.model(image_teed)
edges = [e.detach().cpu().numpy().astype(np.float32)[0, 0] for e in edges]
edges = [
cv2.resize(e, (width, height), interpolation=cv2.INTER_LINEAR)
for e in edges
]
edges = np.stack(edges, axis=2)
edge = 1 / (1 + np.exp(-np.mean(edges, axis=2).astype(np.float64)))
edge = safe_step(edge, 2)
edge = (edge * 255.0).clip(0, 255).astype(np.uint8)

mteed_result = edge
mteed_result = HWC3(mteed_result)

        # Coarse lineart pass: blur minus original, channel-wise minimum, so thin
        # dark strokes (which the blur smooths away) respond strongest.
        x = input_image.astype(np.float32)
g = cv2.GaussianBlur(x, (0, 0), guassian_sigma)
intensity = np.min(g - x, axis=2).clip(0, 255)
intensity /= max(16, np.median(intensity[intensity > intensity_threshold]))
intensity *= 127
lineart_result = intensity.clip(0, 255).astype(np.uint8)

lineart_result = HWC3(lineart_result)

lineart_result = self.get_intensity_mask(
lineart_result, lower_bound=0, upper_bound=255
)

cleaned = morphology.remove_small_objects(
lineart_result.astype(bool), min_size=36, connectivity=1
)
lineart_result = lineart_result * cleaned
final_result = self.combine_layers(mteed_result, lineart_result)

final_result = cv2.resize(
final_result,
(original_width, original_height),
interpolation=cv2.INTER_LINEAR,
)

if output_type == "pil":
final_result = Image.fromarray(final_result)

return final_result

def get_intensity_mask(self, image_array, lower_bound, upper_bound):
mask = image_array[:, :, 0]
mask = np.where((mask >= lower_bound) & (mask <= upper_bound), mask, 0)
mask = np.expand_dims(mask, 2).repeat(3, axis=2)
return mask

def combine_layers(self, base_layer, top_layer):
mask = top_layer.astype(bool)
temp = 1 - (1 - top_layer) * (1 - base_layer)
result = base_layer * (~mask) + temp * mask
return result
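
For reference, a minimal usage sketch of `AnylineDetector`, mirroring the README above; the input path is a hypothetical placeholder:

```python
from PIL import Image

from controlnet_aux import AnylineDetector

anyline = AnylineDetector.from_pretrained(
    "TheMistoAI/MistoLine", filename="MTEED.pth", subfolder="Anyline"
)

img = Image.open("input.png")  # hypothetical input image
edge_map = anyline(img, detect_resolution=1280)  # returns a PIL image by default
edge_map.save("anyline_edges.png")
```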
47 changes: 47 additions & 0 deletions src/controlnet_aux/lineart_standard/__init__.py
@@ -0,0 +1,47 @@
# Code based on the repository comfyui_controlnet_aux:
# https://github.com/Fannovel16/comfyui_controlnet_aux/blob/main/src/controlnet_aux/lineart_standard/__init__.py
import cv2
import numpy as np
from PIL import Image

from ..util import HWC3, resize_image


class LineartStandardDetector:
def __call__(
self,
input_image=None,
guassian_sigma=6.0,
intensity_threshold=8,
detect_resolution=512,
output_type="pil",
):
        if not isinstance(input_image, np.ndarray):
            input_image = np.array(input_image, dtype=np.uint8)
            output_type = output_type or "pil"
        else:
            output_type = output_type or "np"

original_height, original_width, _ = input_image.shape

input_image = HWC3(input_image)
input_image = resize_image(input_image, detect_resolution)

x = input_image.astype(np.float32)
g = cv2.GaussianBlur(x, (0, 0), guassian_sigma)
intensity = np.min(g - x, axis=2).clip(0, 255)
intensity /= max(16, np.median(intensity[intensity > intensity_threshold]))
intensity *= 127
detected_map = intensity.clip(0, 255).astype(np.uint8)

detected_map = HWC3(detected_map)

detected_map = cv2.resize(
detected_map,
(original_width, original_height),
interpolation=cv2.INTER_CUBIC,
)

if output_type == "pil":
detected_map = Image.fromarray(detected_map)

return detected_map
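
`LineartStandardDetector` is pure OpenCV/NumPy and needs no checkpoint. A minimal usage sketch (input path hypothetical):

```python
from PIL import Image

from controlnet_aux import LineartStandardDetector

lineart_standard = LineartStandardDetector()  # no weights to download

img = Image.open("input.png")  # hypothetical input image
out = lineart_standard(img, guassian_sigma=6.0, detect_resolution=1024)
out.save("lineart_standard.png")
```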
19 changes: 19 additions & 0 deletions src/controlnet_aux/teed/Fsmish.py
@@ -0,0 +1,19 @@
"""
Script based on:
Wang, Xueliang, Honge Ren, and Achuan Wang.
"Smish: A Novel Activation Function for Deep Learning Methods.
" Electronics 11.4 (2022): 540.
"""

# import pytorch
import torch


@torch.jit.script
def smish(input):
"""
    Applies the smish function element-wise:
    smish(x) = x * tanh(ln(1 + sigmoid(x)))
    See the Smish module class for additional documentation.
"""
return input * torch.tanh(torch.log(1 + torch.sigmoid(input)))
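
For clarity, the activation implemented above, next to Mish, from which it is derived (smish replaces the inner exponential with a sigmoid):

```latex
\mathrm{smish}(x) = x \tanh\big(\ln(1 + \sigma(x))\big)
\qquad
\mathrm{mish}(x) = x \tanh\big(\ln(1 + e^{x})\big) = x \tanh(\mathrm{softplus}(x))
```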
21 changes: 21 additions & 0 deletions src/controlnet_aux/teed/LICENSE.txt
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2022 Xavier Soria Poma

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
41 changes: 41 additions & 0 deletions src/controlnet_aux/teed/Xsmish.py
@@ -0,0 +1,41 @@
"""
Script based on:
Wang, Xueliang, Honge Ren, and Achuan Wang.
"Smish: A Novel Activation Function for Deep Learning Methods.
" Electronics 11.4 (2022): 540.
smish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + sigmoid(x)))
"""

# import pytorch
# import activation functions
from torch import nn

from .Fsmish import smish


class Smish(nn.Module):
"""
Applies the mish function element-wise:
mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))
Shape:
- Input: (N, *) where * means, any number of additional
dimensions
- Output: (N, *), same shape as the input
Examples:
>>> m = Mish()
>>> input = torch.randn(2)
>>> output = m(input)
Reference: https://pytorch.org/docs/stable/generated/torch.nn.Mish.html
"""

def __init__(self):
"""
Init method.
"""
super().__init__()

def forward(self, input):
"""
Forward pass of the function.
"""
return smish(input)
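
A quick sanity check, as a sketch, that the module wrapper and the scripted function agree with the formula above:

```python
import torch

from controlnet_aux.teed.Fsmish import smish
from controlnet_aux.teed.Xsmish import Smish

x = torch.randn(2, 3)
m = Smish()
# the module is a thin wrapper over the scripted function
assert torch.allclose(m(x), smish(x))
# both compute x * tanh(ln(1 + sigmoid(x)))
assert torch.allclose(m(x), x * torch.tanh(torch.log1p(torch.sigmoid(x))))
```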
84 changes: 84 additions & 0 deletions src/controlnet_aux/teed/__init__.py
@@ -0,0 +1,84 @@
import os

import cv2
import numpy as np
import torch
from einops import rearrange
from huggingface_hub import hf_hub_download
from PIL import Image

from ..util import HWC3, resize_image, safe_step
from .ted import TED


class TEEDdetector:
def __init__(self, model):
self.model = model

@classmethod
def from_pretrained(cls, pretrained_model_or_path, filename=None, subfolder=None):
if os.path.isdir(pretrained_model_or_path):
model_path = os.path.join(pretrained_model_or_path, filename)
else:
model_path = hf_hub_download(
pretrained_model_or_path, filename, subfolder=subfolder
)

model = TED()
model.load_state_dict(torch.load(model_path, map_location="cpu"))

return cls(model)

def to(self, device):
self.model.to(device)
return self

def __call__(
self,
input_image,
detect_resolution=512,
safe_steps=2,
output_type="pil",
):
device = next(iter(self.model.parameters())).device
if not isinstance(input_image, np.ndarray):
input_image = np.array(input_image, dtype=np.uint8)
output_type = output_type or "pil"
else:
output_type = output_type or "np"

original_height, original_width, _ = input_image.shape

input_image = HWC3(input_image)
input_image = resize_image(input_image, detect_resolution)

assert input_image.ndim == 3
height, width, _ = input_image.shape
with torch.no_grad():
image_teed = torch.from_numpy(input_image.copy()).float().to(device)
image_teed = rearrange(image_teed, "h w c -> 1 c h w")
edges = self.model(image_teed)
edges = [e.detach().cpu().numpy().astype(np.float32)[0, 0] for e in edges]
edges = [
cv2.resize(e, (width, height), interpolation=cv2.INTER_LINEAR)
for e in edges
]
edges = np.stack(edges, axis=2)
edge = 1 / (1 + np.exp(-np.mean(edges, axis=2).astype(np.float64)))
if safe_steps != 0:
edge = safe_step(edge, safe_steps)
edge = (edge * 255.0).clip(0, 255).astype(np.uint8)

detected_map = edge
detected_map = HWC3(detected_map)

detected_map = cv2.resize(
detected_map,
(original_width, original_height),
interpolation=cv2.INTER_LINEAR,
)

if output_type == "pil":
detected_map = Image.fromarray(detected_map)

return detected_map
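
A minimal usage sketch of `TEEDdetector`, using the checkpoint referenced in the README (input path hypothetical):

```python
from PIL import Image

from controlnet_aux import TEEDdetector

teed = TEEDdetector.from_pretrained("fal-ai/teed", filename="5_model.pth")

img = Image.open("input.png")  # hypothetical input image
edge_map = teed(img, detect_resolution=1024, safe_steps=2)  # PIL image by default
edge_map.save("teed_edges.png")
```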
332 changes: 332 additions & 0 deletions src/controlnet_aux/teed/ted.py
@@ -0,0 +1,332 @@
# Original from: https://github.com/xavysp/TEED
# TEED is a tiny but efficient edge detector; it comes from LDC-B3
# with a slight modification.
# LDC parameters: 155,665
# TEED parameters: ~58K

import torch
import torch.nn as nn
import torch.nn.functional as F

from .Fsmish import smish as Fsmish
from .Xsmish import Smish


def weight_init(m):
if isinstance(m, (nn.Conv2d,)):
torch.nn.init.xavier_normal_(m.weight, gain=1.0)

if m.bias is not None:
torch.nn.init.zeros_(m.bias)

# for fusion layer
if isinstance(m, (nn.ConvTranspose2d,)):
torch.nn.init.xavier_normal_(m.weight, gain=1.0)
if m.bias is not None:
torch.nn.init.zeros_(m.bias)


class CoFusion(nn.Module):
# from LDC

def __init__(self, in_ch, out_ch):
super(CoFusion, self).__init__()
self.conv1 = nn.Conv2d(
in_ch, 32, kernel_size=3, stride=1, padding=1
) # before 64
self.conv3 = nn.Conv2d(
32, out_ch, kernel_size=3, stride=1, padding=1
) # before 64 instead of 32
self.relu = nn.ReLU()
self.norm_layer1 = nn.GroupNorm(4, 32) # before 64

def forward(self, x):
# fusecat = torch.cat(x, dim=1)
attn = self.relu(self.norm_layer1(self.conv1(x)))
attn = F.softmax(self.conv3(attn), dim=1)
return ((x * attn).sum(1)).unsqueeze(1)


class CoFusion2(nn.Module):
# TEDv14-3
def __init__(self, in_ch, out_ch):
super(CoFusion2, self).__init__()
self.conv1 = nn.Conv2d(
in_ch, 32, kernel_size=3, stride=1, padding=1
) # before 64
# self.conv2 = nn.Conv2d(32, 32, kernel_size=3,
# stride=1, padding=1)# before 64
self.conv3 = nn.Conv2d(
32, out_ch, kernel_size=3, stride=1, padding=1
) # before 64 instead of 32
self.smish = Smish() # nn.ReLU(inplace=True)

def forward(self, x):
# fusecat = torch.cat(x, dim=1)
attn = self.conv1(self.smish(x))
        attn = self.conv3(self.smish(attn))  # before: F.softmax(self.conv3(attn), dim=1)

# return ((fusecat * attn).sum(1)).unsqueeze(1)
return ((x * attn).sum(1)).unsqueeze(1)


class DoubleFusion(nn.Module):
# TED fusion before the final edge map prediction
def __init__(self, in_ch, out_ch):
super(DoubleFusion, self).__init__()
self.DWconv1 = nn.Conv2d(
in_ch, in_ch * 8, kernel_size=3, stride=1, padding=1, groups=in_ch
) # before 64
self.PSconv1 = nn.PixelShuffle(1)

self.DWconv2 = nn.Conv2d(
24, 24 * 1, kernel_size=3, stride=1, padding=1, groups=24
) # before 64 instead of 32

        self.AF = Smish()  # before: XAF() / nn.Tanh()

def forward(self, x):
# fusecat = torch.cat(x, dim=1)
attn = self.PSconv1(
self.DWconv1(self.AF(x))
) # #TEED best res TEDv14 [8, 32, 352, 352]

attn2 = self.PSconv1(
self.DWconv2(self.AF(attn))
) # #TEED best res TEDv14[8, 3, 352, 352]

return Fsmish(((attn2 + attn).sum(1)).unsqueeze(1)) # TED best res


class _DenseLayer(nn.Sequential):
def __init__(self, input_features, out_features):
super(_DenseLayer, self).__init__()

        self.add_module(
            "conv1",
            nn.Conv2d(
                input_features,
                out_features,
                kernel_size=3,
                stride=1,
                padding=2,
                bias=True,
            ),
        )
        self.add_module("smish1", Smish())
        self.add_module(
            "conv2",
            nn.Conv2d(out_features, out_features, kernel_size=3, stride=1, bias=True),
        )

def forward(self, x):
x1, x2 = x

new_features = super(_DenseLayer, self).forward(Fsmish(x1)) # F.relu()

return 0.5 * (new_features + x2), x2


class _DenseBlock(nn.Sequential):
def __init__(self, num_layers, input_features, out_features):
super(_DenseBlock, self).__init__()
for i in range(num_layers):
layer = _DenseLayer(input_features, out_features)
self.add_module("denselayer%d" % (i + 1), layer)
input_features = out_features


class UpConvBlock(nn.Module):
def __init__(self, in_features, up_scale):
super(UpConvBlock, self).__init__()
self.up_factor = 2
self.constant_features = 16

layers = self.make_deconv_layers(in_features, up_scale)
assert layers is not None, layers
self.features = nn.Sequential(*layers)

def make_deconv_layers(self, in_features, up_scale):
layers = []
all_pads = [0, 0, 1, 3, 7]
for i in range(up_scale):
kernel_size = 2**up_scale
pad = all_pads[up_scale] # kernel_size-1
out_features = self.compute_out_features(i, up_scale)
layers.append(nn.Conv2d(in_features, out_features, 1))
layers.append(Smish())
layers.append(
nn.ConvTranspose2d(
out_features, out_features, kernel_size, stride=2, padding=pad
)
)
in_features = out_features
return layers

def compute_out_features(self, idx, up_scale):
return 1 if idx == up_scale - 1 else self.constant_features

def forward(self, x):
return self.features(x)


class SingleConvBlock(nn.Module):
def __init__(self, in_features, out_features, stride, use_ac=False):
super(SingleConvBlock, self).__init__()
# self.use_bn = use_bs
self.use_ac = use_ac
self.conv = nn.Conv2d(in_features, out_features, 1, stride=stride, bias=True)
if self.use_ac:
self.smish = Smish()

def forward(self, x):
x = self.conv(x)
if self.use_ac:
return self.smish(x)
else:
return x


class DoubleConvBlock(nn.Module):
def __init__(
self, in_features, mid_features, out_features=None, stride=1, use_act=True
):
super(DoubleConvBlock, self).__init__()

self.use_act = use_act
if out_features is None:
out_features = mid_features
self.conv1 = nn.Conv2d(in_features, mid_features, 3, padding=1, stride=stride)
self.conv2 = nn.Conv2d(mid_features, out_features, 3, padding=1)
self.smish = Smish() # nn.ReLU(inplace=True)

def forward(self, x):
x = self.conv1(x)
x = self.smish(x)
x = self.conv2(x)
if self.use_act:
x = self.smish(x)
return x


class TED(nn.Module):
"""Definition of Tiny and Efficient Edge Detector
model
"""

def __init__(self):
super(TED, self).__init__()
self.block_1 = DoubleConvBlock(
3,
16,
16,
stride=2,
)
self.block_2 = DoubleConvBlock(16, 32, use_act=False)
self.dblock_3 = _DenseBlock(1, 32, 48) # [32,48,100,100] before (2, 32, 64)

self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

# skip1 connection, see fig. 2
self.side_1 = SingleConvBlock(16, 32, 2)

# skip2 connection, see fig. 2
self.pre_dense_3 = SingleConvBlock(32, 48, 1) # before (32, 64, 1)

# USNet
self.up_block_1 = UpConvBlock(16, 1)
self.up_block_2 = UpConvBlock(32, 1)
self.up_block_3 = UpConvBlock(48, 2) # (32, 64, 1)

self.block_cat = DoubleFusion(3, 3) # TEED: DoubleFusion

self.apply(weight_init)

def slice(self, tensor, slice_shape):
t_shape = tensor.shape
img_h, img_w = slice_shape
if img_w != t_shape[-1] or img_h != t_shape[2]:
new_tensor = F.interpolate(
tensor, size=(img_h, img_w), mode="bicubic", align_corners=False
)

else:
new_tensor = tensor
# tensor[..., :height, :width]
return new_tensor

def resize_input(self, tensor):
t_shape = tensor.shape
if t_shape[2] % 8 != 0 or t_shape[3] % 8 != 0:
img_w = ((t_shape[3] // 8) + 1) * 8
img_h = ((t_shape[2] // 8) + 1) * 8
new_tensor = F.interpolate(
tensor, size=(img_h, img_w), mode="bicubic", align_corners=False
)
else:
new_tensor = tensor
return new_tensor

    @staticmethod
    def crop_bdcn(data1, h, w, crop_h, crop_w):
        # Based on BDCN Implementation @ https://github.com/pkuCactus/BDCN
        # (marked static: the first argument is the tensor, not self)
        _, _, h1, w1 = data1.size()
        assert h <= h1 and w <= w1
        data = data1[:, :, crop_h : crop_h + h, crop_w : crop_w + w]
        return data

def forward(self, x, single_test=False):
assert x.ndim == 4, x.shape
        # suppose the image size is 352x352

# Block 1
block_1 = self.block_1(x) # [8,16,176,176]
block_1_side = self.side_1(block_1) # 16 [8,32,88,88]

# Block 2
block_2 = self.block_2(block_1) # 32 # [8,32,176,176]
block_2_down = self.maxpool(block_2) # [8,32,88,88]
block_2_add = block_2_down + block_1_side # [8,32,88,88]

# Block 3
        block_3_pre_dense = self.pre_dense_3(
            block_2_down
        )  # [8,48,88,88] block 3 L connection
        block_3, _ = self.dblock_3([block_2_add, block_3_pre_dense])  # [8,48,88,88]

# upsampling blocks
out_1 = self.up_block_1(block_1)
out_2 = self.up_block_2(block_2)
out_3 = self.up_block_3(block_3)

results = [out_1, out_2, out_3]

# concatenate multiscale outputs
        block_cat = torch.cat(results, dim=1)  # Bx3xHxW
block_cat = self.block_cat(block_cat) # Bx1xHxW DoubleFusion

results.append(block_cat)
return results


if __name__ == "__main__":
batch_size = 8
img_height = 352
img_width = 352

# device = "cuda" if torch.cuda.is_available() else "cpu"
device = "cpu"
input = torch.rand(batch_size, 3, img_height, img_width).to(device)
# target = torch.rand(batch_size, 1, img_height, img_width).to(device)
print(f"input shape: {input.shape}")
model = TED().to(device)
output = model(input)
print(f"output shapes: {[t.shape for t in output]}")

# for i in range(20000):
# print(i)
# output = model(input)
# loss = nn.MSELoss()(output[-1], target)
# loss.backward()
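
A quick check of the size claim in the header comment (~58K parameters for TEED vs. 155,665 for LDC), as a sketch:

```python
from controlnet_aux.teed.ted import TED

model = TED()
n_params = sum(p.numel() for p in model.parameters())
print(f"TED parameters: {n_params:,}")  # the header above puts this just over 58K
```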
