Skip to content

Commit

Permalink
Remove cmin, cmax and cinv
Browse files Browse the repository at this point in the history
  • Loading branch information
apaszke authored and soumith committed Jan 17, 2017
1 parent 3b6644d commit f91bb96
Show file tree
Hide file tree
Showing 18 changed files with 226 additions and 269 deletions.
2 changes: 0 additions & 2 deletions docs/source/nn.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@ Containers

.. autoclass:: Module
:members:
.. autoclass:: Container
:members:

Convolution Layers
----------------------------------
Expand Down
8 changes: 2 additions & 6 deletions docs/source/tensors.rst
Original file line number Diff line number Diff line change
Expand Up @@ -117,15 +117,11 @@ view of a storage and defines numeric operations on it.
.. automethod:: ceil_
.. automethod:: char
.. automethod:: chunk
.. automethod:: cinv
.. automethod:: cinv_
.. automethod:: inv
.. automethod:: inv_
.. automethod:: clamp
.. automethod:: clamp_
.. automethod:: clone
.. automethod:: cmax
.. automethod:: cmax_
.. automethod:: cmin
.. automethod:: cmin_
.. automethod:: contiguous
.. automethod:: copy_
.. automethod:: cos
Expand Down
4 changes: 1 addition & 3 deletions docs/source/torch.rst
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,8 @@ Pointwise Ops
.. autofunction:: atan
.. autofunction:: atan2
.. autofunction:: ceil
.. autofunction:: cinv
.. autofunction:: inv
.. autofunction:: clamp
.. autofunction:: cmax
.. autofunction:: cmin
.. autofunction:: cos
.. autofunction:: cosh
.. autofunction:: div
Expand Down
15 changes: 7 additions & 8 deletions test/test_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -725,7 +725,7 @@ def gather_variable(shape, index_dim, max_indices):
(Asin, (), (torch.randn(S, S, S).clamp(-0.9, 0.9),) ),
(Acos, (), (torch.randn(S, S, S).clamp(-0.9, 0.9),) ),
(Atan, (), ((S, S, S),) ),
(Cinv, (), (torch.rand(S, S, S) + 0.1,) ),
(Reciprocal, (), (torch.rand(S, S, S) + 0.1,) ),
(Cmax, (), ((S, S, S), (S, S, S)) ),
(Cmin, (), ((S, S, S), (S, S, S)) ),
(Round, (), ((S, S, S),) ),
Expand Down Expand Up @@ -841,7 +841,7 @@ def gather_variable(shape, index_dim, max_indices):
('asin', (S, S, S), () ),
('acos', (S, S, S), () ),
('atan', (S, S, S), () ),
('cinv', (S, S, S), () ),
('reciprocal', (S, S, S), () ),
('round', (S, S, S), () ),
('sign', (S, S, S), () ),
('trunc', (S, S, S), () ),
Expand All @@ -851,10 +851,10 @@ def gather_variable(shape, index_dim, max_indices):
('fmod', (S, S, S), (1.5,) ),
('remainder', (S, S, S), (1.5,) ),
('lerp', (S, S, S), ((S, S, S), 0.4) ),
('cmax', (S, S, S), ((S, S, S),) ),
('cmax', (S, S, S), (0.5,), 'constant' ),
('cmin', (S, S, S), ((S, S, S),) ),
('cmin', (S, S, S), (0.5,), 'constant' ),
('max', (S, S, S), () ),
('max', (S, S, S), ((S, S, S),), 'elementwise' ),
('min', (S, S, S), () ),
('min', (S, S, S), ((S, S, S),), 'elementwise' ),
('mean', (S, S, S), () ),
('mean', (S, S, S), (1,), 'dim' ),
('sum', (S, S, S), () ),
Expand All @@ -872,8 +872,6 @@ def gather_variable(shape, index_dim, max_indices):
('addr', (S, M), ((S,), (M,)), ),
('addr', (S, M), (0.2, 0.6, (S,), (M,)), 'coef' ),
('dot', (L,), ((L,),), ),
('max', (S, S, S), () ),
('min', (S, S, S), () ),
('addcmul', (S, S), ((S, S), (S, S)) ),
('addcmul', (S, S), (0.5, (S, S), (S, S)), 'scale' ),
('addcdiv', (S, S), ((S, S), (S, S)) ),
Expand Down Expand Up @@ -906,6 +904,7 @@ def gather_variable(shape, index_dim, max_indices):
# TODO: mode, median, sort, kthvalue, topk (problem with indices)
# TODO: indexAdd, indexCopy, indexFill
# TODO: resize, resize_as (tensors only have resize_ and resize_as_)
# TODO: clamp with min/max


def create_input(call_args):
Expand Down
6 changes: 3 additions & 3 deletions test/test_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,6 @@ def tmp(t):
('chunk', medium_2d, lambda t: [4, 1], 'dim' ),
('clamp', medium_2d_scaled, lambda t: [-1, 5], ),
('clone', medium_2d, lambda t: [], ),
('cmax', medium_2d, lambda t: [medium_2d(t)], ),
('cmin', medium_2d, lambda t: [medium_2d(t)], ),
('contiguous', medium_2d, lambda t: [], ),
('cross', new_t(M, 3, M), lambda t: [new_t(M, 3, M)(t)], ),
('cumprod', small_3d, lambda t: [1], ),
Expand Down Expand Up @@ -166,8 +164,10 @@ def tmp(t):
('lerp', small_3d, lambda t: [small_3d(t), 0.3], ),
('max', small_3d_unique, lambda t: [], ),
('max', small_3d_unique, lambda t: [1], 'dim' ),
('max', medium_2d, lambda t: [medium_2d(t)], 'elementwise' ),
('min', small_3d_unique, lambda t: [], ),
('min', small_3d_unique, lambda t: [1], 'dim' ),
('min', medium_2d, lambda t: [medium_2d(t)], 'elementwise' ),
('mean', small_3d, lambda t: [], ),
('mean', small_3d, lambda t: [1], 'dim' ),
('mode', small_3d, lambda t: [], ),
Expand Down Expand Up @@ -262,7 +262,7 @@ def tmp(t):
'cos',
'cosh',
'exp',
'cinv',
'reciprocal',
'floor',
'frac',
'neg',
Expand Down
34 changes: 20 additions & 14 deletions test/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,17 +203,11 @@ def _testCSelection(self, torchfn, mathfn):
expected_c.map2_(a, b, lambda _, a, b: mathfn(a, b))
self.assertEqual(expected_c, c, 0)

# Tensor and scalar
v = random.random()
c = torchfn(a, v)
expected_c.map_(a, lambda _, a: mathfn(a, v))
self.assertEqual(expected_c, c, 0)

def test_cmax(self):
self._testCSelection(torch.cmax, max)
def test_max_elementwise(self):
self._testCSelection(torch.max, max)

def test_cmin(self):
self._testCSelection(torch.cmin, min)
def test_min_elementwise(self):
self._testCSelection(torch.min, min)

def test_lerp(self):
def TH_lerp(a, b, weight):
Expand Down Expand Up @@ -332,14 +326,14 @@ def test_neg(self):
res_neg.neg_()
self.assertEqual(res_neg, res_add)

def test_cinv(self):
def test_reciprocal(self):
a = torch.randn(100,89)
zeros = torch.Tensor().resize_as_(a).zero_()

res_pow = torch.pow(a, -1)
res_inv = a.clone()
res_inv.cinv_()
self.assertEqual(res_inv, res_pow)
res_reciprocal = a.clone()
res_reciprocal.reciprocal_()
self.assertEqual(res_reciprocal, res_pow)

def test_mul(self):
m1 = torch.randn(10,10)
Expand Down Expand Up @@ -520,6 +514,18 @@ def test_clamp(self):
res2[i] = max(min_val, min(max_val, res2[i]))
self.assertEqual(res1, res2)

res1 = torch.clamp(m1, min=min_val)
res2 = m1.clone()
for i in iter_indices(res2):
res2[i] = max(min_val, res2[i])
self.assertEqual(res1, res2)

res1 = torch.clamp(m1, max=max_val)
res2 = m1.clone()
for i in iter_indices(res2):
res2[i] = min(max_val, res2[i])
self.assertEqual(res1, res2)

def test_pow(self):
# [res] torch.pow([res,] x)

Expand Down
18 changes: 9 additions & 9 deletions torch/autograd/_functions/pointwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def forward(self, i):

def backward(self, grad_output):
i, = self.saved_tensors
return grad_output * (1 - i.mul(i)).sqrt_().cinv_()
return grad_output * (1 - i.mul(i)).sqrt_().reciprocal_()


class Acos(Function):
Expand All @@ -181,7 +181,7 @@ def forward(self, i):

def backward(self, grad_output):
i, = self.saved_tensors
return grad_output.mul((1 - i.mul(i)).sqrt_().cinv_()).neg_()
return grad_output.mul((1 - i.mul(i)).sqrt_().reciprocal_()).neg_()


class Atan(Function):
Expand All @@ -191,13 +191,13 @@ def forward(self, i):

def backward(self, grad_output):
i, = self.saved_tensors
return grad_output * i.mul(i).add_(1).cinv_()
return grad_output * i.mul(i).add_(1).reciprocal_()


class Cinv(Function):
class Reciprocal(Function):

def forward(self, i):
result = i.cinv()
result = i.reciprocal()
self.save_for_backward(result)
return result

Expand All @@ -210,7 +210,7 @@ class Cmax(Function):

def forward(self, a, b):
self._max_buffer = a.gt(b).type_as(a)
return a.cmax(b)
return a.max(b)

def backward(self, grad_output):
return (
Expand All @@ -227,7 +227,7 @@ def __init__(self, constant):

def forward(self, i):
self._max_buffer = i.gt(self.constant).type_as(i)
return i.cmax(self.constant)
return i.clamp(min=self.constant)

def backward(self, grad_output):
return grad_output * self._max_buffer
Expand All @@ -237,7 +237,7 @@ class Cmin(Function):

def forward(self, a, b):
self._min_buffer = a.lt(b).type_as(a)
return a.cmin(b)
return a.min(b)

def backward(self, grad_output):
return (
Expand All @@ -254,7 +254,7 @@ def __init__(self, constant):

def forward(self, i):
self._min_buffer = i.lt(self.constant).type_as(i)
return i.cmin(self.constant)
return i.clamp(max=self.constant)

def backward(self, grad_output):
return grad_output * self._min_buffer
Expand Down
2 changes: 1 addition & 1 deletion torch/autograd/_functions/stochastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def backward(self, reward):
probs /= probs.sum(1).expand_as(probs)
grad_probs = probs.new().resize_as_(probs).zero_()
output_probs = probs.gather(1, samples)
output_probs.add_(1e-6).cinv_()
output_probs.add_(1e-6).reciprocal_()
output_probs.neg_().mul_(reward)
# TODO: add batched index_add
for i in range(probs.size(0)):
Expand Down
30 changes: 15 additions & 15 deletions torch/autograd/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,23 +389,19 @@ def cosh(self):
def abs(self):
return Abs()(self)

def clamp(self, min_val, max_val):
return Clamp(min_val, max_val)(self)

def cinv(self):
return Cinv()(self)

def cmax(self, other):
if isinstance(other, Variable):
return Cmax()(self, other)
def clamp(self, min=None, max=None):
if min is None and max is None:
raise ValueError("clamp requires specifying at least one of "
"min and max arguments")
elif min is None and max is not None:
return CminConstant(max)(self)
elif min is not None and max is None:
return CmaxConstant(min)(self)
else:
return CmaxConstant(other)(self)
return Clamp(min, max)(self)

def cmin(self, other):
if isinstance(other, Variable):
return Cmin()(self, other)
else:
return CminConstant(other)(self)
def reciprocal(self):
return Reciprocal()(self)

def floor(self):
return Floor()(self)
Expand Down Expand Up @@ -456,9 +452,13 @@ def mean(self, dim=None):
return Mean(dim)(self)

def max(self, dim=None):
if isinstance(dim, Variable):
return Cmax()(self, dim)
return Max(dim)(self)

def min(self, dim=None):
if isinstance(dim, Variable):
return Cmin()(self, dim)
return Min(dim)(self)

def mode(self, dim):
Expand Down
8 changes: 2 additions & 6 deletions torch/csrc/Module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -212,16 +212,14 @@ IMPLEMENT_STATELESS(mean)
IMPLEMENT_STATELESS(std)
IMPLEMENT_STATELESS(var)
IMPLEMENT_STATELESS(norm)
IMPLEMENT_STATELESS(cinv)
IMPLEMENT_STATELESS(reciprocal)
IMPLEMENT_STATELESS(neg)
IMPLEMENT_STATELESS(add)
IMPLEMENT_STATELESS(mul)
IMPLEMENT_STATELESS(div)
IMPLEMENT_STATELESS(fmod)
IMPLEMENT_STATELESS(min)
IMPLEMENT_STATELESS(max)
IMPLEMENT_STATELESS(cmax)
IMPLEMENT_STATELESS(cmin)
IMPLEMENT_STATELESS(dot)
IMPLEMENT_STATELESS(sum)
IMPLEMENT_STATELESS(prod)
Expand Down Expand Up @@ -534,16 +532,14 @@ static PyMethodDef TorchMethods[] = {
{"std", (PyCFunction)THPModule_std, METH_VARARGS | METH_KEYWORDS, NULL},
{"var", (PyCFunction)THPModule_var, METH_VARARGS | METH_KEYWORDS, NULL},
{"norm", (PyCFunction)THPModule_norm, METH_VARARGS | METH_KEYWORDS, NULL},
{"cinv", (PyCFunction)THPModule_cinv, METH_VARARGS | METH_KEYWORDS, NULL},
{"reciprocal", (PyCFunction)THPModule_reciprocal, METH_VARARGS | METH_KEYWORDS, NULL},
{"neg", (PyCFunction)THPModule_neg, METH_VARARGS | METH_KEYWORDS, NULL},
{"add", (PyCFunction)THPModule_add, METH_VARARGS | METH_KEYWORDS, NULL},
{"mul", (PyCFunction)THPModule_mul, METH_VARARGS | METH_KEYWORDS, NULL},
{"div", (PyCFunction)THPModule_div, METH_VARARGS | METH_KEYWORDS, NULL},
{"fmod", (PyCFunction)THPModule_fmod, METH_VARARGS | METH_KEYWORDS, NULL},
{"min", (PyCFunction)THPModule_min, METH_VARARGS | METH_KEYWORDS, NULL},
{"max", (PyCFunction)THPModule_max, METH_VARARGS | METH_KEYWORDS, NULL},
{"cmax", (PyCFunction)THPModule_cmax, METH_VARARGS | METH_KEYWORDS, NULL},
{"cmin", (PyCFunction)THPModule_cmin, METH_VARARGS | METH_KEYWORDS, NULL},
{"dot", (PyCFunction)THPModule_dot, METH_VARARGS | METH_KEYWORDS, NULL},
{"sum", (PyCFunction)THPModule_sum, METH_VARARGS | METH_KEYWORDS, NULL},
{"prod", (PyCFunction)THPModule_prod, METH_VARARGS | METH_KEYWORDS, NULL},
Expand Down
14 changes: 14 additions & 0 deletions torch/csrc/generic/methods/TensorCompare.cwrap
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,13 @@
return: real
arguments:
- THTensor* self
- cname: cmin
return: argument 0
arguments:
- arg: THTensor* result
output: True
- THTensor* self
- THTensor* other
- cname: min
return: argument 0,1
arguments:
Expand All @@ -420,6 +427,13 @@
return: real
arguments:
- THTensor* self
- cname: cmax
return: argument 0
arguments:
- arg: THTensor* result
output: True
- THTensor* self
- THTensor* other
- cname: max
return: argument 0,1
arguments:
Expand Down
Loading

0 comments on commit f91bb96

Please sign in to comment.