initial commit to add Multilayer Perceptron (MLP) extension (NVIDIA#790)
FDecaYed authored Apr 22, 2020
1 parent 2ec84eb commit 71511fa
Showing 7 changed files with 1,210 additions and 1 deletion.
1 change: 1 addition & 0 deletions apex/mlp/__init__.py
@@ -0,0 +1 @@
from .mlp import *
70 changes: 70 additions & 0 deletions apex/mlp/mlp.py
@@ -0,0 +1,70 @@
from copy import copy
import math
import torch
from torch import nn
import mlp_cuda
from .. import amp

class MlpFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, *args):
        # args is (input, weight_0, ..., weight_{n-1}, bias_0, ..., bias_{n-1})
        output = mlp_cuda.forward(args)
        ctx.save_for_backward(*args)
        ctx.outputs = output
        # output[0] is the final activation; output[1] is reserved space kept for backward
        return output[0]

    @staticmethod
    def backward(ctx, grad_o):
        grads = mlp_cuda.backward(grad_o, ctx.outputs, ctx.saved_tensors)
        del ctx.outputs
        return tuple(grads)

mlp_function = amp.half_function(MlpFunction.apply)

class MLP(torch.nn.Module):
    """Launch MLP in C++

    Args:
        mlp_sizes (list of int): MLP sizes. Example: [1024, 1024, 1024] will create two MLP layers with shape 1024x1024
        bias (bool): Default True
        relu (bool): Default True
    """
    def __init__(self, mlp_sizes, bias=True, relu=True):
        if not (bias and relu):
            raise TypeError("bias and relu must both be True.")
        super(MLP, self).__init__()
        self.num_layers = len(mlp_sizes) - 1
        self.mlp_sizes = copy(mlp_sizes)
        self.bias = bias
        self.relu = relu

        # bias=False is not supported yet
        self.weights = []
        self.biases = []
        for i in range(self.num_layers):
            w = torch.nn.Parameter(torch.empty(mlp_sizes[i + 1], mlp_sizes[i]))
            self.weights.append(w)
            name = 'weight_{}'.format(i)
            setattr(self, name, w)
            b = torch.nn.Parameter(torch.empty(mlp_sizes[i + 1]))
            self.biases.append(b)
            name = 'bias_{}'.format(i)
            setattr(self, name, b)

        self.reset_parameters()

    def reset_parameters(self):
        for weight in self.weights:
            dimsum = weight.size(0) + weight.size(1)
            std = math.sqrt(2. / float(dimsum))
            nn.init.normal_(weight, 0., std)
        for bias in self.biases:
            std = math.sqrt(1. / float(bias.size(0)))
            nn.init.normal_(bias, 0., std)

    def forward(self, input):
        return mlp_function(input, *self.weights, *self.biases)

    def extra_repr(self):
        s = f"MLP sizes: {self.mlp_sizes}, Bias={self.bias}, ReLU={self.relu}"
        return s
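
For reference, a minimal usage sketch of the new module (not part of this commit): it assumes the mlp_cuda extension has been built and a CUDA device is available; the layer sizes and batch size below are illustrative.

import torch
from apex.mlp import MLP

# three linear layers (512 -> 1024 -> 1024 -> 256), each with bias and ReLU
model = MLP([512, 1024, 1024, 256]).cuda().half()

x = torch.randn(64, 512, device='cuda', dtype=torch.half, requires_grad=True)
y = model(x)           # single fused call into the C++ extension
y.sum().backward()     # MlpFunction.backward returns grads for input, weights and biases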
139 changes: 139 additions & 0 deletions csrc/mlp.cpp
@@ -0,0 +1,139 @@
#include <torch/extension.h>
#include <torch/torch.h>
#include <vector>

#include <stdio.h>

size_t get_mlp_reserved_space(int batch_size, int num_layers, const int* output_features);

template <typename T>
size_t get_mlp_bp_workspace_in_bytes(int batch_size, int num_layers, const int* output_features);

template <typename T>
int mlp_fp(
    T* X,
    int input_features,
    int batch_size,
    T** WPtr,
    int num_layers,
    int* output_features,
    T** BPtr,
    T* Y,
    T* reserved_space);

template <typename T>
int mlp_bp(
    T* X,
    T* Y,
    int input_features,
    int batch_size,
    T** WPtr,
    int num_layers,
    int* output_features,
    T* dY,
    T* reserved_space,
    T* work_space,
    T* dX,
    T** dwPtr,
    T** dbPtr);

std::vector<at::Tensor> mlp_forward(std::vector<at::Tensor> inputs) {
  // inputs contains (input, weights, biases)
  auto num_layers = (inputs.size() - 1) / 2;
  auto batch_size = inputs[0].size(0);
  auto input_features = inputs[0].size(1);

  std::vector<int> output_features;
  for (int i = 0; i < num_layers; i++) {
    output_features.push_back(inputs[i + 1].size(0));
  }

  auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());

  // create output/workspace tensor
  // TODO(deyuf): just get buffer?
  auto out = at::empty({batch_size, output_features.back()}, inputs[0].type());
  auto reserved_space = at::empty({reserved_size}, inputs[0].type());

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(inputs[0].type(), "mlp_forward", [&] {
    std::vector<scalar_t*> w_ptr;
    std::vector<scalar_t*> b_ptr;
    for (int i = 0; i < num_layers; i++) {
      w_ptr.push_back(inputs[i + 1].data_ptr<scalar_t>());
      b_ptr.push_back(inputs[i + 1 + num_layers].data_ptr<scalar_t>());
    }
    auto result = mlp_fp<scalar_t>(
        inputs[0].data_ptr<scalar_t>(),
        input_features,
        batch_size,
        w_ptr.data(),
        num_layers,
        output_features.data(),
        b_ptr.data(),
        out.data_ptr<scalar_t>(),
        reserved_space.data_ptr<scalar_t>());
  });

  return {out, reserved_space};
}

std::vector<at::Tensor> mlp_backward(
    at::Tensor grad_o,
    std::vector<at::Tensor> fprop_outputs,
    std::vector<at::Tensor> inputs) {
  // same code as forward to get sizes and W pointers
  auto num_layers = (inputs.size() - 1) / 2;
  auto batch_size = inputs[0].size(0);
  auto input_features = inputs[0].size(1);

  std::vector<int> output_features;
  for (int i = 0; i < num_layers; i++) {
    output_features.push_back(inputs[i + 1].size(0));
  }
  // create outputs, one gradient tensor per input
  std::vector<at::Tensor> outputs;
  for (int i = 0; i < inputs.size(); i++) {
    outputs.push_back(at::empty(inputs[i].sizes(), inputs[i].type())); // clone for testing now
  }

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(inputs[0].type(), "mlp_backward", [&] {
    std::vector<scalar_t*> w_ptr;
    std::vector<scalar_t*> b_ptr;
    for (int i = 0; i < num_layers; i++) {
      w_ptr.push_back(inputs[i + 1].data_ptr<scalar_t>());
      b_ptr.push_back(inputs[i + 1 + num_layers].data_ptr<scalar_t>());
    }
    std::vector<scalar_t*> outputs_ptr;
    for (int i = 0; i < inputs.size(); i++) {
      outputs_ptr.push_back(outputs[i].data_ptr<scalar_t>());
    }

    auto work_size =
        get_mlp_bp_workspace_in_bytes<scalar_t>(batch_size, num_layers, output_features.data());

    // auto work_space = at::empty({work_size*4}, at::kByte);
    auto work_space = at::empty({work_size / sizeof(scalar_t)}, inputs[0].type());

    auto result = mlp_bp<scalar_t>(
        inputs[0].data_ptr<scalar_t>(),
        fprop_outputs[0].data_ptr<scalar_t>(),
        input_features,
        batch_size,
        w_ptr.data(),
        num_layers,
        output_features.data(),
        grad_o.contiguous().data_ptr<scalar_t>(),
        fprop_outputs[1].data_ptr<scalar_t>(),
        work_space.data_ptr<scalar_t>(),
        outputs_ptr[0],
        outputs_ptr.data() + 1,
        outputs_ptr.data() + 1 + num_layers);
  });

  return outputs;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &mlp_forward, "MLP forward");
  m.def("backward", &mlp_backward, "MLP backward");
}
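
For context, the Python wrapper in apex/mlp/mlp.py is the intended entry point, but the argument packing these two bindings expect can be sketched directly. The shapes below are illustrative (single layer), and the sketch assumes the mlp_cuda extension has been built from csrc/mlp.cpp together with the kernel implementations of mlp_fp/mlp_bp.

import torch
import mlp_cuda  # compiled extension declared in this file

x = torch.randn(64, 512, device='cuda', dtype=torch.half)
w0 = torch.randn(256, 512, device='cuda', dtype=torch.half)  # (out_features, in_features)
b0 = torch.randn(256, device='cuda', dtype=torch.half)

# forward takes [input, weight_0..weight_{n-1}, bias_0..bias_{n-1}]
# and returns [output, reserved_space]
out, reserved = mlp_cuda.forward([x, w0, b0])

# backward takes (grad_output, forward outputs, the same input list) and
# returns gradients in the same order as the inputs: [dX, dW_0.., dB_0..]
dx, dw0, db0 = mlp_cuda.backward(torch.ones_like(out), [out, reserved], [x, w0, b0])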