Skip to content

Commit

Permalink
Add a requirements file for multi-backend cuda (keras-team#472)
Browse files Browse the repository at this point in the history
Not totally sure if we should merge this now, or wait for tf 2.14, but
figured I could put it up anyway so people could use it. With
tensorflow/tensorflow#59825
tf-nightly can be installed using cuda pip packages. Which means we
can write a recipe for cross framework GPU support.

To install a local development version...
```shell
pip install -r requirements-cuda.txt
python pip_build.py --install
```

To install the official pip version...
```shell
pip install -r requirements-cuda.txt
pip install keras-core --no-deps
```

Note that `--no-deps` is required to avoid pulling in `tensorflow` and
`tf-nightly` at the same time.

This should work in a clean python env, as long nvidia drivers are
>=520.61.05. No conda or cuda shenanigans required!
  • Loading branch information
mattdangerw authored and fchollet committed Jul 16, 2023
1 parent c8953e5 commit 59fca26
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 17 deletions.
15 changes: 15 additions & 0 deletions requirements-common.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
namex
black>=22
flake8
isort
pytest
pandas
absl-py
requests
h5py
protobuf
google
tensorboard-plugin-profile
rich
build
dm-tree
18 changes: 18 additions & 0 deletions requirements-cuda.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Tensorflow.
# Cuda via pip is only on nightly right now.
# We will pin a known working version to avoid breakages (nightly breaks often).
tf-nightly[and-cuda]==2.14.0.dev20230712

# Torch.
# Pin the version used in colab currently (works with tf cuda version).
--extra-index-url https://download.pytorch.org/whl/cu118
torch==2.0.1+cu118
torchvision==0.15.2+cu118

# Jax.
# Pin the version used in colab currently (works with tf cuda version).
--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
jax[cuda11_pip]==0.4.10

# Common deps.
-r requirements-common.txt
26 changes: 9 additions & 17 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,21 +1,13 @@
# Tensorflow.
tensorflow
# TODO: Use Torch CPU
# Remove after resolving Cuda version differences with TF

# Torch.
# TODO: Use Torch CPU, remove after resolving Cuda version differences with TF
torch>=2.0.1+cpu
torchvision>=0.15.1

# Jax.
jax[cpu]
namex
black>=22
flake8
isort
pytest
pandas
absl-py
requests
h5py
protobuf
google
tensorboard-plugin-profile
rich
build
dm-tree

# Common deps.
-r requirements-common.txt

0 comments on commit 59fca26

Please sign in to comment.