Skip to content

Commit

Permalink
Update tcnn
Browse files Browse the repository at this point in the history
  • Loading branch information
Tom94 committed Mar 23, 2022
1 parent 5c69d09 commit 26b0620
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 59 deletions.
18 changes: 3 additions & 15 deletions include/neural-graphics-primitives/nerf_network.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,7 @@ class NerfNetwork : public tcnn::Network<float, T> {

virtual ~NerfNetwork() { }

void inference_mixed_precision(cudaStream_t stream, const tcnn::GPUMatrixDynamic<float>& input, tcnn::GPUMatrixDynamic<T>& output, bool use_inference_params = true) override {
if (input.layout() != tcnn::CM) {
throw std::runtime_error("NerfNetwork::inference_mixed_precision input must be in column major format.");
}

void inference_mixed_precision_impl(cudaStream_t stream, const tcnn::GPUMatrixDynamic<float>& input, tcnn::GPUMatrixDynamic<T>& output, bool use_inference_params = true) override {
uint32_t batch_size = input.n();
tcnn::GPUMatrixDynamic<T> density_network_input{m_pos_encoding->padded_output_width(), batch_size, stream, m_pos_encoding->preferred_output_layout()};
tcnn::GPUMatrixDynamic<T> rgb_network_input{m_rgb_network_input_width, batch_size, stream, m_dir_encoding->preferred_output_layout()};
Expand Down Expand Up @@ -144,11 +140,7 @@ class NerfNetwork : public tcnn::Network<float, T> {
return m_density_network->padded_output_width();
}

std::unique_ptr<tcnn::Context> forward(cudaStream_t stream, const tcnn::GPUMatrixDynamic<float>& input, tcnn::GPUMatrixDynamic<T>* output = nullptr, bool use_inference_params = false, bool prepare_input_gradients = false) override {
if (input.layout() != tcnn::CM || (output && output->layout() != tcnn::CM)) {
throw std::runtime_error("NerfNetwork::forward input and output must be in column major format.");
}

std::unique_ptr<tcnn::Context> forward_impl(cudaStream_t stream, const tcnn::GPUMatrixDynamic<float>& input, tcnn::GPUMatrixDynamic<T>* output = nullptr, bool use_inference_params = false, bool prepare_input_gradients = false) override {
// Make sure our temporary buffers have the correct size for the given batch size
uint32_t batch_size = input.n();

Expand Down Expand Up @@ -192,7 +184,7 @@ class NerfNetwork : public tcnn::Network<float, T> {
return forward;
}

void backward(
void backward_impl(
cudaStream_t stream,
const tcnn::Context& ctx,
const tcnn::GPUMatrixDynamic<float>& input,
Expand All @@ -202,10 +194,6 @@ class NerfNetwork : public tcnn::Network<float, T> {
bool use_inference_params = false,
tcnn::EGradientMode param_gradients_mode = tcnn::EGradientMode::Overwrite
) override {
if (input.layout() != tcnn::CM || output.layout() != tcnn::CM || dL_doutput.layout() != tcnn::CM || (dL_dinput && dL_dinput->layout() != tcnn::CM)) {
throw std::runtime_error("NerfNetwork::backward input and output must be in column major format.");
}

const auto& forward = dynamic_cast<const ForwardContext&>(ctx);

	// Make sure our temporary buffers have the correct size for the given batch size
Expand Down
80 changes: 41 additions & 39 deletions include/neural-graphics-primitives/takikawa_encoding.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ __global__ void kernel_takikawa(
const TriangleOctreeNode* octree_nodes,
const TriangleOctreeDualNode* octree_dual_nodes,
const T* __restrict__ grid,
const tcnn::PitchedPtr<const float> data_in,
tcnn::PitchedPtr<T> data_out,
const tcnn::MatrixView<const float> data_in,
tcnn::MatrixView<T> data_out,
float* __restrict__ dy_dx
) {
uint32_t n_features = N_FEATURES_PER_LEVEL * n_levels;
Expand All @@ -49,9 +49,9 @@ __global__ void kernel_takikawa(
octree_dual_nodes,
n_levels + starting_level,
{
data_in(i)[0],
data_in(i)[1],
data_in(i)[2],
data_in(0, i),
data_in(1, i),
data_in(2, i),
},
[&](const TriangleOctreeDualNode& node, uint32_t level, Eigen::Vector3f pos) {
if (level < starting_level) {
Expand Down Expand Up @@ -97,11 +97,14 @@ __global__ void kernel_takikawa(
// Read params
#pragma unroll
for (uint32_t feature = 0; feature < N_FEATURES_PER_LEVEL; ++feature) {
((T*)&result)[feature] += (T)(weight * (float)((T*)&val)[feature]);
result[feature] += (T)(weight * (float)val[feature]);
}
}

*(tcnn::vector_t<T, N_FEATURES_PER_LEVEL>*)&data_out(i)[level * N_FEATURES_PER_LEVEL] = result;
#pragma unroll
for (uint32_t feature = 0; feature < N_FEATURES_PER_LEVEL; ++feature) {
data_out(level * N_FEATURES_PER_LEVEL + feature, i) = result[feature];
}
}

// Gradient
Expand Down Expand Up @@ -155,7 +158,7 @@ __global__ void kernel_takikawa(
for (; level < n_levels; ++level) {
#pragma unroll
for (uint32_t f = 0; f < N_FEATURES_PER_LEVEL; ++f) {
data_out(i)[level * N_FEATURES_PER_LEVEL + f] = (T)0.0f;
data_out(level * N_FEATURES_PER_LEVEL + f, i) = (T)0.0f;
}
}
}
Expand All @@ -165,23 +168,23 @@ template <typename T>
__global__ void kernel_takikawa_backward_input(
const uint32_t num_elements,
const uint32_t num_grid_features,
const tcnn::PitchedPtr<const T> dL_dy,
const tcnn::MatrixView<const T> dL_dy,
const float* __restrict__ dy_dx,
tcnn::PitchedPtr<float> dL_dx
tcnn::MatrixView<float> dL_dx
) {
const uint32_t input_index = threadIdx.x + blockIdx.x * blockDim.x;
if (input_index >= num_elements) return;

const uint32_t fan_out_grad = num_grid_features * 3;

const uint32_t i = input_index / 3;
const uint32_t j = input_index - i * 3;
const uint32_t j = input_index - i * 3;

float result = 0;
for (int k = 0; k < num_grid_features; ++k) {
result += (float)dL_dy(i)[k] * dy_dx[i * fan_out_grad + j * num_grid_features + k];
result += (float)dL_dy(k, i) * dy_dx[i * fan_out_grad + j * num_grid_features + k];
}
dL_dx(i)[j] = result;
dL_dx(j, i) = result;
}

template <typename T, typename GRAD_T, uint32_t N_FEATURES_PER_LEVEL>
Expand All @@ -193,8 +196,8 @@ __global__ void kernel_takikawa_backward(
const TriangleOctreeNode* octree_nodes,
const TriangleOctreeDualNode* octree_dual_nodes,
GRAD_T* __restrict__ params_gradient,
const tcnn::PitchedPtr<const float> data_in,
const tcnn::PitchedPtr<const T> dL_dy
const tcnn::MatrixView<const float> data_in,
const tcnn::MatrixView<const T> dL_dy
) {
uint32_t i = blockIdx.x * blockDim.x + threadIdx.x;
const uint32_t encoded_index = i * N_FEATURES_PER_LEVEL * n_levels;
Expand All @@ -205,9 +208,9 @@ __global__ void kernel_takikawa_backward(
octree_dual_nodes,
n_levels + starting_level,
{
data_in(i)[0],
data_in(i)[1],
data_in(i)[2],
data_in(0, i),
data_in(1, i),
data_in(2, i),
},
[&](const TriangleOctreeDualNode& node, uint32_t level, Eigen::Vector3f pos) {
if (level < starting_level) {
Expand All @@ -222,7 +225,12 @@ __global__ void kernel_takikawa_backward(
}
}

auto grad = *(tcnn::vector_t<T, N_FEATURES_PER_LEVEL>*)&dL_dy(i)[N_FEATURES_PER_LEVEL * level];
tcnn::vector_t<T, N_FEATURES_PER_LEVEL> grad;

#pragma unroll
for (uint32_t f = 0; f < N_FEATURES_PER_LEVEL; ++f) {
grad[f] = dL_dy(N_FEATURES_PER_LEVEL * level + f, i);
}

// Tri-linear interpolation

Expand Down Expand Up @@ -282,8 +290,8 @@ public:
using grad_t = float;
#endif

TakikawaEncoding(uint32_t starting_level, bool sum_instead_of_concat, std::shared_ptr<TriangleOctree> octree, tcnn::InterpolationType interpolation_type)
: m_starting_level{starting_level}, m_sum_instead_of_concat{sum_instead_of_concat}, m_octree{octree}, m_interpolation_type{interpolation_type} {
TakikawaEncoding(uint32_t starting_level, std::shared_ptr<TriangleOctree> octree, tcnn::InterpolationType interpolation_type)
: m_starting_level{starting_level}, m_octree{octree}, m_interpolation_type{interpolation_type} {

if (m_starting_level >= m_octree->depth()) {
throw std::runtime_error{"Starting level must be below octree depth."};
Expand All @@ -294,16 +302,11 @@ public:
if (N_FEATURES_PER_LEVEL != 1 && N_FEATURES_PER_LEVEL != 2 && N_FEATURES_PER_LEVEL != 4 && N_FEATURES_PER_LEVEL != 8) {
throw std::runtime_error{"Number of features per level must be 1, 2, 4, or 8."};
}

// Only needs temporary storage if gradients are computed with different precision from T.
if (!std::is_same<grad_t, T>::value) {
m_params_gradient_tmp.resize(n_params());
}
}

virtual ~TakikawaEncoding() { }

std::unique_ptr<tcnn::Context> forward(
std::unique_ptr<tcnn::Context> forward_impl(
cudaStream_t stream,
const tcnn::GPUMatrixDynamic<float>& input,
tcnn::GPUMatrixDynamic<T>* output = nullptr,
Expand All @@ -328,15 +331,15 @@ public:
m_octree->nodes_gpu(),
m_octree->dual_nodes_gpu(),
use_inference_params ? m_params_inference : m_params,
input.pitched_ptr(),
output ? output->pitched_ptr() : tcnn::PitchedPtr<T>{},
input.view(),
output ? output->view() : tcnn::MatrixView<T>{},
forward->dy_dx.data()
);

return forward;
}

void backward(
void backward_impl(
cudaStream_t stream,
const tcnn::Context& ctx,
const tcnn::GPUMatrixDynamic<float>& input,
Expand All @@ -357,8 +360,11 @@ public:
// We accumulate gradients with grad_t precision, which, for performance reasons, is not always T.
// If not, accumulate in a temporary buffer and cast later.
grad_t* params_gradient;
tcnn::GPUMemoryArena::Allocation params_gradient_tmp;

if (!std::is_same<grad_t, T>::value) {
params_gradient = (grad_t*)m_params_gradient_tmp.data();
params_gradient_tmp = tcnn::allocate_workspace(stream, n_params() * sizeof(grad_t));
params_gradient = (grad_t*)params_gradient_tmp.data();
} else {
params_gradient = (grad_t*)m_params_gradient;
}
Expand All @@ -375,8 +381,8 @@ public:
m_octree->nodes_gpu(),
m_octree->dual_nodes_gpu(),
params_gradient,
input.pitched_ptr(),
dL_doutput.pitched_ptr()
input.view(),
dL_doutput.view()
);

if (!std::is_same<grad_t, T>::value) {
Expand All @@ -391,9 +397,9 @@ public:
tcnn::linear_kernel(kernel_takikawa_backward_input<T>, 0, stream,
num_elements * input_width(),
N_FEATURES_PER_LEVEL * n_levels(),
dL_doutput.pitched_ptr(),
dL_doutput.view(),
forward.dy_dx.data(),
dL_dinput->pitched_ptr()
dL_dinput->view()
);
}
}
Expand Down Expand Up @@ -453,7 +459,6 @@ public:
return {
{"otype", "Takikawa"},
{"starting_level", m_starting_level},
{"sum_instead_of_concat", m_sum_instead_of_concat},
{"n_levels", m_octree->depth()},
};
}
Expand All @@ -464,7 +469,6 @@ private:
};

uint32_t m_starting_level;
bool m_sum_instead_of_concat;

// derived sizes
uint32_t m_n_input_dims;
Expand All @@ -474,8 +478,6 @@ private:
// Storage of params
T* m_params;
T* m_params_inference;

tcnn::GPUMemory<grad_t> m_params_gradient_tmp;
T* m_params_gradient;

std::shared_ptr<TriangleOctree> m_octree;
Expand Down
6 changes: 3 additions & 3 deletions include/neural-graphics-primitives/trainable_buffer.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,15 @@ public:

virtual ~TrainableBuffer() { }

void inference_mixed_precision(cudaStream_t stream, const tcnn::GPUMatrixDynamic<float>& input, tcnn::GPUMatrixDynamic<T>& output, bool use_inference_matrices = true) override {
void inference_mixed_precision_impl(cudaStream_t stream, const tcnn::GPUMatrixDynamic<float>& input, tcnn::GPUMatrixDynamic<T>& output, bool use_inference_matrices = true) override {
throw std::runtime_error{"The trainable buffer does not support inference(). Its content is meant to be used externally."};
}

std::unique_ptr<tcnn::Context> forward(cudaStream_t stream, const tcnn::GPUMatrixDynamic<float>& input, tcnn::GPUMatrixDynamic<T>* output = nullptr, bool use_inference_matrices = false, bool prepare_input_gradients = false) override {
std::unique_ptr<tcnn::Context> forward_impl(cudaStream_t stream, const tcnn::GPUMatrixDynamic<float>& input, tcnn::GPUMatrixDynamic<T>* output = nullptr, bool use_inference_matrices = false, bool prepare_input_gradients = false) override {
throw std::runtime_error{"The trainable buffer does not support forward(). Its content is meant to be used externally."};
}

void backward(
void backward_impl(
cudaStream_t stream,
const tcnn::Context& ctx,
const tcnn::GPUMatrixDynamic<float>& input,
Expand Down
1 change: 0 additions & 1 deletion src/testbed.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1751,7 +1751,6 @@ void Testbed::reset_network() {

m_encoding.reset(new TakikawaEncoding<precision_t>(
encoding_config["starting_level"],
encoding_config["sum_instead_of_concat"],
m_sdf.triangle_octree,
tcnn::string_to_interpolation_type(encoding_config.value("interpolation", "linear"))
));
Expand Down

0 comments on commit 26b0620

Please sign in to comment.