using DiffEqBayes, CmdStan, DynamicHMC
using Distributions, BenchmarkTools using OrdinaryDiffEq, RecursiveArrayTools, ParameterizedFunctions using Plots
# Activate the GR plotting backend, passing `fmt = :png` as its format option.
gr(; fmt = :png)
Plots.GRBackend()
# Two-state ODE system declared with `@ode_def`: state variables `x` and `y`,
# parameters `a`, `b`, `c`, `d` (listed after `end`, in that order).
f = @ode_def LotkaVolterraTest begin
    dx = a * x - b * x * y
    dy = -c * y + d * x * y
end a b c d
(::Main.WeaveSandBox10.LotkaVolterraTest{Main.WeaveSandBox10.var"###Paramet erizedDiffEqFunction#776",Main.WeaveSandBox10.var"###ParameterizedTGradFunc tion#777",Main.WeaveSandBox10.var"###ParameterizedJacobianFunction#778",Not hing,Nothing,ModelingToolkit.ODESystem}) (generic function with 1 method)
# Initial state and integration interval for the ODE problem.
u0 = [1.0, 1.0]
tspan = (0.0, 10.0)
# One value per model parameter (a, b, c, d). The previous literal ended in
# `1,0` -- a typo for `1.0` -- which produced a stray fifth entry of 0.0.
p = [1.5, 1.0, 3.0, 1.0]
5-element Array{Float64,1}: 1.5 1.0 3.0 1.0 0.0
# Build the ODE problem from the model, initial state, time span and parameters,
# then integrate it with the Tsit5 solver to obtain the reference trajectory.
prob = ODEProblem(f, u0, tspan, p)
sol = solve(prob, Tsit5())
retcode: Success Interpolation: specialized 4th order "free" interpolation t: 34-element Array{Float64,1}: 0.0 0.0776084743154256 0.23264513699277584 0.4291185174543143 0.6790821776882875 0.9444045910389707 1.2674601253261835 1.6192913723304114 1.9869755337814992 2.264090367186479 ⋮ 7.584862904164952 7.978068388305894 8.483164907244102 8.719247868929038 8.949206527971544 9.200184813643565 9.438028630962807 9.711807852444823 10.0 u: 34-element Array{Array{Float64,1},1}: [1.0, 1.0] [1.0454942346944578, 0.8576684823217127] [1.1758715885138267, 0.639459570317544] [1.4196809607170826, 0.4569962601282084] [1.8767193485546056, 0.32473343696185236] [2.5882499852859384, 0.26336255804531] [3.860708771268753, 0.2794458027885767] [5.750812903389158, 0.5220073140479389] [6.814978737433837, 1.917783300239219] [4.3929977807914105, 4.194671536988031] ⋮ [2.614252575185928, 0.26416950055716665] [4.2410731685818694, 0.30512345857554246] [6.791122470590543, 1.1345265418479897] [6.26537352594436, 2.741690196017545] [3.78076791065078, 4.431164786168439] [1.8164212283793362, 4.0640577258289365] [1.1465027088171469, 2.791172606389902] [0.9557986534742364, 1.6235632025270912] [1.03375813933372, 0.9063703701433561]
# Ten equally spaced observation times from 1 to 10.
t = collect(range(1, stop = 10, length = 10))
# Standard deviation of the Gaussian noise added to each observation.
sig = 0.49
# Evaluate the reference solution at every observation time, perturb both
# components with independent normal noise, and pack the samples into a matrix.
noisy_samples = [sol(ti) + sig * randn(2) for ti in t]
data = convert(Array, VectorOfArray(noisy_samples))
2×10 Array{Float64,2}: 2.55558 6.27639 1.36855 2.21923 … 3.53585 3.96592 2.52811 0.415488 2.29089 1.91952 0.421956 0.789497 5.02123 0.985163
# Plot the noisy observations of each component, then overlay the reference
# solution on the same axes.
scatter(t, data[1, :]; lab = "#prey (data)")
scatter!(t, data[2, :]; lab = "#predator (data)")
plot!(sol)
# One truncated-normal prior per model parameter, in parameter order.
priors = [
    Truncated(Normal(1.5, 0.5), 0.5, 2.5),
    Truncated(Normal(1.2, 0.5), 0, 2),
    Truncated(Normal(3.0, 0.5), 1, 4),
    Truncated(Normal(1.0, 0.5), 0, 2),
]
4-element Array{Truncated{Normal{Float64},Continuous,Float64},1}: Truncated(Normal{Float64}(μ=1.5, σ=0.5), range=(0.5, 2.5)) Truncated(Normal{Float64}(μ=1.2, σ=0.5), range=(0.0, 2.0)) Truncated(Normal{Float64}(μ=3.0, σ=0.5), range=(1.0, 4.0)) Truncated(Normal{Float64}(μ=1.0, σ=0.5), range=(0.0, 2.0))
The solution converges for tolerance values lower than 1e-3. Lower tolerances give more accurate results, at the cost of longer warmup and sampling times. Truncated normal priors are used to prevent Stan from stepping into negative values.
# Benchmark Stan-based inference with 10_000 samples and the summary printout
# disabled. The result is bound outside `@btime`: an assignment written inside
# the benchmarked expression is evaluated in the benchmark's own scope and would
# not define `bayesian_result_stan` here. Globals are `$`-interpolated, as
# BenchmarkTools recommends, so they are not accessed as untyped globals.
bayesian_result_stan = @btime stan_inference(
    $prob, $t, $data, $priors; num_samples = 10_000, printsummary = false
)
File /Users/vaibhav/DiffEqBenchmarks.jl/tmp/parameter_estimation_model.stan will be updated. File /Users/vaibhav/DiffEqBenchmarks.jl/tmp/parameter_estimation_model.stan will be updated. File /Users/vaibhav/DiffEqBenchmarks.jl/tmp/parameter_estimation_model.stan will be updated. File /Users/vaibhav/DiffEqBenchmarks.jl/tmp/parameter_estimation_model.stan will be updated. 174.979 s (1361505 allocations: 56.94 MiB) DiffEqBayes.StanModel{Stanmodel,Int64,Array{Float64,3},Array{String,1}}( n ame = "parameter_estimation_model" nchains = 1 num_samples = 10000 num_warmup = 1000 thin = 1 monitors = String[] model_file = "parameter_estimation_model.stan" data_file = "parameter_estimation_model_1.data.R" output = Output() file = "parameter_estimation_model_samples_1.csv" diagnostics_file = "" refresh = 100 pdir = "/Users/vaibhav/DiffEqBenchmarks.jl" tmpdir = "/Users/vaibhav/DiffEqBenchmarks.jl/tmp" output_format = :array method = Sample() num_samples = 10000 num_warmup = 1000 save_warmup = false thin = 1 algorithm = HMC() engine = NUTS() max_depth = 10 metric = CmdStan.diag_e stepsize = 1.0 stepsize_jitter = 1.0 adapt = Adapt() gamma = 0.05 delta = 0.8 kappa = 0.75 t0 = 10.0 init_buffer = 75 term_buffer = 50 window = 25 , 0, [-7.34017 0.853264 … 2.73116 0.915208; -7.55609 0.916111 … 2.78655 0.9 90211; … ; -7.86709 0.998835 … 2.9281 0.984592; -7.17184 0.990564 … 3.15138 1.15395], ["lp__", "accept_stat__", "stepsize__", "treedepth__", "n_leapfr og__", "divergent__", "energy__", "sigma1.1", "sigma1.2", "theta1", "theta2 ", "theta3", "theta4", "theta.1", "theta.2", "theta.3", "theta.4"])
# Benchmark Turing-based inference (Tsit5 solver, 10_000 samples). The result is
# bound outside `@btime`: an assignment written inside the benchmarked expression
# is evaluated in the benchmark's own scope and would not define
# `bayesian_result_turing` here. Globals are `$`-interpolated, as BenchmarkTools
# recommends, so they are not accessed as untyped globals.
bayesian_result_turing = @btime turing_inference(
    $prob, Tsit5(), $t, $data, $priors; num_samples = 10_000
)
36.270 s (227070592 allocations: 17.04 GiB) Object of type Chains, with data of type 9000×17×1 Array{Float64,3} Iterations = 1:9000 Thinning interval = 1 Chains = 1 Samples per chain = 9000 internals = acceptance_rate, hamiltonian_energy, hamiltonian_energy _error, is_accept, log_density, lp, max_hamiltonian_energy_error, n_steps, nom_step_size, numerical_error, step_size, tree_depth parameters = theta[1], theta[2], theta[3], theta[4], σ[1] 2-element Array{MCMCChains.ChainDataFrame,1} Summary Statistics parameters mean std naive_se mcse ess r_hat ────────── ────── ────── ──────── ────── ───────── ────── theta[1] 1.5354 0.1108 0.0012 0.0025 1964.1719 1.0002 theta[2] 1.0029 0.1140 0.0012 0.0024 2301.4344 1.0000 theta[3] 2.8708 0.2866 0.0030 0.0062 2005.1551 1.0001 theta[4] 0.9931 0.1133 0.0012 0.0025 1982.7440 1.0001 σ[1] 0.6659 0.1251 0.0013 0.0019 3876.7909 0.9999 Quantiles parameters 2.5% 25.0% 50.0% 75.0% 97.5% ────────── ────── ────── ────── ────── ────── theta[1] 1.3386 1.4578 1.5296 1.6042 1.7718 theta[2] 0.8143 0.9250 0.9913 1.0679 1.2563 theta[3] 2.3434 2.6735 2.8580 3.0545 3.4664 theta[4] 0.7877 0.9149 0.9858 1.0665 1.2334 σ[1] 0.4662 0.5787 0.6496 0.7350 0.9609
# Benchmark DynamicHMC-based inference (Tsit5 solver, 10_000 samples). The result
# is bound outside `@btime`: an assignment written inside the benchmarked
# expression is evaluated in the benchmark's own scope and would not define
# `bayesian_result_dynamichmc` here. Globals are `$`-interpolated, as
# BenchmarkTools recommends, so they are not accessed as untyped globals.
bayesian_result_dynamichmc = @btime dynamichmc_inference(
    $prob, Tsit5(), $t, $data, $priors; num_samples = 10_000
)
21.949 s (100464976 allocations: 9.87 GiB) (posterior = NamedTuple{(:parameters, :σ),Tuple{Array{Float64,1},Array{Floa t64,1}}}[(parameters = [1.6510330907693331, 1.068371927076397, 2.6700385090 60144, 0.8756214940363967], σ = [1.203931721461571, 0.3308111772255924]), ( parameters = [1.6897436377513675, 1.0430237851150368, 2.601079414339675, 0. 8247486490381333], σ = [1.007210075547111, 0.26357292049156145]), (paramete rs = [1.3588194900440342, 0.8370391784177251, 3.3046299477504593, 1.2798289 643619438], σ = [1.0998922369939241, 0.3738164505949778]), (parameters = [1 .4430017363542815, 0.9091301577477175, 3.271178240763378, 1.086768588189342 5], σ = [1.2038981260993595, 0.2660212128891589]), (parameters = [1.4535668 044488257, 0.909523050743244, 3.276905486386487, 1.0955803011090495], σ = [ 1.193617628965834, 0.266783760453349]), (parameters = [1.45382751529481, 0. 9532735140562241, 3.0300413150609877, 1.1049828006952258], σ = [1.132553004 6772012, 0.2917324556865594]), (parameters = [1.4629441621588475, 0.9054342 337051878, 3.268017222062399, 1.0759377422089689], σ = [1.5434123158162087, 0.2630168067124241]), (parameters = [1.3432192939479721, 0.825200409360124 4, 3.23173317430294, 1.314322745979057], σ = [1.2728622042633615, 0.2834347 620521368]), (parameters = [1.3071165916749654, 0.8104259044697025, 3.68452 1124745626, 1.3408512948378726], σ = [0.9409149022983262, 0.419082388705157 9]), (parameters = [1.4258804864502665, 0.818718339392575, 3.17126692706493 03, 1.1349048291980846], σ = [0.8467645109894342, 0.29965312853945]) … (p arameters = [1.4929664889259344, 1.068060207207104, 2.9980815767360105, 1.0 362722020356832], σ = [0.8524511919141108, 0.4482810791932937]), (parameter s = [1.5068920418235292, 0.9936324443451314, 3.013292259319715, 1.006927312 596321], σ = [0.9778604786346289, 0.4288247971595753]), (parameters = [1.53 0315288282088, 0.9233321541133878, 2.840696323323293, 1.0067946206822225], σ = [0.9540683524575444, 0.35318040365018893]), (parameters 
= [1.5505068918 424791, 0.9340834095344153, 2.694088018241023, 0.981661958022394], σ = [0.8 378202680123857, 0.438497330199877]), (parameters = [1.580069085047083, 0.9 574854791458529, 2.784597017864487, 0.9118234023648312], σ = [1.21505488058 82329, 0.5631116480285596]), (parameters = [1.6025521791897022, 0.988843140 6893479, 2.6233074069937046, 0.9488797265890467], σ = [1.037985906111572, 0 .4863899486464513]), (parameters = [1.5077089778689483, 0.9704062659457484, 3.0367756170875526, 1.0314730363950464], σ = [1.0786621444009297, 0.217552 949931122]), (parameters = [1.5520298148311464, 0.9400878972595683, 2.87265 0856245965, 0.9998739642555784], σ = [1.227634634762198, 0.4718763018501869 ]), (parameters = [1.5273979739480217, 0.8662328033036745, 2.94837472761127 9, 0.9942321580199855], σ = [1.1752103688524014, 0.4380025823869909]), (par ameters = [1.4438018525248824, 1.0782035384607451, 3.0685242676380358, 1.11 4376664725165], σ = [0.9165385987338418, 0.5590224516059303])], chain = [[0 .5014012076016618, 0.06613592624053985, 0.9820928951771175, -0.132821365854 4598, 0.18559263552904492, -1.1062075279342756], [0.5245768237900944, 0.042 1239802792651, 0.9559265182340466, -0.19267660690069474, 0.0071842072197196 81, -1.3334252111772547], [0.3066163007722468, -0.17788440144269851, 1.1953 244996370895, 0.24672644741113411, 0.0952122086361469, -0.9839903758467082] , [0.3667254830857303, -0.0952670072137852, 1.1851502382863834, 0.083208695 12451416, 0.185564730432472, -1.3241792256765121], [0.3740204010537183, -0. 
0948349369325416, 1.1868995277952565, 0.09128417824245073, 0.17698871993422 36, -1.3213168346592004], [0.37419974435428993, -0.047853413298526076, 1.10 85762547619769, 0.09982976986688438, 0.12447438052679104, -1.23191814437407 72], [0.3804509546376101, -0.09934063410839251, 1.1841834468928054, 0.07319 259966502849, 0.43399575467533663, -1.3355373449966152], [0.295069190844180 04, -0.1921290017113334, 1.173018579795021, 0.2733215108911501, 0.241268068 832914, -1.2607732985133422], [0.26782363622966066, -0.21019536151694698, 1 .304140564617526, 0.2933047069037269, -0.060902576750357054, -0.86968774663 49355], [0.35478950813638943, -0.20001516222454077, 1.1541311695804692, 0.1 2654879649088271, -0.16633265015393867, -1.2051297114868595] … [0.4007650 7284361934, 0.06584411273848387, 1.097972609695804, 0.03562985260400764, -0 .15963932435014783, -0.8023348343922363], [0.41004927927037105, -0.00638791 50096104766, 1.103033254886344, 0.006903429002392361, -0.022388279007001202 , -0.8467068417008177], [0.4254737849472855, -0.0797662456031932, 1.0440492 063891291, 0.006771641279273761, -0.047019961822117456, -1.0407762940956404 ], [0.43858190446463063, -0.06818954117058304, 0.9910597494073224, -0.01850 8268157665365, -0.17695167881579085, -0.8244015556721811], [0.4574685707963 2175, -0.04344472342525278, 1.024103165553915, -0.0923089454158774, 0.19478 9244980836, -0.5742773613937568], [0.4715974703994052, -0.01121956389250995 3, 0.964435890736649, -0.05247322540186434, 0.03728220672432128, -0.7207446 133330677], [0.410591265463622, -0.03004046429797, 1.1107963002546206, 0.03 098791299850871, 0.0757215180603821, -1.5253130101228625], [0.4395636321635 849, -0.061781900366717046, 1.0552352467668187, -0.0001260436875935144, 0.2 0508925675922857, -0.7510384000948402], [0.42356561700015377, -0.1436015805 6116548, 1.081254078735848, -0.0057845402200649935, 0.16144716889204658, -0 .8255304727623616], [0.36727980981178215, 0.0752962658522917, 1.12119675144 7984, 
0.10829520349857374, -0.08715109722221948, -0.5815656427589329]], tre e_statistics = DynamicHMC.TreeStatisticsNUTS[DynamicHMC.TreeStatisticsNUTS( -24.972044404887683, 6, turning at positions -58:5, 0.9284041843982449, 63, DynamicHMC.Directions(0x65a2fb05)), DynamicHMC.TreeStatisticsNUTS(-25.3263 8209243423, 4, turning at positions -3:12, 0.9298976637327503, 15, DynamicH MC.Directions(0x047aa51c)), DynamicHMC.TreeStatisticsNUTS(-27.9351072019710 46, 5, turning at positions 21:36, 0.9343109632181998, 63, DynamicHMC.Direc tions(0x2dba49a4)), DynamicHMC.TreeStatisticsNUTS(-25.00498788899212, 6, tu rning at positions -45:18, 0.8699256646318402, 63, DynamicHMC.Directions(0x 8ed16c52)), DynamicHMC.TreeStatisticsNUTS(-22.50513635867809, 2, turning at positions -3:0, 0.9999999999999999, 3, DynamicHMC.Directions(0x2fcadc9c)), DynamicHMC.TreeStatisticsNUTS(-23.968698002582858, 5, turning at positions -15:-30, 0.9165666825262788, 47, DynamicHMC.Directions(0x5d090511)), Dynam icHMC.TreeStatisticsNUTS(-25.585626717018357, 4, turning at positions 0:15, 0.9670086865497939, 15, DynamicHMC.Directions(0x13c553ff)), DynamicHMC.Tre eStatisticsNUTS(-25.631172292787063, 5, turning at positions -27:-42, 0.971 1521254641908, 47, DynamicHMC.Directions(0x3253cfc5)), DynamicHMC.TreeStati sticsNUTS(-26.07656446522952, 5, turning at positions -25:6, 0.949838397156 478, 31, DynamicHMC.Directions(0x4c67ae26)), DynamicHMC.TreeStatisticsNUTS( -27.78911332837673, 6, turning at positions -57:6, 0.8146798188135186, 63, DynamicHMC.Directions(0xc22c3646)) … DynamicHMC.TreeStatisticsNUTS(-24.37 2869031737785, 2, turning at positions -3:0, 0.8507016296521978, 3, Dynamic HMC.Directions(0xd92cd1ac)), DynamicHMC.TreeStatisticsNUTS(-25.297885575321 747, 4, turning at positions 9:24, 0.9885396990076063, 31, DynamicHMC.Direc tions(0xe63419b8)), DynamicHMC.TreeStatisticsNUTS(-22.634170513642232, 5, t urning at positions 23:38, 0.8892474784021949, 47, DynamicHMC.Directions(0x c4729276)), 
DynamicHMC.TreeStatisticsNUTS(-25.08815603029298, 4, turning at positions -23:-26, 0.895586547306448, 27, DynamicHMC.Directions(0x9810a3a1 )), DynamicHMC.TreeStatisticsNUTS(-26.118326468048103, 5, turning at positi ons -24:7, 0.9844447809259682, 31, DynamicHMC.Directions(0x323379e7)), Dyna micHMC.TreeStatisticsNUTS(-24.28092130771148, 4, turning at positions -1:14 , 0.9924232112951569, 15, DynamicHMC.Directions(0xa26197ce)), DynamicHMC.Tr eeStatisticsNUTS(-24.846350847828873, 5, turning at positions 29:60, 0.7453 192520971835, 63, DynamicHMC.Directions(0x1dbc5a3c)), DynamicHMC.TreeStatis ticsNUTS(-24.7120873743796, 5, turning at positions -18:13, 0.8981696068490 232, 31, DynamicHMC.Directions(0xc9d4b76d)), DynamicHMC.TreeStatisticsNUTS( -30.37914990178817, 5, turning at positions -14:-45, 0.9122853621385598, 63 , DynamicHMC.Directions(0xc5496412)), DynamicHMC.TreeStatisticsNUTS(-25.526 836941259393, 5, turning at positions 0:31, 0.9531169057544701, 31, Dynamic HMC.Directions(0x9bf05aff))], κ = Gaussian kinetic energy (LinearAlgebra.Di agonal), √diag(M⁻¹): [0.07239305596109954, 0.09679410633782563, 0.093136522 85944729, 0.13242179029089104, 0.2883075001026171, 0.3132664821176181], ϵ = 0.08349883321492188)
The Lotka-Volterra equation is a "predator-prey" model: it describes the populations of two species, one of which is the predator (wolf) and the other the prey (rabbit). It exhibits cyclic behaviour, which is also seen in its Uncertainty Quantification plots. This behaviour makes the parameters easy to estimate even at very high tolerance values (1e-3).
# Emit the standard benchmark footer for this document, identified by the
# folder and file entries of the Weave arguments.
using DiffEqBenchmarks
DiffEqBenchmarks.bench_footer(WEAVE_ARGS[:folder], WEAVE_ARGS[:file])
These benchmarks are a part of the DiffEqBenchmarks.jl repository, found at: https://github.com/JuliaDiffEq/DiffEqBenchmarks.jl
To locally run this tutorial, do the following commands:
using DiffEqBenchmarks
DiffEqBenchmarks.weave_file("ParameterEstimation","DiffEqBayesLotkaVolterra.jmd")
Computer Information:
Julia Version 1.4.0
Commit b8e9a9ecc6 (2020-03-21 16:36 UTC)
Platform Info:
OS: macOS (x86_64-apple-darwin18.6.0)
CPU: Intel(R) Core(TM) i7-6700HQ CPU @ 2.60GHz
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-8.0.1 (ORCJIT, skylake)
Package Information:
Status: `/Users/vaibhav/DiffEqBenchmarks.jl/Project.toml`
[28f2ccd6-bb30-5033-b560-165f7b14dc2f] ApproxFun 0.11.10
[a134a8b2-14d6-55f6-9291-3336d3ab0209] BlackBoxOptim 0.5.0
[a93c6f00-e57d-5684-b7b6-d8193f3e46c0] DataFrames 0.20.2
[2b5f629d-d688-5b77-993f-72d75c75574e] DiffEqBase 6.25.2
[ebbdde9d-f333-5424-9be2-dbf1e9acfb5e] DiffEqBayes 2.9.1
[eb300fae-53e8-50a0-950c-e21f52c2b7e0] DiffEqBiological 4.2.0
[f3b72e0c-5b89-59e1-b016-84e28bfd966d] DiffEqDevTools 2.18.0
[c894b116-72e5-5b58-be3c-e6d8d4ac2b12] DiffEqJump 6.5.0
[1130ab10-4a5a-5621-a13d-e4788d82bd4c] DiffEqParamEstim 1.13.0
[a077e3f3-b75c-5d7f-a0c6-6bc4c8ec64a9] DiffEqProblemLibrary 4.6.4
[ef61062a-5684-51dc-bb67-a0fcdec5c97d] DiffEqUncertainty 1.4.1
[0c46a032-eb83-5123-abaf-570d42b7fbaa] DifferentialEquations 6.12.0
[7073ff75-c697-5162-941a-fcdaad2a7d2a] IJulia 1.21.1
[7f56f5a3-f504-529b-bc02-0b1fe5e64312] LSODA 0.6.1
[76087f3c-5699-56af-9a33-bf431cd00edd] NLopt 0.5.1
[c030b06c-0b6d-57c2-b091-7029874bd033] ODE 2.6.0
[54ca160b-1b9f-5127-a996-1867f4bc2a2c] ODEInterface 0.4.6
[09606e27-ecf5-54fc-bb29-004bd9f985bf] ODEInterfaceDiffEq 3.6.0
[1dea7af3-3e70-54e6-95c3-0bf5283fa5ed] OrdinaryDiffEq 5.32.2
[2dcacdae-9679-587a-88bb-8b444fb7085b] ParallelDataTransfer 0.5.0
[65888b18-ceab-5e60-b2b9-181511a3b968] ParameterizedFunctions 5.0.3
[91a5bcdd-55d7-5caf-9e0b-520d859cae80] Plots 0.29.9
[b4db0fb7-de2a-5028-82bf-5021f5cfa881] ReactionNetworkImporters 0.1.5
[f2c3362d-daeb-58d1-803e-2bc74f2840b4] RecursiveFactorization 0.1.0
[9672c7b4-1e72-59bd-8a11-6ac3964bc41f] SteadyStateDiffEq 1.5.0
[c3572dad-4567-51f8-b174-8c6c989267f4] Sundials 3.9.0
[a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f] TimerOutputs 0.5.3
[44d3d7a6-8a23-5bf8-98c5-b353f8df5ec9] Weave 0.9.4
[b77e0a4c-d291-57a0-90e8-8db25a27a240] InteractiveUtils
[d6f4376e-aef5-505a-96c1-9c027394607a] Markdown
[44cfe95a-1eb2-52ea-b672-e2afdf69b78f] Pkg
[9a3f8284-a2c9-5f02-9a11-845980a1fd5c] Random