forked from apchenstu/TensoRF
Commit 1e0c7af (1 parent: f791678), committed by 陈安沛 on Mar 17, 2022.
Showing 16 changed files with 3,027 additions and 1 deletion.

@@ -1 +1,61 @@
# TensoRF
## [Project page](https://apchenstu.github.io/TensoRF/) | [Paper](https://arxiv.org/abs/2103.15595)
This repository contains a PyTorch implementation of the paper [TensoRF: Tensorial Radiance Fields](https://arxiv.org/abs/2103.15595). Our work presents a novel approach to modeling and reconstructing radiance fields that achieves a super **fast** training process, a **compact** memory footprint, and **state-of-the-art** rendering quality.<br><br>
xxx.mp4
## Installation

#### Tested on Ubuntu 20.04 + PyTorch 1.10.1

Install the environment:
```
conda create -n TensoRF python=3.8
conda activate TensoRF
pip install torch torchvision
pip install tqdm scikit-image opencv-python configargparse lpips imageio-ffmpeg
```
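
To verify the install, a quick optional check (the second value prints `False` on CPU-only machines):
```
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
```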
## Dataset
* [Synthetic-NeRF](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1)
* [Synthetic-NSVF](https://dl.fbaipublicfiles.com/nsvf/dataset/Synthetic_NSVF.zip)
* [Tanks&Temples](https://dl.fbaipublicfiles.com/nsvf/dataset/TanksAndTemple.zip)
* [Forward-facing](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1)

## Training
The training script is `train.py`; we provide a command list in `run_batch.py` to reproduce our results. Please note the following options:

* `dataset_name`: choices are ['blender', 'llff', 'nsvf', 'dtu', 'tankstemple'];
* `shadingMode`: choices are ['MLP_PE', 'SH'];
* `n_lamb_sigma` and `n_lamb_sh`: string-typed options giving the number of basis components for density and appearance along the XYZ dimensions;
* `N_voxel_init` and `N_voxel_final`: control the initial and final resolutions of the matrix and vector factors;
* `N_vis` and `vis_every`: control visualization during training.

You need to set `--render_test 1` / `--render_path 1` if you want to render test views or a camera path after training; see the example command below.

For more options, refer to `opt.py`.
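
As an example, a representative invocation might look like the following. The flag values are illustrative rather than tuned settings, the dataset path is a placeholder, and `--datadir`/`--expname` are assumed to be defined in `opt.py`:
```
python train.py --dataset_name blender --datadir ./data/nerf_synthetic/lego \
                --expname tensorf_lego --shadingMode MLP_PE \
                --n_lamb_sigma "[16,16,16]" --n_lamb_sh "[48,48,48]" \
                --N_voxel_init 2097152 --N_voxel_final 27000000 \
                --N_vis 5 --vis_every 10000 --render_test 1
```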
### For pretrained checkpoints and results, please see:
[https://1drv.ms/u/s!Ard0t_p4QWIMgQ2qSEAs7MUk8hVw?e=dc6hBm](https://1drv.ms/u/s!Ard0t_p4QWIMgQ2qSEAs7MUk8hVw?e=dc6hBm)

## Rendering
Simply pass `--render_only 1` and `--ckpt path/to/your/checkpoint` to render images from a pre-trained checkpoint, as in the example below.
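
A sketch of such a command, assuming rendering is driven through `train.py` just like training:
```
python train.py --render_only 1 --render_test 1 --ckpt path/to/your/checkpoint
```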

## Citation
If you find our code or paper helpful, please consider citing:
```
xxx
```

@@ -0,0 +1,11 @@
from .llff import LLFFDataset
from .blender import BlenderDataset
from .nsvf import NSVF
from .tankstemple import TanksTempleDataset


dataset_dict = {'blender': BlenderDataset,
                'llff': LLFFDataset,
                'tankstemple': TanksTempleDataset,
                'nsvf': NSVF}
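
This registry lets the training code map the `dataset_name` option to a loader class. A minimal usage sketch (the package name `dataLoader` and the data path are assumptions):
```python
from dataLoader import dataset_dict  # assumed package name

# Look up the loader class selected by --dataset_name and build the train split.
dataset_class = dataset_dict['blender']
train_dataset = dataset_class('./data/nerf_synthetic/lego', split='train', downsample=1.0)
```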

@@ -0,0 +1,141 @@
import torch
import cv2
import json
import os
import numpy as np  # used below for poses, intrinsics, and depth
from torch.utils.data import Dataset
from tqdm import tqdm
from PIL import Image
from torchvision import transforms as T

from .ray_utils import *  # expected to provide get_ray_directions, get_rays, read_pfm

class BlenderDataset(Dataset):
    def __init__(self, datadir, split='train', downsample=1.0, is_stack=False, N_vis=-1):

        self.N_vis = N_vis
        self.root_dir = datadir
        self.split = split
        self.is_stack = is_stack
        self.downsample = downsample  # set before read_meta() so the resize branch there can see it
        self.img_wh = (int(800/downsample), int(800/downsample))
        self.define_transforms()

        self.scene_bbox = torch.tensor([[-1.5, -1.5, -1.5], [1.5, 1.5, 1.5]])
        self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
        self.read_meta()
        self.define_proj_mat()

        self.white_bg = True
        self.near_far = [2.0, 6.0]

        self.center = torch.mean(self.scene_bbox, axis=0).float().view(1, 1, 3)
        self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3)
        # if device is not None:
        #     self.all_rgbs = self.all_rgbs.to(device)
        #     self.all_rays = self.all_rays.to(device)

    def read_depth(self, filename):
        depth = np.array(read_pfm(filename)[0], dtype=np.float32)  # (800, 800)
        return depth

    def read_meta(self):

        with open(os.path.join(self.root_dir, f"transforms_{self.split}.json"), 'r') as f:
            self.meta = json.load(f)

        w, h = self.img_wh
        self.focal = 0.5 * 800 / np.tan(0.5 * self.meta['camera_angle_x'])  # original focal length
        self.focal *= self.img_wh[0] / 800  # modify focal length to match size self.img_wh

        # ray directions for all pixels, same for all images (same H, W, focal)
        self.directions = get_ray_directions(h, w, [self.focal, self.focal])  # (h, w, 3)
        self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True)
        self.intrinsics = torch.tensor([[self.focal, 0, w/2], [0, self.focal, h/2], [0, 0, 1]]).float()

        self.image_paths = []
        self.poses = []
        self.all_rays = []
        self.all_rgbs = []
        self.all_masks = []
        self.all_depth = []

        img_eval_interval = 1 if self.N_vis < 0 else len(self.meta['frames']) // self.N_vis
        idxs = list(range(0, len(self.meta['frames']), img_eval_interval))
        for i in tqdm(idxs, desc=f'Loading data {self.split} ({len(idxs)})'):

            frame = self.meta['frames'][i]
            pose = np.array(frame['transform_matrix']) @ self.blender2opencv
            c2w = torch.FloatTensor(pose)
            self.poses += [c2w]

            image_path = os.path.join(self.root_dir, f"{frame['file_path']}.png")
            self.image_paths += [image_path]
            img = Image.open(image_path)

            if self.downsample != 1.0:
                img = img.resize(self.img_wh, Image.LANCZOS)
            img = self.transform(img)  # (4, h, w)
            img = img.view(4, -1).permute(1, 0)  # (h*w, 4) RGBA
            img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:])  # blend A onto a white background
            self.all_rgbs += [img]

            rays_o, rays_d = get_rays(self.directions, c2w)  # both (h*w, 3)
            self.all_rays += [torch.cat([rays_o, rays_d], 1)]  # (h*w, 6)

        self.poses = torch.stack(self.poses)
        if not self.is_stack:
            self.all_rays = torch.cat(self.all_rays, 0)  # (len(self.meta['frames'])*h*w, 6)
            self.all_rgbs = torch.cat(self.all_rgbs, 0)  # (len(self.meta['frames'])*h*w, 3)
        else:
            self.all_rays = torch.stack(self.all_rays, 0)  # (len(self.meta['frames']), h*w, 6)
            self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1, *self.img_wh[::-1], 3)  # (len(self.meta['frames']), h, w, 3)

    def define_transforms(self):
        self.transform = T.ToTensor()

    def define_proj_mat(self):
        self.proj_mat = self.intrinsics.unsqueeze(0) @ torch.inverse(self.poses)[:, :3]

    def world2ndc(self, points, lindisp=None):
        device = points.device
        return (points - self.center.to(device)) / self.radius.to(device)

    def __len__(self):
        return len(self.all_rgbs)

    def __getitem__(self, idx):

        if self.split == 'train':  # use data in the buffers
            sample = {'rays': self.all_rays[idx],
                      'rgbs': self.all_rgbs[idx]}

        else:  # create data for each image separately

            img = self.all_rgbs[idx]
            rays = self.all_rays[idx]
            # masks are never populated above, so guard the lookup (used for quantitative evaluation)
            mask = self.all_masks[idx] if len(self.all_masks) > idx else None

            sample = {'rays': rays,
                      'rgbs': img,
                      'mask': mask}
        return sample