Skip to content

Commit

Permalink
minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
bvdmitri committed Jan 11, 2023
1 parent 0b1dc3b commit 5301e06
Showing 1 changed file with 15 additions and 13 deletions.
28 changes: 15 additions & 13 deletions src/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ function __inference_postprocess end
struct DefaultPostprocess end

__inference_postprocess(::DefaultPostprocess, result::Marginal) = __inference_postprocess(DefaultPostprocess(), result, ReactiveMP.getaddons(result))
__inference_postprocess(::DefaultPostprocess, result::AbstractVector) = map((element) -> __inference_postprocess(DefaultPostprocess(), element), result)
__inference_postprocess(::DefaultPostprocess, result::AbstractArray) = map((element) -> __inference_postprocess(DefaultPostprocess(), element), result)

# Default postprocessing step removes Marginal type wrapper if no addons are present, and keeps the Marginal type wrapper otherwise
__inference_postprocess(::DefaultPostprocess, result, addons::Nothing) = __inference_postprocess(UnpackMarginalPostprocess(), result)
Expand All @@ -151,7 +151,7 @@ __inference_postprocess(::DefaultPostprocess, result, addons::Any) = __inference
struct UnpackMarginalPostprocess end

__inference_postprocess(::UnpackMarginalPostprocess, result::Marginal) = getdata(result)
__inference_postprocess(::UnpackMarginalPostprocess, result::AbstractVector) = map((element) -> __inference_postprocess(UnpackMarginalPostprocess(), element), result)
__inference_postprocess(::UnpackMarginalPostprocess, result::AbstractArray) = map((element) -> __inference_postprocess(UnpackMarginalPostprocess(), element), result)

"""This postprocessing step does nothing"""
struct NoopPostprocess end
Expand Down Expand Up @@ -640,9 +640,8 @@ Base.string(::FromMessageAutoUpdate) = "μ"

import Base: fetch

# TODO for arrays
Base.fetch(strategy::Union{ FromMarginalAutoUpdate, FromMessageAutoUpdate }, variables::AbstractArray) = (Base.fetch(strategy, variable) for variable in variables)
Base.fetch(::FromMarginalAutoUpdate, variable::Union{DataVariable, RandomVariable}) = ReactiveMP.getmarginal(variable, IncludeAll())

Base.fetch(::FromMessageAutoUpdate, variable::RandomVariable) = ReactiveMP.messagein(variable, 1) # Here we assume that predictive message has index `1`
Base.fetch(::FromMessageAutoUpdate, variable::DataVariable) = error("`FromMessageAutoUpdate` fetch strategy is not implemented for `DataVariable`")

Expand Down Expand Up @@ -693,9 +692,11 @@ end

import Base: fetch

# TODO for arrays
Base.fetch(autoupdate::RxInferenceAutoUpdate) = fetch(autoupdate, ReactiveMP.getdata(Rocket.getrecent(autoupdate.recent)))
Base.fetch(autoupdate::RxInferenceAutoUpdate, something) = zip(as_tuple(autoupdate.datavars), as_tuple(autoupdate.callback(something)))
Base.fetch(autoupdate::RxInferenceAutoUpdate) = fetch(autoupdate, autoupdate.recent)
Base.fetch(autoupdate::RxInferenceAutoUpdate, something) = fetch(autoupdate, something, ReactiveMP.getdata(ReactiveMP.getrecent(something)))
Base.fetch(autoupdate::RxInferenceAutoUpdate, something::Union{AbstractArray, Base.Generator}) = fetch(autoupdate, something, ReactiveMP.getdata.(ReactiveMP.getrecent.(something)))

Base.fetch(autoupdate::RxInferenceAutoUpdate, _, data) = zip(as_tuple(autoupdate.datavars), as_tuple(autoupdate.callback(data)))

"""
@autoupdates
Expand Down Expand Up @@ -1113,7 +1114,8 @@ function Rocket.on_next!(executor::RxInferenceEventExecutor{T}, event::T) where
fupdates = map(fetch, _autoupdates)

# This loop correspond to the different VMP iterations
for iteration in 1:_iterations
# Here `_iterations` can be `Ref` too, so we use `[]`. Should not affect integers
for iteration in 1:_iterations[]
inference_invoke_event(Val(:before_iteration), Val(_enabled_events), _events, _model, iteration)

# At first we update all our priors (auto updates) with the fixed values from the `redirectupdate` field
Expand Down Expand Up @@ -1667,8 +1669,8 @@ function rxinference(;

# `iterations` might be set to `nothing` in which case we assume `1` iteration
_iterations = something(iterations, 1)
_iterations isa Integer || error("`iterations` argument must be of type Integer or `nothing`")
_iterations > 0 || error("`iterations` arguments must be greater than zero")
(_iterations isa Integer || _iterations isa Ref{<:Integer}) || error("`iterations` argument must be of type Integer, Ref{<:Integer}, or `nothing`")
_iterations[] > 0 || error("`iterations` arguments must be greater than zero")

_keephistory = something(keephistory, 0)
_keephistory isa Integer || error("`keephistory` argument must be of type Integer or `nothing`")
Expand All @@ -1689,7 +1691,7 @@ function rxinference(;

if is_free_energy
if _keephistory > 0
fe_actor = ScoreActor(S, _iterations, _keephistory)
fe_actor = ScoreActor(S, _iterations[], _keephistory)
end
fe_scheduler = PendingScheduler()
fe_objective = BetheFreeEnergy(BetheFreeEnergyDefaultMarginalSkipStrategy, fe_scheduler, free_energy_diagnostics)
Expand Down Expand Up @@ -1725,7 +1727,7 @@ function rxinference(;
if isnothing(historyvars)
# First what we do - we check if `historyvars` is nothing
# In which case we mirror the `returnvars` specication and use either `KeepLast()` or `KeepEach` (depending on the iterations spec)
historyoption = _iterations > 1 ? KeepEach() : KeepLast()
historyoption = _iterations[] > 1 ? KeepEach() : KeepLast()
historyvars = Dict(name => historyoption for name in returnvars)
elseif historyvars === KeepEach() || historyvars === KeepLast()
# Second we check if it is one of the two possible global values: `KeepEach` and `KeepLast`.
Expand All @@ -1748,7 +1750,7 @@ function rxinference(;
history = nothing

if !isnothing(historyvars) && _keephistory > 0
historyactors = Dict(name => make_actor(vardict[name], historyoption, _iterations) for (name, historyoption) in pairs(historyvars))
historyactors = Dict(name => make_actor(vardict[name], historyoption, _iterations[]) for (name, historyoption) in pairs(historyvars))
history = Dict(name => CircularBuffer(_keephistory) for (name, _) in pairs(historyvars))
end

Expand Down

0 comments on commit 5301e06

Please sign in to comment.