Skip to content

Commit

Permalink
Merge pull request jax-ml#10904 from jakevdp:x64-trapz
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 452205618
  • Loading branch information
jax authors committed Jun 1, 2022
2 parents 2153a57 + 7e0fe7b commit a1f7ced
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,13 +330,17 @@ def result_type(*args):
@_wraps(np.trapz)
@partial(jit, static_argnames=('axis',))
def trapz(y, x=None, dx=1.0, axis: int = -1):
_check_arraylike('trapz', y)
y = moveaxis(y, axis, -1)
if x is not None:
if x is None:
_check_arraylike('trapz', y)
y, = _promote_dtypes_inexact(y)
else:
_check_arraylike('trapz', y, x)
y, x = _promote_dtypes_inexact(y, x)
if ndim(x) == 1:
dx = diff(x)
else:
dx = moveaxis(diff(x, axis=axis), axis, -1)
y = moveaxis(y, axis, -1)
return 0.5 * (dx * (y[..., 1:] + y[..., :-1])).sum(-1)


Expand Down

0 comments on commit a1f7ced

Please sign in to comment.