Calibration Example of Complex Model #1516

Open
ClaudMor opened this issue Jan 8, 2021 · 1 comment

Comments

@ClaudMor

ClaudMor commented Jan 8, 2021

Hello,

Following the suggestion from @cpfiffer in this issue, here is a self-contained script with a complex model whose calibration with Turing I may not be able to optimize properly.

Unfortunately I am not a domain expert in MCMC or ADVI, so the best I could do was to try all the techniques and grid-search all the arguments shown in the tutorials. I also implemented the suggestions I received in these issues:

  1. Type inference in Turing model when using vectorized ~
  2. :forwarddiff backend with type specification goes OOM
  3. Inference using :reversediff, :tracker and :zygote performs worse than :forwarddiff

However, I think the amount of domain expertise required to properly tune the samplers and ADVI warrants a question of its own. This issue may of course also serve as an example of how to tune Turing to calibrate models with over 30 parameters and around 15 time series to reproduce.

Below I present a proxy for our model and a calibration pipeline with (I hope!) as many details as possible already set up, so that, if you think it appropriate, you may help us optimize its calibration as far as Turing allows, exploring Turing's parameters and functionality surely better and more efficiently than I managed to.

Our model is an age-stratified S-E-I-H-ICU-R-D epidemiological model with 32 parameters in total. We have 5 age classes, and the data we would like to calibrate the parameters against are 5 time series of hospitalizations (one per age class), 5 time series of ICU occupancy and 5 time series of cumulative deaths.
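For reference, this is the system that the right-hand-side function implements, written per age class i = 1, …, 5 and reconstructed from the code. Here μ, η, χ and δ^ICU are fixed, and the remaining rates are the ones calibrated:

$$
\begin{aligned}
\Lambda_i &= \beta\,\mu_i \sum_{j=1}^{5} C_{ij}\,\frac{I_j}{N_j} \\
\dot S_i &= -\Lambda_i S_i \\
\dot E_i &= \Lambda_i S_i - \epsilon E_i \\
\dot I_i &= \epsilon E_i - \left[(1-\eta_i)\lambda^{IR}_i + \eta_i\lambda^{IH}_i\right] I_i \\
\dot H_i &= \eta_i\lambda^{IH}_i I_i - \left[\chi_i\lambda^{HICU}_i + (1-\chi_i)\lambda^{HR}_i\right] H_i \\
\dot{\mathrm{ICU}}_i &= \chi_i\lambda^{HICU}_i H_i - \left[\delta^{ICU}_i\lambda^{ICUD}_i + (1-\delta^{ICU}_i)\lambda^{ICUR}_i\right] \mathrm{ICU}_i \\
\dot R_i &= (1-\eta_i)\lambda^{IR}_i I_i + (1-\chi_i)\lambda^{HR}_i H_i + (1-\delta^{ICU}_i)\lambda^{ICUR}_i\,\mathrm{ICU}_i \\
\dot D_i &= \delta^{ICU}_i\lambda^{ICUD}_i\,\mathrm{ICU}_i \\
\dot{\mathrm{AdH}}_i &= \eta_i\lambda^{IH}_i I_i \\
\dot{\mathrm{AdICU}}_i &= \chi_i\lambda^{HICU}_i H_i
\end{aligned}
$$

Summing the S, E, I, H, ICU, R and D equations of one age class, every term cancels, so S_i + E_i + I_i + H_i + ICU_i + R_i + D_i stays equal to the population N_i. AdH and AdICU are cumulative counters outside this balance.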
Here is the code:

using DifferentialEquations, ModelingToolkit                # ModelingToolkit provides modelingtoolkitize
using Memoization, Turing, ReverseDiff, Zygote
using DiffEqParamEstim, Optim

using Distributions
import Statistics

using Bijectors                                            # to use the more advanced interfaces of ADVI
using Bijectors: Scale, Shift

using Plots
pyplot()

using BenchmarkTools 



# I leave these commented out for now, but I noticed that performance varies a lot depending on them.
# Turing.setredcache(false)
# Turing.setadbackend(:forwarddiff)



# Define the epidemiological model as a DifferentialEquations.jl right-hand side. The model is wrapped in a closure that allows one to fix the values of the last 32 - 12 = 20 parameters: in this way, the 32-parameter model effectively becomes a 12-parameter model. These 20 values may be taken from a previous successful calibration, which may reduce calibration times when testing. If calibrated_parameters is empty, all 32 parameters are read from p.
function SEIH_ICU_RD!(; calibrated_parameters::Vector{Float64})
    function parameterized_SEIH_ICU_RD!(du,u,p,t)

        # passed parameters
        if length(calibrated_parameters) == 0
            β,                                                      # the transmissibility
            ϵ,                                                      # incubation period
            λ_IR_1, λ_IR_2, λ_IR_3, λ_IR_4, λ_IR_5,                 # age-stratified delays from infected to recovered (without going through hospitalization)
            λ_IH_1, λ_IH_2, λ_IH_3, λ_IH_4, λ_IH_5,                 # age-stratified delays from infected to hospitalized
            λ_HICU_1, λ_HICU_2, λ_HICU_3,λ_HICU_4, λ_HICU_5,        # age-stratified delays from hospitalized to ICU
            λ_HR_1, λ_HR_2, λ_HR_3, λ_HR_4, λ_HR_5,                 # age-stratified delays from hospitalized to recovery
            λ_ICUD_1, λ_ICUD_2,  λ_ICUD_3, λ_ICUD_4, λ_ICUD_5,      # age-stratified delays from ICU to death
            λ_ICUR_1, λ_ICUR_2, λ_ICUR_3, λ_ICUR_4,  λ_ICUR_5 = p   # age-stratified delays from ICU to recovery
        else
            β,                                                      
            ϵ,                                                      
            λ_IR_1, λ_IR_2, λ_IR_3, λ_IR_4, λ_IR_5,                 
            λ_IH_1, λ_IH_2, λ_IH_3, λ_IH_4, λ_IH_5 = p              
            
            λ_HICU_1, λ_HICU_2, λ_HICU_3,λ_HICU_4, λ_HICU_5,                           
            λ_HR_1, λ_HR_2, λ_HR_3, λ_HR_4, λ_HR_5,                                     
            λ_ICUD_1, λ_ICUD_2,  λ_ICUD_3, λ_ICUD_4, λ_ICUD_5,                          
            λ_ICUR_1, λ_ICUR_2, λ_ICUR_3, λ_ICUR_4,  λ_ICUR_5 = calibrated_parameters   
        end
        



        # group parameters by age classes
        λ_IR  =  [λ_IR_1, λ_IR_2, λ_IR_3, λ_IR_4, λ_IR_5]


        λ_IH   = [λ_IH_1, λ_IH_2, λ_IH_3, λ_IH_4, λ_IH_5]


        λ_HICU = [λ_HICU_1, λ_HICU_2, λ_HICU_3, λ_HICU_4, λ_HICU_5]


        λ_HR   = [λ_HR_1, λ_HR_2, λ_HR_3, λ_HR_4, λ_HR_5]


        λ_ICUD = [λ_ICUD_1, λ_ICUD_2,  λ_ICUD_3, λ_ICUD_4, λ_ICUD_5]


        λ_ICUR = [λ_ICUR_1, λ_ICUR_2, λ_ICUR_3, λ_ICUR_4, λ_ICUR_5]


        # State variables

        # Susceptibles
        S       =   @view u[5*0+1:5*1]
        # Exposed
        E       =   @view u[5*1+1:5*2]
        # Infected
        I       =   @view u[5*2+1:5*3]
        # Hospitalized
        H       =   @view u[5*3+1:5*4] 
        # ICU occupancy 
        ICU     =   @view u[5*4+1:5*5]
        # Cumulated Recovered
        R       =   @view u[5*5+1:5*6]
        # Cumulative deaths (age-disaggregated)
        D       =   @view u[5*6+1:5*7]
        # Cumulative hospital admissions (everyone transferred to the ICU is first admitted to the hospital)
        AdH     =   @view u[5*7+1:5*8]
        # Cumulative ICU admissions
        AdICU   =   @view u[5*8+1:5*9]



        # State variables differentials
        dS       =   @view du[5*0+1:5*1]
        dE       =   @view du[5*1+1:5*2]
        dI       =   @view du[5*2+1:5*3]
        dH       =   @view du[5*3+1:5*4]
        dICU     =   @view du[5*4+1:5*5]
        dR       =   @view du[5*5+1:5*6]
        dD       =   @view du[5*6+1:5*7]
        dAdH     =   @view du[5*7+1:5*8]
        dAdICU   =   @view du[5*8+1:5*9]

        # Force of infection: Λ_i = β μ_i Σ_j C_ij I_j / N_j
        Λ = β .* μ .* (C_5 * (I ./ N_5))

        # System of equations

        @. dS   = - Λ*S

        @. dE   = Λ*S - ϵ*E 

        @. dI   =  ϵ * E - ((1 - η )*λ_IR + η* λ_IH)*I

        @. dH   = η * λ_IH * I - (χ * λ_HICU + (1 - χ) * λ_HR) * H

        @. dICU = λ_HICU * χ * H - (δ_ICU * λ_ICUD + (1 - δ_ICU) * λ_ICUR) * ICU 

        @. dR   = (1 - η ) * λ_IR * I + (1-χ)*λ_HR*H + (1-δ_ICU)*λ_ICUR*ICU

        @. dD   =  δ_ICU*λ_ICUD*ICU

        @. dAdH     =  η * λ_IH * I

        @. dAdICU   =  λ_HICU * χ * H 

    end
end;
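The force of infection inside the closure is, for each age class i, β μ_i Σ_j C_ij I_j / N_j, which is a single matrix-vector product. The standalone check below uses the contact matrix, population, susceptibility and β value from this script, plus an infected vector made up purely for illustration, to confirm that the nested comprehension and the product give the same result. The product form avoids allocating an inner array for every age class on every right-hand-side call.

```julia
# Contact matrix, population, susceptibility and β copied from the script
C_5 = [2.87954   3.07131  2.21196  0.343359  0.256795;
       2.93036   4.6993   5.49855  1.01668   0.233958;
       1.53385   3.99626  2.88881  1.00161   0.631737;
       0.597362  1.85386  2.51295  1.41536   0.943012;
       0.402476  0.38432  1.42786  0.849536  7.48054e-5]
N_5 = [922868, 967257, 1330871, 530458, 588825]
μ   = [0.4759758925436791, 0.8266397761918497, 0.8280047127031845, 0.8108310931308417, 0.7399999999999999]
β   = 0.15

# Infected per age class: arbitrary values chosen only for illustration
I = [5.0, 12.0, 30.0, 8.0, 3.0]

# Comprehension form: Λ_i = β μ_i Σ_j C_ij I_j / N_j
Λ_loop = β * μ .* [sum(C_5[i, j] * I[j] / N_5[j] for j in 1:size(C_5, 2)) for i in 1:size(C_5, 1)]

# Same quantity as one matrix-vector product
Λ_matvec = β .* μ .* (C_5 * (I ./ N_5))

println(Λ_loop)
println(Λ_matvec)
```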

Model initial conditions, contact matrix and population:

# Initial conditions.
# contact matrix

const C_5 = [2.87954   3.07131  2.21196  0.343359  0.256795;
             2.93036   4.6993   5.49855  1.01668   0.233958;
             1.53385   3.99626  2.88881  1.00161   0.631737;
             0.597362  1.85386  2.51295  1.41536   0.943012;
             0.402476  0.38432  1.42786  0.849536  7.48054e-5 ]

# age-stratified population
const N_5             = [922868, 967257, 1330871, 530458, 588825]
# number of age classes
const num_age_classes = length(N_5)
# initial infected
const I₀              = repeat([5],num_age_classes)
# initial susceptibles
const S₀              = N_5 .- I₀
# initial exposed
const E₀              = [0 for n in 1:num_age_classes]
# initial hospitalized
const H₀              = [0 for n in 1:num_age_classes]
# initial ICUs
const ICU₀            = [0 for n in 1:num_age_classes]
# initial recovered
const R₀              = [0 for n in 1:num_age_classes]
# initial deaths
const D₀              = [0 for n in 1:num_age_classes]
# initial hopsitalizal admissons
const AdH₀            = [0 for n in 1:num_age_classes]
# initial ICU admissions
const AdICU₀          = [0 for n in 1:num_age_classes]
# collect the initial conditions in a single Float64 array for the SEIH_ICU_RD! model
const B = Float64.(vcat(S₀, E₀, I₀, H₀, ICU₀, R₀, D₀, AdH₀, AdICU₀));
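The state vector is a flat array of 45 entries: nine compartment blocks of five age classes each, in the order used to build B. A small sketch of that layout follows; the helper `block` is hypothetical and not part of the script. It reproduces the `save_idxs` used later to compare H, ICU and D with the data.

```julia
num_age_classes = 5

# Hypothetical helper (not in the script): index range of the k-th compartment block
block(k) = ((k - 1) * num_age_classes + 1):(k * num_age_classes)

# Compartment order used to build the initial-condition vector B
compartments = [:S, :E, :I, :H, :ICU, :R, :D, :AdH, :AdICU]
idx = Dict(c => block(k) for (k, c) in enumerate(compartments))

# Entries compared with the data: hospitalized, ICU and cumulative deaths, per age class
save_idxs = vcat(idx[:H], idx[:ICU], idx[:D])

println(idx[:AdICU])   # 41:45, the last block
println(save_idxs)
```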

Model initial parameter values:

# Model initial parameter values.

# The following values have been derived from the literature and constitute the initial values of the parameters (epidemiological delays) to be calibrated
const P = [ 0.15,
            0.1724137931034483,
            0.5, 0.5, 0.5, 0.5, 0.5,
            0.11834005925895003, 0.11834005925895003, 0.11834005925895003, 0.12104523689190094, 0.12104523689190094, 
            0.3846153745721452, 0.3846153745721452, 0.3846153745721452, 0.3846153745721452, 0.3846153745721452,
            0.1111111111111111, 0.1111111111111111, 0.1111111111111111, 0.1111111111111111, 0.1111111111111111, 
            0.08522883763258698, 0.08522883763258698, 0.08522883763258698, 0.08522883763258698, 0.08522883763258698, 
            0.04947445517352199, 0.04947445517352199, 0.04947445517352199, 0.04947445517352199, 0.04947445517352199 ]

# The following values have been derived from the literature and are kept constant: they are not calibrated

# susceptibility
const μ     = [ 0.4759758925436791 ,0.8266397761918497 ,0.8280047127031845 ,0.8108310931308417 ,0.7399999999999999]

# hospitalization rate
const η     =  [0.0027585284135976107 ,0.032802784575350706 ,0.10232295466653041 ,0.20404289877803708 ,0.261949855298025]

# ICU fraction
const χ     = [ 0.05 ,0.053934808432505525 ,0.1401166566857344 ,0.35206205203805013 ,0.6069703305850978]

# ICU conditioned fatality rate
const δ_ICU =  [0.007842851966738704, 0.018512721107622265, 0.06495525917707529, 0.1850599579504777, 0.45689350535412815]; 
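Apart from β, the calibrated parameters are per-day transition rates, even though the comments call them delays. When a single transition acts alone, its waiting time is exponential with mean 1/rate, so the starting values correspond to the mean delays printed below. For compartments with competing exits (for example I, which empties at rate (1 - η)λ_IR + ηλ_IH), the mean residence time is instead the inverse of the total exit rate. A quick conversion for a hand-picked subset of the values in P:

```julia
# Starting rates (per day) copied from P.
# Each mean delay is 1/rate, valid for a transition acting alone.
rates = [
    ("ϵ (incubation)", 0.1724137931034483),
    ("λ_IR",           0.5),
    ("λ_HICU",         0.3846153745721452),
    ("λ_HR",           0.1111111111111111),
    ("λ_ICUD",         0.08522883763258698),
    ("λ_ICUR",         0.04947445517352199),
]

mean_delay_days = Dict(name => 1 / rate for (name, rate) in rates)

for (name, _) in rates
    println(rpad(name, 16), round(mean_delay_days[name]; digits = 2), " days")
end
```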

Model parameter priors:

# model uninformative priors

const priors = [ Uniform(0.11, 0.29) ,
                Uniform(0.08608442196509436, 0.4459389648119829) ,
                Uniform(0.05, 0.95) ,Uniform(0.05, 0.95) ,Uniform(0.05, 0.95) ,Uniform(0.05, 0.95) ,Uniform(0.05, 0.95) ,
                Uniform( 0.05756513737673442, 0.8451135758017284) ,Uniform(0.05756513737673442, 0.8451135758017284) ,Uniform(0.05756513737673442, 0.8451135758017284) ,Uniform(0.05699050074851731, 0.978307745097927) ,Uniform( 0.05699050074851731, 0.978307745097927) ,
                Uniform(0.0717764830429439, 20.0) ,Uniform(0.0717764830429439, 20.0) ,Uniform(0.0717764830429439, 20.0) ,Uniform( 0.0717764830429439, 20.0) ,Uniform(0.0717764830429439, 20.0) ,
                Uniform(0.06825938566552901, 0.11976047904191617) ,Uniform(0.06825938566552901, 0.11976047904191617) ,Uniform( 0.06825938566552901, 0.11976047904191617) ,Uniform(0.06825938566552901, 0.11976047904191617) ,Uniform(0.06825938566552901, 0.11976047904191617) ,
                Uniform(0.033601109960270686, 0.6397852910389515) ,Uniform( 0.033601109960270686, 0.6397852910389515) ,Uniform(0.033601109960270686, 0.6397852910389515) ,Uniform(0.033601109960270686, 0.6397852910389515) ,Uniform(0.033601109960270686, 0.6397852910389515) ,
                Uniform( 0.015723763849568886, 1.397487888396577) ,Uniform(0.015723763849568886, 1.397487888396577) ,Uniform(0.015723763849568886, 1.397487888396577) ,Uniform(0.015723763849568886, 1.397487888396577) ,Uniform( 0.015723763849568886, 1.397487888396577) ]
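Before calibrating, it may be worth checking that the starting values P lie strictly inside the prior bounds. To my understanding, Optim's Fminbox expects an interior starting point, and a value on the boundary of a Uniform prior would map to ±∞ in the unconstrained space used by NUTS and ADVI. The sketch below copies the (a, b) bounds as plain tuples so it runs without Distributions.jl. It also reports where each starting value sits within its interval, which shows for instance that λ_HICU starts close to the lower end of a very wide (0.07, 20) interval.

```julia
# Uniform prior bounds (a, b) copied from `priors`, in the same order as P
bounds = vcat(
    [(0.11, 0.29)],                                          # β
    [(0.08608442196509436, 0.4459389648119829)],             # ϵ
    fill((0.05, 0.95), 5),                                   # λ_IR
    fill((0.05756513737673442, 0.8451135758017284), 3),      # λ_IH, classes 1-3
    fill((0.05699050074851731, 0.978307745097927), 2),       # λ_IH, classes 4-5
    fill((0.0717764830429439, 20.0), 5),                     # λ_HICU
    fill((0.06825938566552901, 0.11976047904191617), 5),     # λ_HR
    fill((0.033601109960270686, 0.6397852910389515), 5),     # λ_ICUD
    fill((0.015723763849568886, 1.397487888396577), 5),      # λ_ICUR
)

# Starting values copied from P
P = vcat(0.15, 0.1724137931034483, fill(0.5, 5),
         fill(0.11834005925895003, 3), fill(0.12104523689190094, 2),
         fill(0.3846153745721452, 5), fill(0.1111111111111111, 5),
         fill(0.08522883763258698, 5), fill(0.04947445517352199, 5))

lower, upper = first.(bounds), last.(bounds)

# Every starting value should lie strictly between its bounds
println("all strictly inside: ", all(lower .< P .< upper))

# Relative position inside each interval (0 = lower bound, 1 = upper bound)
relpos = (P .- lower) ./ (upper .- lower)
println("lowest relative position: ", round(minimum(relpos); digits = 3), " at index ", argmin(relpos))
```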

Time span and time interval we are going to calibrate and solve within:

# time points at which the data are observed (one per day)
const t = 1:30
# integration time span
const T = (1.0, 30.0)
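The Turing model further down rebuilds the save times from the span as 𝒯[1]:𝒯[2], i.e. with a unit step. It therefore lines up with the data columns only because the observations are daily and start at T[1]. A tiny sanity check of that assumption, using the values defined here:

```julia
# Values copied from the script
t = 1:30            # observation days
T = (1.0, 30.0)     # integration time span

# Save times as reconstructed inside the Turing model: unit step from T[1] to T[2]
save_times = T[1]:T[2]

println(length(save_times))                   # 30
println(collect(save_times) == collect(t))    # true: same days as the data columns
```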

Hard-coded calibration data. Each time series corresponds to an age class.

# calibration data
const H_calibration_data = vcat([ 
                            [ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.005517056827195221, 0.005517056827195221, 0.005517056827195221, 0.005517056827195221, 0.019309698895183275, 0.019309698895183275, 0.03586086937676894, 0.03861939779036655, 0.04413645461756177, 0.10758260813030682, 0.16551170481585664, 0.23999197198299213, 0.4082622052124464, 0.799973239943307, 0.7503197284985501, 1.026172569858311, 1.0868601949574586, 1.282715712322889, 1.906143133795949, 1.8978675485551562, 2.474399986997057, 3.3957484771386586, 4.369509007138616, 5.536366526090404],
                            [ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.06560556915070141, 0.06560556915070141, 0.06560556915070141, 0.06560556915070141, 0.22961949202745494, 0.22961949202745494, 0.4264361994795592, 0.45923898405490987, 0.5248445532056113, 1.2793085984386776, 1.9681670745210424, 2.8538422580555114, 4.854812117151904, 9.512807526851704, 8.922357404495392, 12.202635862030462, 12.924297122688179, 15.253294827538078, 22.66672414156734, 22.568315787841286, 29.424097764089584, 40.38022781225672, 51.95961076735552, 65.83518864272887],
                            [ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.20464590933306082, 0.20464590933306082, 0.20464590933306082, 0.20464590933306082, 0.7162606826657129, 0.7162606826657129, 1.3301984106648952, 1.4325213653314257, 1.6371672746644865, 3.990595231994686, 6.139377279991825, 8.902097055988145, 15.1437972906465, 29.673656853293817, 27.83184366929627, 38.06413913594931, 40.31524413861298, 47.58017391993664, 70.70516167457251, 70.39819281057292, 91.78369033587778, 125.95955719449893, 162.07956019178417, 205.36217001572652],
                            [ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.40808579755607416, 0.40808579755607416, 0.40808579755607416, 0.40808579755607416, 1.4283002914462595, 1.4283002914462595, 2.6525576841144822, 2.856600582892519, 3.2646863804485933, 7.957673052343446, 12.242573926682224, 17.751732193689225, 30.19834901914949, 59.17244064563075, 55.49966846762609, 75.9039583454298, 80.39290211854662, 94.87994793178724, 140.99364305562364, 140.3815143592895, 183.02648020389927, 251.17680839576366, 323.20395166441074, 409.5140978475204],
                            [ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.52389971059605, 0.52389971059605, 0.52389971059605, 0.52389971059605, 1.833648987086175, 1.833648987086175, 3.405348118874325, 3.66729797417235, 4.1911976847684, 10.216044356622975, 15.7169913178815, 22.789637410928176, 38.7685785841077, 75.96545803642725, 71.2503606410628, 97.4453461708653, 103.20824298742185, 121.80668271358162, 181.00735001093528, 180.22150044504122, 234.96902020232844, 322.46027187186877, 414.9285707920716, 525.7333595831362]
                            ])

const ICU_calibration_data = vcat([
                            [ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1, 0.1, 0.15000000000000002, 0.65, 0.8500000000000001, 1.5, 1.9000000000000001, 2.25, 2.5, 3.3000000000000003, 3.75, 4.8500000000000005, 6.75, 7.5, 8.55, 9.3, 10.3, 11.350000000000001],
                            [ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.10786961686501105, 0.10786961686501105, 0.16180442529751657, 0.7011525096225718, 0.9168917433525939, 1.6180442529751657, 2.0495227204352098, 2.4270663794627487, 2.696740421625276, 3.5596973565453647, 4.045110632437915, 5.231676417953036, 7.281199138388246, 8.09022126487583, 9.222852241958444, 10.031874368446028, 11.110570537096137, 12.243201514178754],
                            [ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.2802333133714688, 0.2802333133714688, 0.4203499700572032, 1.8215165369145472, 2.381983163657485, 4.203499700572031, 5.324432954057907, 6.305249550858048, 7.0058328342867195, 9.24769934125847, 10.50874925143008, 13.591315698516237, 18.91574865257414, 21.01749850286016, 23.95994829326058, 26.061698143546597, 28.864031277261283, 31.806481067661707],
                            [ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.7041241040761003, 0.7041241040761003, 1.0561861561141503, 4.576806676494652, 5.9850548846468525, 10.561861561141503, 13.378357977445905, 15.842792341712256, 17.603102601902506, 23.23609543451131, 26.40465390285376, 34.15001904769086, 47.52837702513677, 52.80930780570752, 60.20261089850657, 65.48354167907732, 72.52478271983833, 79.91808581263739],
                            [ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.2139406611701955, 1.2139406611701955, 1.8209109917552933, 7.8906142976062705, 10.318495619946662, 18.209109917552933, 23.064872562233713, 27.313664876329398, 30.348516529254887, 40.060041818616455, 45.522774793882334, 58.87612206675448, 81.9409946289882, 91.04554958776467, 103.79192653005171, 112.89648148882819, 125.03588810053013, 137.7822650428172]
                                ])

const D_calibration_data = vcat([
                            [ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.01568570393347741, 0.03137140786695482, 0.03921425983369352, 0.03921425983369352, 0.10195707556760315, 0.13332848343455797, 0.1646998913015128, 0.2039141511352063, 0.3607711904699804, 0.46272826603758355, 0.635271009305835, 0.8705565683079962, 1.0430993115762477, 1.2077992028777604],
                            [ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.03702544221524453, 0.07405088443048906, 0.09256360553811133, 0.09256360553811133, 0.24066537439908944, 0.3147162588295785, 0.38876714326006756, 0.4813307487981789, 0.8515851709506241, 1.0922505453497136, 1.4995304097174034, 2.0549120429460714, 2.4621919073137613, 2.8509590505738287],
                            [ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.12991051835415057, 0.25982103670830115, 0.32477629588537643, 0.32477629588537643, 0.8444183693019787, 1.1042394060102798, 1.364060442718581, 1.6888367386039573, 2.9879419221454633, 3.8323602914474417, 5.261375993343099, 7.210033768655356, 8.639049470551013, 10.003109913269594],
                            [ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3701199159009554, 0.7402398318019108, 0.9252997897523885, 0.9252997897523885, 2.4057794533562102, 3.146019285158121, 3.886259116960032, 4.8115589067124205, 8.512758065721973, 10.918537519078184, 14.989856593988693, 20.541655332503023, 24.61297440741353, 28.499233524373565],
                            [ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9137870107082563, 1.8275740214165126, 2.2844675267706407, 2.2844675267706407, 5.939615569603666, 7.767189591020179, 9.594763612436692, 11.879231139207333, 21.017101246289894, 26.95671681589356, 37.00837393368438, 50.715179094308226, 60.766836212099044, 70.36159982453573]
                                ])

You may check that the model compiles and runs correctly:

# you may check it works
problem = ODEProblem(SEIH_ICU_RD!(calibrated_parameters = Float64[]) , B, (1.0, 120.0), P)
solution = solve(problem, Tsit5(), saveat = 1:120);
plot(solution.t, [slice[6] for slice in solution.u ])

And we can also optimize it via ModelingToolkit.jl:

# and you can modelingtoolkitize it to get about 5x faster solves
fast_problem = ODEProblem(modelingtoolkitize(problem), B, (1.0, 120.0), P)
fast_solution = solve(fast_problem, Tsit5(), saveat = 1:120);
plot(fast_solution.t, [slice[6] for slice in fast_solution.u ])

Now to the calibration. We will first show that the model can be calibrated against these data using Optim.jl through DiffEqParamEstim.jl. This may also be useful to get starting values for other optimization algorithms.

# prepare the data: a 15×30 Int matrix (rows = series, columns = days), values floored to Int
const optim_data = permutedims(reduce(hcat, [floor.(Int, data) for data in vcat(H_calibration_data, ICU_calibration_data, D_calibration_data)]))

# state indices (the H, ICU and D blocks) whose trajectories are compared with the data
const save_idxs = vcat(5*3+1:5*4, 5*4+1:5*5 ,5*6+1:5*7);

# set bounds for Fminbox
const lower = [uniform.a for uniform in priors];
const upper = [uniform.b for  uniform in priors];

problem = ODEProblem(SEIH_ICU_RD!(calibrated_parameters = Float64[]) , B, T, P)
fast_problem = ODEProblem(modelingtoolkitize(problem), B, T, P; jac = true, sparse = true)
cost_function = build_loss_objective(fast_problem,Tsit5(),L2Loss(t,optim_data), maxiters=5*10^7,verbose=true, save_idxs=save_idxs) 

result_Optim = @time Optim.optimize(  cost_function,lower,upper, P, Optim.Fminbox(BFGS()) ,Optim.Options(show_trace = false , iterations =  200,time_limit = 180)) # will take approx 1 minute and a half.

One may plot the simulated trajectories obtained with the calibrated parameters together with the data used for calibration:

# get calibrated parameters
const P_optim = result_Optim.minimizer;

# get the solution with the calibrated parameters
const optim_calibrated_solution = solve(fast_problem , Tsit5(); p = P_optim , saveat = 1:30 );


# plot the solution obtained with the calibrated parameters together with the calibration points
# check that the calibration was somewhat successful: please keep in mind that the data are fake, and that the time series poorly reproduced by the model tend to have a maximum value below 10, so they are probably nearly ignored by the loss during calibration. There are techniques to account for this, but we omit them here for the sake of simplicity.
plots =[]
for (i, idx) in enumerate(save_idxs)
    p = plot(optim_calibrated_solution.t , [optim_calibrated_solution.u[j][idx] for j in t], size = (1400, 3000), label = "simulated")
    plot!(optim_calibrated_solution.t , optim_data[i,:], seriestype = :scatter , label = "data")
    push!(plots, p)
end
plot(plots..., layout = (length(plots), 1))
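The comment in the plotting code notes that series peaking below 10 are probably almost ignored by the loss. A plain sum of squared errors weights each residual by the square of its series' scale, so a series 50 times smaller contributes about 2,500 times less for the same relative error. One simple option, sketched below with made-up numbers, is to rescale every series by its own maximum before squaring: even with a 50% relative error on the small series versus 5% on the large one, the plain loss barely notices the small series. With DiffEqParamEstim a similar effect could presumably be obtained through per-point weights (if I recall correctly, L2Loss accepts a `data_weight` keyword), but please check its documentation.

```julia
# Made-up series on very different scales (illustration only, not the issue's data)
small_data = [0.0, 1.0, 2.0, 4.0, 8.0]          # peaks below 10
large_data = [0.0, 50.0, 120.0, 260.0, 525.0]
small_sim  = 1.5 .* small_data                   # 50% relative error
large_sim  = 1.05 .* large_data                  # 5% relative error

# Sum of squared errors, as in a plain L2 loss
sse(sim, data) = sum(abs2, sim .- data)

# One possible remedy: divide each series (simulation and data) by the data maximum
scaled_sse(sim, data) = sse(sim ./ maximum(data), data ./ maximum(data))

plain  = [sse(small_sim, small_data), sse(large_sim, large_data)]
scaled = [scaled_sse(small_sim, small_data), scaled_sse(large_sim, large_data)]

println("share of plain loss from the small series:  ", plain[1] / sum(plain))
println("share of scaled loss from the small series: ", scaled[1] / sum(scaled))
```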

Now let's try to calibrate using NUTS. This is the Turing model we are using:

# now let's write down the turing model we are currently using
@model function turing_model(data::Array{Int64,2} , problem::ODEProblem, priors , save_idxs::Array{Int64,1} ,  𝒯::Tuple{Float64,Float64} , ::Type{T} = Float64  ) where {T}
    σ ~ InverseGamma(2, 3)

    
    p ~ arraydist(priors) 


    predicted::Vector{Vector{T}}  = solve(problem, Tsit5(), p = p, saveat=𝒯[1]:𝒯[2], save_idxs = save_idxs).u   

    # when the solve fails (e.g. it is aborted early), predicted contains fewer time points than the data. We detect this and add -Inf to the log-likelihood to reject the sample.
    if length(predicted) == size(data, 2)
        for i = 1:length(predicted)
            data[:,i] ~ MvNormal(predicted[i], σ) 
        end
    else
        Turing.@addlogprob! -Inf
        return
    end
    
end
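Each `data[:, i] ~ MvNormal(predicted[i], σ)` statement adds to the model's log-joint the log-density of an isotropic Gaussian with standard deviation σ. With a scalar second argument, the Distributions.jl version of the time interpreted σ as a standard deviation; newer versions prefer `MvNormal(μ, σ^2 * I)`. Because a single σ is shared by all 15 series, the same scale imbalance discussed for the L2 loss applies here too. A standalone sketch of the likelihood this loop computes, written without Turing or Distributions; the function names are my own:

```julia
# Log-density of an isotropic Gaussian N(μ, σ²I) at x, with n = length(x):
#   log p(x) = -n/2 * log(2πσ²) - ‖x - μ‖² / (2σ²)
function iso_gauss_logpdf(x::AbstractVector, μ::AbstractVector, σ::Real)
    n = length(x)
    return -n / 2 * log(2π * σ^2) - sum(abs2, x .- μ) / (2 * σ^2)
end

# Log-likelihood of the whole data matrix (rows = series, columns = time points),
# mirroring the loop in turing_model: one Gaussian term per saved time point
function loglik(data::AbstractMatrix, predicted::AbstractVector{<:AbstractVector}, σ::Real)
    # a failed solve returns fewer time points than the data: reject the sample
    length(predicted) == size(data, 2) || return -Inf
    return sum(iso_gauss_logpdf(data[:, i], predicted[i], σ) for i in eachindex(predicted))
end

# Toy usage: 2 series observed at 2 time points, with exact predictions
data = [1 2; 2 3]
predicted = [[1.0, 2.0], [2.0, 3.0]]
println(loglik(data, predicted, 1.0))    # equals -2 * log(2π) here
```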

Let's calibrate the model:

# instantiate the Turing model
tur_mod = turing_model(optim_data, fast_problem, priors, save_idxs, T )

# sample from the posterior. It takes approx. 5 minutes
chain = @time sample(tur_mod, NUTS(), 1000)   # using MCMCThreads() over multiple chains seriously slows it down. Should one use MCMCDistributed() instead?

Get the parameters and check it correctly calibrated the time series:

# get the parameters
P_turing = vcat([Statistics.mean(chain["p[$i]"]) for i in 1:32])   

# evaluate the solution with such parameters
const turing_calibrated_solution = solve(fast_problem , Tsit5(); p = P_turing , saveat = 1:30 );

# plot
plots =[]
for (i, idx) in enumerate(save_idxs)
    p = plot(turing_calibrated_solution.t , [turing_calibrated_solution.u[j][idx] for j in t], size = (1400, 3000), label = "simulated")
    plot!(turing_calibrated_solution.t , optim_data[i,:], seriestype = :scatter , label = "data")
    push!(plots, p)
end
plot(plots..., layout = (length(plots), 1))

We can also try ADVI:

advi = ADVI(10, 10000)
q = @time vi(tur_mod, advi)  #about 80 seconds
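ADVI as used here optimizes a Gaussian approximation in an unconstrained space and maps draws back to the parameter space through Bijectors.jl. As far as I know, in `ADVI(10, 10000)` the two arguments are the number of Monte Carlo samples per gradient step and the maximum number of iterations. For a Uniform(a, b) prior the mapping is, to my understanding, a logit scaled to (a, b), which is why every `rand(q)` draw respects the prior bounds and posterior means can be averaged directly on the draws. A standalone sketch of such a bijection; the helper functions are hypothetical and not the Bijectors API:

```julia
using Random
Random.seed!(42)

# Sketch of a scaled-logit bijection for a Uniform(a, b) parameter
# (illustrates the idea; this is not the Bijectors.jl source code)
logistic(z) = 1 / (1 + exp(-z))
to_unconstrained(x, a, b) = log((x - a) / (b - x))      # (a, b) -> real line
to_constrained(z, a, b)   = a + (b - a) * logistic(z)   # real line -> (a, b)

# λ_HR prior bounds and starting value, copied from the script
a, b = 0.06825938566552901, 0.11976047904191617
x = 0.1111111111111111

z = to_unconstrained(x, a, b)
x_back = to_constrained(z, a, b)           # round trip recovers x (up to rounding)

# A Gaussian in the unconstrained space, as in mean-field ADVI, always maps back inside (a, b)
zs = z .+ 2.0 .* randn(10_000)
xs = to_constrained.(zs, a, b)
println(extrema(xs))
```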

Sample from the (approximate) multivariate posterior and plot the simulations together with the data:

# sample from the posterior
extractions_advi = [rand(q) for i in 1:10000] ;

# evaluate the parameter values (component 1 is σ, so p corresponds to components 2:33)
P_advi = [Statistics.mean([extraction[i] for extraction in extractions_advi]) for i in 2:33] 

# get the advi calibrated solution
const advi_calibrated_solution = solve(fast_problem, Tsit5(); p = P_advi , saveat = 1:30 );


# plot advi results
plots =[]
for (i, idx) in enumerate(save_idxs)
    p = plot(advi_calibrated_solution.t , [advi_calibrated_solution.u[j][idx] for j in t], size = (1400, 3000), label = "simulated")
    plot!(advi_calibrated_solution.t , optim_data[i,:], seriestype = :scatter , label = "data")
    push!(plots, p)
end
plot(plots..., layout = (length(plots), 1))

Based on this and other models we tried, NUTS seems to perform generally better than ADVI, but we are far from sure.

How would you calibrate this 32-parameter model? What exact setup would you use? Could you please present your practical reasoning, so that non-domain experts may try to reproduce it in other scenarios?

Thanks in advance

EDIT1: as suggested in this comment, I moved save_idxs to a keyword argument of the solve call inside turing_model.

@devmotion
Member

Completely unrelated comment: you should specify save_idxs as a keyword argument to solve.
