
Commit: fix atomicAdd bug

ashawkey committed Jan 29, 2022
1 parent 74dfd12 commit bd92b9c
Showing 14 changed files with 58 additions and 34 deletions.
5 changes: 4 additions & 1 deletion hashencoder/backend.py
@@ -5,7 +5,10 @@

_backend = load(name='_hash_encoder',
extra_cflags=['-O3'], # '-std=c++17'
extra_cuda_cflags=['-O3'], # '-arch=sm_70'
extra_cuda_cflags=[
'-O3',
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', # undefine flags, necessary!
], # '-arch=sm_70'
sources=[os.path.join(_src_path, 'src', f) for f in [
'hashencoder.cu',
'bindings.cpp',
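For context: PyTorch's cpp_extension build defines __CUDA_NO_HALF_OPERATORS__, __CUDA_NO_HALF_CONVERSIONS__ and __CUDA_NO_HALF2_OPERATORS__ by default, which hides the __half/__half2 operator overloads in cuda_fp16.h; the -U flags above restore them so the kernels below can do half arithmetic and packed __half2 atomics. A minimal sketch of a kernel that compiles only with these flags in place (standalone and illustrative, not from this repo):

#include <cuda_fp16.h>

// Without -U__CUDA_NO_HALF_OPERATORS__, the __half operator* overload is
// hidden by PyTorch's default defines and this kernel fails to compile.
__global__ void scale_half(__half* x, __half s, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) x[i] = x[i] * s;
}
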
3 changes: 3 additions & 0 deletions hashencoder/hashgrid.py
@@ -81,6 +81,9 @@ def __init__(self, input_dim=3, num_levels=16, level_dim=2, base_resolution=16,
self.base_resolution = base_resolution
self.output_dim = num_levels * level_dim

if level_dim % 2 != 0:
print('[WARN] detected HashGrid level_dim % 2 != 0, which will cause very slow backward if fp16 is also enabled! (maybe fix later)')

# allocate parameters
self.offsets = []
offset = 0
18 changes: 15 additions & 3 deletions hashencoder/src/hashencoder.cu
@@ -21,6 +21,7 @@


// requires CUDA >= 10 and ARCH >= 70
// this is painfully slow...
static inline __device__ at::Half atomicAdd(at::Half *address, at::Half val) {
return atomicAdd(reinterpret_cast<__half*>(address), val);
}
@@ -233,9 +234,20 @@ __global__ void kernel_grid_backward(

uint32_t index = get_grid_index<D, C>(ch, hashmap_size, resolution, pos_grid_local);

#pragma unroll
for (uint32_t c = 0; c < N_C; c++) {
atomicAdd(&grad_grid[index + c], w * grad[c]);

// atomicAdd for __half is slow (especially for large values), so we use __half2 if N_C % 2 == 0
// TODO: use float which is better than __half, if N_C % 2 != 0
if (N_C % 2 == 0) {
#pragma unroll
for (uint32_t c = 0; c < N_C; c += 2) {
__half2 v = {(__half)((float)grad[c] * w), (__half)((float)grad[c + 1] * w)};
atomicAdd((__half2*)&grad_grid[index + c], v);
}
} else {
#pragma unroll
for (uint32_t c = 0; c < N_C; c++) {
atomicAdd(&grad_grid[index + c], w * grad[c]);
}
}
}
}
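The change above halves the number of atomic operations and sidesteps the slow scalar __half atomicAdd noted earlier. A self-contained sketch of the same packed-half pattern (names are illustrative; assumes the destination is 4-byte aligned and the channel count is even, mirroring the N_C % 2 == 0 guard):

#include <cuda_fp16.h>

// Accumulate w * src into dst, two half values per 32-bit atomic.
// atomicAdd on __half2 is hardware-accelerated on sm_60 and newer.
__global__ void weighted_accumulate(__half* dst, const __half* src, float w, int n) {
    int i = 2 * (blockIdx.x * blockDim.x + threadIdx.x);
    if (i + 1 < n) {
        __half2 v = __floats2half2_rn(__half2float(src[i]) * w,
                                      __half2float(src[i + 1]) * w);
        atomicAdd(reinterpret_cast<__half2*>(dst + i), v);
    }
}
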
4 changes: 2 additions & 2 deletions nerf/network.py
@@ -84,10 +84,10 @@ class NeRFNetwork(nn.Module):
def __init__(self,
encoding="hashgrid",
encoding_dir="sphere_harmonics",
num_layers=3,
num_layers=2,
hidden_dim=64,
geo_feat_dim=15,
num_layers_color=4,
num_layers_color=3,
hidden_dim_color=64,
density_grid_size=-1, # density grid size
):
6 changes: 3 additions & 3 deletions nerf/utils.py
@@ -143,7 +143,6 @@ def __init__(self,
device=None, # device to use, usually setting to None is OK. (auto choose device)
mute=False, # whether to mute all print
fp16=False, # amp optimize level
use_grad_scaler=False, # use amp grad scaler
eval_interval=1, # eval once every $ epoch
max_keep_ckpt=2, # max num of saved ckpts in disk
workspace='workspace', # workspace to save logs & ckpts
@@ -163,7 +162,6 @@ def __init__(self,
self.workspace = workspace
self.ema_decay = ema_decay
self.fp16 = fp16
self.use_grad_scaler = use_grad_scaler and fp16
self.best_mode = best_mode
self.use_loss_as_metric = use_loss_as_metric
self.max_keep_ckpt = max_keep_ckpt
@@ -200,7 +198,7 @@ def __init__(self,
else:
self.ema = None

self.scaler = torch.cuda.amp.GradScaler(enabled=self.use_grad_scaler)
self.scaler = torch.cuda.amp.GradScaler(enabled=self.fp16)

# variable init
self.epoch = 1
@@ -466,9 +464,11 @@ def train_one_epoch(self, loader):
#print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))

#with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU,torch.profiler.ProfilerActivity.CUDA,]) as p:

self.scaler.scale(loss).backward()
self.scaler.step(self.optimizer)
self.scaler.update()

#print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))

if self.ema is not None:
5 changes: 3 additions & 2 deletions readme.md
@@ -16,17 +16,18 @@ SDF | NeRF
* HashGrid Encoder
- [x] basic pytorch CUDA extension
- [x] fp16 support
- [ ] improve performance (currently the slowest part in nerf inference)
- [x] improve performance
* Experiments
- SDF
- [x] baseline
- [ ] better SDF calculation (especially for non-watertight meshes)
- NeRF
- [x] baseline (although much slower)
- [ ] fp16 with GradScaler enabled leads to slower backward in training...
- [ ] ray marching in CUDA.

# News
* 1.30:
* fixed atomicAdd() to use __half2 in HashGrid Encoder's backward, now the training speed with fp16 is as expected!
* 1.29:
* finished an experimental binding of fully-fused MLP.
* replace SHEncoder with a CUDA implementation.
2 changes: 1 addition & 1 deletion scripts/run_nerf.sh
@@ -1 +1 @@
OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=3 python train_nerf.py data/fox --workspace trial_nerf_ff --fp16 --ff #--cuda_raymarching
OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=2 python train_nerf.py data/fox --workspace trial_nerf --fp16 --ff #--cuda_raymarching
2 changes: 1 addition & 1 deletion scripts/run_sdf.sh
@@ -1 +1 @@
OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=3 python train_sdf.py data/armadillo.obj --workspace trial_sdf --fp16
OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=3 python train_sdf.py data/armadillo.obj --workspace trial_sdf --fp16 --ff
4 changes: 1 addition & 3 deletions sdf/utils.py
@@ -85,7 +85,6 @@ def __init__(self,
device=None, # device to use, usually setting to None is OK. (auto choose device)
mute=False, # whether to mute all print
fp16=False, # amp optimize level
use_grad_scaler=True, # use amp grad scaler
eval_interval=1, # eval once every $ epoch
max_keep_ckpt=2, # max num of saved ckpts in disk
workspace='workspace', # workspace to save logs & ckpts
@@ -104,7 +103,6 @@ def __init__(self,
self.workspace = workspace
self.ema_decay = ema_decay
self.fp16 = fp16
self.use_grad_scaler = use_grad_scaler and fp16
self.best_mode = best_mode
self.use_loss_as_metric = use_loss_as_metric
self.max_keep_ckpt = max_keep_ckpt
@@ -141,7 +139,7 @@ def __init__(self,
else:
self.ema = None

self.scaler = torch.cuda.amp.GradScaler(enabled=self.use_grad_scaler)
self.scaler = torch.cuda.amp.GradScaler(enabled=self.fp16)

# variable init
self.epoch = 1
13 changes: 10 additions & 3 deletions test_nerf.py
@@ -1,6 +1,7 @@
import torch

from nerf.network import NeRFNetwork
from nerf.network_ff import NeRFNetwork as NeRFNetwork_FF
from nerf.provider import NeRFDataset
from nerf.utils import *

@@ -16,20 +17,26 @@
parser.add_argument('--num_steps', type=int, default=256)
parser.add_argument('--upsample_steps', type=int, default=256)
parser.add_argument('--max_ray_batch', type=int, default=4096) # lower if OOM
parser.add_argument('--fp16', action='store_true')
parser.add_argument('--fp16', action='store_true', help="use amp mixed precision training")
parser.add_argument('--ff', action='store_true', help="use fully-fused MLP")

parser.add_argument('--radius', type=float, default=2, help="assume the camera is located on sphere(0, radius)")
parser.add_argument('--bound', type=float, default=2, help="assume the scene is bounded in sphere(0, size)")
parser.add_argument('--bound', type=float, default=2, help="assume the scene is bounded in box(-size, size)")

parser.add_argument('--cuda_raymarching', action='store_true', help="use CUDA raymarching instead of pytorch (unstable now)")

opt = parser.parse_args()

print(opt)

if opt.ff:
assert opt.fp16, "fully-fused mode must be used with fp16 mode"
Network = NeRFNetwork_FF
else:
Network = NeRFNetwork
seed_everything(opt.seed)

model = NeRFNetwork(
model = Network(
encoding="hashgrid", encoding_dir="sphere_harmonics",
num_layers=2, hidden_dim=64, geo_feat_dim=15, num_layers_color=3, hidden_dim_color=64,
density_grid_size=128 if opt.cuda_raymarching else -1,
12 changes: 9 additions & 3 deletions test_sdf.py
@@ -1,23 +1,29 @@
from sdf.netowrk import SDFNetwork
from sdf.netowrk_ff import SDFNetwork as SDFNetwork_FF
from sdf.utils import *

import argparse

CLIP_SDF = None

if __name__ == '__main__':

parser = argparse.ArgumentParser()
parser.add_argument('path', type=str)
parser.add_argument('--workspace', type=str, default='workspace')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--fp16', action='store_true')
parser.add_argument('--fp16', action='store_true', help="use amp mixed precision training")
parser.add_argument('--ff', action='store_true', help="use fully-fused MLP")

opt = parser.parse_args()

if opt.ff:
assert opt.fp16, "fully-fused mode must be used with fp16 mode"
Network = SDFNetwork_FF
else:
Network = SDFNetwork
seed_everything(opt.seed)

model = SDFNetwork(encoding="hashgrid", clip_sdf=CLIP_SDF)
model = Network(encoding="hashgrid")
#model = SDFNetwork(encoding="frequency", num_layers=8, skips=[4], hidden_dim=256, clip_sdf=CLIP_SDF)

print(model)
5 changes: 1 addition & 4 deletions testing/test_ffmlp.py
@@ -115,10 +115,7 @@ def forward(self, x):

starter.record()
y2 = net1(x2)
ender.record()
torch.cuda.synchronize()
curr_time = starter.elapsed_time(ender)
print(f'time1 (fp32 train) = {curr_time}')
ender.record(); torch.cuda.synchronize(); curr_time = starter.elapsed_time(ender); print(f'time1 (fp32 train) = {curr_time}')

starter.record()
y2.sum().backward()
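The starter.record() / ender.record() / elapsed_time() pattern in this test is PyTorch's wrapper around CUDA events; for reference, a native equivalent looks roughly like this (hypothetical standalone program, not part of the repo):

#include <cuda_runtime.h>
#include <cstdio>

int main() {
    cudaEvent_t start, stop;
    cudaEventCreate(&start);
    cudaEventCreate(&stop);
    cudaEventRecord(start);
    // ... launch the kernel being timed here ...
    cudaEventRecord(stop);
    cudaEventSynchronize(stop);  // analogous to torch.cuda.synchronize()
    float ms = 0.0f;
    cudaEventElapsedTime(&ms, start, stop);
    printf("elapsed = %.3f ms\n", ms);
    return 0;
}
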
9 changes: 3 additions & 6 deletions train_nerf.py
@@ -19,8 +19,8 @@
parser.add_argument('--num_steps', type=int, default=256)
parser.add_argument('--upsample_steps', type=int, default=256)
parser.add_argument('--max_ray_batch', type=int, default=4096)
parser.add_argument('--fp16', action='store_true')
parser.add_argument('--ff', action='store_true')
parser.add_argument('--fp16', action='store_true', help="use amp mixed precision training")
parser.add_argument('--ff', action='store_true', help="use fully-fused MLP")

parser.add_argument('--radius', type=float, default=2, help="assume the camera is located on sphere(0, radius)")
parser.add_argument('--bound', type=float, default=2, help="assume the scene is bounded in box(-size, size)")
@@ -33,10 +33,8 @@

if opt.ff:
assert opt.fp16, "fully-fused mode must be used with fp16 mode"
use_grad_scaler = opt.fp16 # pytorch amp gradscaler lead to 10x slower backward for fp16, why?
Network = NeRFNetwork_FF
else:
use_grad_scaler = opt.fp16 # pytorch amp gradscaler lead to 10x slower backward for fp16, why?
Network = NeRFNetwork

seed_everything(opt.seed)
@@ -66,12 +64,11 @@

scheduler = lambda optimizer: optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50, 100, 150], gamma=0.33)

trainer = Trainer('ngp', vars(opt), model, workspace=opt.workspace, optimizer=optimizer, criterion=criterion, ema_decay=0.95, fp16=opt.fp16, use_grad_scaler=use_grad_scaler, lr_scheduler=scheduler, use_checkpoint='latest', eval_interval=1)
trainer = Trainer('ngp', vars(opt), model, workspace=opt.workspace, optimizer=optimizer, criterion=criterion, ema_decay=0.95, fp16=opt.fp16, lr_scheduler=scheduler, use_checkpoint='latest', eval_interval=1)

trainer.train(train_loader, valid_loader, 200)

# test dataset
test_dataset = NeRFDataset(opt.path, 'test', radius=opt.radius, n_test=10)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1)

trainer.test(test_loader)
4 changes: 2 additions & 2 deletions train_sdf.py
@@ -15,8 +15,8 @@
parser.add_argument('path', type=str)
parser.add_argument('--workspace', type=str, default='workspace')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--fp16', action='store_true')
parser.add_argument('--ff', action='store_true')
parser.add_argument('--fp16', action='store_true', help="use amp mixed precision training")
parser.add_argument('--ff', action='store_true', help="use fully-fused MLP")

opt = parser.parse_args()

