Skip to content

Commit

Permalink
adjusting random sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
GenMNL committed Oct 16, 2022
1 parent fdab2e7 commit 76d14ff
Show file tree
Hide file tree
Showing 9 changed files with 53 additions and 22 deletions.
23 changes: 18 additions & 5 deletions hoge.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,22 @@
import torch
import numpy as np

tensor1 = torch.randn(1, 3)
print(tensor1)
tens = torch.randn((2, 3, 5))
idx = np.zeros((2, 5))
for b in range(2):
id = np.arange(5, dtype=int)
id = np.random.permutation(id)
idx[b, :] = id
idx = torch.tensor(idx, dtype=int)

tensor2 = torch.randn(1, 3)
print(tensor2)
batch_indices = torch.arange(2, dtype=torch.long)
batch_indices = batch_indices.view(2, 1)
batch_indices = batch_indices.repeat(1, 5)

print(tensor1 + tensor2)
print(idx)
print(tens)
print(tens[batch_indices, :, idx].permute(0, 2, 1))
# print(tens[idx])

# idx = np.repeat(idx, )
# print(idx)
20 changes: 14 additions & 6 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@ def forward(self, x):
tensor: (B, 3, N)
"""
device = x.device
batchsize, _, _ = x.shape
B , _, _ = x.shape
partial = x
features = self.encoder(x)

coarse_output = []
for k in range(0, self.num_surfaces):
rand_grid = torch.rand((batchsize, 2, self.num_output_points//self.num_surfaces),
rand_grid = torch.rand((B, 2, self.num_output_points//self.num_surfaces),
dtype=torch.float32,
device=device)
x = features.unsqueeze(dim=2).repeat(1, 1, rand_grid.shape[2])
Expand All @@ -58,9 +58,9 @@ def forward(self, x):
coarse_output = coarse_output.transpose(1, 2).contiguous() # [B, C, N]

# get id of input partial points and coarse output poitns
id_partial = torch.zeros(batchsize, 1, partial.shape[2], device=device)
id_partial = torch.zeros(B, 1, partial.shape[2], device=device)
x_partial = torch.cat([partial, id_partial], dim=1)
id_coarse = torch.ones(batchsize, 1, coarse_output.shape[2], device=device)
id_coarse = torch.ones(B, 1, coarse_output.shape[2], device=device)
x_coarse = torch.cat([coarse_output, id_coarse], dim=1)
# concatnate partial input points and coarse output points which have identifier index
x = torch.cat([x_partial, x_coarse], dim=2) # [B, 4(xyz+identifier), N(partial+coarse)]
Expand All @@ -74,6 +74,14 @@ def forward(self, x):
elif self.sampling_method == "FPS":
FPS_indices = farthest_point_sampling(x[:, 0:3, :], x_coarse.shape[2])
x = index2point_converter(x, FPS_indices)
elif self.sampling_method == "random":
random_indices = np.zeros((B, x_coarse.shape[2]))
for b in range(B):
id = np.arange(x.shape[2])
id = np.random.permutation(id)
random_indices[b, :] = id[0:x_coarse.shape[2]]
random_indices = torch.tensor(random_indices, dtype=int, device=device)
x = index2point_converter(x, random_indices)

# This is decoder to get fine output
# the num points of fine and coarse is same, but the accuracy of fine is more than coarse
Expand All @@ -84,9 +92,9 @@ def forward(self, x):
return coarse_output, fine_output, loss_mst

if __name__ == "__main__":
input = torch.randn(10, 3, 1024, device="cuda")
input = torch.randn(10, 3, 4000, device="cuda")

model = MSN(1024, 1024, 16, "cuda").to("cuda")
model = MSN(1024, 16384, 32, "random").to("cuda")

coarse_output, fine_output, loss= model(input)
print(coarse_output.shape)
Expand Down
18 changes: 14 additions & 4 deletions module.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def farthest_point_sampling(xyz, num_sumpling):

for i in range(num_sumpling):
centroids[:, i] = farthest # updating list for centroids
centroid = xyz[batch_indicies, :, farthest] # centriud has points cordinate of farthest
centroid = xyz[batch_indicies, :, farthest] # centroid has points cordinate of farthest
centroid = centroid.view(B, C, 1) # reshape for compute distance between centroid and points in xyz
dist = torch.sum((centroid - xyz)**2, dim=1) # computing distance
mask = dist < distance # make boolean list
Expand Down Expand Up @@ -75,8 +75,18 @@ def index2point_converter(xyz, indices):
return new_xyz.permute(0, 2, 1)
# --------------------------------------------------------------------------------------

# --------------------------------------------------------------------------------------
# modules for random point sampling

if __name__ == "__main__":
x = torch.randn(10, 1026, 100)
MLP = SharedMLP(1026, 1026)
out = MLP(x)
# x = torch.randn(10, 1026, 100)
# MLP = SharedMLP(1026, 1026)
# out = MLP(x)
# print(out.shape)


x = torch.randn(10, 4, 20384, device="cuda")
idx = farthest_point_sampling(x[:, 0:3, :], 16384)
out = index2point_converter(x, idx)
print(idx.shape)
print(out.shape)
7 changes: 3 additions & 4 deletions options.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import argparse
import torch

# ----------------------------------------------------------------------------------------
def make_parser():
Expand All @@ -10,10 +9,10 @@ def make_parser():
parser.add_argument("--num_comp", default=16384, type=int)
parser.add_argument("--num_output_points", default=16384, type=int)
parser.add_argument("--emb_dim", default=1024, type=int)
parser.add_argument("--num_surfaces", default=1024, type=int)
parser.add_argument("--num_surfaces", default=32, type=int)
parser.add_argument("--batch_size", default=6, type=int)
parser.add_argument("--epochs", default=500, type=int)
parser.add_argument("-sm", "--sampling_method", default="FPS", help="You can use MDS if you use pytorch1.2.0")
parser.add_argument("--epochs", default=1000, type=int)
parser.add_argument("-sm", "--sampling_method", default="random", help="You can use MDS if you use pytorch1.2.0 or FPS")
parser.add_argument("--optimizer", default="Adam", help="if you want to choose other optimization, you must change the code.")
parser.add_argument("--lr", default=1e-4, help="learning rate", type=float)
parser.add_argument("--dataset_dir", default="../PCN/data/BridgeCompletion")
Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
7 changes: 4 additions & 3 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,12 @@ def val_one_epoch(model, dataloader):
with torch.no_grad():
for i, points in enumerate(tqdm(dataloader, desc="validation")):
comp = points[0]
comp = comp.permute(0, 2, 1)
partial = points[1]
partial = partial.permute(0, 2, 1)

# prediction
coarse, fine, loss_expantion= model(partial)
partial = partial.permute(0, 2, 1)
_, fine, _ = model(partial)
fine = fine.permute(0, 2, 1) # [B, N, 3]
# get chamfer distance loss
emd_fine, _ = emd_loss(fine, comp, eps, iters)
emd_fine = torch.sqrt(emd_fine).mean(dim=1)
Expand Down

0 comments on commit 76d14ff

Please sign in to comment.