forked from FluxML/Flux.jl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathloading.jl
103 lines (86 loc) · 3.96 KB
/
loading.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
loadleaf!(dst, src, err) = dst
loadleaf!(dst::AbstractArray, src, err) =
error("Tried to copy $src into an array destination; this is not allowed.")
loadleaf!(dst, src::AbstractArray, err) =
error("Tried to copy an array to $dst; this is not allowed.")
function loadleaf!(dst::AbstractArray, src::Bool, err)
if iszero(src)
dst .= src
else
error("Cannot copy boolean parameter == true to non-zero parameter.")
end
return dst
end
loadleaf!(dst::Bool, src::AbstractArray, err) = iszero(dst) ? dst :
error("Cannot copy non-zero parameter to boolean parameter == true.")
function loadleaf!(dst::AbstractArray, src::AbstractArray, err)
(size(dst) == size(src)) || throw(err)
copyto!(dst, src)
end
_tie_check(dst::Bool, src::AbstractArray) = iszero(dst) ||
error("Encountered tied parameter with boolean source at some nodes and non-boolean sources at others.")
_tie_check(dst::AbstractArray, src::Bool) = (iszero(dst) && iszero(src)) ||
error("Encountered tied parameter with boolean source at some nodes and non-boolean sources at others.")
_tie_check(dst::AbstractArray, src::AbstractArray) = (dst == src) ||
error("Encountered tied destination parameters with untied and mismatched sources.")
_tie_check(dst, src) = true
_bool_tie_check(dst, src) = true
_filter_children(f, children::NamedTuple) =
NamedTuple(filter(kv -> f(kv[2]), pairs(children)))
_filter_children(f, children) = filter(f, children)
"""
loadmodel!(dst, src)
Copy all the parameters (trainable and non-trainable) from `src` into `dst`.
Recursively walks `dst` and `src` together using [`Functors.children`](@ref),
and calling `copyto!` on parameter arrays or throwing an error when there is a mismatch.
Non-array elements (such as activation functions) are not copied and need not match.
Zero bias vectors and `bias=false` are considered equivalent
(see extended help for more details).
# Examples
```julia
julia> dst = Chain(Dense(Flux.ones32(2, 5), Flux.ones32(2), tanh), Dense(2 => 1; bias = [1f0]))
Chain(
Dense(5 => 2, tanh), # 12 parameters
Dense(2 => 1), # 3 parameters
) # Total: 4 arrays, 15 parameters, 316 bytes.
julia> dst[1].weight ≈ ones(2, 5) # by construction
true
julia> src = Chain(Dense(5 => 2, relu), Dense(2 => 1, bias=false));
julia> Flux.loadmodel!(dst, src);
julia> dst[1].weight ≈ ones(2, 5) # values changed
false
julia> iszero(dst[2].bias)
true
```
# Extended help
Throws an error when:
- `dst` and `src` do not share the same fields (at any level)
- the sizes of leaf nodes are mismatched between `dst` and `src`
- copying non-array values to/from an array parameter
(except inactive parameters described below)
- `dst` is a "tied" parameter (i.e. refers to another parameter) and
loaded into multiple times with mismatched source values
Inactive parameters can be encoded by using the boolean value `false` instead of an array.
If `dst == false` and `src` is an all-zero array, no error will be raised (and no values copied);
however, attempting to copy a non-zero array to an inactive parameter will throw an error.
Likewise, copying a `src` value of `false` to any `dst` array is valid,
but copying a `src` value of `true` will error.
"""
function loadmodel!(dst, src; filter = _ -> true, cache = Base.IdSet())
ldsts = _filter_children(filter, functor(dst)[1])
lsrcs = _filter_children(filter, functor(src)[1])
(keys(ldsts) == keys(lsrcs)) ||
throw(ArgumentError("Tried to load $src into $dst but the structures do not match."))
err = DimensionMismatch("Tried to load $src into $dst but the parameter sizes do not match.")
foreach(ldsts, lsrcs) do ldst, lsrc
if ldst in cache # we already loaded this parameter before
_tie_check(ldst, lsrc) && return ldst
elseif Functors.isleaf(ldst) # our first time loading this leaf
push!(cache, ldst)
loadleaf!(ldst, lsrc, err)
else # this isn't a leaf
loadmodel!(ldst, lsrc; filter = filter, cache = cache)
end
end
return dst
end