Commit 8c34755

change for subset all
Gen Matono committed May 3, 2023
1 parent 621ee56 commit 8c34755
Showing 9 changed files with 159 additions and 125 deletions.
6 changes: 3 additions & 3 deletions eval.py
@@ -4,9 +4,9 @@
 import numpy as np
 import pandas as pd
 import os
-from options import *
-from model import *
-from data import *
+from utils.options import *
+from models.model import *
+from utils.data import *

 # ----------------------------------------------------------------------------------------
 # make function
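
The rewritten imports above (and the renames later in this diff) imply the flat repository was reorganized into models/ and utils/ packages. A sketch of the inferred layout after this commit, reconstructed only from the imports and renames visible here, so the exact tree is an assumption:

.
├── train.py
├── eval.py
├── models/
│   ├── model.py     (moved from ./model.py)
│   ├── encoder.py   (moved from ./encoder.py)
│   ├── decoder.py   (moved from ./decoder.py)
│   └── module.py
├── utils/
│   ├── options.py   (moved from ./options.py)
│   └── data.py      (moved from ./data.py)
├── expansion_penalty/
├── MDS/
└── emd/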
2 changes: 1 addition & 1 deletion decoder.py → models/decoder.py
@@ -1,6 +1,6 @@
 import torch
 import torch.nn as nn
-from module import *
+from models.module import *

 # class for morphing based decoder
 class OneMorphingDecoder(nn.Module):
30 changes: 16 additions & 14 deletions encoder.py → models/encoder.py
@@ -1,19 +1,19 @@
 import torch
 import torch.nn as nn
-from module import *
-from stn import *
+from models.module import *
+# from stn import *

 class Encoder(nn.Module):
     def __init__(self, emb_dim):
         super(Encoder, self).__init__()
         self.emb_dim = emb_dim

-        self.stn3d = STNkd(3)
+        # self.stn3d = STNkd(3)
         self.MLP1 = nn.Sequential(
             SharedMLP(3, 64),
             SharedMLP(64, 64)
         )
-        self.stn64d = STNkd(64)
+        # self.stn64d = STNkd(64)
         self.MLP2 = nn.Sequential(
             SharedMLP(64, 64),
             SharedMLP(64, 128),
@@ -27,22 +27,24 @@ def __init__(self, emb_dim):

     def forward(self, x):
         # stn for input
-        trans_3d = self.stn3d(x)
-        x = x.permute(0, 2, 1)
-        trans_x = torch.bmm(x, trans_3d)
-        trans_x = trans_x.permute(0, 2, 1)
+        # trans_3d = self.stn3d(x)
+        # x = x.permute(0, 2, 1)
+        # trans_x = torch.bmm(x, trans_3d)
+        # trans_x = trans_x.permute(0, 2, 1)

         # MLP1
-        x = self.MLP1(trans_x)
+        # x = self.MLP1(trans_x)
+        x = self.MLP1(x)

         # stn for second t-net
-        trans_64d = self.stn64d(x)
-        x = x.permute(0, 2, 1)
-        trans_x = torch.bmm(x, trans_64d)
-        trans_x = trans_x.permute(0, 2, 1)
+        # trans_64d = self.stn64d(x)
+        # x = x.permute(0, 2, 1)
+        # trans_x = torch.bmm(x, trans_64d)
+        # trans_x = trans_x.permute(0, 2, 1)

         # MLP2
-        x = self.MLP2(trans_x)
+        # x = self.MLP2(trans_x)
+        x = self.MLP2(x)

         # Max Pooling
         x, _ = torch.max(x, dim=2)
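
With both T-Nets (the PointNet-style STNkd alignment blocks) commented out, the encoder reduces to a plain PointNet-style feature extractor: shared pointwise MLPs followed by a channel-wise max pool. A minimal sketch of the forward pass as it reads after this commit, assuming x has shape (batch, 3, num_points) as the removed torch.bmm alignment code implies; the later MLP2 widths are hidden by the collapsed diff:

def forward(self, x):
    # x: (B, 3, N) point cloud; the STN alignment steps are skipped after this commit
    x = self.MLP1(x)            # (B, 3, N) -> (B, 64, N)
    x = self.MLP2(x)            # (B, 64, N) -> (B, C, N); C is not visible in this diff
    x, _ = torch.max(x, dim=2)  # global max pool over points -> (B, C)
    return x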
8 changes: 4 additions & 4 deletions model.py → models/model.py
@@ -1,12 +1,12 @@
 import torch
 import torch.nn as nn
-from encoder import *
-from decoder import *
+from models.encoder import *
+from models.decoder import *
 import sys
 sys.path.append("./expansion_penalty")
 sys.path.append("./MDS")
-import expansion_penalty_module as expansion
-from module import farthest_point_sampling, index2point_converter
+import expansion_penalty.expansion_penalty_module as expansion
+from models.module import farthest_point_sampling, index2point_converter

 class MSN(nn.Module):
     def __init__(self, emb_dim, num_output_points, num_surfaces, sampling_method):
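
One wrinkle worth noting: the sys.path.append("./expansion_penalty") and sys.path.append("./MDS") hacks survive even though expansion_penalty_module is now imported through its package path, and package-qualified imports like models.module only resolve when the script is launched from the repository root. A hypothetical safeguard (not in this commit) that pins the root onto sys.path for models/model.py regardless of the working directory:

import os
import sys

# Hypothetical, not in the commit: put the repository root (parent of models/)
# on sys.path so models.*, utils.*, and expansion_penalty.* resolve from any
# working directory.
REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, REPO_ROOT)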
File renamed without changes.
File renamed without changes.
33 changes: 15 additions & 18 deletions train.py
@@ -6,9 +6,9 @@
 import os
 import sys
 import datetime
-from data import *
-from options import make_parser
-from model import MSN
+from utils.data import *
+from utils.options import make_parser
+from models.model import MSN
 sys.path.append("./emd")
 import emd_module as emd

@@ -87,7 +87,7 @@ def val_one_epoch(model, dataloader):

 # ----------------------------------------------------------------------------------------
 if __name__ == "__main__":
-    # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
+    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
     # get options
     parser = make_parser()
     args = parser.parse_args()
@@ -104,30 +104,28 @@ def val_one_epoch(model, dataloader):
         f.write('')

     writter = SummaryWriter()
-    # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
-
-    # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
+    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
     # make dataloader
     # data_dir = os.path.join(args.dataset_dir)
     train_dataset = MakeDataset(dataset_path=args.dataset_dir, subset=args.subset,
                                 eval="train", num_partial_pattern=4, device=args.device)
     train_dataloader = DataLoader(dataset=train_dataset, batch_size=args.batch_size,
                                   shuffle=True, drop_last=True,
-                                  collate_fn=OriginalCollate(args.num_partial, args.num_comp, args.device)) # DataLoader is iterable object.
+                                  collate_fn=OriginalCollate(args.device)) # DataLoader is iterable object.

     # validation data
     val_dataset = MakeDataset(dataset_path=args.dataset_dir, subset=args.subset,
                               eval="val", num_partial_pattern=4, device=args.device)
     val_dataloader = DataLoader(dataset=val_dataset, batch_size=2,
                                 shuffle=True, drop_last=True,
-                                collate_fn=OriginalCollate(args.num_partial, args.num_comp, args.device))
+                                collate_fn=OriginalCollate(args.device))

     # check of data in dataloader
     # for i, points in enumerate(tqdm(train_dataloader)):
     #     print(f"complete points:{points[0].shape}, partial points:{points[1].shape}")
-    # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
-
-    # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
+    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
     # prepare model and optimaizer
     model = MSN(args.emb_dim, args.num_output_points, args.num_surfaces, args.sampling_method).to(args.device)
     if args.optimizer == "Adam":
@@ -136,19 +134,18 @@ def val_one_epoch(model, dataloader):
         optim = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.6)

     # lr_schdual = torch.optim.lr_scheduler.StepLR(optim, step_size=int(args.epochs/4), gamma=0.7)
-    # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
-
-    # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
+    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
     # main loop
     best_loss = np.inf
     for epoch in tqdm(range(1, args.epochs+1), desc="main loop"):

         # determin the ration of loss
-        if epoch < 50:
+        if epoch < 40:
             alpha = 0.01
-        elif epoch < 100:
+        elif epoch < 800:
             alpha = 0.1
-        elif epoch < 200:
+        elif epoch < 120:
             alpha = 0.5
         else:
             alpha = 1.0
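
In train.py, OriginalCollate is now constructed with only the device; the fixed point counts (args.num_partial, args.num_comp) it used to receive are gone, consistent with the commit message's move to the full "all" subset, where the collate function presumably reads sizes from the batch itself. The class body is not part of this diff, so the following is purely illustrative, not the repository's actual implementation:

import torch

class OriginalCollate:
    # Illustrative sketch only: the (complete, partial) pair structure is taken
    # from the commented shape-check loop above; everything else is assumed.
    def __init__(self, device):
        self.device = device

    def __call__(self, batch):
        # batch: list of (complete, partial) tensor pairs, assumed equal-sized
        complete = torch.stack([pair[0] for pair in batch]).to(self.device)
        partial = torch.stack([pair[1] for pair in batch]).to(self.device)
        return complete, partial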
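
The rewritten loss-weight schedule deserves a second look: since epoch < 800 is tested before epoch < 120, the 0.5 branch can never fire, so alpha jumps from 0.1 straight to 1.0 at epoch 800 (and with fewer than 800 epochs it simply stays at 0.1 after epoch 40). The schedule as committed, restated as a function with the dead branch flagged:

def alpha_schedule(epoch):
    # loss-weight ramp as committed in 8c34755
    if epoch < 40:
        return 0.01
    elif epoch < 800:
        return 0.1
    elif epoch < 120:   # unreachable: any epoch < 120 already satisfies < 800
        return 0.5
    else:
        return 1.0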
