forked from TuringLang/AdvancedVI.jl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path AdvancedVITapirExt.jl
37 lines (33 loc) · 889 Bytes
/
AdvancedVITapirExt.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
module AdvancedVITapirExt

# Julia >= 1.9 loads this file as a package extension (parent packages are top-level
# names); older Julia versions include it from the parent module instead, so the
# dependencies have to be reached through the enclosing module (`..`).
if isdefined(Base, :get_extension)
    using AdvancedVI
    using AdvancedVI: ADTypes, DiffResults
    using Tapir
else
    using ..AdvancedVI
    using ..AdvancedVI: ADTypes, DiffResults
    using ..Tapir
end

"""
    _tapir_value_and_gradient!(out, f, x, rest...)

Differentiate the call `f(x, rest...)` with Tapir, store the primal value and the
gradient with respect to `x` into `out`, and return `out`.

Tangents of `f` itself and of any trailing arguments in `rest` are computed by Tapir
but discarded; only the `x` component is written to `out`.
"""
function _tapir_value_and_gradient!(out::DiffResults.MutableDiffResult, f, x, rest...)
    # NOTE(review): the rule is rebuilt on every invocation; nothing is cached here.
    tape_rule = Tapir.build_rrule(f, x, rest...)
    primal, tangents = Tapir.value_and_gradient!!(tape_rule, f, x, rest...)

    # `tangents` follows the argument order of the call (f, x, rest...), so the
    # gradient w.r.t. `x` sits at position 2.
    dx = tangents[2]

    DiffResults.value!(out, primal)
    DiffResults.gradient!(out, dx)
    return out
end

# AdvancedVI entry point for `AutoTapir`: evaluate `f(x)` and its gradient w.r.t. `x`,
# writing both into `out` and returning `out`.
function AdvancedVI.value_and_gradient!(
    ::ADTypes.AutoTapir, f, x::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult
)
    return _tapir_value_and_gradient!(out, f, x)
end

# Variant with an auxiliary argument: evaluate `f(x, aux)` and write its value and the
# gradient w.r.t. `x` into `out`; the tangent w.r.t. `aux` is dropped. Returns `out`.
function AdvancedVI.value_and_gradient!(
    ::ADTypes.AutoTapir,
    f,
    x::AbstractVector{<:Real},
    aux,
    out::DiffResults.MutableDiffResult,
)
    return _tapir_value_and_gradient!(out, f, x, aux)
end

end