forked from TuringLang/AdvancedVI.jl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathAdvancedVIForwardDiffExt.jl
42 lines (36 loc) · 1011 Bytes
/
AdvancedVIForwardDiffExt.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
module AdvancedVIForwardDiffExt

# Bring the weak dependencies into scope. When `Base.get_extension` exists, they
# are imported as top-level packages. Otherwise they are reached relative to the
# enclosing module through `..` (presumably a Requires.jl-style fallback for older
# Julia versions — TODO confirm).
if isdefined(Base, :get_extension)
    using ForwardDiff
    using AdvancedVI
    using AdvancedVI: ADTypes, DiffResults
else
    using ..ForwardDiff
    using ..AdvancedVI
    using ..AdvancedVI: ADTypes, DiffResults
end

# Return the chunk size stored as the first type parameter of an `AutoForwardDiff`
# backend. A value of `nothing` means no chunk size was requested, in which case
# ForwardDiff's default gradient configuration is used below.
function getchunksize(::ADTypes.AutoForwardDiff{N}) where {N}
    return N
end

# Evaluate `f` at `x` together with its gradient using ForwardDiff.
# Both are stored into the preallocated result `out`, which is also returned.
function AdvancedVI.value_and_gradient!(
    ad::ADTypes.AutoForwardDiff,
    f,
    x::AbstractVector{<:Real},
    out::DiffResults.MutableDiffResult,
)
    requested = getchunksize(ad)
    # NOTE(review): the requested size is passed as the second (threshold)
    # argument of `ForwardDiff.Chunk(len, threshold)`. The chunk actually used may
    # therefore be smaller than requested, e.g. when `length(x)` is smaller.
    # Confirm this is the intended semantics.
    cfg = isnothing(requested) ?
        ForwardDiff.GradientConfig(f, x) :
        ForwardDiff.GradientConfig(f, x, ForwardDiff.Chunk(length(x), requested))
    ForwardDiff.gradient!(out, f, x, cfg)
    return out
end

# Variant that takes auxiliary data: `aux` is bound as the second argument of `f`,
# so only `x` is differentiated. The call is forwarded to the method above.
# NOTE(review): that method requires `eltype(x) <: Real`, even though this
# signature accepts any `AbstractVector`.
function AdvancedVI.value_and_gradient!(
    ad::ADTypes.AutoForwardDiff,
    f,
    x::AbstractVector,
    aux,
    out::DiffResults.MutableDiffResult,
)
    f_of_x = Base.Fix2(f, aux)
    return AdvancedVI.value_and_gradient!(ad, f_of_x, x, out)
end

end