Commit

first commit
陈安沛 committed Mar 17, 2022
1 parent f791678 commit 1e0c7af
Showing 16 changed files with 3,027 additions and 1 deletion.
62 changes: 61 additions & 1 deletion README.md
@@ -1 +1,61 @@
# TensoRF
## [Project page](https://apchenstu.github.io/TensoRF/) | [Paper](https://arxiv.org/abs/2103.15595)
This repository contains a PyTorch implementation of the paper: [TensoRF: Tensorial Radiance Fields](https://arxiv.org/abs/2103.15595). Our work presents a novel approach to modeling and reconstructing radiance fields that achieves a super-**fast** training process, a **compact** memory footprint, and **state-of-the-art** rendering quality.<br><br>

xxx.mp4

## Installation

#### Tested on Ubuntu 20.04 + PyTorch 1.10.1

Install environment:
```
conda create -n TensoRF python=3.8
conda activate TensoRF
pip install torch torchvision
pip install tqdm scikit-image opencv-python configargparse lpips imageio-ffmpeg
```


## Dataset
* [Synthetic-NeRF](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1)
* [Synthetic-NSVF](https://dl.fbaipublicfiles.com/nsvf/dataset/Synthetic_NSVF.zip)
* [Tanks&Temples](https://dl.fbaipublicfiles.com/nsvf/dataset/TanksAndTemple.zip)
* [Forward-facing](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1)



## Training
The training script is `train.py`; we provide a list of commands in `run_batch.py` to reproduce our results. Please note:

* `dataset_name`, choices = ['blender', 'llff', 'nsvf', 'dtu', 'tankstemple'];
* `shadingMode`, choices = ['MLP_PE', 'SH'];
* `n_lamb_sigma` and `n_lamb_sh` are string-typed and set the number of basis components for density and appearance along the X, Y, and Z dimensions;
* `N_voxel_init` and `N_voxel_final` control the resolution of the matrix and vector factors;
* `N_vis` and `vis_every` control visualization during training.


Set `--render_test 1` / `--render_path 1` if you want to render test views or a camera path after training.

For more options, refer to `opt.py`.
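
For example, a training run on a Synthetic-NeRF scene might be launched as sketched below. This is only an illustration assembled from the options listed above: the `--datadir` flag name and all values are assumptions, so check `opt.py` and `run_batch.py` for the exact spellings and defaults.
```
# hypothetical invocation; flag names not documented above (e.g. --datadir)
# and all values are placeholders -- verify against opt.py
python train.py --dataset_name blender --datadir path/to/nerf_synthetic/lego \
                --shadingMode MLP_PE \
                --n_lamb_sigma "[16,16,16]" --n_lamb_sh "[48,48,48]" \
                --N_vis 5 --vis_every 10000 \
                --render_test 1
```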

### For pretrained checkpoints and results, please see:
[https://1drv.ms/u/s!Ard0t_p4QWIMgQ2qSEAs7MUk8hVw?e=dc6hBm](https://1drv.ms/u/s!Ard0t_p4QWIMgQ2qSEAs7MUk8hVw?e=dc6hBm)



## Rendering
Simply pass `--render_only 1` and `--ckpt path/to/your/checkpoint` to render images from a pretrained checkpoint.
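
As a sketch, rendering from a saved checkpoint could then look like the following, assuming the render flags are consumed by `train.py`; the dataset flags and paths are placeholders.
```
# hypothetical invocation; --render_only, --ckpt and --render_test are documented
# above, the remaining flags and paths are placeholders
python train.py --dataset_name blender --datadir path/to/nerf_synthetic/lego \
                --ckpt path/to/your/checkpoint --render_only 1 --render_test 1
```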

## Citation
If you find our code or paper helpful, please consider citing:
```
xxx
```
11 changes: 11 additions & 0 deletions dataLoader/__init__.py
@@ -0,0 +1,11 @@
from .llff import LLFFDataset
from .blender import BlenderDataset
from .nsvf import NSVF
from .tankstemple import TanksTempleDataset



dataset_dict = {'blender': BlenderDataset,
                'llff': LLFFDataset,
                'tankstemple': TanksTempleDataset,
                'nsvf': NSVF}
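
A minimal usage sketch for this registry (the data directory is a placeholder; the constructor arguments follow the `BlenderDataset` signature shown in `dataLoader/blender.py` below):
```
# hypothetical usage: resolve a dataset class by name and instantiate it
from dataLoader import dataset_dict

train_dataset = dataset_dict['blender'](
    'path/to/nerf_synthetic/lego',  # placeholder data directory
    split='train',
    downsample=1.0,
    is_stack=False,  # flattened ray/RGB buffers, convenient for random-batch training
)
```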
141 changes: 141 additions & 0 deletions dataLoader/blender.py
@@ -0,0 +1,141 @@
import json
import os

import cv2
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms as T
from tqdm import tqdm

from .ray_utils import *  # provides get_ray_directions, get_rays, read_pfm


class BlenderDataset(Dataset):
def __init__(self, datadir, split='train', downsample=1.0, is_stack=False, N_vis=-1):

self.N_vis = N_vis
self.root_dir = datadir
self.split = split
self.is_stack = is_stack
        self.img_wh = (int(800 / downsample), int(800 / downsample))
        self.downsample = downsample  # set before read_meta() so images are actually resized to img_wh
        self.define_transforms()

self.scene_bbox = torch.tensor([[-1.5, -1.5, -1.5], [1.5, 1.5, 1.5]])
self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
self.read_meta()
self.define_proj_mat()

self.white_bg = True
self.near_far = [2.0,6.0]

self.center = torch.mean(self.scene_bbox, axis=0).float().view(1, 1, 3)
self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3)
# if device is not None:
# self.all_rgbs = self.all_rgbs.to(device)
# self.all_rays = self.all_rays.to(device)

def read_depth(self, filename):
depth = np.array(read_pfm(filename)[0], dtype=np.float32) # (800, 800)
return depth

def read_meta(self):

with open(os.path.join(self.root_dir, f"transforms_{self.split}.json"), 'r') as f:
self.meta = json.load(f)
# with open(os.path.join(self.root_dir, f"transforms_train.json"), 'r') as f:
# self.meta = json.load(f)

w, h = self.img_wh
self.focal = 0.5 * 800 / np.tan(0.5 * self.meta['camera_angle_x']) # original focal length
self.focal *= self.img_wh[0] / 800 # modify focal length to match size self.img_wh


# ray directions for all pixels, same for all images (same H, W, focal)
self.directions = get_ray_directions(h, w, [self.focal,self.focal]) # (h, w, 3)
self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True)
self.intrinsics = torch.tensor([[self.focal,0,w/2],[0,self.focal,h/2],[0,0,1]]).float()

self.image_paths = []
self.poses = []
self.all_rays = []
self.all_rgbs = []
self.all_masks = []
self.all_depth = []

img_eval_interval = 1 if self.N_vis < 0 else len(self.meta['frames']) // self.N_vis
idxs = list(range(0, len(self.meta['frames']), img_eval_interval))
for i in tqdm(idxs, desc=f'Loading data {self.split} ({len(idxs)})'):#img_list:#

frame = self.meta['frames'][i]
pose = np.array(frame['transform_matrix']) @ self.blender2opencv
c2w = torch.FloatTensor(pose)
self.poses += [c2w]

image_path = os.path.join(self.root_dir, f"{frame['file_path']}.png")
self.image_paths += [image_path]
img = Image.open(image_path)

if self.downsample!=1.0:
img = img.resize(self.img_wh, Image.LANCZOS)
img = self.transform(img) # (4, h, w)
img = img.view(4, -1).permute(1, 0) # (h*w, 4) RGBA
img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:]) # blend A to RGB
self.all_rgbs += [img]

# if self.split=='train':
# depth = self.read_depth(os.path.join(self.root_dir, f"{frame['file_path']}_depth.pfm"))
# depth = torch.from_numpy(cv2.resize(depth, self.img_wh))
# self.all_depth += [depth]

rays_o, rays_d = get_rays(self.directions, c2w) # both (h*w, 3)
# vdirs = rays_d / torch.norm(rays_d, dim=-1, keepdim=True)
# t_minmax = aabb(rays_o, vdirs, self.scene_bbox)
            self.all_rays += [torch.cat([rays_o, rays_d], 1)]  # (h*w, 6)
# self.near * torch.ones_like(rays_o[:, :1]),
# self.far * torch.ones_like(rays_o[:, :1])
# self.all_masks += []


        self.poses = torch.stack(self.poses)
        if not self.is_stack:
            self.all_rays = torch.cat(self.all_rays, 0)  # (len(self.meta['frames'])*h*w, 6)
            self.all_rgbs = torch.cat(self.all_rgbs, 0)  # (len(self.meta['frames'])*h*w, 3)

            # self.all_depth = torch.cat(self.all_depth, 0)  # (len(self.meta['frames'])*h*w, 3)
        else:
            self.all_rays = torch.stack(self.all_rays, 0)  # (len(self.meta['frames']), h*w, 6)
            self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1, *self.img_wh[::-1], 3)  # (len(self.meta['frames']), h, w, 3)
            # self.all_masks = torch.stack(self.all_masks, 0).reshape(-1, *self.img_wh[::-1])  # (len(self.meta['frames']), h, w)


def define_transforms(self):
self.transform = T.ToTensor()

def define_proj_mat(self):
self.proj_mat = self.intrinsics.unsqueeze(0) @ torch.inverse(self.poses)[:,:3]

def world2ndc(self,points,lindisp=None):
device = points.device
return (points - self.center.to(device)) / self.radius.to(device)

def __len__(self):
return len(self.all_rgbs)

def __getitem__(self, idx):

if self.split == 'train': # use data in the buffers
sample = {'rays': self.all_rays[idx],
'rgbs': self.all_rgbs[idx]}

        else:  # create data for each image separately
            img = self.all_rgbs[idx]
            rays = self.all_rays[idx]

            sample = {'rays': rays,
                      'rgbs': img}
            # masks are never filled in read_meta above, so only attach one if it exists
            if len(self.all_masks) == len(self.all_rgbs):
                sample['mask'] = self.all_masks[idx]  # for quantitative evaluation
return sample
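
A minimal sketch of how the flattened buffers built above could feed a ray-batch training loop; the batch size and data path are illustrative, not taken from the repository:
```
# hypothetical ray-batch sampling from BlenderDataset(is_stack=False)
import torch
from dataLoader.blender import BlenderDataset

train_set = BlenderDataset('path/to/nerf_synthetic/lego', split='train', is_stack=False)
batch = torch.randint(0, train_set.all_rays.shape[0], (4096,))  # random ray indices
rays = train_set.all_rays[batch]  # (4096, 6): per-ray origin and direction
rgbs = train_set.all_rgbs[batch]  # (4096, 3): white-background RGB targets
```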