# Model Architecture
# Author: Landy Xu, created on Nov. 12, 2022
# Last modified by Simon on Nov. 13
# Version 2: add attention to the shallow feature; change the first conv to a 1x1 kernel
'''
Change log:
- Landy: create feature extractor and DILRAN
- Simon: revise the writing style of some module configs (e.g., replace=True -> inplace=True),
  refine the FE module, add the recon module
- Simon: create the full model pipeline
- Simon: add leaky ReLU to the recon module
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class ConvLeakyRelu2d(nn.Module):
    # convolution followed by leaky ReLU
    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, stride=1, dilation=1, groups=1):
        super(ConvLeakyRelu2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding, stride=stride, dilation=dilation, groups=groups)
        # self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        return F.leaky_relu(self.conv(x), negative_slope=0.2)
class Sobelxy(nn.Module):
    # fixed depthwise Sobel filters along x and y; returns |Gx| + |Gy|
    def __init__(self, channels, kernel_size=3, padding=1, stride=1, dilation=1, groups=1):
        super(Sobelxy, self).__init__()
        sobel_filter = np.array([[1, 0, -1],
                                 [2, 0, -2],
                                 [1, 0, -1]], dtype=np.float32)
        self.convx = nn.Conv2d(channels, channels, kernel_size=kernel_size, padding=padding, stride=stride, dilation=dilation, groups=channels, bias=False)
        # the (3, 3) kernel broadcasts across the (channels, 1, 3, 3) depthwise weight
        self.convx.weight.data.copy_(torch.from_numpy(sobel_filter))
        self.convy = nn.Conv2d(channels, channels, kernel_size=kernel_size, padding=padding, stride=stride, dilation=dilation, groups=channels, bias=False)
        self.convy.weight.data.copy_(torch.from_numpy(sobel_filter.T))

    def forward(self, x):
        sobelx = self.convx(x)
        sobely = self.convy(x)
        return torch.abs(sobelx) + torch.abs(sobely)
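# Minimal usage sketch (illustrative only; `_demo_sobelxy` is a hypothetical
# helper, and the channel count and input size are assumed): the depthwise
# Sobel pair should preserve the input shape while producing |Gx| + |Gy|.
def _demo_sobelxy():
    edge = Sobelxy(channels=8)
    x = torch.rand(2, 8, 32, 32)
    y = edge(x)
    assert y.shape == x.shape  # 3x3 kernel with padding=1 keeps H and W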
class Conv1(nn.Module):
    # plain 1x1 convolution (no activation)
    def __init__(self, in_channels, out_channels, kernel_size=1, padding=0, stride=1, dilation=1, groups=1):
        super(Conv1, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding, stride=stride, dilation=dilation, groups=groups)

    def forward(self, x):
        return self.conv(x)
class DenseBlock(nn.Module):
    # two-layer dense block: the input is concatenated with its conv output,
    # so a `channels` input yields a `2*channels` output
    def __init__(self, channels):
        super(DenseBlock, self).__init__()
        # self.conv_in = nn.Conv2d(1, 64, kernel_size=1)
        self.conv1 = ConvLeakyRelu2d(channels, channels)
        self.conv2 = ConvLeakyRelu2d(2 * channels, 2 * channels)
        # self.conv3 = ConvLeakyRelu2d(3 * channels, channels)

    def forward(self, x):
        # x = self.conv_in(x)
        x = torch.cat((x, self.conv1(x)), dim=1)
        x = self.conv2(x)
        # x = torch.cat((x, self.conv3(x)), dim=1)
        return x
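# Minimal shape sketch (illustrative; `_demo_denseblock` is a hypothetical
# helper): concatenating the input with conv1's output doubles the channels,
# and conv2 keeps that doubled width.
def _demo_denseblock():
    block = DenseBlock(channels=16)
    y = block(torch.rand(1, 16, 64, 64))
    assert y.shape == (1, 32, 64, 64)  # channels: 16 -> cat -> 32 -> conv2 -> 32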
class Edge_Enhancer(nn.Module):
    # dense path (with a 1x1 channel reduction) plus a Sobel edge path,
    # fused by element-wise addition
    def __init__(self, in_channels, out_channels):
        super(Edge_Enhancer, self).__init__()
        self.dense = DenseBlock(in_channels)
        self.convdown = Conv1(2 * in_channels, out_channels)
        self.sobelconv = Sobelxy(in_channels)
        self.convup = Conv1(in_channels, out_channels)
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        x1 = self.dropout(self.dense(x))
        x1 = self.convdown(x1)
        x2 = self.sobelconv(x)
        x2 = self.convup(x2)
        return self.dropout(x1 + x2)
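# Minimal usage sketch (illustrative; the helper name and sizes are assumed):
# the dense path (2*in -> out via 1x1) and the Sobel path (in -> out via 1x1)
# must agree on `out_channels` so the element-wise sum is valid.
def _demo_edge_enhancer():
    ee = Edge_Enhancer(in_channels=16, out_channels=16)
    y = ee(torch.rand(1, 16, 64, 64))
    assert y.shape == (1, 16, 64, 64)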
class DILRAN(nn.Module):
    def __init__(self):
        super(DILRAN, self).__init__()
        # TODO: confirm convolution
        self.conv = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
        self.up = nn.Upsample(scale_factor=2, mode='nearest')
        self.down = nn.AvgPool2d(2, 2)
        self.lu = nn.ReLU(inplace=True)

    def forward(self, x):
        # multiscale addition: one, two, and three stacked 3x3 convs
        prev = self.conv(x) + self.conv(self.conv(x)) + self.conv(self.conv(self.conv(x)))
        # pool -> upsample -> ReLU attention applied multiplicatively, plus a residual
        return torch.mul(self.lu(self.up(self.down(x))), prev) + x
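# Minimal usage sketch (illustrative; the helper name and sizes are assumed):
# DILRAN is shape-preserving, but H and W must be even so that
# AvgPool2d(2, 2) followed by Upsample(scale_factor=2) restores the input size.
def _demo_dilran():
    net = DILRAN()
    x = torch.rand(1, 64, 32, 32)  # 64 channels to match self.conv
    assert net(x).shape == x.shape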
class FeatureExtractor(nn.Module):
    def __init__(self, level):
        super(FeatureExtractor, self).__init__()
        # TODO: confirm dilated convolution
        self.conv = nn.Conv2d(1, 64, (1, 1), (1, 1), (0, 0), dilation=2)
        self.network = DILRAN()
        self.up = nn.Upsample(scale_factor=2, mode='nearest')
        self.down = nn.AvgPool2d(2, 2)
        self.lu = nn.ReLU(inplace=True)

    def forward(self, x):
        # x is a sequence of three single-channel inputs; the three outputs
        # are concatenated along dim 0 (batch-wise)
        n1 = self.network(self.conv(x[0]))
        n2 = self.network(self.conv(x[1]))
        n3 = self.network(self.conv(x[2]))
        return torch.cat((n1, n2, n3), 0)
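# Minimal usage sketch (illustrative; the helper name and sizes are assumed):
# the extractor expects three single-channel inputs and stacks the three
# 64-channel outputs along dim 0, i.e. batch-wise rather than channel-wise.
def _demo_feature_extractor():
    fe = FeatureExtractor(level=3)  # `level` is accepted but unused in this version
    xs = [torch.rand(1, 1, 32, 32) for _ in range(3)]
    assert fe(xs).shape == (3, 64, 32, 32)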
class DILRAN_V1(nn.Module):
    '''
    V1: concatenate the outputs of the three (conv-d, DILRAN) paths channel-wise
    and add the low-level feature to the concatenated output.
    Temporary; will edit if necessary.
    '''
    def __init__(self, cat_first=False, use_leaky=False):
        super(DILRAN_V1, self).__init__()
        # cat_first: whether to perform the channel-wise concat before DILRAN;
        # the conv's in_channels is the channel count of the previous block
        if not cat_first:
            self.conv_d = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding="same")
            self.bnorm = nn.BatchNorm2d(num_features=64)
        else:
            self.conv_d = nn.Conv2d(in_channels=64 * 3, out_channels=64 * 3, kernel_size=3, stride=1, padding="same")
            self.bnorm = nn.BatchNorm2d(num_features=64 * 3)

        if not use_leaky:
            self.relu = nn.ReLU()
        else:
            # stored under the same attribute name so forward works for both settings
            self.relu = nn.LeakyReLU(0.2, inplace=True)

        self.down = nn.AvgPool2d(2, 2)
        self.up = nn.Upsample(scale_factor=2, mode="nearest")

    def forward(self, x):
        # pooling -> upsample -> ReLU block
        pur_path = self.relu(self.up(self.down(x)))
        # 3x3, 5x5, 7x7 multiscale addition block (stacked 3x3 convs give the larger receptive fields)
        conv_path = self.conv_d(x) + self.conv_d(self.conv_d(x)) + self.conv_d(self.conv_d(self.conv_d(x)))
        # attention
        attn = torch.mul(pur_path, conv_path)
        # residual + attention
        resid_x = x + attn
        return resid_x
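# Minimal usage sketch (illustrative; the helper name and sizes are assumed):
# with the default cat_first=False the block expects 64 channels, and the
# attention-plus-residual forward pass is shape-preserving (even H and W).
def _demo_dilran_v1():
    net = DILRAN_V1()
    x = torch.rand(1, 64, 32, 32)
    assert net(x).shape == x.shape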
class FE_V1(nn.Module):
    '''
    feature extractor block (temporary, will edit if necessary)
    '''
    def __init__(self):
        super(FE_V1, self).__init__()

        # multiscale dilated conv2d
        self.convd1 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, dilation=1, padding="same")
        self.convd2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, dilation=3, padding="same")
        self.convd3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, dilation=5, padding="same")

        # 1x1 conv that reduces the concatenated 192 channels back to 64
        self.reduce = nn.Conv2d(in_channels=64 * 3, out_channels=64, kernel_size=1, stride=1, padding="same")
        self.relu = nn.ReLU()

        self.bnorm1 = nn.BatchNorm2d(num_features=64)

        self.dilran = DILRAN_V1()

    def forward(self, x):
        # dilated convolutions at three scales
        dilf1 = self.convd1(x)
        dilf2 = self.convd2(x)
        dilf3 = self.convd3(x)

        diltotal = torch.cat((dilf1, dilf2, dilf3), dim=1)
        diltotal = self.reduce(diltotal)
        diltotal = self.bnorm1(diltotal)

        # single DILRAN
        out = self.dilran(diltotal)
        out = self.bnorm1(out)
        # out = self.relu(out)
        return out

        # earlier variant: one DILRAN per dilated path, batchnorm each,
        # then fuse by element-wise addition
        # dilran_o1 = self.bnorm1(self.dilran(dilf1))
        # dilran_o2 = self.bnorm1(self.dilran(dilf2))
        # dilran_o3 = self.bnorm1(self.dilran(dilf3))
        # cat_o = dilran_o1 + dilran_o2 + dilran_o3
        # return cat_o
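# Minimal usage sketch (illustrative; the helper name and sizes are assumed):
# the three dilated branches are concatenated to 192 channels, reduced back
# to 64 by the 1x1 conv, and DILRAN_V1 then preserves that shape.
def _demo_fe_v1():
    fe = FE_V1()
    x = torch.rand(2, 64, 32, 32)  # batch > 1 so BatchNorm2d works in training mode
    assert fe(x).shape == (2, 64, 32, 32)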
class MSFuNet(nn.Module):
    '''
    the whole network (from input image to the feature maps used by the fusion strategy)
    temporary, will edit if necessary
    '''
    def __init__(self):
        super(MSFuNet, self).__init__()

        self.conv_id = nn.Sequential(nn.Conv2d(in_channels=64 * 3, out_channels=64, kernel_size=1, stride=1, padding="same"))
                                     # nn.BatchNorm2d(num_features=64))
                                     # nn.ReLU(inplace=True))
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=1, stride=1, padding="same")

        self.conv2 = nn.Sequential(nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding="same"),
                                   nn.BatchNorm2d(num_features=64),
                                   nn.ReLU(),
                                   nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding="same"),
                                   nn.BatchNorm2d(num_features=64))

        self.relu = nn.ReLU()
        self.down = nn.AvgPool2d(2, 2)
        self.bnorm = nn.BatchNorm2d(num_features=64)
        self.up = nn.Upsample(scale_factor=2, mode="nearest")

        self.fe = FE_V1()
        self.edge_enhance = Edge_Enhancer(64, 64)

    def forward(self, x):
        # x: input image, [b, 1, H, W]
        resid = self.conv1(x)
        temp0 = self.conv1(x)  # shallow feature, 64 x (1x1)
        # attention on the shallow feature (the 1-channel input broadcasts against 64 channels)
        pur_orig = self.relu(self.up(self.down(x)))
        attn = torch.mul(pur_orig, temp0)
        x = x + attn
        # features returned from the feature extractor
        deep_fe = self.fe(x)
        pur_x = self.relu(self.up(self.down(x)))
        attn2 = torch.mul(pur_x, deep_fe)
        add = attn2 + x

        # added for edge enhancement
        edge_x = self.edge_enhance(resid)
        add = add + edge_x
        return add

        # earlier variants:
        # x = x + cat_feature
        # short-cut connection
        # expand_x = self.conv_id(x)
        # add = expand_x + cat_feature
        # add = self.conv2(add)
        # add = self.conv2(resid)  # should get shape [b, 64, 256, 256]
        # return add
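# Minimal usage sketch (illustrative; the helper name and sizes are assumed):
# a single-channel image goes in, and broadcasting against the 64-channel
# shallow feature lifts everything to 64 feature maps of the same spatial size.
def _demo_msfunet():
    net = MSFuNet()
    y = net(torch.rand(2, 1, 64, 64))
    assert y.shape == (2, 64, 64, 64)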
class Recon(nn.Module):
    '''
    reconstruction module (temporary, will edit if necessary)
    '''
    def __init__(self):
        super(Recon, self).__init__()

        # version 1
        # self.recon_conv = nn.Sequential(nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding="same"),
        #                                 nn.LeakyReLU(0.2, inplace=True),
        #                                 nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding="same"),
        #                                 nn.LeakyReLU(0.2, inplace=True),
        #                                 nn.Conv2d(in_channels=32, out_channels=16, kernel_size=3, stride=1, padding="same"),
        #                                 nn.LeakyReLU(0.2, inplace=True),
        #                                 nn.Conv2d(in_channels=16, out_channels=1, kernel_size=3, stride=1, padding="same"),
        #                                 nn.LeakyReLU(0.2, inplace=True))

        # version 2: three convs with leaky ReLU between them and no activation on the output
        self.recon_conv = nn.Sequential(nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding="same"),
                                        nn.LeakyReLU(0.2, inplace=True),
                                        nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding="same"),
                                        nn.LeakyReLU(0.2, inplace=True),
                                        nn.Conv2d(in_channels=32, out_channels=1, kernel_size=3, stride=1, padding="same"))

    def forward(self, x):
        x = self.recon_conv(x)
        return x  # should get shape [b, 1, 256, 256]
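# Minimal usage sketch (illustrative; the helper name and sizes are assumed):
# Recon maps the 64-channel feature maps back down to a single-channel image.
def _demo_recon():
    recon = Recon()
    y = recon(torch.rand(2, 64, 256, 256))
    assert y.shape == (2, 1, 256, 256)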
class fullModel(nn.Module):
    '''
    feature extractor + reconstruction:
    the full model pipeline
    '''
    def __init__(self):
        super(fullModel, self).__init__()

        self.fe = MSFuNet()
        self.recon = Recon()

    def forward(self, x):
        deep_fe = self.fe(x)
        recon_img = self.recon(deep_fe)
        return recon_img
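# Illustrative end-to-end smoke test (the 256x256 single-channel input size is
# an assumption based on the shape comments above, not something this file pins down).
if __name__ == "__main__":
    model = fullModel()
    out = model(torch.rand(1, 1, 256, 256))
    print(out.shape)  # expected: torch.Size([1, 1, 256, 256])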