Skip to content

Commit

Permalink
Some dtype fixes (keras-team#935)
Browse files Browse the repository at this point in the history
* Some dtype fixes

* Nits
  • Loading branch information
fchollet committed Sep 20, 2023
1 parent 0386142 commit 72c2e8c
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 5 deletions.
7 changes: 7 additions & 0 deletions keras_core/backend/jax/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,13 @@ def append(


def arange(start, stop=None, step=1, dtype=None):
if dtype is None:
if hasattr(start, "dtype"):
dtype = start.dtype
elif isinstance(start, int):
dtype = "int32"
else:
dtype = config.floatx()
return jnp.arange(start, stop, step=step, dtype=dtype)


Expand Down
18 changes: 17 additions & 1 deletion keras_core/backend/numpy/numpy.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import numpy as np

from keras_core.backend import config
from keras_core.backend import standardize_dtype


def add(x1, x2):
return np.add(x1, x2)
Expand Down Expand Up @@ -77,6 +80,13 @@ def append(


def arange(start, stop=None, step=None, dtype=None):
if dtype is None:
if hasattr(start, "dtype"):
dtype = start.dtype
elif isinstance(start, int):
dtype = "int32"
else:
dtype = config.floatx()
return np.arange(start, stop, step=step, dtype=dtype)


Expand Down Expand Up @@ -124,6 +134,7 @@ def argsort(x, axis=-1):


def array(x, dtype=None):
dtype = dtype or config.floatx()
return np.array(x, dtype=dtype)


Expand Down Expand Up @@ -271,6 +282,7 @@ def floor(x):


def full(shape, fill_value, dtype=None):
dtype = dtype or config.floatx()
return np.full(shape, fill_value, dtype=dtype)


Expand Down Expand Up @@ -592,7 +604,11 @@ def square(x):


def sqrt(x):
return np.sqrt(x)
dtype = None
if hasattr(x, "dtype"):
if standardize_dtype(x.dtype).startswith("int"):
dtype = config.floatx()
return np.sqrt(x, dtype=dtype)


def squeeze(x, axis=None):
Expand Down
11 changes: 11 additions & 0 deletions keras_core/backend/tensorflow/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import tensorflow as tf
from tensorflow.experimental import numpy as tfnp

from keras_core.backend import config
from keras_core.backend.tensorflow.core import convert_to_tensor


Expand Down Expand Up @@ -176,6 +177,13 @@ def append(
def arange(start, stop=None, step=1, dtype=None):
# tfnp.arange has trouble with dynamic Tensors in compiled function.
# tf.range does not.
if dtype is None:
if hasattr(start, "dtype"):
dtype = start.dtype
elif isinstance(start, int):
dtype = "int32"
else:
dtype = config.floatx()
return tf.range(start, stop, delta=step, dtype=dtype)


Expand Down Expand Up @@ -749,6 +757,9 @@ def square(x):


def sqrt(x):
x = convert_to_tensor(x)
if tf.as_dtype(x.dtype).is_integer:
x = tf.cast(x, dtype=config.floatx())
return tfnp.sqrt(x)


Expand Down
10 changes: 9 additions & 1 deletion keras_core/backend/torch/numpy.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import torch

from keras_core.backend import config
from keras_core.backend.torch.core import cast
from keras_core.backend.torch.core import convert_to_tensor
from keras_core.backend.torch.core import get_device
Expand Down Expand Up @@ -91,7 +92,7 @@ def zeros(shape, dtype="float32"):

def zeros_like(x, dtype=None):
x = convert_to_tensor(x)
dtype = to_torch_dtype(dtype)
dtype = to_torch_dtype(dtype or x.dtype)
return torch.zeros_like(x, dtype=dtype)


Expand Down Expand Up @@ -160,6 +161,13 @@ def append(


def arange(start, stop=None, step=1, dtype=None):
if dtype is None:
if hasattr(start, "dtype"):
dtype = start.dtype
elif isinstance(start, int):
dtype = "int32"
else:
dtype = config.floatx()
dtype = to_torch_dtype(dtype)
if stop is None:
return torch.arange(end=start, dtype=dtype, device=get_device())
Expand Down
36 changes: 33 additions & 3 deletions keras_core/ops/numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3571,9 +3571,37 @@ def test_split(self):
self.assertEqual(len(knp.Split(2)(x)), 2)

def test_sqrt(self):
x = np.array([[1, 4, 9], [16, 25, 36]])
self.assertAllClose(knp.sqrt(x), np.sqrt(x))
self.assertAllClose(knp.Sqrt()(x), np.sqrt(x))
x = np.array([[1, 4, 9], [16, 25, 36]], dtype="float32")
ref_y = np.sqrt(x)
y = knp.sqrt(x)
self.assertEqual(standardize_dtype(y.dtype), "float32")
self.assertAllClose(y, ref_y)
y = knp.Sqrt()(x)
self.assertEqual(standardize_dtype(y.dtype), "float32")
self.assertAllClose(y, ref_y)

@pytest.mark.skipif(
backend.backend() == "jax", reason="JAX does not support float64."
)
def test_sqrt_float64(self):
x = np.array([[1, 4, 9], [16, 25, 36]], dtype="float64")
ref_y = np.sqrt(x)
y = knp.sqrt(x)
self.assertEqual(standardize_dtype(y.dtype), "float64")
self.assertAllClose(y, ref_y)
y = knp.Sqrt()(x)
self.assertEqual(standardize_dtype(y.dtype), "float64")
self.assertAllClose(y, ref_y)

def test_sqrt_int32(self):
x = np.array([[1, 4, 9], [16, 25, 36]], dtype="int32")
ref_y = np.sqrt(x)
y = knp.sqrt(x)
self.assertEqual(standardize_dtype(y.dtype), "float32")
self.assertAllClose(y, ref_y)
y = knp.Sqrt()(x)
self.assertEqual(standardize_dtype(y.dtype), "float32")
self.assertAllClose(y, ref_y)

def test_stack(self):
x = np.array([[1, 2, 3], [3, 2, 1]])
Expand Down Expand Up @@ -3704,6 +3732,8 @@ def test_arange(self):
self.assertAllClose(knp.Arange()(3, 7), np.arange(3, 7))
self.assertAllClose(knp.Arange()(3, 7, 2), np.arange(3, 7, 2))

self.assertEqual(standardize_dtype(knp.arange(3).dtype), "int32")

def test_full(self):
self.assertAllClose(knp.full([2, 3], 0), np.full([2, 3], 0))
self.assertAllClose(knp.full([2, 3], 0.1), np.full([2, 3], 0.1))
Expand Down

0 comments on commit 72c2e8c

Please sign in to comment.