Skip to content

Commit

Permalink
Implement RNN for Torch backend (keras-team#340)
Browse files Browse the repository at this point in the history
* Add PyTorch numpy functionality

* Add dtype conversion

* Partial fix for PyTorch numpy tests

* small logic fix

* Revert numpy_test

* Add tensor conversion to numpy

* Fix some arithmetic tests

* Fix some torch functions for numpy compatibility

* Fix pytorch ops for numpy compatibility, add TODOs

* Fix formatting

* Implement nits and fix dtype standardization

* Add pytest skipif decorator and fix nits

* Fix formatting and rename dtypes map

* Split tests by backend

* Merge space

* Fix dtype issues from new type checking

* Implement torch.full and torch.full_like numpy compatible

* Implements logspace and linspace with tensor support for start and stop

* Replace len of shape with ndim

* Fix formatting

* Implement torch.trace

* Implement eye k diagonal arg

* Implement torch.tri

* Fix formatting issues

* Fix torch.take dimensionality

* Add split functionality

* Revert torch.eye implementation to prevent conflict

* Implement all padding modes

* Adds torch image resizing and torchvision dependency.

* Fix conditional syntax

* Make torchvision import optional

* Partial implementation of torch RNN

* Duplicate torch demo file

* Small ops fixes for torch unit tests

* delete nonfunctional gpu test file

* Revert rnn and formatting fixes

* Revert progbar

* Fix formatting

* Restore torch rnn

* Rough implementation of Torch RNN

* Rewrite tf.while_loop in Torch

* Implement tensor list comprehension functionality

* Revert tf changes

* Debug RNN tests

* Debug convolutional LSTM tests

* Fix tensor list conversion

* Fix zero output for masking

* Fix formatting

* Update comment
  • Loading branch information
nkovela1 authored and fchollet committed Jun 13, 2023
1 parent 4fd3ea1 commit d65881b
Show file tree
Hide file tree
Showing 4 changed files with 387 additions and 7 deletions.
9 changes: 7 additions & 2 deletions keras_core/backend/torch/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,13 @@ def convert_to_tensor(x, dtype=None):
return x.to(dtype)
return x

# Convert to np first in case of any non-numpy, numpy-compatible array.
x = np.array(x)
# Convert to np in case of any array-like that is not list or tuple.
if not isinstance(x, (list, tuple)):
x = np.array(x)
# Handle list or tuple of torch tensors
elif len(x) > 0 and isinstance(x[0], torch.Tensor):
return torch.stack(x)

return torch.as_tensor(x, dtype=dtype, device=get_device())


Expand Down
6 changes: 6 additions & 0 deletions keras_core/backend/torch/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,12 @@ def zeros(shape, dtype="float32"):
return torch.zeros(size=shape, dtype=dtype)


def zeros_like(x, dtype=None):
x = convert_to_tensor(x)
dtype = to_torch_dtype(dtype)
return torch.zeros_like(x, dtype=dtype)


def absolute(x):
return abs(x)

Expand Down
Loading

0 comments on commit d65881b

Please sign in to comment.