Skip to content

Commit

Permalink
[ROLLBACK]
Browse files Browse the repository at this point in the history
Copybara import of the project:

--
3ad0854 by Jake VanderPlas <[email protected]>:

[x64] make jnp.histogram and related functions work with strict promotion

PiperOrigin-RevId: 452189426
  • Loading branch information
yashk2810 authored and jax authors committed Jun 1, 2022
1 parent 8816fd9 commit f6d4373
Showing 1 changed file with 24 additions and 32 deletions.
56 changes: 24 additions & 32 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,38 +396,35 @@ def correlate(a, v, mode='valid', *, precision=None):

@_wraps(np.histogram_bin_edges)
def histogram_bin_edges(a, bins=10, range=None, weights=None):
del weights # unused, because string bins is not supported.
if isinstance(bins, str):
raise NotImplementedError("string values for `bins` not implemented.")
_check_arraylike("histogram_bin_edges", a, bins)
a = ravel(a)
dtype = dtypes._to_inexact_dtype(_dtype(a))
if _ndim(bins) == 1:
return asarray(bins, dtype=dtype)
bins = core.concrete_or_error(operator.index, bins,
"bins argument of histogram_bin_edges")
b = asarray(bins)
if b.ndim == 1:
return b
if range is None:
range = [a.min(), a.max()]
range = asarray(range, dtype=dtype)
if range.shape != (2,):
raise ValueError("`range` must be either None or a sequence of scalars.")
assert len(range) == 2
range = asarray(range)
range = (where(ptp(range) == 0, range[0] - 0.5, range[0]),
where(ptp(range) == 0, range[1] + 0.5, range[1]))
dtype = _dtype(a)
if issubdtype(dtype, integer):
dtype = promote_types(dtype, float32)
return linspace(range[0], range[1], bins + 1, dtype=dtype)


@_wraps(np.histogram)
def histogram(a, bins=10, range=None, weights=None, density=None):
if weights is None:
_check_arraylike("histogram", a, bins)
a = ravel(*_promote_dtypes_inexact(a))
weights = ones_like(a)
_check_arraylike("histogram", a, bins)
if weights is not None and a.shape != weights.shape:
raise ValueError("weights should have the same shape as a.")
a = ravel(a)
if weights is not None:
weights = ravel(weights)
else:
_check_arraylike("histogram", a, bins, weights)
if a.shape != weights.shape:
raise ValueError("weights should have the same shape as a.")
a, weights = map(ravel, _promote_dtypes_inexact(a, weights))

weights = ones_like(a)
bin_edges = histogram_bin_edges(a, bins, range, weights)
bin_idx = searchsorted(bin_edges, a, side='right')
bin_idx = where(a == bin_edges[-1], len(bin_edges) - 1, bin_idx)
Expand Down Expand Up @@ -455,16 +452,12 @@ def histogram2d(x, y, bins=10, range=None, weights=None, density=None):

@_wraps(np.histogramdd)
def histogramdd(sample, bins=10, range=None, weights=None, density=None):
if weights is None:
_check_arraylike("histogramdd", sample)
sample, = _promote_dtypes_inexact(sample)
else:
_check_arraylike("histogramdd", sample, weights)
if weights.shape != sample.shape[:1]:
raise ValueError("should have one weight for each sample.")
sample, weights = _promote_dtypes_inexact(sample, weights)
_check_arraylike("histogramdd", sample)
N, D = shape(sample)

if weights is not None and weights.shape != (N,):
raise ValueError("should have one weight for each sample.")

if range is not None and (
len(range) != D or _any(r is not None and len(r) != 2 for r in range)):
raise ValueError(f"For sample.shape={(N, D)}, range must be a sequence "
Expand All @@ -478,10 +471,10 @@ def histogramdd(sample, bins=10, range=None, weights=None, density=None):
# when bin_size is integer, the same bin is used for each dimension
bins = D * [bins]

bin_idx_by_dim = D * [None]
bin_idx_by_dim = D*[None]
nbins = np.empty(D, int)
bin_edges_by_dim = D * [None]
dedges = D * [None]
bin_edges_by_dim = D*[None]
dedges = D*[None]

for i in builtins.range(D):
range_i = None if range is None else range[i]
Expand All @@ -500,7 +493,6 @@ def histogramdd(sample, bins=10, range=None, weights=None, density=None):
hist = hist[core]

if density:
hist = hist.astype(sample.dtype)
hist /= hist.sum()
for norm in ix_(*dedges):
hist /= norm
Expand Down Expand Up @@ -808,9 +800,9 @@ def ravel_multi_index(multi_index, dims, mode='raise', order='C'):
else:
raise ValueError(f"invalid order={order!r}. Expected 'C' or 'F'")

result = array(0, dtype=multi_index[0].dtype)
result = array(0, dtype=dtypes.canonicalize_dtype(int_))
for i, s in zip(multi_index, strides):
result = result + i * int(s)
result = result + i * s
return result


Expand Down

0 comments on commit f6d4373

Please sign in to comment.