update opt and readme
Anpei Chen (陈安沛) committed Mar 18, 2022
1 parent 1e0c7af commit b191a56
Showing 11 changed files with 170 additions and 133 deletions.
34 changes: 28 additions & 6 deletions README.md
@@ -1,9 +1,9 @@
# TensoRF
## [Project page](https://apchenstu.github.io/TensoRF/) | [Paper](https://arxiv.org/abs/2103.15595)
## [Project page](https://apchenstu.github.io/TensoRF/) | [Paper](https://arxiv.org/abs/2203.09517)
This repository contains a PyTorch implementation of the paper [TensoRF: Tensorial Radiance Fields](https://arxiv.org/abs/2203.09517). Our work presents a novel approach to modeling and reconstructing radiance fields that achieves a super
**fast** training process, a **compact** memory footprint, and **state-of-the-art** rendering quality.<br><br>

xxx.mp4
https://github.com/apchenstu/apchenstu.github.io/blob/master/TensoRF/video/train_process.mp4

## Installation

@@ -26,13 +26,22 @@ pip install tqdm scikit-image opencv-python configargparse lpips imageio-ffmpeg



## Training
The training script is in `train.py`, we have provided command list in `run_batch.py` to reproduce our results, please note:
## Quick Start
The training script is `train.py`. To train a TensoRF:

```
python train.py --config configs/lego.txt
```


We provide a few example configurations in the `configs` folder (a command-line sketch follows this list of options). Please note:

`dataset_name`, choices = ['blender', 'llff', 'nsvf', 'dtu', 'tankstemple'];

`shadingMode`, choices = ['MLP_PE', 'SH'];

`model_name`, choices = ['TensorVMSplit', 'TensorCP'], corresponding to the VM and CP decomposition;

`n_lamb_sigma` and `n_lamb_sh` are list-valued and specify the number of basis components for density and appearance along the XYZ dimensions;
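For example, these list-valued options can also be overridden on the command line. The command below is a hypothetical sketch (not taken from this repository's docs): after this commit, `opt.py` declares them with argparse's `append` action, so each component is passed by repeating the flag:

```
python train.py --config configs/lego.txt \
    --n_lamb_sigma 16 --n_lamb_sigma 16 --n_lamb_sigma 16 \
    --n_lamb_sh 48 --n_lamb_sh 48 --n_lamb_sh 48
```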

@@ -51,11 +60,24 @@ More options: refer to `opt.py`.


## Rendering

```
python train.py --config configs/lego.txt --ckpt path/to/your/checkpoint --render_only 1 --render_test 1
```

You can simply pass `--render_only 1` and `--ckpt path/to/your/checkpoint` to render images from a pre-trained
checkpoint.
checkpoint. You may also need to specify what you want to render, e.g. `--render_test 1`, `--render_train 1`, or `--render_path 1`.
The rendering results are located in your checkpoint folder.

## Citation
If you find our code or paper helpful, please consider citing:
```
xxx
@misc{TensoRF,
title={TensoRF: Tensorial Radiance Fields},
    author={Anpei Chen and Zexiang Xu and Andreas Geiger and Jingyi Yu and Hao Su},
year={2022},
eprint={2203.09517},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
```
32 changes: 32 additions & 0 deletions configs/flower.txt
@@ -0,0 +1,32 @@

dataset_name = llff
datadir = ./data/nerf_llff_data/flower
expname = tensorf_flower
basedir = ./log

n_iters = 25000
batch_size = 4096

N_voxel_init = 2097152 # 128**3
N_voxel_final = 262144000 # 640**3
upsamp_list = [2000,3000,4000,5500]
update_AlphaMask_list = [2500]

N_vis = 5
vis_every = 10000

render_test = 1
render_path = 1

n_lamb_sigma = [16,4,4]
n_lamb_sh = [48,12,12]

shadingMode = MLP_Fea
fea2denseAct = relu

view_pe = 0
fea_pe = 0

TV_weight_density = 1.0
TV_weight_app = 1.0
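For context, a minimal sketch of the kind of total-variation penalty that `TV_weight_density` and `TV_weight_app` scale (the form below is an assumption; the exact regularizer lives in the training code):

```python
import torch

def tv_loss_2d(plane):
    # mean squared difference between neighboring entries of a (C, H, W) plane
    dh = (plane[:, 1:, :] - plane[:, :-1, :]).pow(2).mean()
    dw = (plane[:, :, 1:] - plane[:, :, :-1]).pow(2).mean()
    return dh + dw

plane = torch.randn(16, 128, 128, requires_grad=True)  # toy density plane
loss = 1.0 * tv_loss_2d(plane)  # TV_weight_density = 1.0 in this config
loss.backward()
```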

31 changes: 31 additions & 0 deletions configs/lego.txt
@@ -0,0 +1,31 @@

dataset_name = blender
datadir = ./data/nerf_synthetic/lego
expname = tensorf_lego
basedir = ./log

n_iters = 30000
batch_size = 4096

N_voxel_init = 2097152 # 128**3
N_voxel_final = 27000000 # 300**3
upsamp_list = [2000,3000,4000,5500,7000]
update_AlphaMask_list = [2000,4000]

N_vis = 5
vis_every = 10000

render_test = 1

n_lamb_sigma = [16,16,16]
n_lamb_sh = [48,48,48]

shadingMode = MLP_Fea
fea2denseAct = softplus

view_pe = 2
fea_pe = 2

L1_weight_inital = 8e-5
L1_weight_rest = 4e-5
rm_weight_mask_thre = 1e-4
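For intuition about the schedule above, one common choice (an assumption here, not necessarily the exact rule in `train.py`) is to interpolate the grid size log-linearly from `N_voxel_init` to `N_voxel_final` across the `upsamp_list` iterations:

```python
import numpy as np

N_voxel_init, N_voxel_final = 128 ** 3, 300 ** 3
upsamp_list = [2000, 3000, 4000, 5500, 7000]

# log-linear interpolation from the initial to the final voxel count
n_voxels = np.exp(np.linspace(np.log(N_voxel_init), np.log(N_voxel_final),
                              len(upsamp_list) + 1)).round().astype(int)[1:]
print(dict(zip(upsamp_list, n_voxels)))  # iteration -> target voxel count
```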
32 changes: 32 additions & 0 deletions configs/truck.txt
@@ -0,0 +1,32 @@


dataset_name = tankstemple
datadir = ./data/TanksAndTemple/Truck
expname = tensorf_truck
basedir = ./log

n_iters = 30000
batch_size = 4096

N_voxel_init = 2097152 # 128**3
N_voxel_final = 27000000 # 300**3
upsamp_list = [2000,3000,4000,5500,7000]
update_AlphaMask_list = [2000,4000]

N_vis = 5
vis_every = 10000

render_test = 1

n_lamb_sigma = [16,16,16]
n_lamb_sh = [48,48,48]

shadingMode = MLP_Fea
fea2denseAct = softplus

view_pe = 2
fea_pe = 2

TV_weight_density = 0.1
TV_weight_app = 0.01

31 changes: 31 additions & 0 deletions configs/wineholder.txt
@@ -0,0 +1,31 @@

dataset_name = nsvf
datadir = ./data/Synthetic_NSVF/Wineholder
expname = tensorf_Wineholder
basedir = ./log

n_iters = 30000
batch_size = 4096

N_voxel_init = 2097152 # 128**3
N_voxel_final = 27000000 # 300**3
upsamp_list = [2000,3000,4000,5500,7000]
update_AlphaMask_list = [2000,4000]

N_vis = 5
vis_every = 10000

render_test = 1

n_lamb_sigma = [16,16,16]
n_lamb_sh = [48,48,48]

shadingMode = MLP_Fea
fea2denseAct = softplus

view_pe = 2
fea_pe = 2

L1_weight_inital = 8e-5
L1_weight_rest = 4e-5
rm_weight_mask_thre = 1e-4
18 changes: 2 additions & 16 deletions dataLoader/blender.py
@@ -31,9 +31,6 @@ def __init__(self, datadir, split='train', downsample=1.0, is_stack=False, N_vis
self.center = torch.mean(self.scene_bbox, axis=0).float().view(1, 1, 3)
self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3)
self.downsample=downsample
# if device is not None:
# self.all_rgbs = self.all_rgbs.to(device)
# self.all_rays = self.all_rays.to(device)

def read_depth(self, filename):
depth = np.array(read_pfm(filename)[0], dtype=np.float32) # (800, 800)
@@ -43,8 +40,6 @@ def read_meta(self):

with open(os.path.join(self.root_dir, f"transforms_{self.split}.json"), 'r') as f:
self.meta = json.load(f)
# with open(os.path.join(self.root_dir, f"transforms_train.json"), 'r') as f:
# self.meta = json.load(f)

w, h = self.img_wh
self.focal = 0.5 * 800 / np.tan(0.5 * self.meta['camera_angle_x']) # original focal length
@@ -83,19 +78,10 @@ def read_meta(self):
img = img.view(4, -1).permute(1, 0) # (h*w, 4) RGBA
img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:]) # blend A to RGB
self.all_rgbs += [img]

# if self.split=='train':
# depth = self.read_depth(os.path.join(self.root_dir, f"{frame['file_path']}_depth.pfm"))
# depth = torch.from_numpy(cv2.resize(depth, self.img_wh))
# self.all_depth += [depth]


rays_o, rays_d = get_rays(self.directions, c2w) # both (h*w, 3)
# vdirs = rays_d / torch.norm(rays_d, dim=-1, keepdim=True)
# t_minmax = aabb(rays_o, vdirs, self.scene_bbox)
self.all_rays += [torch.cat([rays_o, rays_d], 1)] # (h*w, 8)
# self.near * torch.ones_like(rays_o[:, :1]),
# self.far * torch.ones_like(rays_o[:, :1])
# self.all_masks += []
self.all_rays += [torch.cat([rays_o, rays_d], 1)] # (h*w, 6)


self.poses = torch.stack(self.poses)
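A standalone sketch (toy sizes assumed) of the ray-tensor layout behind the corrected `(h*w, 6)` comment: only origins and directions are concatenated, since the near/far columns that previously made the width 8 are no longer stored:

```python
import torch

h, w = 4, 4                     # toy image size
rays_o = torch.zeros(h * w, 3)  # per-pixel ray origins
rays_d = torch.randn(h * w, 3)  # per-pixel ray directions
rays = torch.cat([rays_o, rays_d], 1)
print(rays.shape)               # torch.Size([16, 6])
```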
23 changes: 1 addition & 22 deletions models/tensoRF.py
@@ -177,9 +177,7 @@ def vectorDiffs(self, vector_comps):
n_comp, n_size = vector_comps[idx].shape[1:-1]

dotp = torch.matmul(vector_comps[idx].view(n_comp,n_size), vector_comps[idx].view(n_comp,n_size).transpose(-1,-2))
# print(vector_comps[idx].shape, vector_comps[idx].view(n_comp,n_size).transpose(-1,-2).shape, dotp.shape)
non_diagonal = dotp.view(-1)[1:].view(n_comp-1, n_comp+1)[...,:-1]
# print(vector_comps[idx].shape, vector_comps[idx].view(n_comp,n_size).transpose(-1,-2).shape, dotp.shape,non_diagonal.shape)
total = total + torch.mean(torch.abs(non_diagonal))
return total
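The off-diagonal extraction in `vectorDiffs` is compact; this standalone check (toy sizes assumed) shows that `dotp.view(-1)[1:].view(n_comp-1, n_comp+1)[..., :-1]` keeps exactly the off-diagonal entries:

```python
import torch

n_comp = 4
dotp = torch.arange(n_comp * n_comp, dtype=torch.float32).view(n_comp, n_comp)
non_diagonal = dotp.view(-1)[1:].view(n_comp - 1, n_comp + 1)[..., :-1]
# rows hold every entry of `dotp` except the diagonal (0, 5, 10, 15),
# so the mean penalizes only cross-correlations between components
print(non_diagonal)
print(non_diagonal.abs().mean())
```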

@@ -247,15 +245,6 @@ def compute_appfeature(self, xyz_sampled):

return self.basis_mat((plane_coef_point * line_coef_point).T)

# idx = [130, 8, 129, 113, 72, 26, 114, 137, 66, 101, 134, 46, 102, 127,
# 117, 20, 4, 118, 43, 103, 40, 3, 122, 10, 121, 86, 33, 38,
# 58, 123, 138, 11, 12, 35, 48, 18, 142, 110, 9, 25, 24, 56,
# 39, 61, 22, 59, 141, 115, 105, 45, 13, 7, 140, 64, 32, 47,
# 37, 21, 116, 67, 139, 128, 51, 49, 2, 1, 107, 57, 69, 71,
# 68, 89, 5, 95, 27, 44, 28, 74, 136, 84, 16, 63, 132, 6,
# 36, 85, 143, 55, 106, 17, 19, 97, 93, 96, 34, 29, 52, 30,
# 42, 88, 76, 112, 87, 109, 80, 31, 54, 91]
# return torch.mm((plane_coef_point * line_coef_point)[idx].T, self.basis_mat.weight[:,idx].T)


@torch.no_grad()
@@ -278,13 +267,6 @@ def upsample_volume_grid(self, res_target):
self.app_plane, self.app_line = self.up_sampling_VM(self.app_plane, self.app_line, res_target)
self.density_plane, self.density_line = self.up_sampling_VM(self.density_plane, self.density_line, res_target)

# for plane in self.density_plane:
# plane.sub(0.1)

# scale = res_target[0]/self.line_coef.shape[2] #assuming xyz have the same scale
# plane_coef = F.interpolate(self.plane_coef.detach().data, scale_factor=scale, mode='bilinear',align_corners=True)
# line_coef = F.interpolate(self.line_coef.detach().data, size=(res_target[0],1), mode='bilinear',align_corners=True)
# self.plane_coef, self.line_coef = torch.nn.Parameter(plane_coef), torch.nn.Parameter(line_coef)
self.update_stepSize(res_target)
print(f'upsampling to {res_target}')

@@ -297,7 +279,6 @@ def shrink(self, new_aabb):
# print(t_l, b_r,self.alphaMask.alpha_volume.shape)
t_l, b_r = torch.round(torch.round(t_l)).long(), torch.round(b_r).long() + 1
b_r = torch.stack([b_r, self.gridSize]).amin(0)
print('================>',t_l, b_r)

for i in range(len(self.vecMode)):
mode0 = self.vecMode[i]
@@ -314,9 +295,7 @@
self.app_plane[i] = torch.nn.Parameter(
self.app_plane[i].data[...,t_l[mode1]:b_r[mode1],t_l[mode0]:b_r[mode0]]
)
# if self.alphaMask is not None:
# alpha_volume = self.alphaMask.alpha_volume[:, :, t_l[2]:b_r[2], t_l[1]:b_r[1],t_l[0]:b_r[0]]
# self.alphaMask = AlphaGridMask(self.device, new_aabb, alpha_volume)


if not torch.all(self.alphaMask.gridSize == self.gridSize):
t_l_r, b_r_r = t_l / (self.gridSize-1), (b_r-1) / (self.gridSize-1)
2 changes: 1 addition & 1 deletion models/tensorBase.py
@@ -1,7 +1,7 @@
import torch
import torch.nn
import torch.nn.functional as F
from sh import eval_sh_bases
from .sh import eval_sh_bases
import numpy as np
import time

12 changes: 7 additions & 5 deletions opt.py
@@ -2,11 +2,13 @@

def config_parser(cmd=None):
parser = configargparse.ArgumentParser()
parser.add_argument('--config', is_config_file=True,
help='config file path')
parser.add_argument("--expname", type=str,
help='experiment name')
parser.add_argument("--basedir", type=str, default='./log',
help='where to store ckpts and logs')
parser.add_argument("--add_timestamp", type=int, default=1,
parser.add_argument("--add_timestamp", type=int, default=0,
help='add timestamp to dir')
parser.add_argument("--datadir", type=str, default='./data/llff/fern',
help='input data directory')
@@ -55,8 +57,8 @@ def config_parser(cmd=None):

# model
# volume options
parser.add_argument("--n_lamb_sigma", type=str, default='[16,4,4]')
parser.add_argument("--n_lamb_sh", type=str, default='[48,12,12]')
parser.add_argument("--n_lamb_sigma", type=int, action="append")
parser.add_argument("--n_lamb_sh", type=int, action="append")
parser.add_argument("--data_dim_color", type=int, default=27)

parser.add_argument("--rm_weight_mask_thre", type=float, default=0.0001,
@@ -114,8 +116,8 @@ def config_parser(cmd=None):
parser.add_argument('--N_voxel_final',
type=int,
default=300**3)
parser.add_argument("--upsamp_list", type=str, default="[2000, 3000, 4000, 5500]")
parser.add_argument("--update_AlphaMask_list", type=str, default="[2000, 3000, 4000, 5500,7000]")
parser.add_argument("--upsamp_list", type=int, action="append")
parser.add_argument("--update_AlphaMask_list", type=int, action="append")

parser.add_argument('--idx_view',
type=int,
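With the switch to `action="append"`, list-valued options are accumulated one value at a time. A minimal standalone sketch (mirroring the declarations above) of how the `n_lamb_sigma` value from the configs parses:

```python
import configargparse

parser = configargparse.ArgumentParser()
parser.add_argument("--config", is_config_file=True)
parser.add_argument("--n_lamb_sigma", type=int, action="append")

# equivalent to `n_lamb_sigma = [16,4,4]` in a config file such as configs/flower.txt
args = parser.parse_args(["--n_lamb_sigma", "16",
                          "--n_lamb_sigma", "4",
                          "--n_lamb_sigma", "4"])
print(args.n_lamb_sigma)  # [16, 4, 4]
```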

