forked from keras-team/keras
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add a requirements file for multi-backend cuda (keras-team#472)
Not totally sure if we should merge this now, or wait for tf 2.14, but figured I could put it up anyway so people could use it. With tensorflow/tensorflow#59825 tf-nightly can be installed using cuda pip packages. Which means we can write a recipe for cross framework GPU support. To install a local development version... ```shell pip install -r requirements-cuda.txt python pip_build.py --install ``` To install the official pip version... ```shell pip install -r requirements-cuda.txt pip install keras-core --no-deps ``` Note that `--no-deps` is required to avoid pulling in `tensorflow` and `tf-nightly` at the same time. This should work in a clean python env, as long nvidia drivers are >=520.61.05. No conda or cuda shenanigans required!
- Loading branch information
1 parent
c8953e5
commit 59fca26
Showing
3 changed files
with
42 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
namex | ||
black>=22 | ||
flake8 | ||
isort | ||
pytest | ||
pandas | ||
absl-py | ||
requests | ||
h5py | ||
protobuf | ||
tensorboard-plugin-profile | ||
rich | ||
build | ||
dm-tree |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# Tensorflow. | ||
# Cuda via pip is only on nightly right now. | ||
# We will pin a known working version to avoid breakages (nightly breaks often). | ||
tf-nightly[and-cuda]==2.14.0.dev20230712 | ||
|
||
# Torch. | ||
# Pin the version used in colab currently (works with tf cuda version). | ||
--extra-index-url https://download.pytorch.org/whl/cu118 | ||
torch==2.0.1+cu118 | ||
torchvision==0.15.2+cu118 | ||
|
||
# Jax. | ||
# Pin the version used in colab currently (works with tf cuda version). | ||
--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html | ||
jax[cuda11_pip]==0.4.10 | ||
|
||
# Common deps. | ||
-r requirements-common.txt |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,21 +1,13 @@ | ||
# Tensorflow. | ||
tensorflow | ||
# TODO: Use Torch CPU | ||
# Remove after resolving Cuda version differences with TF | ||
|
||
# Torch. | ||
# TODO: Use Torch CPU, remove after resolving Cuda version differences with TF | ||
torch>=2.0.1+cpu | ||
torchvision>=0.15.1 | ||
|
||
# Jax. | ||
jax[cpu] | ||
namex | ||
black>=22 | ||
flake8 | ||
isort | ||
pytest | ||
pandas | ||
absl-py | ||
requests | ||
h5py | ||
protobuf | ||
tensorboard-plugin-profile | ||
rich | ||
build | ||
dm-tree | ||
|
||
# Common deps. | ||
-r requirements-common.txt |