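import Pkg
# Install Lux together with the optimiser and AD packages used below (Random is a standard library)
Pkg.add(["Lux", "Optimisers", "Zygote"])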
using Lux, Random, Optimisers, Zygote
# using LuxCUDA, AMDGPU, Metal, oneAPI # Optional packages for GPU support
# Seeding
rng = Xoshiro(0)
# Construct the layer
model = Chain(BatchNorm(128), Dense(128, 256, tanh), BatchNorm(256),
Chain(Dense(256, 1, tanh), Dense(1, 10)))
# Get the device determined by Lux
device = gpu_device()
# Parameter and State Variables
ps, st = Lux.setup(rng, model) .|> device
# Dummy Input
x = rand(rng, Float32, 128, 2) |> device
# Run the model
y, st = Lux.apply(model, x, ps, st)
# Gradients
gs = only(gradient(p -> sum(first(Lux.apply(model, x, p, st))), ps))
# Optimization
st_opt = Optimisers.setup(Optimisers.Adam(0.0001), ps)
st_opt, ps = Optimisers.update(st_opt, ps, gs)
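The quick start above performs a single optimisation step. Below is a minimal sketch of repeating the gradient and `Optimisers.update` steps in a loop. It runs on the CPU and uses a small Dense-only model, so the state `st` never changes. The toy target y = x₁ + x₂ is purely an assumption for illustration, and the helper names `mse` and `train` are hypothetical, not part of Lux.

```julia
using Lux, Random, Optimisers, Zygote

# Hypothetical helper: mean squared error of the model output.
# `st` is passed through unchanged because this Dense-only model carries no state.
function mse(model, ps, st, x, y)
    y_pred, _ = Lux.apply(model, x, ps, st)
    return sum(abs2, y_pred .- y) / size(x, 2)
end

# Hypothetical helper: repeat the gradient + update steps from the quick start.
function train(model, ps, st, x, y; epochs = 200, lr = 0.01f0)
    opt_state = Optimisers.setup(Optimisers.Adam(lr), ps)
    for _ in 1:epochs
        # Gradient of the loss with respect to the parameters only
        gs = only(Zygote.gradient(p -> mse(model, p, st, x, y), ps))
        opt_state, ps = Optimisers.update(opt_state, ps, gs)
    end
    return ps
end

rng = Xoshiro(0)
model = Chain(Dense(2, 16, tanh), Dense(16, 1))
ps, st = Lux.setup(rng, model)

# Toy regression data (assumption for illustration): learn y = x₁ + x₂
x = rand(rng, Float32, 2, 64)
y = sum(x; dims = 1)

initial_loss = mse(model, ps, st, x, y)
ps_trained = train(model, ps, st, x, y)
final_loss = mse(model, ps_trained, st, x, y)
println("loss: ", initial_loss, " -> ", final_loss)
```

For models containing layers such as BatchNorm, `Lux.apply` returns an updated state during training, so a real loop would also need to carry that updated `st` forward between steps.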
See the examples directory for self-contained usage examples; the documentation also groups examples by category.
The full Lux.jl test suite takes a long time to run, so here is how to test only a portion of the code.
Each @testitem has corresponding tags, for example:
@testitem "SkipConnection" setup=[SharedTestSetup] tags=[:core_layers]
Consider the tests for SkipConnection:
@testitem "SkipConnection" setup=[SharedTestSetup] tags=[:core_layers] begin
...
end
We can test the group to which SkipConnection belongs by testing core_layers.
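To make the relationship between tags and groups concrete, here is a hypothetical sketch of tag-based selection: an item runs when the group is "all" or when the group name matches one of its tags. This is not the code in Lux's runtests.jl or in ReTestItems; `TestItemInfo`, `select_items`, and every item and tag except SkipConnection/core_layers are invented for illustration.

```julia
# Hypothetical record of a test item's name and tags (not a Lux or ReTestItems type)
struct TestItemInfo
    name::String
    tags::Vector{Symbol}
end

# Keep every item for "all"; otherwise keep items tagged with the group name
function select_items(items::Vector{TestItemInfo}, group::AbstractString)
    group == "all" && return items
    return filter(item -> Symbol(group) in item.tags, items)
end

items = [
    TestItemInfo("SkipConnection", [:core_layers]),
    TestItemInfo("Dense", [:core_layers]),            # hypothetical item
    TestItemInfo("BatchNorm", [:normalize_layers]),   # hypothetical item and tag
]

selected = select_items(items, "core_layers")
println([item.name for item in selected])
```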
To test only the core_layers group, set the LUX_TEST_GROUP environment variable, or rename the tag to further narrow the test scope:
export LUX_TEST_GROUP="core_layers"
Or directly modify the default test tag in runtests.jl:
# const LUX_TEST_GROUP = lowercase(get(ENV, "LUX_TEST_GROUP", "all"))
const LUX_TEST_GROUP = lowercase(get(ENV, "LUX_TEST_GROUP", "core_layers"))
But be sure to restore the default value "all" before submitting the code.
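The line in runtests.jl reads the variable once, when the file is loaded, so the environment variable has to be set before the test run starts. The sketch below re-evaluates the same expression under different environments using Julia's `withenv`; the wrapper function `read_test_group` is a hypothetical name introduced only for this illustration.

```julia
# Same expression as in runtests.jl, wrapped in a (hypothetical) function
# so it can be re-evaluated under different environments.
read_test_group() = lowercase(get(ENV, "LUX_TEST_GROUP", "all"))

# Variable unset (`nothing` removes it temporarily): falls back to "all"
default_group = withenv(read_test_group, "LUX_TEST_GROUP" => nothing)

# Equivalent to `export LUX_TEST_GROUP="core_layers"` in the shell
core_group = withenv(read_test_group, "LUX_TEST_GROUP" => "core_layers")

# `lowercase` makes the value case-insensitive
mixed_case = withenv(read_test_group, "LUX_TEST_GROUP" => "Core_Layers")

println((default_group, core_group, mixed_case))
```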
Furthermore, if you want to run a specific test based on the name of the testset, you can use TestEnv.jl as follows. Start by activating the Lux environment, then run the following:
using TestEnv; TestEnv.activate(); using ReTestItems;
# Assuming you are in the root directory of the Lux repository
ReTestItems.runtests("test/"; name = "NAME OF THE TEST")
For the SkipConnection tests that would be:
ReTestItems.runtests("test/"; name = "SkipConnection")
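The test name must be passed as a string: an unquoted SkipConnection refers to the Lux layer type, not to the name of a test item. The sketch below models the distinction between an exact string match and a `Regex` pattern match. It assumes ReTestItems treats a string `name` as an exact match and a `Regex` as a pattern (check the ReTestItems documentation before relying on this); `name_matches` and the item names are hypothetical.

```julia
# Hypothetical helper modelling a name filter:
# a string must match exactly, a Regex only needs to occur in the name.
name_matches(name::AbstractString, pattern::AbstractString) = name == pattern
name_matches(name::AbstractString, pattern::Regex) = occursin(pattern, name)

# Hypothetical test item names
item_names = ["SkipConnection", "SkipConnection with Chain", "Dense"]

exact = filter(n -> name_matches(n, "SkipConnection"), item_names)
by_pattern = filter(n -> name_matches(n, r"^SkipConnection"), item_names)

println(exact)
println(by_pattern)
```

Passing a type such as `Int` (or an unquoted layer type) has no matching method here, mirroring why the quotes are required.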
For usage-related questions, please use GitHub Discussions, which allows questions and answers to be indexed. To report bugs, use GitHub Issues or, even better, send in a pull request.
If you found this library useful in academic work, please cite:
@software{pal2023lux,
author = {Pal, Avik},
title = {{Lux: Explicit Parameterization of Deep Neural Networks in Julia}},
month = apr,
year = 2023,
note = {If you use this software, please cite it as below.},
publisher = {Zenodo},
version = {v0.5.0},
doi = {10.5281/zenodo.7808904},
url = {https://doi.org/10.5281/zenodo.7808904}
}
@thesis{pal2023efficient,
title = {{On Efficient Training \& Inference of Neural Differential Equations}},
author = {Pal, Avik},
year = {2023},
school = {Massachusetts Institute of Technology}
}
Also consider starring our GitHub repository.