Lotka-Volterra Bayesian Parameter Estimation Benchmarks

Vaibhav Dixit, Chris Rackauckas

Parameter Estimation of Lotka-Volterra Equation using DiffEqBayes.jl

using DiffEqBayes, CmdStan, DynamicHMC
using Distributions, BenchmarkTools
using OrdinaryDiffEq, RecursiveArrayTools, ParameterizedFunctions
using Plots
gr(fmt=:png)
Plots.GRBackend()

Initializing the problem

f = @ode_def LotkaVolterraTest begin
    dx = a*x - b*x*y
    dy = -c*y + d*x*y
end a b c d
(::Main.WeaveSandBox10.LotkaVolterraTest{Main.WeaveSandBox10.var"###Paramet
erizedDiffEqFunction#776",Main.WeaveSandBox10.var"###ParameterizedTGradFunc
tion#777",Main.WeaveSandBox10.var"###ParameterizedJacobianFunction#778",Not
hing,Nothing,ModelingToolkit.ODESystem}) (generic function with 1 method)
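
For readers unfamiliar with the `@ode_def` macro, the same right-hand side can be written as an ordinary in-place function. This is only a sketch: the name `lotka_volterra!` and the variables `p_true`, `du_at_u0` and `du_at_eq` are illustrative and do not appear in the benchmark. It also checks the derivative at the initial condition and at the non-trivial equilibrium (c/d, a/b).

```julia
# Plain-function form of the same right-hand side (in-place, DiffEq-style signature).
# `lotka_volterra!` is an illustrative name, not part of the benchmark.
function lotka_volterra!(du, u, p, t)
    x, y = u            # prey, predator
    a, b, c, d = p      # the four model parameters
    du[1] = a*x - b*x*y
    du[2] = -c*y + d*x*y
    return nothing
end

p_true = [1.5, 1.0, 3.0, 1.0]
du = zeros(2)

# Derivative at the initial condition u0 = [1, 1]
lotka_volterra!(du, [1.0, 1.0], p_true, 0.0)
du_at_u0 = copy(du)

# Derivative at the equilibrium (x*, y*) = (c/d, a/b), where both components vanish
lotka_volterra!(du, [p_true[3]/p_true[4], p_true[1]/p_true[2]], p_true, 0.0)
du_at_eq = copy(du)
```
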
u0 = [1.0,1.0]
tspan = (0.0,10.0)
p = [1.5,1.0,3.0,1.0]
4-element Array{Float64,1}:
 1.5
 1.0
 3.0
 1.0
prob = ODEProblem(f,u0,tspan,p)
sol = solve(prob,Tsit5())
retcode: Success
Interpolation: specialized 4th order "free" interpolation
t: 34-element Array{Float64,1}:
  0.0
  0.0776084743154256
  0.23264513699277584
  0.4291185174543143
  0.6790821776882875
  0.9444045910389707
  1.2674601253261835
  1.6192913723304114
  1.9869755337814992
  2.264090367186479
  ⋮
  7.584862904164952
  7.978068388305894
  8.483164907244102
  8.719247868929038
  8.949206527971544
  9.200184813643565
  9.438028630962807
  9.711807852444823
 10.0
u: 34-element Array{Array{Float64,1},1}:
 [1.0, 1.0]
 [1.0454942346944578, 0.8576684823217127]
 [1.1758715885138267, 0.639459570317544]
 [1.4196809607170826, 0.4569962601282084]
 [1.8767193485546056, 0.32473343696185236]
 [2.5882499852859384, 0.26336255804531]
 [3.860708771268753, 0.2794458027885767]
 [5.750812903389158, 0.5220073140479389]
 [6.814978737433837, 1.917783300239219]
 [4.3929977807914105, 4.194671536988031]
 ⋮
 [2.614252575185928, 0.26416950055716665]
 [4.2410731685818694, 0.30512345857554246]
 [6.791122470590543, 1.1345265418479897]
 [6.26537352594436, 2.741690196017545]
 [3.78076791065078, 4.431164786168439]
 [1.8164212283793362, 4.0640577258289365]
 [1.1465027088171469, 2.791172606389902]
 [0.9557986534742364, 1.6235632025270912]
 [1.03375813933372, 0.9063703701433561]

We sample the solution at ten time points and add Gaussian noise (standard deviation 0.49) to obtain the data used for Bayesian inference of the parameters.

t = collect(range(1,stop=10,length=10))
sig = 0.49
data = convert(Array, VectorOfArray([(sol(t[i]) + sig*randn(2)) for i in 1:length(t)]))
2×10 Array{Float64,2}:
 2.55558   6.27639  1.36855  2.21923   …  3.53585   3.96592  2.52811
 0.415488  2.29089  1.91952  0.421956     0.789497  5.02123  0.985163
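
Because the noise is drawn with an unseeded `randn`, the data above changes from run to run. A minimal sketch of a reproducible variant follows. It assumes a fixed seed (1234 is an arbitrary choice) and uses a hypothetical stand-in `truth(ti)` in place of `sol(ti)` so that it runs without a solver; `noisy_data` is likewise an illustrative helper.

```julia
using Random

sig = 0.49                                   # noise standard deviation, as in the benchmark
t = collect(range(1, stop=10, length=10))

# Hypothetical stand-in for the ODE solution; in the benchmark this would be `sol(ti)`.
truth(ti) = [2 + sin(ti), 2 + cos(ti)]

# One column per time point, one row per state variable.
noisy_data(rng) = reduce(hcat, [truth(ti) + sig*randn(rng, 2) for ti in t])

# The same seed gives the same noise, so the two matrices are identical.
data_a = noisy_data(MersenneTwister(1234))
data_b = noisy_data(MersenneTwister(1234))
```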

Plot of the true solution and the generated noisy data

scatter(t, data[1,:], lab="#prey (data)")
scatter!(t, data[2,:], lab="#predator (data)")
plot!(sol)
priors = [Truncated(Normal(1.5,0.5),0.5,2.5),Truncated(Normal(1.2,0.5),0,2),Truncated(Normal(3.0,0.5),1,4),Truncated(Normal(1.0,0.5),0,2)]
4-element Array{Truncated{Normal{Float64},Continuous,Float64},1}:
 Truncated(Normal{Float64}(μ=1.5, σ=0.5), range=(0.5, 2.5))
 Truncated(Normal{Float64}(μ=1.2, σ=0.5), range=(0.0, 2.0))
 Truncated(Normal{Float64}(μ=3.0, σ=0.5), range=(1.0, 4.0))
 Truncated(Normal{Float64}(μ=1.0, σ=0.5), range=(0.0, 2.0))
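
To illustrate what a truncated normal prior means, the sketch below draws from a normal distribution and rejects anything outside the bounds. It is not how Distributions.jl implements `Truncated`; `sample_truncated_normal` is a hypothetical helper written only for this example, and the seed is arbitrary.

```julia
using Random

# Draw from Normal(mu, sigma) restricted to [lo, hi] by simple rejection.
# Hypothetical helper for illustration; not a Distributions.jl function.
function sample_truncated_normal(rng, mu, sigma, lo, hi)
    while true
        x = mu + sigma*randn(rng)
        lo <= x <= hi && return x
    end
end

rng = MersenneTwister(42)
# Same (mu, sigma, lo, hi) as the first prior: Truncated(Normal(1.5,0.5),0.5,2.5)
draws = [sample_truncated_normal(rng, 1.5, 0.5, 0.5, 2.5) for _ in 1:1000]
```

Every draw lies inside the bounds, so a sampler using such a prior never proposes a parameter value outside the allowed range.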

Stan.jl backend

The solution converges for tolerance values lower than 1e-3. Lower tolerances give more accurate results at the cost of longer warmup and sampling times. Truncated normal priors are used to prevent Stan from stepping into negative parameter values.
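
The accuracy versus cost trade-off can be illustrated with a rough proxy. The sketch below uses a fixed-step classical RK4 integrator, not the adaptive solver that Stan uses, and compares a coarse and a fine step count against a very fine reference run. All names (`f`, `rk4`, `reference`, `err_coarse`, `err_fine`) are illustrative assumptions.

```julia
# Lotka-Volterra right-hand side with the benchmark's parameter values
f(u) = [1.5*u[1] - 1.0*u[1]*u[2], -3.0*u[2] + 1.0*u[1]*u[2]]

# Classical fixed-step RK4 from t = 0 to t = T in n steps
function rk4(u0, T, n)
    h = T/n
    u = copy(u0)
    for _ in 1:n
        k1 = f(u)
        k2 = f(u + (h/2)*k1)
        k3 = f(u + (h/2)*k2)
        k4 = f(u + h*k3)
        u = u + (h/6)*(k1 + 2*k2 + 2*k3 + k4)
    end
    return u
end

u0 = [1.0, 1.0]
reference  = rk4(u0, 10.0, 100_000)   # very fine steps, treated as the reference
err_coarse = maximum(abs.(rk4(u0, 10.0, 100)  .- reference))   # 100 steps
err_fine   = maximum(abs.(rk4(u0, 10.0, 1000) .- reference))   # 1000 steps
```

The finer run costs ten times as many right-hand-side evaluations and ends closer to the reference, which mirrors the trade-off described above for solver tolerances.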

@btime bayesian_result_stan = stan_inference(prob,t,data,priors,num_samples=10_000,printsummary=false)
File /Users/vaibhav/DiffEqBenchmarks.jl/tmp/parameter_estimation_model.stan
 will be updated.


  174.979 s (1361505 allocations: 56.94 MiB)
DiffEqBayes.StanModel{Stanmodel,Int64,Array{Float64,3},Array{String,1}}(  name =                    "parameter_estimation_model"
  nchains =                 1
  num_samples =             10000
  num_warmup =                1000
  thin =                    1
  monitors =                String[]
  model_file =              "parameter_estimation_model.stan"
  data_file =               "parameter_estimation_model_1.data.R"
  output =                  Output()
    file =                    "parameter_estimation_model_samples_1.csv"
    diagnostics_file =        ""
    refresh =                 100
  pdir =                   "/Users/vaibhav/DiffEqBenchmarks.jl"
  tmpdir =                 "/Users/vaibhav/DiffEqBenchmarks.jl/tmp"
  output_format =           :array
  method =                  Sample()
    num_samples =             10000
    num_warmup =              1000
    save_warmup =             false
    thin =                    1
    algorithm =               HMC()
      engine =                  NUTS()
        max_depth =               10
      metric =                  CmdStan.diag_e
      stepsize =                1.0
      stepsize_jitter =         1.0
    adapt =                   Adapt()
      gamma =                   0.05
      delta =                   0.8
      kappa =                   0.75
      t0 =                      10.0
      init_buffer =             75
      term_buffer =             50
      window =                  25
, 0, [-7.34017 0.853264 … 2.73116 0.915208; -7.55609 0.916111 … 2.78655 0.990211; … ; -7.86709 0.998835 … 2.9281 0.984592; -7.17184 0.990564 … 3.15138 1.15395], ["lp__", "accept_stat__", "stepsize__", "treedepth__", "n_leapfrog__", "divergent__", "energy__", "sigma1.1", "sigma1.2", "theta1", "theta2", "theta3", "theta4", "theta.1", "theta.2", "theta.3", "theta.4"])

Turing.jl backend

@btime bayesian_result_turing = turing_inference(prob,Tsit5(),t,data,priors,num_samples=10_000)
36.270 s (227070592 allocations: 17.04 GiB)
Object of type Chains, with data of type 9000×17×1 Array{Float64,3}

Iterations        = 1:9000
Thinning interval = 1
Chains            = 1
Samples per chain = 9000
internals         = acceptance_rate, hamiltonian_energy, hamiltonian_energy_error, is_accept, log_density, lp, max_hamiltonian_energy_error, n_steps, nom_step_size, numerical_error, step_size, tree_depth
parameters        = theta[1], theta[2], theta[3], theta[4], σ[1]

2-element Array{MCMCChains.ChainDataFrame,1}

Summary Statistics
  parameters    mean     std  naive_se    mcse        ess   r_hat
  ──────────  ──────  ──────  ────────  ──────  ─────────  ──────
    theta[1]  1.5354  0.1108    0.0012  0.0025  1964.1719  1.0002
    theta[2]  1.0029  0.1140    0.0012  0.0024  2301.4344  1.0000
    theta[3]  2.8708  0.2866    0.0030  0.0062  2005.1551  1.0001
    theta[4]  0.9931  0.1133    0.0012  0.0025  1982.7440  1.0001
        σ[1]  0.6659  0.1251    0.0013  0.0019  3876.7909  0.9999

Quantiles
  parameters    2.5%   25.0%   50.0%   75.0%   97.5%
  ──────────  ──────  ──────  ──────  ──────  ──────
    theta[1]  1.3386  1.4578  1.5296  1.6042  1.7718
    theta[2]  0.8143  0.9250  0.9913  1.0679  1.2563
    theta[3]  2.3434  2.6735  2.8580  3.0545  3.4664
    theta[4]  0.7877  0.9149  0.9858  1.0665  1.2334
        σ[1]  0.4662  0.5787  0.6496  0.7350  0.9609

DynamicHMC.jl backend

@btime bayesian_result_dynamichmc = dynamichmc_inference(prob,Tsit5(),t,data,priors,num_samples=10_000)
21.949 s (100464976 allocations: 9.87 GiB)
(posterior = NamedTuple{(:parameters, :σ),Tuple{Array{Float64,1},Array{Floa
t64,1}}}[(parameters = [1.6510330907693331, 1.068371927076397, 2.6700385090
60144, 0.8756214940363967], σ = [1.203931721461571, 0.3308111772255924]), (
parameters = [1.6897436377513675, 1.0430237851150368, 2.601079414339675, 0.
8247486490381333], σ = [1.007210075547111, 0.26357292049156145]), (paramete
rs = [1.3588194900440342, 0.8370391784177251, 3.3046299477504593, 1.2798289
643619438], σ = [1.0998922369939241, 0.3738164505949778]), (parameters = [1
.4430017363542815, 0.9091301577477175, 3.271178240763378, 1.086768588189342
5], σ = [1.2038981260993595, 0.2660212128891589]), (parameters = [1.4535668
044488257, 0.909523050743244, 3.276905486386487, 1.0955803011090495], σ = [
1.193617628965834, 0.266783760453349]), (parameters = [1.45382751529481, 0.
9532735140562241, 3.0300413150609877, 1.1049828006952258], σ = [1.132553004
6772012, 0.2917324556865594]), (parameters = [1.4629441621588475, 0.9054342
337051878, 3.268017222062399, 1.0759377422089689], σ = [1.5434123158162087,
 0.2630168067124241]), (parameters = [1.3432192939479721, 0.825200409360124
4, 3.23173317430294, 1.314322745979057], σ = [1.2728622042633615, 0.2834347
620521368]), (parameters = [1.3071165916749654, 0.8104259044697025, 3.68452
1124745626, 1.3408512948378726], σ = [0.9409149022983262, 0.419082388705157
9]), (parameters = [1.4258804864502665, 0.818718339392575, 3.17126692706493
03, 1.1349048291980846], σ = [0.8467645109894342, 0.29965312853945])  …  (p
arameters = [1.4929664889259344, 1.068060207207104, 2.9980815767360105, 1.0
362722020356832], σ = [0.8524511919141108, 0.4482810791932937]), (parameter
s = [1.5068920418235292, 0.9936324443451314, 3.013292259319715, 1.006927312
596321], σ = [0.9778604786346289, 0.4288247971595753]), (parameters = [1.53
0315288282088, 0.9233321541133878, 2.840696323323293, 1.0067946206822225], 
σ = [0.9540683524575444, 0.35318040365018893]), (parameters = [1.5505068918
424791, 0.9340834095344153, 2.694088018241023, 0.981661958022394], σ = [0.8
378202680123857, 0.438497330199877]), (parameters = [1.580069085047083, 0.9
574854791458529, 2.784597017864487, 0.9118234023648312], σ = [1.21505488058
82329, 0.5631116480285596]), (parameters = [1.6025521791897022, 0.988843140
6893479, 2.6233074069937046, 0.9488797265890467], σ = [1.037985906111572, 0
.4863899486464513]), (parameters = [1.5077089778689483, 0.9704062659457484,
 3.0367756170875526, 1.0314730363950464], σ = [1.0786621444009297, 0.217552
949931122]), (parameters = [1.5520298148311464, 0.9400878972595683, 2.87265
0856245965, 0.9998739642555784], σ = [1.227634634762198, 0.4718763018501869
]), (parameters = [1.5273979739480217, 0.8662328033036745, 2.94837472761127
9, 0.9942321580199855], σ = [1.1752103688524014, 0.4380025823869909]), (par
ameters = [1.4438018525248824, 1.0782035384607451, 3.0685242676380358, 1.11
4376664725165], σ = [0.9165385987338418, 0.5590224516059303])], chain = [[0
.5014012076016618, 0.06613592624053985, 0.9820928951771175, -0.132821365854
4598, 0.18559263552904492, -1.1062075279342756], [0.5245768237900944, 0.042
1239802792651, 0.9559265182340466, -0.19267660690069474, 0.0071842072197196
81, -1.3334252111772547], [0.3066163007722468, -0.17788440144269851, 1.1953
244996370895, 0.24672644741113411, 0.0952122086361469, -0.9839903758467082]
, [0.3667254830857303, -0.0952670072137852, 1.1851502382863834, 0.083208695
12451416, 0.185564730432472, -1.3241792256765121], [0.3740204010537183, -0.
0948349369325416, 1.1868995277952565, 0.09128417824245073, 0.17698871993422
36, -1.3213168346592004], [0.37419974435428993, -0.047853413298526076, 1.10
85762547619769, 0.09982976986688438, 0.12447438052679104, -1.23191814437407
72], [0.3804509546376101, -0.09934063410839251, 1.1841834468928054, 0.07319
259966502849, 0.43399575467533663, -1.3355373449966152], [0.295069190844180
04, -0.1921290017113334, 1.173018579795021, 0.2733215108911501, 0.241268068
832914, -1.2607732985133422], [0.26782363622966066, -0.21019536151694698, 1
.304140564617526, 0.2933047069037269, -0.060902576750357054, -0.86968774663
49355], [0.35478950813638943, -0.20001516222454077, 1.1541311695804692, 0.1
2654879649088271, -0.16633265015393867, -1.2051297114868595]  …  [0.4007650
7284361934, 0.06584411273848387, 1.097972609695804, 0.03562985260400764, -0
.15963932435014783, -0.8023348343922363], [0.41004927927037105, -0.00638791
50096104766, 1.103033254886344, 0.006903429002392361, -0.022388279007001202
, -0.8467068417008177], [0.4254737849472855, -0.0797662456031932, 1.0440492
063891291, 0.006771641279273761, -0.047019961822117456, -1.0407762940956404
], [0.43858190446463063, -0.06818954117058304, 0.9910597494073224, -0.01850
8268157665365, -0.17695167881579085, -0.8244015556721811], [0.4574685707963
2175, -0.04344472342525278, 1.024103165553915, -0.0923089454158774, 0.19478
9244980836, -0.5742773613937568], [0.4715974703994052, -0.01121956389250995
3, 0.964435890736649, -0.05247322540186434, 0.03728220672432128, -0.7207446
133330677], [0.410591265463622, -0.03004046429797, 1.1107963002546206, 0.03
098791299850871, 0.0757215180603821, -1.5253130101228625], [0.4395636321635
849, -0.061781900366717046, 1.0552352467668187, -0.0001260436875935144, 0.2
0508925675922857, -0.7510384000948402], [0.42356561700015377, -0.1436015805
6116548, 1.081254078735848, -0.0057845402200649935, 0.16144716889204658, -0
.8255304727623616], [0.36727980981178215, 0.0752962658522917, 1.12119675144
7984, 0.10829520349857374, -0.08715109722221948, -0.5815656427589329]], tre
e_statistics = DynamicHMC.TreeStatisticsNUTS[DynamicHMC.TreeStatisticsNUTS(
-24.972044404887683, 6, turning at positions -58:5, 0.9284041843982449, 63,
 DynamicHMC.Directions(0x65a2fb05)), DynamicHMC.TreeStatisticsNUTS(-25.3263
8209243423, 4, turning at positions -3:12, 0.9298976637327503, 15, DynamicH
MC.Directions(0x047aa51c)), DynamicHMC.TreeStatisticsNUTS(-27.9351072019710
46, 5, turning at positions 21:36, 0.9343109632181998, 63, DynamicHMC.Direc
tions(0x2dba49a4)), DynamicHMC.TreeStatisticsNUTS(-25.00498788899212, 6, tu
rning at positions -45:18, 0.8699256646318402, 63, DynamicHMC.Directions(0x
8ed16c52)), DynamicHMC.TreeStatisticsNUTS(-22.50513635867809, 2, turning at
 positions -3:0, 0.9999999999999999, 3, DynamicHMC.Directions(0x2fcadc9c)),
 DynamicHMC.TreeStatisticsNUTS(-23.968698002582858, 5, turning at positions
 -15:-30, 0.9165666825262788, 47, DynamicHMC.Directions(0x5d090511)), Dynam
icHMC.TreeStatisticsNUTS(-25.585626717018357, 4, turning at positions 0:15,
 0.9670086865497939, 15, DynamicHMC.Directions(0x13c553ff)), DynamicHMC.Tre
eStatisticsNUTS(-25.631172292787063, 5, turning at positions -27:-42, 0.971
1521254641908, 47, DynamicHMC.Directions(0x3253cfc5)), DynamicHMC.TreeStati
sticsNUTS(-26.07656446522952, 5, turning at positions -25:6, 0.949838397156
478, 31, DynamicHMC.Directions(0x4c67ae26)), DynamicHMC.TreeStatisticsNUTS(
-27.78911332837673, 6, turning at positions -57:6, 0.8146798188135186, 63, 
DynamicHMC.Directions(0xc22c3646))  …  DynamicHMC.TreeStatisticsNUTS(-24.37
2869031737785, 2, turning at positions -3:0, 0.8507016296521978, 3, Dynamic
HMC.Directions(0xd92cd1ac)), DynamicHMC.TreeStatisticsNUTS(-25.297885575321
747, 4, turning at positions 9:24, 0.9885396990076063, 31, DynamicHMC.Direc
tions(0xe63419b8)), DynamicHMC.TreeStatisticsNUTS(-22.634170513642232, 5, t
urning at positions 23:38, 0.8892474784021949, 47, DynamicHMC.Directions(0x
c4729276)), DynamicHMC.TreeStatisticsNUTS(-25.08815603029298, 4, turning at
 positions -23:-26, 0.895586547306448, 27, DynamicHMC.Directions(0x9810a3a1
)), DynamicHMC.TreeStatisticsNUTS(-26.118326468048103, 5, turning at positi
ons -24:7, 0.9844447809259682, 31, DynamicHMC.Directions(0x323379e7)), Dyna
micHMC.TreeStatisticsNUTS(-24.28092130771148, 4, turning at positions -1:14
, 0.9924232112951569, 15, DynamicHMC.Directions(0xa26197ce)), DynamicHMC.Tr
eeStatisticsNUTS(-24.846350847828873, 5, turning at positions 29:60, 0.7453
192520971835, 63, DynamicHMC.Directions(0x1dbc5a3c)), DynamicHMC.TreeStatis
ticsNUTS(-24.7120873743796, 5, turning at positions -18:13, 0.8981696068490
232, 31, DynamicHMC.Directions(0xc9d4b76d)), DynamicHMC.TreeStatisticsNUTS(
-30.37914990178817, 5, turning at positions -14:-45, 0.9122853621385598, 63
, DynamicHMC.Directions(0xc5496412)), DynamicHMC.TreeStatisticsNUTS(-25.526
836941259393, 5, turning at positions 0:31, 0.9531169057544701, 31, Dynamic
HMC.Directions(0x9bf05aff))], κ = Gaussian kinetic energy (LinearAlgebra.Di
agonal), √diag(M⁻¹): [0.07239305596109954, 0.09679410633782563, 0.093136522
85944729, 0.13242179029089104, 0.2883075001026171, 0.3132664821176181], ϵ =
 0.08349883321492188)
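
The raw result above is hard to read directly. As a sketch of how such a result could be summarised, the block below rebuilds a tiny mock of the `posterior` field from the first three draws printed above (rounded to four decimals) and computes column-wise means and standard deviations. It also checks that the first `chain` entry is the logarithm of the first draw, which is what the printed numbers suggest. The names `P`, `param_mean`, `param_std` and `chain_first` are illustrative.

```julia
using Statistics

# First three posterior draws copied (rounded) from the output above;
# the real result holds many more.
posterior = [
    (parameters = [1.6510, 1.0684, 2.6700, 0.8756], σ = [1.2039, 0.3308]),
    (parameters = [1.6897, 1.0430, 2.6011, 0.8247], σ = [1.0072, 0.2636]),
    (parameters = [1.3588, 0.8370, 3.3046, 1.2798], σ = [1.0999, 0.3738]),
]

# Stack the draws into an (n_draws × n_params) matrix, then summarise column-wise.
P = permutedims(reduce(hcat, [d.parameters for d in posterior]))
param_mean = vec(mean(P, dims=1))
param_std  = vec(std(P, dims=1))

# First entry of `chain` (rounded): it matches log.(parameters) followed by log.(σ).
chain_first = [0.5014, 0.0661, 0.9821, -0.1328, 0.1856, -1.1062]
back_transformed = exp.(chain_first)
```

With the full result, the same stacking applies to all draws rather than to the three shown here.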

Conclusion

The Lotka-Volterra equations are a "predator-prey" model: they describe the populations of two species, one the predator (wolf) and the other the prey (rabbit). The solution is cyclic, which is also visible in the uncertainty quantification plots. This behaviour makes the parameters easy to estimate even at loose tolerances such as 1e-3.
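
One way to see the cyclic behaviour is through the quantity V(x, y) = d*x - c*log(x) + b*y - a*log(y), which is constant along exact trajectories of this system, so orbits are closed curves. The sketch below evaluates it at a few states copied from the `sol` output earlier in this document; the names `V`, `states` and `invariants` are illustrative. The values should stay close to the initial value, with small deviations coming from the numerical solver.

```julia
a, b, c, d = 1.5, 1.0, 3.0, 1.0

# First integral of the Lotka-Volterra system: constant along exact trajectories
V(u) = d*u[1] - c*log(u[1]) + b*u[2] - a*log(u[2])

# States copied from the `sol` output: t = 0, two interior points, and t = 10
states = [
    [1.0, 1.0],
    [6.814978737433837, 1.917783300239219],
    [4.3929977807914105, 4.194671536988031],
    [1.03375813933372, 0.9063703701433561],
]
invariants = V.(states)
```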

using DiffEqBenchmarks
DiffEqBenchmarks.bench_footer(WEAVE_ARGS[:folder],WEAVE_ARGS[:file])

Appendix

These benchmarks are a part of the DiffEqBenchmarks.jl repository, found at: https://github.com/JuliaDiffEq/DiffEqBenchmarks.jl

To run this tutorial locally, use the following commands:

using DiffEqBenchmarks
DiffEqBenchmarks.weave_file("ParameterEstimation","DiffEqBayesLotkaVolterra.jmd")

Computer Information:

Julia Version 1.4.0
Commit b8e9a9ecc6 (2020-03-21 16:36 UTC)
Platform Info:
  OS: macOS (x86_64-apple-darwin18.6.0)
  CPU: Intel(R) Core(TM) i7-6700HQ CPU @ 2.60GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-8.0.1 (ORCJIT, skylake)

Package Information:

Status: `/Users/vaibhav/DiffEqBenchmarks.jl/Project.toml`
[28f2ccd6-bb30-5033-b560-165f7b14dc2f] ApproxFun 0.11.10
[a134a8b2-14d6-55f6-9291-3336d3ab0209] BlackBoxOptim 0.5.0
[a93c6f00-e57d-5684-b7b6-d8193f3e46c0] DataFrames 0.20.2
[2b5f629d-d688-5b77-993f-72d75c75574e] DiffEqBase 6.25.2
[ebbdde9d-f333-5424-9be2-dbf1e9acfb5e] DiffEqBayes 2.9.1
[eb300fae-53e8-50a0-950c-e21f52c2b7e0] DiffEqBiological 4.2.0
[f3b72e0c-5b89-59e1-b016-84e28bfd966d] DiffEqDevTools 2.18.0
[c894b116-72e5-5b58-be3c-e6d8d4ac2b12] DiffEqJump 6.5.0
[1130ab10-4a5a-5621-a13d-e4788d82bd4c] DiffEqParamEstim 1.13.0
[a077e3f3-b75c-5d7f-a0c6-6bc4c8ec64a9] DiffEqProblemLibrary 4.6.4
[ef61062a-5684-51dc-bb67-a0fcdec5c97d] DiffEqUncertainty 1.4.1
[0c46a032-eb83-5123-abaf-570d42b7fbaa] DifferentialEquations 6.12.0
[7073ff75-c697-5162-941a-fcdaad2a7d2a] IJulia 1.21.1
[7f56f5a3-f504-529b-bc02-0b1fe5e64312] LSODA 0.6.1
[76087f3c-5699-56af-9a33-bf431cd00edd] NLopt 0.5.1
[c030b06c-0b6d-57c2-b091-7029874bd033] ODE 2.6.0
[54ca160b-1b9f-5127-a996-1867f4bc2a2c] ODEInterface 0.4.6
[09606e27-ecf5-54fc-bb29-004bd9f985bf] ODEInterfaceDiffEq 3.6.0
[1dea7af3-3e70-54e6-95c3-0bf5283fa5ed] OrdinaryDiffEq 5.32.2
[2dcacdae-9679-587a-88bb-8b444fb7085b] ParallelDataTransfer 0.5.0
[65888b18-ceab-5e60-b2b9-181511a3b968] ParameterizedFunctions 5.0.3
[91a5bcdd-55d7-5caf-9e0b-520d859cae80] Plots 0.29.9
[b4db0fb7-de2a-5028-82bf-5021f5cfa881] ReactionNetworkImporters 0.1.5
[f2c3362d-daeb-58d1-803e-2bc74f2840b4] RecursiveFactorization 0.1.0
[9672c7b4-1e72-59bd-8a11-6ac3964bc41f] SteadyStateDiffEq 1.5.0
[c3572dad-4567-51f8-b174-8c6c989267f4] Sundials 3.9.0
[a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f] TimerOutputs 0.5.3
[44d3d7a6-8a23-5bf8-98c5-b353f8df5ec9] Weave 0.9.4
[b77e0a4c-d291-57a0-90e8-8db25a27a240] InteractiveUtils 
[d6f4376e-aef5-505a-96c1-9c027394607a] Markdown 
[44cfe95a-1eb2-52ea-b672-e2afdf69b78f] Pkg 
[9a3f8284-a2c9-5f02-9a11-845980a1fd5c] Random