Skip to content

Commit

Permalink
Merge pull request numpy#13107 from eric-wieser/simplify-val-nd
Browse files Browse the repository at this point in the history
MAINT: Unify polynomial valnd functions
  • Loading branch information
eric-wieser authored Mar 13, 2019
2 parents 35a905f + fcea19a commit a1d6cf7
Show file tree
Hide file tree
Showing 7 changed files with 72 additions and 144 deletions.
28 changes: 4 additions & 24 deletions numpy/polynomial/chebyshev.py
Original file line number Diff line number Diff line change
Expand Up @@ -1224,14 +1224,7 @@ def chebval2d(x, y, c):
.. versionadded:: 1.7.0
"""
try:
x, y = np.array((x, y), copy=0)
except Exception:
raise ValueError('x, y are incompatible')

c = chebval(x, c)
c = chebval(y, c, tensor=False)
return c
return pu._valnd(chebval, c, x, y)


def chebgrid2d(x, y, c):
Expand Down Expand Up @@ -1284,9 +1277,7 @@ def chebgrid2d(x, y, c):
.. versionadded:: 1.7.0
"""
c = chebval(x, c)
c = chebval(y, c)
return c
return pu._gridnd(chebval, c, x, y)


def chebval3d(x, y, z, c):
Expand Down Expand Up @@ -1337,15 +1328,7 @@ def chebval3d(x, y, z, c):
.. versionadded:: 1.7.0
"""
try:
x, y, z = np.array((x, y, z), copy=0)
except Exception:
raise ValueError('x, y, z are incompatible')

c = chebval(x, c)
c = chebval(y, c, tensor=False)
c = chebval(z, c, tensor=False)
return c
return pu._valnd(chebval, c, x, y, z)


def chebgrid3d(x, y, z, c):
Expand Down Expand Up @@ -1401,10 +1384,7 @@ def chebgrid3d(x, y, z, c):
.. versionadded:: 1.7.0
"""
c = chebval(x, c)
c = chebval(y, c)
c = chebval(z, c)
return c
return pu._gridnd(chebval, c, x, y, z)


def chebvander(x, deg):
Expand Down
28 changes: 4 additions & 24 deletions numpy/polynomial/hermite.py
Original file line number Diff line number Diff line change
Expand Up @@ -981,14 +981,7 @@ def hermval2d(x, y, c):
.. versionadded:: 1.7.0
"""
try:
x, y = np.array((x, y), copy=0)
except Exception:
raise ValueError('x, y are incompatible')

c = hermval(x, c)
c = hermval(y, c, tensor=False)
return c
return pu._valnd(hermval, c, x, y)


def hermgrid2d(x, y, c):
Expand Down Expand Up @@ -1041,9 +1034,7 @@ def hermgrid2d(x, y, c):
.. versionadded:: 1.7.0
"""
c = hermval(x, c)
c = hermval(y, c)
return c
return pu._gridnd(hermval, c, x, y)


def hermval3d(x, y, z, c):
Expand Down Expand Up @@ -1094,15 +1085,7 @@ def hermval3d(x, y, z, c):
.. versionadded:: 1.7.0
"""
try:
x, y, z = np.array((x, y, z), copy=0)
except Exception:
raise ValueError('x, y, z are incompatible')

c = hermval(x, c)
c = hermval(y, c, tensor=False)
c = hermval(z, c, tensor=False)
return c
return pu._valnd(hermval, c, x, y, z)


def hermgrid3d(x, y, z, c):
Expand Down Expand Up @@ -1158,10 +1141,7 @@ def hermgrid3d(x, y, z, c):
.. versionadded:: 1.7.0
"""
c = hermval(x, c)
c = hermval(y, c)
c = hermval(z, c)
return c
return pu._gridnd(hermval, c, x, y, z)


def hermvander(x, deg):
Expand Down
28 changes: 4 additions & 24 deletions numpy/polynomial/hermite_e.py
Original file line number Diff line number Diff line change
Expand Up @@ -975,14 +975,7 @@ def hermeval2d(x, y, c):
.. versionadded:: 1.7.0
"""
try:
x, y = np.array((x, y), copy=0)
except Exception:
raise ValueError('x, y are incompatible')

c = hermeval(x, c)
c = hermeval(y, c, tensor=False)
return c
return pu._valnd(hermeval, c, x, y)


def hermegrid2d(x, y, c):
Expand Down Expand Up @@ -1035,9 +1028,7 @@ def hermegrid2d(x, y, c):
.. versionadded:: 1.7.0
"""
c = hermeval(x, c)
c = hermeval(y, c)
return c
return pu._gridnd(hermeval, c, x, y)


def hermeval3d(x, y, z, c):
Expand Down Expand Up @@ -1088,15 +1079,7 @@ def hermeval3d(x, y, z, c):
.. versionadded:: 1.7.0
"""
try:
x, y, z = np.array((x, y, z), copy=0)
except Exception:
raise ValueError('x, y, z are incompatible')

c = hermeval(x, c)
c = hermeval(y, c, tensor=False)
c = hermeval(z, c, tensor=False)
return c
return pu._valnd(hermeval, c, x, y, z)


def hermegrid3d(x, y, z, c):
Expand Down Expand Up @@ -1152,10 +1135,7 @@ def hermegrid3d(x, y, z, c):
.. versionadded:: 1.7.0
"""
c = hermeval(x, c)
c = hermeval(y, c)
c = hermeval(z, c)
return c
return pu._gridnd(hermeval, c, x, y, z)


def hermevander(x, deg):
Expand Down
28 changes: 4 additions & 24 deletions numpy/polynomial/laguerre.py
Original file line number Diff line number Diff line change
Expand Up @@ -981,14 +981,7 @@ def lagval2d(x, y, c):
.. versionadded:: 1.7.0
"""
try:
x, y = np.array((x, y), copy=0)
except Exception:
raise ValueError('x, y are incompatible')

c = lagval(x, c)
c = lagval(y, c, tensor=False)
return c
return pu._valnd(lagval, c, x, y)


def laggrid2d(x, y, c):
Expand Down Expand Up @@ -1041,9 +1034,7 @@ def laggrid2d(x, y, c):
.. versionadded:: 1.7.0
"""
c = lagval(x, c)
c = lagval(y, c)
return c
return pu._gridnd(lagval, c, x, y)


def lagval3d(x, y, z, c):
Expand Down Expand Up @@ -1094,15 +1085,7 @@ def lagval3d(x, y, z, c):
.. versionadded:: 1.7.0
"""
try:
x, y, z = np.array((x, y, z), copy=0)
except Exception:
raise ValueError('x, y, z are incompatible')

c = lagval(x, c)
c = lagval(y, c, tensor=False)
c = lagval(z, c, tensor=False)
return c
return pu._valnd(lagval, c, x, y, z)


def laggrid3d(x, y, z, c):
Expand Down Expand Up @@ -1158,10 +1141,7 @@ def laggrid3d(x, y, z, c):
.. versionadded:: 1.7.0
"""
c = lagval(x, c)
c = lagval(y, c)
c = lagval(z, c)
return c
return pu._gridnd(lagval, c, x, y, z)


def lagvander(x, deg):
Expand Down
28 changes: 4 additions & 24 deletions numpy/polynomial/legendre.py
Original file line number Diff line number Diff line change
Expand Up @@ -1025,14 +1025,7 @@ def legval2d(x, y, c):
.. versionadded:: 1.7.0
"""
try:
x, y = np.array((x, y), copy=0)
except Exception:
raise ValueError('x, y are incompatible')

c = legval(x, c)
c = legval(y, c, tensor=False)
return c
return pu._valnd(legval, c, x, y)


def leggrid2d(x, y, c):
Expand Down Expand Up @@ -1085,9 +1078,7 @@ def leggrid2d(x, y, c):
.. versionadded:: 1.7.0
"""
c = legval(x, c)
c = legval(y, c)
return c
return pu._gridnd(legval, c, x, y)


def legval3d(x, y, z, c):
Expand Down Expand Up @@ -1138,15 +1129,7 @@ def legval3d(x, y, z, c):
.. versionadded:: 1.7.0
"""
try:
x, y, z = np.array((x, y, z), copy=0)
except Exception:
raise ValueError('x, y, z are incompatible')

c = legval(x, c)
c = legval(y, c, tensor=False)
c = legval(z, c, tensor=False)
return c
return pu._valnd(legval, c, x, y, z)


def leggrid3d(x, y, z, c):
Expand Down Expand Up @@ -1202,10 +1185,7 @@ def leggrid3d(x, y, z, c):
.. versionadded:: 1.7.0
"""
c = legval(x, c)
c = legval(y, c)
c = legval(z, c)
return c
return pu._gridnd(legval, c, x, y, z)


def legvander(x, deg):
Expand Down
28 changes: 4 additions & 24 deletions numpy/polynomial/polynomial.py
Original file line number Diff line number Diff line change
Expand Up @@ -910,14 +910,7 @@ def polyval2d(x, y, c):
.. versionadded:: 1.7.0
"""
try:
x, y = np.array((x, y), copy=0)
except Exception:
raise ValueError('x, y are incompatible')

c = polyval(x, c)
c = polyval(y, c, tensor=False)
return c
return pu._valnd(polyval, c, x, y)


def polygrid2d(x, y, c):
Expand Down Expand Up @@ -970,9 +963,7 @@ def polygrid2d(x, y, c):
.. versionadded:: 1.7.0
"""
c = polyval(x, c)
c = polyval(y, c)
return c
return pu._gridnd(polyval, c, x, y)


def polyval3d(x, y, z, c):
Expand Down Expand Up @@ -1023,15 +1014,7 @@ def polyval3d(x, y, z, c):
.. versionadded:: 1.7.0
"""
try:
x, y, z = np.array((x, y, z), copy=0)
except Exception:
raise ValueError('x, y, z are incompatible')

c = polyval(x, c)
c = polyval(y, c, tensor=False)
c = polyval(z, c, tensor=False)
return c
return pu._valnd(polyval, c, x, y, z)


def polygrid3d(x, y, z, c):
Expand Down Expand Up @@ -1087,10 +1070,7 @@ def polygrid3d(x, y, z, c):
.. versionadded:: 1.7.0
"""
c = polyval(x, c)
c = polyval(y, c)
c = polyval(z, c)
return c
return pu._gridnd(polyval, c, x, y, z)


def polyvander(x, deg):
Expand Down
48 changes: 48 additions & 0 deletions numpy/polynomial/polyutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,3 +489,51 @@ def _fromroots(line_f, mul_f, roots):
p = tmp
n = m
return p[0]


def _valnd(val_f, c, *args):
"""
Helper function used to implement the ``<type>val<n>d`` functions.
Parameters
----------
val_f : function(array_like, array_like, tensor: bool) -> array_like
The ``<type>val`` function, such as ``polyval``
c, args :
See the ``<type>val<n>d`` functions for more detail
"""
try:
args = tuple(np.array(args, copy=False))
except Exception:
# preserve the old error message
if len(args) == 2:
raise ValueError('x, y, z are incompatible')
elif len(args) == 3:
raise ValueError('x, y are incompatible')
else:
raise ValueError('ordinates are incompatible')

it = iter(args)
x0 = next(it)

# use tensor on only the first
c = val_f(x0, c)
for xi in it:
c = val_f(xi, c, tensor=False)
return c


def _gridnd(val_f, c, *args):
"""
Helper function used to implement the ``<type>grid<n>d`` functions.
Parameters
----------
val_f : function(array_like, array_like, tensor: bool) -> array_like
The ``<type>val`` function, such as ``polyval``
c, args :
See the ``<type>grid<n>d`` functions for more detail
"""
for xi in args:
c = val_f(xi, c)
return c

0 comments on commit a1d6cf7

Please sign in to comment.