Skip to content

Commit

Permalink
Remove device argument from diag and diagflat
Browse files Browse the repository at this point in the history
  • Loading branch information
niboshi committed Oct 10, 2019
1 parent 4fd2210 commit 0cad1fd
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 19 deletions.
12 changes: 2 additions & 10 deletions chainerx_cc/chainerx/python/routines.cc
Original file line number Diff line number Diff line change
Expand Up @@ -262,16 +262,8 @@ void InitChainerxCreation(pybind11::module& m) {
"k"_a = 0,
"dtype"_a = "float64",
"device"_a = nullptr);
m.def("diag",
[](const ArrayBodyPtr& v, int64_t k, py::handle device) { return MoveArrayBody(Diag(Array{v}, k, GetDevice(device))); },
"v"_a,
"k"_a = 0,
"device"_a = nullptr);
m.def("diagflat",
[](const ArrayBodyPtr& v, int64_t k, py::handle device) { return MoveArrayBody(Diagflat(Array{v}, k, GetDevice(device))); },
"v"_a,
"k"_a = 0,
"device"_a = nullptr);
m.def("diag", [](const ArrayBodyPtr& v, int64_t k) { return MoveArrayBody(Diag(Array{v}, k)); }, "v"_a, "k"_a = 0);
m.def("diagflat", [](const ArrayBodyPtr& v, int64_t k) { return MoveArrayBody(Diagflat(Array{v}, k)); }, "v"_a, "k"_a = 0);
m.def("linspace",
[](Scalar start, Scalar stop, int64_t num, bool endpoint, py::handle dtype, py::handle device) {
return MoveArrayBody(Linspace(
Expand Down
11 changes: 6 additions & 5 deletions chainerx_cc/chainerx/routines/creation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -257,8 +257,9 @@ Array AsContiguousArray(const Array& a, const absl::optional<Dtype>& dtype) {
return out;
}

Array Diag(const Array& v, int64_t k, Device& device) {
Array Diag(const Array& v, int64_t k) {
Array out{};
Device& device = v.device();

int8_t ndim = v.ndim();
if (ndim == 1) {
Expand Down Expand Up @@ -293,19 +294,19 @@ Array Diag(const Array& v, int64_t k, Device& device) {

BackwardBuilder bb{"diag", v, out};
if (BackwardBuilder::Target bt = bb.CreateTarget(0)) {
bt.Define([& device = v.device(), k](BackwardContext& bctx) {
bt.Define([k](BackwardContext& bctx) {
const Array& gout = *bctx.output_grad();
bctx.input_grad() = Diag(gout, k, device);
bctx.input_grad() = Diag(gout, k);
});
}
bb.Finalize();

return out;
}

Array Diagflat(const Array& v, int64_t k, Device& device) {
Array Diagflat(const Array& v, int64_t k) {
// TODO(hvy): Use Ravel or Flatten when implemented instead of Reshape.
return Diag(v.Reshape({v.GetTotalSize()}), k, device);
return Diag(v.Reshape({v.GetTotalSize()}), k);
}

// Creates a 1-d array with evenly spaced numbers.
Expand Down
6 changes: 2 additions & 4 deletions chainerx_cc/chainerx/routines/creation.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,11 +98,9 @@ inline Array AsContiguous(const Array& a) { return AsContiguous(a, a.dtype()); }
// An input array with shape {} results in a new array with shape {1}.
Array AsContiguousArray(const Array& a, const absl::optional<Dtype>& dtype = absl::nullopt);

// TODO(niboshi): Remove device argument and use v.device(). Also fix tests
Array Diag(const Array& v, int64_t k = 0, Device& device = GetDefaultDevice());
Array Diag(const Array& v, int64_t k = 0);

// TODO(niboshi): Remove device argument and use v.device(). Also fix tests
Array Diagflat(const Array& v, int64_t k = 0, Device& device = GetDefaultDevice());
Array Diagflat(const Array& v, int64_t k = 0);

// Creates a 1-d array with evenly spaced numbers.
Array Linspace(
Expand Down

0 comments on commit 0cad1fd

Please sign in to comment.