Commit 16ef524

update

ashawkey committed Jan 21, 2022
1 parent eb58820 commit 16ef524
Showing 11 changed files with 1,391 additions and 18 deletions.
68 changes: 68 additions & 0 deletions nerf/network.py
@@ -0,0 +1,68 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from encoding import get_encoder


class NeRFNetwork(nn.Module):
    def __init__(self,
                 encoding="hashgrid",
                 encoding_view="frequency",
                 num_layers=3,
                 skips=[],
                 hidden_dim=64,
                 clip_sdf=None,
                 ):
        super().__init__()

        self.num_layers = num_layers
        self.skips = skips
        self.hidden_dim = hidden_dim
        self.clip_sdf = clip_sdf

        # positional encoder for the 3D input coordinates
        self.encoder, self.in_dim = get_encoder(encoding)

        # MLP backbone: skip layers re-concatenate the encoded input,
        # and the final layer outputs a single value (the SDF).
        backbone = []

        for l in range(num_layers):
            if l == 0:
                in_dim = self.in_dim
            elif l in self.skips:
                in_dim = self.hidden_dim + self.in_dim
            else:
                in_dim = self.hidden_dim

            if l == num_layers - 1:
                out_dim = 1
            else:
                out_dim = self.hidden_dim

            backbone.append(nn.Linear(in_dim, out_dim, bias=False))

        self.backbone = nn.ModuleList(backbone)


    def forward(self, x):
        # x: [B, 3]

        #print('forward: x', x.shape, x.min().item(), x.max().item())

        x = self.encoder(x)

        #print('forward: enc(x)', x.shape, x.min().item(), x.max().item())

        h = x
        for l in range(self.num_layers):
            if l in self.skips:
                h = torch.cat([h, x], dim=-1)
            h = self.backbone[l](h)
            if l != self.num_layers - 1:
                h = F.relu(h, inplace=True)

        # optionally clamp the predicted SDF to a fixed range
        if self.clip_sdf is not None:
            h = h.clamp(-self.clip_sdf, self.clip_sdf)

        #print('forward: y', h.shape, h.min().item(), h.max().item())

        return h
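
Below is a minimal usage sketch, not part of the commit, showing how the NeRFNetwork above could be smoke-tested; it assumes the encoding module provides get_encoder("hashgrid") returning an encoder and its output dimension, as the constructor expects.

# Hypothetical smoke test (not from this repository's code).
import torch
from nerf.network import NeRFNetwork

model = NeRFNetwork(encoding="hashgrid", num_layers=3, hidden_dim=64, clip_sdf=1.0)

# a batch of 3D query points, e.g. sampled in [-1, 1]^3
x = torch.rand(1024, 3) * 2 - 1

sdf = model(x)  # [1024, 1], clamped to [-clip_sdf, clip_sdf]
print(sdf.shape, sdf.min().item(), sdf.max().item())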
46 changes: 46 additions & 0 deletions nerf/provider.py
@@ -0,0 +1,46 @@
import os
import time
import glob
import numpy as np

import cv2
from PIL import Image

import torch
from torch.utils.data import DataLoader, Dataset

# NeRF dataset
import json


class NeRFDataset(Dataset):
    def __init__(self, path):
        super().__init__()

        self.path = path

        # load cameras
        transform_path = os.path.join(self.path, 'transforms.json')
        with open(transform_path, 'r') as f:
            transform = json.load(f)

        # containers to be populated from the loaded transforms
        # (still left empty in this commit)
        self.images = []
        self.cameras = []
        self.intrinsics = []

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        # not yet implemented in this commit: returns an empty result dict
        results = {
        }

        return results
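
The dataset above is still a stub in this commit: __init__ reads transforms.json but does not yet fill self.images, self.cameras, or self.intrinsics, and __getitem__ returns an empty dict. As a hedged sketch only (not the author's implementation), a Blender-style transforms.json is commonly parsed along these lines, assuming camera_angle_x and per-frame transform_matrix / file_path fields:

# Hypothetical helper illustrating typical transforms.json parsing; not part of this commit.
import os
import json
import numpy as np

def load_transforms(path):
    with open(os.path.join(path, 'transforms.json'), 'r') as f:
        transform = json.load(f)

    poses, image_paths = [], []
    for frame in transform['frames']:
        # 4x4 camera-to-world matrix and the corresponding image file
        poses.append(np.array(frame['transform_matrix'], dtype=np.float32))
        image_paths.append(os.path.join(path, frame['file_path']))

    # horizontal FoV; the focal length is derived from it once the image width is known
    camera_angle_x = transform['camera_angle_x']
    return np.stack(poses, axis=0), image_paths, camera_angle_x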
(The remaining 9 changed files are not shown here.)
