diff --git a/docs/src/utilities.md b/docs/src/utilities.md index 5edaf46c2a..f41500520c 100644 --- a/docs/src/utilities.md +++ b/docs/src/utilities.md @@ -40,6 +40,15 @@ Flux.orthogonal Flux.sparse_init ``` +## Changing the type of model parameters + +```@docs +Flux.f64 +Flux.f32 +``` + +The default `eltype` for models is `Float32` since models are often trained/run on GPUs. The `eltype` of model `m` can be changed to `Float64` by `f64(m)`, or to `Float32` by `f32(m)`. + ## Model Building Flux provides some utility functions to help you generate models in an automated fashion. diff --git a/src/functor.jl b/src/functor.jl index 1e7f9e1fc2..afda1f5b84 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -74,7 +74,18 @@ adapt_storage(T::Type{<:Real}, xs::AbstractArray{<:Real}) = convert.(T, xs) paramtype(T::Type{<:Real}, m) = fmap(x -> adapt(T, x), m) +""" + f32(m) + +Convert the `eltype` of model's parameters to `Float32`. +""" f32(m) = paramtype(Float32, m) + +""" + f64(m) + +Convert the `eltype` of model's parameters to `Float64`. +""" f64(m) = paramtype(Float64, m) # Functors for certain Julia data structures