optimize torch performance (keras-team#465)
* optimize torch performance

* fixing tests

---------

Co-authored-by: Haifeng Jin <[email protected]>
2 people authored and fchollet committed Jul 12, 2023
1 parent b06d183 commit 32f8bc6
Showing 1 changed file with 22 additions and 25 deletions.
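In short, the commit replaces the per-call `get_default_device()` lookup (which re-ran `torch.cuda.is_available()` on every call) with a module-level `DEFAULT_DEVICE` constant, and adds early-return fast paths to `convert_to_tensor` and `cast`. A minimal sketch of the caching pattern; the function names and timing harness below are illustrative, not part of the commit:

import timeit

import torch

# Pattern the commit applies: query CUDA availability once at import time.
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def get_device_cached():
    # New behavior: return the precomputed module-level constant.
    return DEFAULT_DEVICE


def get_device_uncached():
    # Old behavior: re-check CUDA availability on every call.
    return "cuda" if torch.cuda.is_available() else "cpu"


# Rough comparison of the two lookups; exact numbers depend on the machine.
print("cached:  ", timeit.timeit(get_device_cached, number=100_000))
print("uncached:", timeit.timeit(get_device_uncached, number=100_000))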
47 changes: 22 additions & 25 deletions keras_core/backend/torch/core.py
@@ -11,7 +11,7 @@
 from keras_core.backend.common.stateless_scope import StatelessScope
 
 DYNAMIC_SHAPES_OK = True
-
+DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
 TORCH_DTYPES = {
     "float16": torch.float16,
@@ -39,20 +39,14 @@ def device_scope(device):
         global_state.set_global_attribute("torch_device", previous_device)
 
 
-def get_default_device():
-    return "cuda" if torch.cuda.is_available() else "cpu"
-
-
 def get_device():
     device = global_state.get_global_attribute("torch_device", None)
     if device is None:
-        return get_default_device()
+        return DEFAULT_DEVICE
     return device
 
 
 def to_torch_dtype(dtype):
-    if dtype in [value for key, value in TORCH_DTYPES.items()]:
-        return dtype
     dtype = standardize_dtype(dtype)
     dtype = TORCH_DTYPES.get(dtype, None)
     if dtype is None:
@@ -114,32 +108,32 @@ def __eq__(self, other):
 
 
 def convert_to_tensor(x, dtype=None):
-    dtype = to_torch_dtype(dtype or getattr(x, "dtype", None))
-    device = get_device()
-    if isinstance(x, int):
-        dtype = torch.int32
-    if isinstance(x, float):
-        dtype = torch.float32
+    if is_tensor(x):
+        if dtype is None:
+            return x
+        return x.to(to_torch_dtype(dtype))
     if isinstance(x, Variable):
         # TorchDynamo has bugs supporting nn.Parameter type check.
         # Return it directly instead of pass it to the rest of the logic in the
         # function.
         return x.value
-    if is_tensor(x):
-        if dtype and dtype != x.dtype:
-            x = x.to(dtype)
-        return x.to(device)
-
+    if isinstance(x, int):
+        return torch.as_tensor(x, dtype=torch.int32, device=get_device())
+    if isinstance(x, float):
+        return torch.as_tensor(x, dtype=torch.float32, device=get_device())
     # Convert to np in case of any array-like that is not list or tuple.
     if not isinstance(x, (list, tuple)):
         x = np.array(x)
     elif len(x) > 0 and any(isinstance(x1, torch.Tensor) for x1 in x):
         # Handle list or tuple of torch tensors
         return torch.stack([convert_to_tensor(x1) for x1 in x])
-    if isinstance(x, np.ndarray) and x.dtype == np.uint32:
-        # Torch backend does not support uint32.
-        x = x.astype(np.int64)
-    return torch.as_tensor(x, dtype=dtype, device=device)
+    if isinstance(x, np.ndarray):
+        if x.dtype == np.uint32:
+            # Torch backend does not support uint32.
+            x = x.astype(np.int64)
+        dtype = dtype or x.dtype
+    dtype = to_torch_dtype(dtype)
+    return torch.as_tensor(x, dtype=dtype, device=get_device())
 
 
 def convert_to_numpy(x):
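With the hunk above, `convert_to_tensor` hands back an existing torch tensor unchanged when no dtype is requested, instead of always resolving a device and dispatching `.to(device)`. A small check of that fast path, assuming a keras_core install at this commit with the torch backend selected:

import os

os.environ["KERAS_BACKEND"] = "torch"  # select the backend before importing keras_core

import torch

from keras_core.backend.torch.core import convert_to_tensor

x = torch.ones((4, 4))

# Fast path: no dtype requested, so the same tensor object should come back.
assert convert_to_tensor(x) is x

# Requesting a different dtype still converts.
assert convert_to_tensor(x, dtype="float16").dtype == torch.float16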
@@ -170,7 +164,10 @@ def cast(x, dtype):
     if isinstance(x, KerasVariable):
         x = x.value
     if is_tensor(x):
-        return x.to(dtype)
+        if x.dtype == dtype:
+            return x
+        else:
+            return x.to(dtype)
     return convert_to_tensor(x, dtype)
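The `cast` hunk above skips the `.to()` dispatch when the tensor already has the requested dtype. A standalone sketch of the same short-circuit in plain torch, not the keras_core function itself:

import torch


def cast(x, dtype):
    # Short-circuit: skip the `.to()` call entirely when nothing would change.
    if x.dtype == dtype:
        return x
    return x.to(dtype)


x = torch.zeros((2, 2), dtype=torch.float32)
assert cast(x, torch.float32) is x                    # no-op, same object
assert cast(x, torch.float16).dtype == torch.float16  # real conversion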


@@ -220,7 +217,7 @@ def symbolic_call(fn, args, kwargs, fill_value):
             )
             return fn(*meta_args, **meta_kwargs)
         except:
-            with device_scope(get_default_device()):
+            with device_scope(DEFAULT_DEVICE):
                 # If the `"meta"` device placement fails, fall back to tracing
                 # eagerly with tensors on the default device. This will be
                 # more robust, but more expensive.
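The final hunk only swaps `get_default_device()` for the cached constant, but the surrounding logic it sits in is worth spelling out: output specs are first traced with zero-cost `"meta"` tensors, then re-traced on the real default device if an op rejects meta placement. A self-contained sketch of that pattern using plain torch rather than the keras_core helpers:

import torch

DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def infer_output_shape(fn, input_shape, dtype=torch.float32):
    try:
        # "meta" tensors carry only shape/dtype, so this trace costs no real
        # compute or memory.
        meta_in = torch.empty(input_shape, dtype=dtype, device="meta")
        return fn(meta_in).shape
    except Exception:
        # Some ops are not implemented for the meta device; retry with a real
        # tensor on the default device (more robust, more expensive).
        real_in = torch.zeros(input_shape, dtype=dtype, device=DEFAULT_DEVICE)
        return fn(real_in).shape


print(infer_output_shape(torch.nn.functional.relu, (2, 8)))  # torch.Size([2, 8])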