Skip to content

Commit

Permalink
add Transpose View Expand C functions
Browse files Browse the repository at this point in the history
  • Loading branch information
albanD authored and soumith committed Jun 17, 2017
1 parent dd5c7c4 commit 462ab8a
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 1 deletion.
7 changes: 6 additions & 1 deletion torch/csrc/autograd/functions/init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -247,9 +247,14 @@ bool THPAutograd_initFunctions(PyObject* _unused)

static PyTypeObject CloneClass;
addClass<Clone, NoCtor>(module, CloneClass, "Clone");

static PyTypeObject IdentityClass;
addClass<Identity, NoCtor>(module, IdentityClass, "Identity");
static PyTypeObject TransposeClass;
addClass<Transpose, NoCtor>(module, TransposeClass, "CTranspose");
static PyTypeObject ViewClass;
addClass<View, NoCtor>(module, ViewClass, "CView");
static PyTypeObject ExpandClass;
addClass<Expand, NoCtor>(module, ExpandClass, "CExpand");

THPObjectPtr parent(PyImport_ImportModule("torch._C"));
if (!parent) return false;
Expand Down
40 changes: 40 additions & 0 deletions torch/csrc/autograd/functions/tensor.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "tensor.h"

#include "torch/csrc/autograd/variable.h"
#include "torch/csrc/autograd/functions/basic_ops.h"
#include "torch/csrc/autograd/functions/utils.h"
#include "torch/csrc/utils/auto_gpu.h"

Expand All @@ -22,4 +23,43 @@ auto Clone::apply(const variable_list& inputs) -> variable_list {
});
};

auto Transpose::apply(const variable_list& inputs) -> variable_list {
check_input_variables("Transpose", inputs, 1);

auto& input = inputs[0]->data;
AutoGPU guard(input->getDevice());

std::unique_ptr<thpp::Tensor> output(input->newTranspose(dim1, dim2));

return wrap_outputs(inputs, as_tensor_list(std::move(output)), [&](FunctionFlags f) {
return std::make_shared<Transpose>(dim1, dim2);
});
}

auto View::apply(const variable_list& inputs) -> variable_list {
check_input_variables("View", inputs, 1);

auto& input = inputs[0]->data;
AutoGPU guard(input->getDevice());

std::unique_ptr<thpp::Tensor> output(input->newView(size));

return wrap_outputs(inputs, as_tensor_list(std::move(output)), [&](FunctionFlags f) {
return std::make_shared<View>(input->sizes());
});
}

auto Expand::apply(const variable_list& inputs) -> variable_list {
check_input_variables("Expand", inputs, 1);

auto& input = inputs[0]->data;
AutoGPU guard(input->getDevice());

std::unique_ptr<thpp::Tensor> output(input->newExpand(size));

return wrap_outputs(inputs, as_tensor_list(std::move(output)), [&](FunctionFlags f) {
return std::make_shared<Error>("Expand is not differentiable", std::move(f));
});
}

}} // namespace torch::autograd
29 changes: 29 additions & 0 deletions torch/csrc/autograd/functions/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,35 @@ struct Clone : public Function {
virtual variable_list apply(const variable_list& inputs) override;
};

struct Transpose : public Function {
Transpose(long dim1, long dim2)
: dim1(dim1)
, dim2(dim2) {}

virtual variable_list apply(const variable_list& inputs) override;

long dim1;
long dim2;
};

struct View : public Function {
View(std::vector<long> size)
: size(size) {}

virtual variable_list apply(const variable_list& inputs) override;

std::vector<long> size;
};

struct Expand : public Function {
Expand(std::vector<long> size)
: size(size) {}

virtual variable_list apply(const variable_list& inputs) override;

std::vector<long> size;
};

}}


0 comments on commit 462ab8a

Please sign in to comment.