Skip to content

Commit

Permalink
JAX now detects CUDA.
Browse files Browse the repository at this point in the history
Switching entirely to nixpkgs/unstable and using a proper overlay to
prevent JAX from compiling jaxlib was what made it work.
  • Loading branch information
hayden-donnelly committed Mar 21, 2024
1 parent 5bd71be commit dab3992
Showing 1 changed file with 6 additions and 14 deletions.
20 changes: 6 additions & 14 deletions flake.nix
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,9 @@
flake-utils.url = "github:numtide/flake-utils";
};
outputs = inputs@{ self, nixpkgs, nixpkgs-unstable, flake-utils, ... }:
flake-utils.lib.eachSystem [ "x86_64-linux" ] (system: let
inherit (nixpkgs) lib;
unstableCudaPkgs = import nixpkgs-unstable {
inherit system;
config = {
allowUnfree = true;
cudaSupport = true;
};
};
in {
flake-utils.lib.eachSystem [ "x86_64-linux" ] (system: {
devShells = let
pyVer = "310";
pyVer = "311";
py = "python${pyVer}";
overlays = [
(final: prev: {
Expand All @@ -35,24 +26,25 @@
nativeCheckInputs = [];
pythonImportsCheck = [];
pytestFlagsArray = [];
passthru.tests = [];
doCheck = false;
});
};
};
})
];
stableJaxPkgs = import nixpkgs {
unstableCudaPkgs = import nixpkgs-unstable {
inherit system overlays;
config = {
allowUnfree = true;
cudaSupport = true;
};
};
in rec {
default = stableJaxPkgs.mkShell {
default = unstableCudaPkgs.mkShell {
name = "cuda";
buildInputs = [
(stableJaxPkgs.${py}.withPackages (pyp: with pyp; [
(unstableCudaPkgs.${py}.withPackages (pyp: with pyp; [
jax
jaxlib-bin
]))
Expand Down

0 comments on commit dab3992

Please sign in to comment.