diff --git a/tensorflow_probability/python/math/interpolation.py b/tensorflow_probability/python/math/interpolation.py index 74c3c91a37..d8d21e12a1 100644 --- a/tensorflow_probability/python/math/interpolation.py +++ b/tensorflow_probability/python/math/interpolation.py @@ -551,11 +551,12 @@ def batch_interp_regular_nd_grid(x, Interpolate a function of one variable. ```python - y_ref = tf.exp(tf.linspace(start=0., stop=10., 20)) + y_ref = tf.exp(tf.linspace(start=0., stop=10., num=20)) tfp.math.batch_interp_regular_nd_grid( # x.shape = [3, 1], x_ref_min/max.shape = [1]. Trailing `1` for `1-D`. - x=[[6.0], [0.5], [3.3]], x_ref_min=[0.], x_ref_max=[1.], y_ref=y_ref) + x=[[6.0], [0.5], [3.3]], x_ref_min=[0.], x_ref_max=[10.], y_ref=y_ref, + axis=0) ==> approx [exp(6.0), exp(0.5), exp(3.3)] ```