Have median reduce over all dims and return just the value when dim is not provided
lantiga authored and soumith committed Jul 4, 2017
1 parent 635bb5e commit 05c2baf
Showing 18 changed files with 68 additions and 5 deletions.
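In user-facing terms, the change looks roughly like this (a minimal usage sketch; the tensor shape is illustrative):

    import torch

    x = torch.rand(4, 5)

    # Without `dim`: reduce over all elements and return just the median value
    # (previously the call reduced over a dimension and returned (values, indices)).
    m = torch.median(x)

    # With `dim`: unchanged -- still returns (values, indices) along that dimension.
    values, indices = torch.median(x, dim=1, keepdim=False)
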
9 changes: 8 additions & 1 deletion test/test_torch.py
@@ -1539,7 +1539,14 @@ def test_median(self):
x = torch.rand(size, size)
x0 = x.clone()

res1val, res1ind = torch.median(x, keepdim=False)
nelem = x.nelement()
res1val = torch.median(x)
res2val, _ = torch.sort(x.view(nelem))
ind = int(math.floor((nelem + 1) / 2) - 1)

self.assertEqual(res2val[ind], res1val, 0)

res1val, res1ind = torch.median(x, dim=1, keepdim=False)
res2val, res2ind = torch.sort(x)
ind = int(math.floor((size + 1) / 2) - 1)

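The index formula in the new assertion picks the lower median of the sorted, flattened tensor. A small worked example of the same check (values chosen only for illustration):

    import math
    import torch

    t = torch.Tensor([3.0, 1.0, 4.0, 1.5])             # nelem = 4, an even count
    sorted_vals, _ = torch.sort(t.view(t.nelement()))   # [1.0, 1.5, 3.0, 4.0]
    ind = int(math.floor((t.nelement() + 1) / 2) - 1)   # floor(2.5) - 1 = 1
    print(sorted_vals[ind])                             # 1.5, the lower of the two middle values
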
2 changes: 0 additions & 2 deletions torch/_torch_docs.py
@@ -2327,8 +2327,6 @@
is squeezed (see :func:`torch.squeeze`), resulting in the outputs Tensor having 1 fewer
dimension than :attr:`input`.
.. note:: This function is not defined for ``torch.cuda.Tensor`` yet.
Args:
input (Tensor): the input `Tensor`
dim (int): the dimension to reduce
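For the dim-wise form the docstring describes, `keepdim` decides whether the reduced dimension is retained with size 1 or squeezed away; a quick shape illustration (sizes are arbitrary):

    import torch

    x = torch.rand(4, 5)
    v, i = torch.median(x, dim=1, keepdim=True)    # v has size (4, 1): the reduced dim is kept
    v, i = torch.median(x, dim=1, keepdim=False)   # v has size (4,): one fewer dimension
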
2 changes: 1 addition & 1 deletion torch/autograd/_functions/reduce.py
@@ -193,7 +193,7 @@ class Mode(_SelectionFunction):


class Median(_SelectionFunction):
has_all_reduce = False
pass


class Kthvalue(_SelectionFunction):
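Dropping `has_all_reduce = False` lets `Median` fall back to the `_SelectionFunction` default (presumably `True`), so the autograd wrapper also accepts the new no-`dim` form. As a conceptual sketch only, assuming the usual behavior of selection-style reductions (this is not the actual `_SelectionFunction` code): the backward of the all-reduce form routes the upstream gradient to the single selected element and zeros everything else.

    import torch

    def selection_all_reduce_backward(input, selected_flat_index, grad_output):
        # Hypothetical helper for illustration: scatter the scalar gradient
        # back to the position of the selected (median) element.
        grad_input = torch.zeros(input.size()).view(-1)
        grad_input[selected_flat_index] = grad_output
        return grad_input.view(input.size())

    x = torch.rand(3, 4)
    lower_median_index = torch.sort(x.view(-1))[1][(x.nelement() + 1) // 2 - 1]
    grad = selection_all_reduce_backward(x, lower_median_index, 1.0)
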
8 changes: 7 additions & 1 deletion torch/csrc/generic/methods/TensorCompare.cwrap
@@ -628,12 +628,18 @@
name: median
backends:
- CPU
- CUDA
variants:
- method
- function
return: argument 0,1
options:
- before_call: long __last_dim = THTensor_(nDimension)(LIBRARY_STATE ((THPTensor*)$arg2)->cdata)-1;
- cname: medianall
return: real
arguments:
- THTensor* self
- cname: median
before_call: long __last_dim = THTensor_(nDimension)(LIBRARY_STATE ((THPTensor*)$arg2)->cdata)-1;
arguments:
- arg: THTensor* values
output: True
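The stanza declares both `method` and `function` variants, and the two options are selected by argument matching: a bare call binds the new `medianall` kernel and returns a plain scalar, while passing a dimension still binds the per-dimension `median` kernel. Roughly (shape illustrative):

    import torch

    x = torch.rand(3, 3)
    assert x.median() == torch.median(x)   # bare call: the `medianall` option, a single value
    vals, idx = x.median(1)                # dim form: the `median` option, (values, indices)
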
1 change: 1 addition & 0 deletions torch/lib/THD/master_worker/common/Functions.hpp
@@ -87,6 +87,7 @@ enum Functions: std::uint16_t {
tensorDot,
tensorMinall,
tensorMaxall,
tensorMedianall,
tensorSumall,
tensorProdall,
tensorNeg,
9 changes: 9 additions & 0 deletions torch/lib/THD/master_worker/master/generic/THDTensor.cpp
@@ -876,6 +876,15 @@ real THDTensor_(maxall)(THDTensor *self) {
return receiveValueFromWorker<real>(THDState::s_current_worker);
}

real THDTensor_(medianall)(THDTensor *self) {
masterCommandChannel->sendMessage(
packMessage(Functions::tensorMedianall, self),
THDState::s_current_worker
);

return receiveValueFromWorker<real>(THDState::s_current_worker);
}

accreal THDTensor_(sumall)(THDTensor *self) {
masterCommandChannel->sendMessage(
packMessage(Functions::tensorSumall, self),
1 change: 1 addition & 0 deletions torch/lib/THD/master_worker/master/generic/THDTensor.h
@@ -160,6 +160,7 @@ THD_API real THDTensor_(get4d)(const THDTensor *tensor, long x0, long x1,
THD_API accreal THDTensor_(dot)(THDTensor *self, THDTensor *src);
THD_API real THDTensor_(minall)(THDTensor *self);
THD_API real THDTensor_(maxall)(THDTensor *self);
THD_API real THDTensor_(medianall)(THDTensor *self);
THD_API accreal THDTensor_(sumall)(THDTensor *self);
THD_API accreal THDTensor_(prodall)(THDTensor *self);
THD_API void THDTensor_(neg)(THDTensor *self, THDTensor *src);
1 change: 1 addition & 0 deletions torch/lib/THD/master_worker/worker/Dispatch.cpp
@@ -106,6 +106,7 @@ static const std::unordered_map<rpc::function_id_type, dispatch_fn> functions {
{Functions::tensorDot, tensorDot},
{Functions::tensorMinall, tensorMinall},
{Functions::tensorMaxall, tensorMaxall},
{Functions::tensorMedianall, tensorMedianall},
{Functions::tensorSumall, tensorSumall},
{Functions::tensorProdall, tensorProdall},
{Functions::tensorNeg, tensorNeg},
15 changes: 15 additions & 0 deletions torch/lib/THD/master_worker/worker/dispatch/Tensor.cpp
@@ -386,6 +386,21 @@ static void tensorMaxall(rpc::RPCMessage& raw_message) {
}
}

static void tensorMedianall(rpc::RPCMessage& raw_message) {
thpp::Tensor *tensor = unpackRetrieveTensor(raw_message);
finalize(raw_message);

if (thpp::isInteger(tensor->type())) {
long long value = dynamic_cast<thpp::IntTensor*>(tensor)->medianall();
sendValueToMaster(value);
} else if (thpp::isFloat(tensor->type())) {
double value = dynamic_cast<thpp::FloatTensor*>(tensor)->medianall();
sendValueToMaster(value);
} else {
throw std::invalid_argument("expected scalar type");
}
}

static void tensorSumall(rpc::RPCMessage& raw_message) {
thpp::Tensor *tensor = unpackRetrieveTensor(raw_message);
finalize(raw_message);
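Together with the handler registered in Dispatch.cpp and the master stub in THDTensor.cpp, this completes the usual THD round trip: the master packs a command identifying the function and the tensor, sends it to the worker, and blocks until the worker computes the reduction locally and sends the scalar back. A toy, self-contained model of that round trip (the queue-based channel and helper names are illustrative, not the THD API):

    import queue
    import threading
    import torch

    request_channel, reply_channel = queue.Queue(), queue.Queue()
    tensors = {0: torch.rand(3, 3)}   # worker-local tensor table

    def worker_loop():
        op, tensor_id = request_channel.get()                      # worker receives and unpacks the command
        if op == "tensorMedianall":
            reply_channel.put(torch.median(tensors[tensor_id]))    # compute locally, send the value back

    def master_medianall(tensor_id):
        request_channel.put(("tensorMedianall", tensor_id))        # master packs and sends the command...
        return reply_channel.get()                                 # ...then blocks on the worker's reply

    threading.Thread(target=worker_loop).start()
    print(master_medianall(0))
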
1 change: 1 addition & 0 deletions torch/lib/THPP/Tensor.hpp
@@ -225,6 +225,7 @@ struct TensorScalarInterface : public Tensor {
virtual scalar_type dot(const Tensor& source) = 0;
virtual scalar_type minall() = 0;
virtual scalar_type maxall() = 0;
virtual scalar_type medianall() = 0;
virtual scalar_type sumall() = 0;
virtual scalar_type prodall() = 0;
virtual TensorScalarInterface& add(const Tensor& src, scalar_type value) = 0;
1 change: 1 addition & 0 deletions torch/lib/THPP/tensors/THCSTensor.hpp
@@ -245,6 +245,7 @@ struct THCSTensor : public interface_traits<real>::tensor_interface_type {
virtual scalar_type dot(const Tensor& source) override;
virtual scalar_type minall() override;
virtual scalar_type maxall() override;
virtual scalar_type medianall() override;
virtual scalar_type sumall() override;
virtual scalar_type prodall() override;
virtual THCSTensor& neg(const Tensor& src) override;
1 change: 1 addition & 0 deletions torch/lib/THPP/tensors/THCTensor.hpp
@@ -250,6 +250,7 @@ struct THCTensor : public interface_traits<real>::tensor_interface_type {
virtual scalar_type dot(const Tensor& source) override;
virtual scalar_type minall() override;
virtual scalar_type maxall() override;
virtual scalar_type medianall() override;
virtual scalar_type sumall() override;
virtual scalar_type prodall() override;
virtual THCTensor& neg(const Tensor& src) override;
1 change: 1 addition & 0 deletions torch/lib/THPP/tensors/THSTensor.hpp
@@ -248,6 +248,7 @@ struct THSTensor : public interface_traits<real>::tensor_interface_type {
virtual scalar_type dot(const Tensor& source) override;
virtual scalar_type minall() override;
virtual scalar_type maxall() override;
virtual scalar_type medianall() override;
virtual scalar_type sumall() override;
virtual scalar_type prodall() override;
virtual THSTensor& neg(const Tensor& src) override;
1 change: 1 addition & 0 deletions torch/lib/THPP/tensors/THTensor.hpp
@@ -258,6 +258,7 @@ struct THTensor : public interface_traits<real>::tensor_interface_type {
virtual scalar_type dot(const Tensor& source) override;
virtual scalar_type minall() override;
virtual scalar_type maxall() override;
virtual scalar_type medianall() override;
virtual scalar_type sumall() override;
virtual scalar_type prodall() override;
virtual THTensor& neg(const Tensor& src) override;
5 changes: 5 additions & 0 deletions torch/lib/THPP/tensors/generic/THCSTensor.cpp
@@ -797,6 +797,11 @@ auto THCSTensor<real>::maxall() -> scalar_type {
throw std::runtime_error("THCSTensor::maxall() not supported");
}

template<>
auto THCSTensor<real>::medianall() -> scalar_type {
throw std::runtime_error("THCSTensor::medianall() not supported");
}

template<>
auto THCSTensor<real>::sumall() -> scalar_type {
throw std::runtime_error("THCSTensor::sumall() not supported");
5 changes: 5 additions & 0 deletions torch/lib/THPP/tensors/generic/THCTensor.cpp
@@ -1277,6 +1277,11 @@ auto THCTensor<real>::maxall() -> scalar_type {
return uncast_scalar(THCTensor_(maxall)(state, tensor));
}

template<>
auto THCTensor<real>::medianall() -> scalar_type {
return uncast_scalar(THCTensor_(medianall)(state, tensor));
}

template<>
auto THCTensor<real>::sumall() -> scalar_type {
return THCTensor_(sumall)(state, tensor);
5 changes: 5 additions & 0 deletions torch/lib/THPP/tensors/generic/THSTensor.cpp
@@ -785,6 +785,11 @@ auto THSTensor<real>::maxall() -> scalar_type {
throw std::runtime_error("THSTensor::maxall() not supported");
}

template<>
auto THSTensor<real>::medianall() -> scalar_type {
throw std::runtime_error("THSTensor::medianall() not supported");
}

template<>
auto THSTensor<real>::sumall() -> scalar_type {
throw std::runtime_error("THSTensor::sumall() not supported");
5 changes: 5 additions & 0 deletions torch/lib/THPP/tensors/generic/THTensor.cpp
@@ -1373,6 +1373,11 @@ auto THTensor<real>::maxall() -> scalar_type {
return THTensor_(maxall)(tensor);
}

template<>
auto THTensor<real>::medianall() -> scalar_type {
return THTensor_(medianall)(tensor);
}

template<>
auto THTensor<real>::sumall() -> scalar_type {
return THTensor_(sumall)(tensor);
