Skip to content

Commit

Permalink
add tiled gridencoder
Browse files Browse the repository at this point in the history
  • Loading branch information
ashawkey committed Apr 13, 2022
1 parent fd5dd43 commit 213b0b7
Show file tree
Hide file tree
Showing 15 changed files with 85 additions and 67 deletions.
8 changes: 6 additions & 2 deletions encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,12 @@ def get_encoder(encoding, input_dim=3,
encoder = SHEncoder(input_dim=input_dim, degree=degree)

elif encoding == 'hashgrid':
from hashencoder import HashEncoder
encoder = HashEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution)
from gridencoder import GridEncoder
encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='hash')

elif encoding == 'tiledgrid':
from gridencoder import GridEncoder
encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='tiled')

elif encoding == 'ash':
from ashencoder import AshEncoder
Expand Down
1 change: 1 addition & 0 deletions gridencoder/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .grid import GridEncoder
4 changes: 2 additions & 2 deletions hashencoder/backend.py → gridencoder/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@ def find_cl_path():
raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
os.environ["PATH"] += ";" + cl_path

_backend = load(name='_hash_encoder',
_backend = load(name='_grid_encoder',
extra_cflags=c_flags,
extra_cuda_cflags=nvcc_flags,
sources=[os.path.join(_src_path, 'src', f) for f in [
'hashencoder.cu',
'gridencoder.cu',
'bindings.cpp',
]],
)
Expand Down
33 changes: 20 additions & 13 deletions hashencoder/hashgrid.py → gridencoder/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,15 @@

from .backend import _backend

class _hash_encode(Function):
_gridtype_to_id = {
'hash': 0,
'tiled': 1,
}

class _grid_encode(Function):
@staticmethod
@custom_fwd(cast_inputs=torch.half)
def forward(ctx, inputs, embeddings, offsets, per_level_scale, base_resolution, calc_grad_inputs=False):
def forward(ctx, inputs, embeddings, offsets, per_level_scale, base_resolution, calc_grad_inputs=False, gridtype=0):
# inputs: [B, D], float in [0, 1]
# embeddings: [sO, C], float
# offsets: [L + 1], int
Expand All @@ -35,13 +40,13 @@ def forward(ctx, inputs, embeddings, offsets, per_level_scale, base_resolution,
else:
dy_dx = torch.empty(1, device=inputs.device, dtype=inputs.dtype)

_backend.hash_encode_forward(inputs, embeddings, offsets, outputs, B, D, C, L, S, H, calc_grad_inputs, dy_dx)
_backend.grid_encode_forward(inputs, embeddings, offsets, outputs, B, D, C, L, S, H, calc_grad_inputs, dy_dx, gridtype)

# permute back to [B, L * C]
outputs = outputs.permute(1, 0, 2).reshape(B, L * C)

ctx.save_for_backward(inputs, embeddings, offsets, dy_dx)
ctx.dims = [B, D, C, L, S, H]
ctx.dims = [B, D, C, L, S, H, gridtype]
ctx.calc_grad_inputs = calc_grad_inputs

return outputs
Expand All @@ -52,7 +57,7 @@ def forward(ctx, inputs, embeddings, offsets, per_level_scale, base_resolution,
def backward(ctx, grad):

inputs, embeddings, offsets, dy_dx = ctx.saved_tensors
B, D, C, L, S, H = ctx.dims
B, D, C, L, S, H, gridtype = ctx.dims
calc_grad_inputs = ctx.calc_grad_inputs

# grad: [B, L * C] --> [L, B, C]
Expand All @@ -65,19 +70,19 @@ def backward(ctx, grad):
else:
grad_inputs = torch.zeros(1, device=inputs.device, dtype=inputs.dtype)

_backend.hash_encode_backward(grad, inputs, embeddings, offsets, grad_embeddings, B, D, C, L, S, H, calc_grad_inputs, dy_dx, grad_inputs)
_backend.grid_encode_backward(grad, inputs, embeddings, offsets, grad_embeddings, B, D, C, L, S, H, calc_grad_inputs, dy_dx, grad_inputs, gridtype)

if calc_grad_inputs:
return grad_inputs, grad_embeddings, None, None, None, None
return grad_inputs, grad_embeddings, None, None, None, None, None
else:
return None, grad_embeddings, None, None, None, None
return None, grad_embeddings, None, None, None, None, None


hash_encode = _hash_encode.apply
grid_encode = _grid_encode.apply


class HashEncoder(nn.Module):
def __init__(self, input_dim=3, num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=None):
class GridEncoder(nn.Module):
def __init__(self, input_dim=3, num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=None, gridtype='hash'):
super().__init__()

# the finest resolution desired at the last level, if provided, overridee per_level_scale
Expand All @@ -91,6 +96,8 @@ def __init__(self, input_dim=3, num_levels=16, level_dim=2, per_level_scale=2, b
self.log2_hashmap_size = log2_hashmap_size
self.base_resolution = base_resolution
self.output_dim = num_levels * level_dim
self.gridtype = gridtype
self.gridtype_id = _gridtype_to_id[gridtype] # "tiled" or "hash"

if level_dim % 2 != 0:
print('[WARN] detected HashGrid level_dim % 2 != 0, which will cause very slow backward is also enabled fp16! (maybe fix later)')
Expand Down Expand Up @@ -121,7 +128,7 @@ def reset_parameters(self):
self.embeddings.data.uniform_(-std, std)

def __repr__(self):
return f"HashEncoder: input_dim={self.input_dim} num_levels={self.num_levels} level_dim={self.level_dim} base_resolution={self.base_resolution} per_level_scale={self.per_level_scale} params={tuple(self.embeddings.shape)}"
return f"GridEncoder: input_dim={self.input_dim} num_levels={self.num_levels} level_dim={self.level_dim} base_resolution={self.base_resolution} per_level_scale={self.per_level_scale} params={tuple(self.embeddings.shape)} gridtype={self.gridtype}"

def forward(self, inputs, bound=1):
# inputs: [..., input_dim], normalized real world positions in [-bound, bound]
Expand All @@ -134,7 +141,7 @@ def forward(self, inputs, bound=1):
prefix_shape = list(inputs.shape[:-1])
inputs = inputs.view(-1, self.input_dim)

outputs = hash_encode(inputs, self.embeddings, self.offsets, self.per_level_scale, self.base_resolution, inputs.requires_grad)
outputs = grid_encode(inputs, self.embeddings, self.offsets, self.per_level_scale, self.base_resolution, inputs.requires_grad, self.gridtype_id)
outputs = outputs.view(prefix_shape + [self.output_dim])

#print('outputs', outputs.shape, outputs.dtype, outputs.min().item(), outputs.max().item())
Expand Down
8 changes: 8 additions & 0 deletions gridencoder/src/bindings.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#include <torch/extension.h>

#include "gridencoder.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("grid_encode_forward", &grid_encode_forward, "grid_encode_forward (CUDA)");
m.def("grid_encode_backward", &grid_encode_backward, "grid_encode_backward (CUDA)");
}
Loading

0 comments on commit 213b0b7

Please sign in to comment.