
Commit

updated
ihaeyong committed Jul 18, 2024
1 parent 513f660 commit 0d765bc
Showing 31 changed files with 5,376 additions and 1,011 deletions.
257 changes: 250 additions & 7 deletions model_nerv.py
@@ -8,6 +8,14 @@

from model.subnet import SubnetConv2d, SubnetLinear, SubnetConvTranspose2d

from neuralop.spectral_convolution import FactorizedSpectralConv as SpectralConv
#from neuralop.spectral_convolution import FactorizedSpectralConvV1 as SpectralConv
from neuralop.spectral_convolution import FactorizedSpectralConv2d as SpectralConv2D
from neuralop.spectral_linear import FactorizedSpectralLinear as SpectralLinear
from neuralop.fno_block import FNOBlocks

from neuralop.resample import resample
from neuralop.skip_connections import skip_connection

class CustomDataSet(Dataset):
    def __init__(self, main_dir, transform, vid_list=[None], frame_gap=1, visualize=False):
@@ -46,6 +54,79 @@ def __getitem__(self, idx):

        return tensor_image, frame_idx

class CustomDataSetMTL1(Dataset):
    def __init__(self, main_dir_list, transform, vid_list=[None], frame_gap=1):

        self.frame_path = []
        self.task_id = []
        self.frame_idx = []
        for id, main_dir in enumerate(main_dir_list):
            # import ipdb; ipdb.set_trace()
            print('loading...' + main_dir)
            #self.main_dir = main_dir
            self.transform = transform
            frame_idx = []
            accum_img_num = []
            all_imgs = os.listdir(main_dir)
            all_imgs.sort()

            num_frame = 0
            for img_id in all_imgs:
                img_id = os.path.join(main_dir, img_id)
                self.frame_path.append(img_id)
                frame_idx.append(num_frame)
                self.task_id.append(id)
                num_frame += 1

            accum_img_num.append(num_frame)

            # normalize frame indices to [0, 1) within each video (task)
            frame_idx = [float(x) / len(frame_idx) for x in frame_idx]
            self.accum_img_num = np.asfarray(accum_img_num)
            if None not in vid_list:
                frame_idx = [frame_idx[i] for i in vid_list]

            # import ipdb; ipdb.set_trace()
            self.frame_idx.extend(frame_idx)

        self.frame_gap = frame_gap

    def __len__(self):
        return len(self.frame_idx) // self.frame_gap

    def __getitem__(self, idx):
        valid_idx = idx * self.frame_gap
        img_name = self.frame_path[valid_idx]
        #img_name = os.path.join(self.main_dir, img_id)
        image = Image.open(img_name).convert("RGB")
        tensor_image = self.transform(image)
        if tensor_image.size(1) > tensor_image.size(2):
            tensor_image = tensor_image.permute(0, 2, 1)

        tensor_image = F.adaptive_avg_pool2d(tensor_image, (720, 1280))

        frame_idx = torch.tensor(self.frame_idx[valid_idx])
        task_id = torch.tensor(self.task_id[valid_idx])

        return tensor_image, frame_idx, task_id

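A minimal usage sketch for CustomDataSetMTL1 (not part of the commit; the directory names are placeholders): each item is a frame tensor pooled to 720x1280, its normalized position within its own video, and the index of the source directory (task).

import torch
from torch.utils.data import DataLoader
from torchvision import transforms

transform = transforms.Compose([transforms.ToTensor()])
dataset = CustomDataSetMTL1(['data/video_a', 'data/video_b'], transform, frame_gap=1)
loader = DataLoader(dataset, batch_size=2, shuffle=True)

for tensor_image, frame_idx, task_id in loader:
    # tensor_image: [B, 3, 720, 1280]; frame_idx in [0, 1); task_id indexes main_dir_list
    break
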
class CustomDataSetMTL(Dataset):
    def __init__(self, data, norm_idx, task_id, max_num_tasks):
        self.data = data
        self.norm_idx = norm_idx
        self.task_id = task_id
        self.max_num_tasks = max_num_tasks

    def __len__(self):
        return self.max_num_tasks

    def __getitem__(self, idx):
        data = self.data[idx]
        norm_idx = self.norm_idx[idx]
        task_id = self.task_id[idx]

        return data, norm_idx, task_id

class Sin(nn.Module):
    def __init__(self, inplace: bool = False):
        super(Sin, self).__init__()
@@ -102,6 +183,8 @@ def __init__(self, **kargs):
        self.subnet = kargs['subnet']
        self.sparsity = kargs['sparsity']
        self.conv_type = kargs['conv_type']
        self.name = kargs['name']

        if self.conv_type == 'conv':
            if not self.subnet:
                self.conv = nn.Conv2d(ngf, new_ngf * stride * stride, 3, 1, 1, bias=kargs['bias'])
@@ -123,9 +206,85 @@ def __init__(self, **kargs):
            self.conv = None
            self.up_scale = None


        elif self.conv_type == 'convfreq_sum':

            self.device = kargs['device']
            self.scale = 0.1
            self.var = 0.5
            self.idx_th = 0

            # number of retained Fourier modes per stage, matched to the
            # feature-map size of the layer this block is placed at
            if self.name == 'layers.0.conv':
                n_modes = (9, 16)
            elif self.name == 'layers.1.conv':
                n_modes = (45, 80)
            elif self.name == 'layers.2.conv':
                n_modes = (90, 160)
            elif self.name == 'layers.3.conv':
                n_modes = (180, 320)

            self.conv = nn.Conv2d(ngf,
                                  new_ngf * stride * stride, 3, 1, 1,
                                  bias=kargs['bias'])

            self.up_scale = nn.PixelShuffle(stride)

            factorization = 'subnet'
            rank = 1.0
            self.n_layers = 1
            output_scaling_factor = (stride, stride)
            incremental_n_modes = None
            fft_norm = 'forward'
            fixed_rank_modes = False
            implementation = 'factorized'
            separable = False
            decomposition_kwargs = dict()
            joint_factorization = False

            # freq-sparsity
            self.fft_scale = 1.0
            # fft_sparsity = 1 - (1 - self.sparsity) * self.fft_scale
            fft_sparsity = 0.5  # dense

            self.conv_freq = SpectralConv(
                ngf, new_ngf // self.n_layers,
                n_modes,
                output_scaling_factor=output_scaling_factor,
                incremental_n_modes=incremental_n_modes,
                rank=rank,
                fft_norm=fft_norm,
                fixed_rank_modes=fixed_rank_modes,
                implementation=implementation,
                separable=separable,
                factorization=factorization,
                decomposition_kwargs=decomposition_kwargs,
                joint_factorization=joint_factorization,
                n_layers=self.n_layers, bias=False,
                sparsity=fft_sparsity,
                device=self.device, scale=self.scale,
                idx_th=self.idx_th, var=self.var)

    def forward(self, x):
        if self.conv_type == 'convfreq_sum':
            out = self.conv(x)
            out = self.up_scale(out)

            # masks are None since this is a dense (non-subnet) layer
            weight_mask_real = None
            weight_mask_imag = None

            # add the spectral-convolution branch on top of the spatial branch
            out += self.conv_freq(x=x, indices=0, output_shape=None,
                                  weight_mask_real=weight_mask_real,
                                  weight_mask_imag=weight_mask_imag,
                                  mode='train')

            return out
        else:
            out = self.conv(x)
            return self.up_scale(out)

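The 'convfreq_sum' path sums the usual conv + PixelShuffle branch with a spectral-convolution branch computed in the Fourier domain. Below is a minimal, self-contained sketch of that idea in plain torch, assuming a simplified spectral layer (no factorization, subnet masking, or frequency sparsity, which the repository's SpectralConv adds on top); ToySpectralSum and its arguments are illustrative names, not part of the diff.

import torch
import torch.nn as nn

class ToySpectralSum(nn.Module):
    def __init__(self, ngf, new_ngf, stride, n_modes=(9, 16)):
        super().__init__()
        # spatial branch: 3x3 conv followed by PixelShuffle upscaling, as in the block above
        self.spatial = nn.Sequential(
            nn.Conv2d(ngf, new_ngf * stride * stride, 3, 1, 1),
            nn.PixelShuffle(stride))
        # the toy spectral branch works at the input resolution, so upsample its output to match
        self.up = nn.Upsample(scale_factor=stride, mode='nearest')
        m1, m2 = n_modes
        # learned complex weights on the kept low-frequency modes only
        self.weight = nn.Parameter(0.02 * torch.randn(ngf, new_ngf, m1, m2, dtype=torch.cfloat))
        self.n_modes = n_modes

    def forward(self, x):
        b, c, h, w = x.shape
        m1, m2 = self.n_modes
        x_ft = torch.fft.rfft2(x, norm='forward')                    # [B, C, H, W//2 + 1]
        out_ft = torch.zeros(b, self.weight.shape[1], h, w // 2 + 1,
                             dtype=torch.cfloat, device=x.device)
        # keep only the lowest (m1, m2) modes (simplified: negative height frequencies dropped)
        out_ft[:, :, :m1, :m2] = torch.einsum('bixy,ioxy->boxy',
                                              x_ft[:, :, :m1, :m2], self.weight)
        freq = torch.fft.irfft2(out_ft, s=(h, w), norm='forward')
        return self.spatial(x) + self.up(freq)                       # sum of the two branches
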

# Multiple Input Sequential
@@ -174,7 +333,9 @@ def __init__(self, **kargs):
                               bias=kargs['bias'],
                               conv_type=kargs['conv_type'],
                               subnet=kargs['subnet'],
                               sparsity=kargs['sparsity'],
                               device=kargs['device'],
                               name=kargs['name'] + '.conv')

        self.norm = NormLayer(kargs['norm'], kargs['new_ngf'])
        self.act = ActivationLayer(kargs['act'])
@@ -199,6 +360,8 @@ def __init__(self, **kargs):
        else:
            self.stem = SubnetMLP(dim_list=mlp_dim_list, act=kargs['act'], sparsity=self.sparsity)

        # a non-negative kargs['freq'] selects the stage that gets the frequency (spectral) block
        self.freq = True if kargs['freq'] >= 0 else False

        # BUILD CONV LAYERS
        self.layers, self.head_layers = [nn.ModuleList() for _ in range(2)]
        ngf = self.fc_dim
@@ -211,8 +374,25 @@ def __init__(self, **kargs):
                new_ngf = max(ngf // (1 if stride == 1 else kargs['reduction']), kargs['lower_width'])

            for j in range(kargs['num_blocks']):
                name = 'layers.{}'.format(i)

                # place the frequency-domain block only at the stage chosen by kargs['freq']
                if i == kargs['freq'] and self.freq:
                    conv_type = 'convfreq_sum' if kargs['cat_size'] < 0 else 'convfreq_cat'
                    self.layers.append(NeRVBlock(ngf=ngf, new_ngf=new_ngf, stride=1 if j else stride,
                                                 bias=kargs['bias'], norm=kargs['norm'], act=kargs['act'],
                                                 conv_type=conv_type, subnet=self.subnet,
                                                 sparsity=self.sparsity, device=kargs['device'],
                                                 name=name))
                else:
                    self.layers.append(NeRVBlock(ngf=ngf, new_ngf=new_ngf, stride=1 if j else stride,
                                                 bias=kargs['bias'], norm=kargs['norm'], act=kargs['act'],
                                                 conv_type=kargs['conv_type'], subnet=self.subnet,
                                                 sparsity=self.sparsity, device=kargs['device'],
                                                 name=name))

                ngf = new_ngf

        # build head classifier, upscale feature layer, upscale img layer
@@ -239,8 +419,6 @@ def forward(self, input):
        output = output.view(output.size(0), self.fc_dim, self.fc_h, self.fc_w)

        out_list = []
        for layer, head_layer in zip(self.layers, self.head_layers):
            output = layer(output)
            if head_layer is not None:
@@ -251,4 +429,69 @@
                out_list.append(img_out)

        return out_list





class GeneratorMTL(nn.Module):
    def __init__(self, **kargs):
        super().__init__()

        self.name = 'generator'
        stem_dim, stem_num = [int(x) for x in kargs['stem_dim_num'].split('_')]
        self.fc_h, self.fc_w, self.fc_dim = [int(x) for x in kargs['fc_hw_dim'].split('_')]
        mlp_dim_list = [kargs['embed_length'] * 2] + [stem_dim] * stem_num + [self.fc_h * self.fc_w * self.fc_dim]

        self.subnet = kargs['subnet']
        self.sparsity = kargs['sparsity']
        if not self.subnet:
            self.stem = MLP(dim_list=mlp_dim_list, act=kargs['act'])
        else:
            self.stem = SubnetMLP(dim_list=mlp_dim_list, act=kargs['act'], sparsity=self.sparsity)

        # BUILD CONV LAYERS
        self.layers, self.head_layers = [nn.ModuleList() for _ in range(2)]
        ngf = self.fc_dim
        for i, stride in enumerate(kargs['stride_list']):
            if i == 0:
                # expand channel width at first stage
                new_ngf = int(ngf * kargs['expansion'])
            else:
                # change the channel width for each stage
                new_ngf = max(ngf // (1 if stride == 1 else kargs['reduction']), kargs['lower_width'])

            for j in range(kargs['num_blocks']):
                self.layers.append(NeRVBlock(ngf=ngf, new_ngf=new_ngf, stride=1 if j else stride,
                                             bias=kargs['bias'], norm=kargs['norm'], act=kargs['act'],
                                             conv_type=kargs['conv_type'], subnet=self.subnet,
                                             sparsity=self.sparsity))
                ngf = new_ngf

        # kargs['sin_res']:
        # one 1x1 conv head per task
        for t in range(kargs['n_tasks']):
            head_layer = nn.Conv2d(ngf, 3, 1, 1, bias=kargs['bias'])
            self.head_layers.append(head_layer)

        self.sigmoid = kargs['sigmoid']

    def forward(self, input, task_id=None):
        output = self.stem(input)
        output = output.view(output.size(0), self.fc_dim, self.fc_h, self.fc_w)

        out_list = []
        for layer in self.layers:
            output = layer(output)

        img_out = []
        for i, id in enumerate(task_id):
            out = self.head_layers[task_id[i]](output[i])
            img_out.append(out[None])

        img_out = torch.cat(img_out)

        # normalize the final output with sigmoid or tanh function
        img_out = torch.sigmoid(img_out) if self.sigmoid else (torch.tanh(img_out) + 1) * 0.5
        out_list.append(img_out)

        return out_list

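A small sketch of the per-task head dispatch in GeneratorMTL.forward (illustrative only; the sizes and task ids below are made up): each sample in the batch is routed through the 1x1 conv head indexed by its task id, and the per-sample outputs are concatenated back along the batch dimension, written here as a compact equivalent of the loop in forward.

import torch
import torch.nn as nn

ngf, n_tasks = 16, 3
head_layers = nn.ModuleList([nn.Conv2d(ngf, 3, 1, 1) for _ in range(n_tasks)])

output = torch.randn(4, ngf, 18, 32)          # features after the shared NeRV blocks
task_id = torch.tensor([0, 2, 1, 0])          # one task id per sample

img_out = torch.cat([head_layers[int(t)](output[i:i + 1]) for i, t in enumerate(task_id)])
img_out = torch.sigmoid(img_out)              # or (torch.tanh(img_out) + 1) * 0.5
print(img_out.shape)                          # torch.Size([4, 3, 18, 32])
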