
Commit

fp16
ashawkey committed Jan 26, 2022
1 parent 5c0c61e commit 2ee63d3
Showing 26 changed files with 1,287 additions and 224 deletions.
3 changes: 1 addition & 2 deletions hashencoder/backend.py
@@ -5,9 +5,8 @@

 _backend = load(name='_hash_encoder',
                 extra_cflags=['-O3'], # '-std=c++17'
-                extra_cuda_cflags=['-O3'], # '-arch=sm_70'
+                extra_cuda_cflags=['-O3', '-arch=sm_70'], # '-arch=sm_70'
                 sources=[os.path.join(_src_path, 'src', f) for f in [
-                    'hashencoder.cpp',
                     'hashencoder.cu',
                     'bindings.cpp',
                 ]],
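The only functional change here adds '-arch=sm_70' to the CUDA compile flags (compute capability 7.0, i.e. Volta and newer), presumably because the new fp16 kernels rely on half-precision intrinsics that older architectures lack; the now-deleted hashencoder.cpp is also dropped from the source list. A hypothetical sketch, not part of this commit: the same flag could be derived from the local GPU via torch.cuda.get_device_capability() instead of being hard-coded.

# Hypothetical sketch (not in the commit): build the arch flag from the local GPU.
import os
import torch
from torch.utils.cpp_extension import load

_src_path = os.path.dirname(os.path.abspath(__file__))

major, minor = torch.cuda.get_device_capability()   # e.g. (7, 0) on a V100
arch_flag = f'-arch=sm_{major}{minor}'

_backend = load(name='_hash_encoder',
                extra_cflags=['-O3'],
                extra_cuda_cflags=['-O3', arch_flag],
                sources=[os.path.join(_src_path, 'src', f) for f in [
                    'hashencoder.cu',
                    'bindings.cpp',
                ]])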
23 changes: 17 additions & 6 deletions hashencoder/hashgrid.py
@@ -3,11 +3,14 @@
 import torch
 import torch.nn as nn
 from torch.autograd import Function
+from torch.cuda.amp import custom_bwd, custom_fwd
 
 from .backend import _backend
 
 class _hash_encode(Function):
     @staticmethod
+    @custom_fwd(cast_inputs=torch.half)
+    #@custom_fwd
     def forward(ctx, inputs, embeddings, offsets, base_resolution, calc_grad_inputs=False):
         # inputs: [B, D], float in [0, 1]
         # embeddings: [sO, C], float
@@ -23,12 +26,12 @@ def forward(ctx, inputs, embeddings, offsets, base_resolution, calc_grad_inputs=
         C = embeddings.shape[1] # embedding dim for each level
         H = base_resolution # base resolution
 
-        outputs = torch.zeros(B, L * C, device=inputs.device)
+        outputs = torch.zeros(B, L * C, device=inputs.device, dtype=inputs.dtype)
 
         if calc_grad_inputs:
-            dy_dx = torch.zeros(B, L * D * C).to(inputs.device)
+            dy_dx = torch.zeros(B, L * D * C).to(inputs.device, dtype=inputs.dtype)
         else:
-            dy_dx = torch.zeros(1).to(inputs.device)
+            dy_dx = torch.zeros(1).to(inputs.device, dtype=inputs.dtype)
 
         _backend.hash_encode_forward(inputs, embeddings, offsets, outputs, B, D, C, L, H, calc_grad_inputs, dy_dx)
 
@@ -39,6 +42,7 @@ def forward(ctx, inputs, embeddings, offsets, base_resolution, calc_grad_inputs=
         return outputs
 
     @staticmethod
+    @custom_bwd
     def backward(ctx, grad):
         # grad: [B, L * C]
 
@@ -53,7 +57,7 @@ def backward(ctx, grad):
         if calc_grad_inputs:
             grad_inputs = torch.zeros_like(inputs)
         else:
-            grad_inputs = torch.zeros(1).to(inputs.device)
+            grad_inputs = torch.zeros(1).to(inputs.device, dtype=inputs.dtype)
 
         _backend.hash_encode_backward(grad, inputs, embeddings, offsets, grad_embeddings, B, D, C, L, H, calc_grad_inputs, dy_dx, grad_inputs)
 
@@ -108,12 +112,19 @@ def forward(self, inputs, size=1, calc_grad_inputs=False):
         # inputs: [..., input_dim], normalized real world positions in [-size, size]
         # return: [..., num_levels * level_dim]
 
+        if inputs.min().item() < -size or inputs.max().item() > size:
+            raise ValueError(f'HashGrid encoder: inputs range [{inputs.min().item()}, {inputs.max().item()}] not in [{-size}, {size}]!')
+
         inputs = (inputs + size) / (2 * size) # map to [0, 1]
 
+        #print('inputs', inputs.shape, inputs.dtype, inputs.min().item(), inputs.max().item())
+
         prefix_shape = list(inputs.shape[:-1])
-        inputs = inputs.reshape(-1, self.input_dim)
+        inputs = inputs.view(-1, self.input_dim) # this consumes most time ?????
 
         outputs = hash_encode(inputs, self.embeddings, self.offsets, self.base_resolution, calc_grad_inputs)
-        outputs = outputs.reshape(prefix_shape + [self.output_dim])
+        outputs = outputs.view(prefix_shape + [self.output_dim])
 
+        #print('outputs', outputs.shape, outputs.dtype, outputs.min().item(), outputs.max().item())
+
         return outputs
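The changes above wire the extension into PyTorch's automatic mixed precision: custom_fwd(cast_inputs=torch.half) makes the forward pass receive half-precision tensors whenever autocast is enabled, custom_bwd keeps the backward pass in the matching autocast state, and the temporary buffers (outputs, dy_dx, grad_inputs) are now allocated with dtype=inputs.dtype so the CUDA kernels see consistently fp16 or fp32 data. Below is a minimal sketch of that decorator pattern, assuming a CUDA device and using a toy _Square function rather than the repository's extension.

# Minimal sketch of the custom_fwd/custom_bwd pattern adopted above;
# _Square is a toy stand-in, not part of the repository. Requires a CUDA device.
import torch
from torch.cuda.amp import custom_bwd, custom_fwd

class _Square(torch.autograd.Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.half)  # under autocast, tensor args arrive as fp16
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    @custom_bwd  # backward runs in the same autocast state as forward
    def backward(ctx, grad):
        x, = ctx.saved_tensors
        return 2.0 * x * grad

x = torch.randn(8, device='cuda', requires_grad=True)  # fp32 leaf tensor
with torch.cuda.amp.autocast():
    y = _Square.apply(x)          # forward sees a half-precision copy of x
print(y.dtype)                    # torch.float16
y.float().sum().backward()        # gradients flow back to the fp32 leaf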
56 changes: 0 additions & 56 deletions hashencoder/src/hashencoder.cpp

This file was deleted.

