fix hashgrid grad
ashawkey committed Feb 28, 2022
1 parent d36b739 commit 4cb588e
Showing 3 changed files with 71 additions and 1 deletion.
2 changes: 1 addition & 1 deletion hashencoder/hashgrid.py
@@ -102,7 +102,7 @@ def __init__(self, input_dim=3, num_levels=16, level_dim=2, per_level_scale=2, b
         for i in range(num_levels):
             resolution = int(np.ceil(base_resolution * per_level_scale ** i))
             params_in_level = min(self.max_params, (resolution + 1) ** input_dim) # limit max number
-            params_in_level = int(params_in_level / 8) * 8 # make divisible
+            #params_in_level = np.ceil(params_in_level / 8) * 8 # make divisible
             offsets.append(offset)
             offset += params_in_level
         offsets.append(offset)
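For reference, here is a minimal standalone sketch of the offset table this loop now produces with the "make divisible" rounding dropped. The parameter values are borrowed from testing/test_hashgrid_grad.py added in this commit (input_dim=3, num_levels=4, per_level_scale=2, base_resolution=4, log2_hashmap_size=8); they are illustrative, not the encoder's defaults.

import numpy as np

# sketch of the per-level offsets after this change (rounding line commented out)
input_dim, num_levels = 3, 4
per_level_scale, base_resolution = 2, 4
max_params = 2 ** 8  # log2_hashmap_size = 8

offsets, offset = [], 0
for i in range(num_levels):
    resolution = int(np.ceil(base_resolution * per_level_scale ** i))
    params_in_level = min(max_params, (resolution + 1) ** input_dim)  # limit max number
    offsets.append(offset)
    offset += params_in_level
offsets.append(offset)

print(offsets)  # [0, 125, 381, 637, 893]
# With the removed line `int(params_in_level / 8) * 8`, level 0 would have been
# truncated to 120 entries for its 125 grid vertices, since int() rounds toward zero.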
8 changes: 8 additions & 0 deletions hashencoder/src/hashencoder.cu
@@ -244,6 +244,14 @@ __global__ void kernel_grid_backward(
     const float scale = exp2f(level * S) * H - 1.0f;
     const uint32_t resolution = (uint32_t)ceil(scale) + 1;
 
+    // check input range (should be in [0, 1])
+    #pragma unroll
+    for (uint32_t d = 0; d < D; d++) {
+        if (inputs[d] < 0 || inputs[d] > 1) {
+            return; // grad is init as 0, so we simply return.
+        }
+    }
+
     // calculate coordinate
     float pos[D];
     uint32_t pos_grid[D];
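The scale/resolution arithmetic and the new range guard can be mirrored in a few lines of NumPy. This is a hedged sketch, assuming (as on the Python side above) that S is log2(per_level_scale) and H is base_resolution; it only illustrates the values the kernel works with and is not part of the commit.

import numpy as np

# mirror of the per-level scale/resolution computed in kernel_grid_backward
def grid_params(level, S, H):
    scale = np.exp2(level * S) * H - 1.0
    resolution = int(np.ceil(scale)) + 1
    return scale, resolution

# mirror of the new guard: samples outside [0, 1] contribute no gradient,
# because the gradient buffer is zero-initialized and the thread returns early
def in_range(inputs):
    return bool(np.all((inputs >= 0) & (inputs <= 1)))

print(grid_params(level=1, S=np.log2(2), H=4))  # (7.0, 8), matching ceil(4 * 2**1) on the Python side
print(in_range(np.array([0.3, 0.9, 1.2])))      # False -> this sample would be skipped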
62 changes: 62 additions & 0 deletions testing/test_hashgrid_grad.py
@@ -0,0 +1,62 @@
# we need to check the grad_hash_grid
import torch
import torch.nn.functional as F
from torch.autograd import gradcheck
import numpy as np
from hashencoder.hashgrid import _hash_encode
import random
import os
# import torch.random as random
device=torch.device(0)
input_dim=3 # 2
num_levels=4 # 1
level_dim=2 # 1
per_level_scale=2
base_resolution=4 # 2
log2_hashmap_size=8 # 4
# inputs , embeddings, offsets, per_level_scale, base_resolution, calc_grad_inputs=False

output_dim = num_levels * level_dim

if level_dim % 2 != 0:
    print('[WARN] detected HashGrid level_dim % 2 != 0, which will cause very slow backward if fp16 is also enabled! (maybe fix later)')

# allocate parameters
offsets = []
offset = 0
max_params = 2 ** log2_hashmap_size
for i in range(num_levels):
    resolution = int(np.ceil(base_resolution * per_level_scale ** i))
    params_in_level = min(max_params, (resolution + 1) ** input_dim) # limit max number
    #params_in_level = np.ceil(params_in_level / 8) * 8 # make divisible
    offsets.append(offset)
    offset += params_in_level
offsets.append(offset)

print(offsets)

def seed_torch(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

#seed_torch()

# parameters
inputs = torch.rand(1, input_dim, dtype= torch.float64, requires_grad=False).to(device)

offsets = torch.from_numpy(np.array(offsets, dtype=np.int32)).to(device)
embeddings = torch.randn(offset, level_dim, dtype=torch.float64, requires_grad=True).to(device) * 0.1

print(inputs)
print(embeddings)


Inputs = (inputs, embeddings, offsets, per_level_scale, base_resolution, inputs.requires_grad)
check_results1 = torch.autograd.gradcheck(_hash_encode.apply, Inputs, eps=1e-2, atol=1e-3, rtol=0.01, fast_mode=False)
print("check_results1", check_results1)
