-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathflake.nix
85 lines (84 loc) · 3.58 KB
/
flake.nix
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
{
description = "Generative neural networks for 3D terrain";

# Flake inputs, each given as a nested attribute set with a `url`.
inputs = {
  # nixpkgs at the 23.11 ref, plus the rolling nixpkgs-unstable branch.
  nixpkgs = { url = "github:nixos/nixpkgs/23.11"; };
  nixpkgs-unstable = { url = "github:nixos/nixpkgs/nixpkgs-unstable"; };

  flake-utils = { url = "github:numtide/flake-utils"; };

  # Fork of nixGL carrying kenrandunderscore's patch from
  # https://github.com/nix-community/nixGL/pull/165.
  # TODO: move back to github:nix-community/nixGL once that PR is merged.
  nixgl = { url = "github:hayden-donnelly/nixGL"; };
};
# Per-system outputs; only x86_64-linux is enumerated. Exposes a single
# CUDA-enabled development shell as devShells.<system>.default.
outputs = { self, nixpkgs, nixpkgs-unstable, flake-utils, nixgl, ... }:
  flake-utils.lib.eachSystem [ "x86_64-linux" ] (system: let
    # Library helpers are taken from nixpkgs-unstable, the same channel the
    # dev-shell packages are imported from below.
    inherit (nixpkgs-unstable) lib;
  in {
    devShells = let
      # Attribute name of the Python interpreter whose package set is overridden.
      pyVer = "311";
      py = "python${pyVer}";

      # Overrides for the Python package set: test suites and import checks are
      # switched off for jax, flax and wandb. For flax this avoids building
      # tensorflow (see below). NOTE(review): the reason for jax and wandb is not
      # recorded here — presumably also to skip heavy check-time dependencies.
      pythonOverrides = finalPkgs: prevPkgs: {
        jax = prevPkgs.jax.overridePythonAttrs (o: {
          nativeCheckInputs = [ ];
          pythonImportsCheck = [ ];
          pytestFlagsArray = [ ];
          # Merge into the existing passthru instead of replacing the whole
          # attrset, and clear tests with an attrset (the conventional type).
          passthru = (o.passthru or { }) // { tests = { }; };
          doCheck = false;
        });

        # Flax has jaxlib as a buildInput and tensorflow as a nativeCheckInput,
        # so point jaxlib at jaxlib-bin to avoid building jaxlib, and turn off
        # all of flax's checks to avoid building tensorflow.
        jaxlib = prevPkgs.jaxlib-bin;
        flax = prevPkgs.flax.overridePythonAttrs (o: {
          nativeCheckInputs = [ ];
          pythonImportsCheck = [ ];
          pytestFlagsArray = [ ];
          doCheck = false;
        });

        wandb = prevPkgs.wandb.overridePythonAttrs (o: {
          nativeCheckInputs = [ ];
          # Correct attribute name; the misspelled `pythonIMportsCheck` was
          # ignored and left the import check enabled.
          pythonImportsCheck = [ ];
          doCheck = false;
        });
      };

      overlays = [
        # Provides the `nixgl` attribute used in shellHook.
        nixgl.overlay
        (final: prev: {
          ${py} = prev.${py}.override { packageOverrides = pythonOverrides; };
        })
      ];

      # nixpkgs-unstable with unfree packages allowed and CUDA support enabled.
      unstableCudaPkgs = import nixpkgs-unstable {
        inherit system overlays;
        config = {
          allowUnfree = true;
          cudaSupport = true;
        };
      };
    in {
      default = unstableCudaPkgs.mkShell {
        name = "cuda";
        buildInputs = [
          (unstableCudaPkgs.${py}.withPackages (pyp: with pyp; [
            jax
            jaxlib-bin
            flax
            pyarrow
            pillow
            pandas
            datasets
            wandb
            tifffile
            zarr
          ]))
          unstableCudaPkgs.cudaPackages.cudatoolkit
          unstableCudaPkgs.cudaPackages.cuda_cudart
          unstableCudaPkgs.cudaPackages.cudnn
        ];
        # Source nixGL's Intel and NVIDIA wrapper scripts into the shell, after
        # deleting every line that mentions `$@`. The rest of each script then
        # runs in the current shell instead of wrapping a command. Presumably the
        # deleted line is the wrapper's final `exec "$@"` — confirm.
        # The trailing `*` globs the NVIDIA wrapper path; presumably the actual
        # binary name carries a version suffix that lib.getExe omits — TODO confirm.
        # NOTE(review): nixGL's `auto` attributes typically require evaluating
        # with `nix develop --impure` — confirm.
        shellHook = ''
          source <(sed -Ee '/\$@/d' ${lib.getExe unstableCudaPkgs.nixgl.nixGLIntel})
          source <(sed -Ee '/\$@/d' ${lib.getExe unstableCudaPkgs.nixgl.auto.nixGLNvidia}*)
        '';
      };
    };
  });
}