Diffusion Equation Physics-Informed Neural Network (PINN) Optimizer Benchmarks

Adapted from NeuralPDE: Automating Physics-Informed Neural Networks (PINNs) with Error Approximations. Uses the NeuralPDE.jl library from the SciML Scientific Machine Learning Open Source Organization for the implementation of physics-informed neural networks (PINNs) and other science-guided AI techniques.


using NeuralPDE, OptimizationFlux, ModelingToolkit, Optimization, OptimizationOptimJL
using Lux, Plots
import ModelingToolkit: Interval, infimum, supremum
"""
    solve(opt)

Train a physics-informed neural network on the 1-D diffusion benchmark problem and
record the loss history reported to the optimization callback.

The PDE, posed on `x ∈ [-1, 1]`, `t ∈ [0, 1]`, is

    ∂u/∂t - ∂²u/∂x² = -exp(-t) * (sin(πx) - π² sin(πx))

with initial condition `u(x, 0) = sin(πx)` and boundary conditions
`u(-1, t) = u(1, t) = 0`. It is discretized with `QuadratureTraining()` and a
2 → 18 → 18 → 1 Lux network with `tanh` hidden activations.

# Arguments
- `opt`: an optimizer passed to `Optimization.solve` for 200 iterations, or the
  string `"both"` to run 50 `ADAM()` iterations followed by 150 `BFGS()`
  iterations warm-started from the ADAM result.

# Returns
A tuple `(loss, times)` of equal-length `Vector{Float64}`s. `loss` holds the loss
passed to each callback invocation. `times` holds the wall-clock seconds elapsed
since the first invocation. The first time entry is replaced by `0.01` so the
series can be drawn on a log-scaled time axis.
"""
function solve(opt)
    strategy = QuadratureTraining()

    @parameters x t
    @variables u(..)
    Dt = Differential(t)
    Dxx = Differential(x)^2

    eq = Dt(u(x, t)) - Dxx(u(x, t)) ~ -exp(-t) * (sin(pi * x) - pi^2 * sin(pi * x))

    bcs = [u(x, 0) ~ sin(pi * x),
           u(-1, t) ~ 0.0,
           u(1, t) ~ 0.0]

    domains = [x ∈ Interval(-1.0, 1.0),
               t ∈ Interval(0.0, 1.0)]

    chain = Lux.Chain(Lux.Dense(2, 18, tanh), Lux.Dense(18, 18, tanh), Lux.Dense(18, 1))

    discretization = PhysicsInformedNN(chain, strategy)

    indvars = [x, t]     # physically independent variables
    depvars = [u(x, t)]  # dependent (target) variable

    loss = Float64[]
    times = Float64[]
    # Held in a Ref rather than a reassigned captured local, so the closure below
    # does not capture a boxed variable.
    start_time = Ref{Union{Nothing, Float64}}(nothing)

    cb_ = function (p, l)
        if start_time[] === nothing
            start_time[] = time()
        end
        push!(times, time() - start_time[])
        push!(loss, l)
        return false  # never request early termination
    end

    @named pde_system = PDESystem(eq, bcs, domains, indvars, depvars)
    prob = discretize(pde_system, discretization)

    if opt == "both"
        # ADAM warm start, then BFGS from the ADAM iterate. 50 + 150 iterations
        # matches the 200-iteration budget of the single-optimizer runs.
        res = Optimization.solve(prob, ADAM(); callback = cb_, maxiters = 50)
        prob = remake(prob, u0 = res.u)
        Optimization.solve(prob, BFGS(); callback = cb_, maxiters = 150)
    else
        Optimization.solve(prob, opt; callback = cb_, maxiters = 200)
    end

    # Move the first timestamp (always 0) off zero so it survives a log10 axis.
    isempty(times) || (times[1] = 0.01)

    return loss, times  # TODO: also return the numeric solution
end
solve (generic function with 1 method)
# First-order optimizers at three learning rates each. The no-argument forms use
# the library default; the plot labels assume it is 0.001.
opt1, opt2, opt3 = ADAM(), ADAM(0.005), ADAM(0.05)
opt4, opt5, opt6 = RMSProp(), RMSProp(0.005), RMSProp(0.05)

# Quasi-Newton optimizers, accessed through OptimizationOptimJL.
opt7 = OptimizationOptimJL.BFGS()
opt8 = OptimizationOptimJL.LBFGS()
Optim.LBFGS{Nothing, LineSearches.InitialStatic{Float64}, LineSearches.HagerZhang{Float64, Base.RefValue{Bool}}, Optim.var"#19#21"}(10, LineSearches.InitialStatic{Float64}
  alpha: Float64 1.0
  scaled: Bool false
, LineSearches.HagerZhang{Float64, Base.RefValue{Bool}}
  delta: Float64 0.1
  sigma: Float64 0.9
  alphamax: Float64 Inf
  rho: Float64 5.0
  epsilon: Float64 1.0e-6
  gamma: Float64 0.66
  linesearchmax: Int64 50
  psi3: Float64 0.1
  display: Int64 0
  mayterminate: Base.RefValue{Bool}
, nothing, Optim.var"#19#21"(), Optim.Flat(), true)


# Run the nine optimizer configurations one after another. Each `solve` call
# returns the `(loss, times)` histories recorded by its optimization callback.
# opt1–opt3 are ADAM and opt4–opt6 are RMSProp at increasing learning rates;
# opt7 is BFGS and opt8 is LBFGS. "both" selects the 50-iteration ADAM warm
# start followed by 150 BFGS iterations.
loss_1, times_1 = solve(opt1)
loss_2, times_2 = solve(opt2)
loss_3, times_3 = solve(opt3)
loss_4, times_4 = solve(opt4)
loss_5, times_5 = solve(opt5)
loss_6, times_6 = solve(opt6)
loss_7, times_7 = solve(opt7)
loss_8, times_8 = solve(opt8)
loss_9, times_9 = solve("both")
(Any[159.66046306053968, 158.60439041573756, 157.6122315779257, 156.6821639506822,
155.8128720996569, 155.0016514999154, 154.24388145654493, 153.53436910467752,
152.86669360910818, 152.2338221167056  …  0.000730318056255846, 0.0007085908482378465,
0.0006727545628227806, 0.0006440050956860218, 0.0006184283400141107, 0.0006121742008270983,
0.0005887785855185885, 0.0005754594331250222, 0.000563704674752501, 0.0005568433361956758],
Any[0.01, 0.01190805435180664, 0.02372002601623535, 0.03544306755065918, 0.08127593994140625,
0.09351301193237305, 0.10534405708312988, 0.11698603630065918, 0.12873196601867676,
0.1404709815979004  …  14.976590871810913, 15.000741004943848, 15.02496886253357,
15.049202919006348, 15.073423862457275, 15.097702980041504, 15.12190294265747,
15.1460440158844, 15.170176982879639, 15.2182478904724


# Legend labels and line colours shared by both figures, in run order
# (opt1 … opt8, then the ADAM → BFGS schedule). Plots.jl applies row-vector
# attributes per series, hence `permutedims` on the column literals.
optimizer_labels = permutedims(["ADAM(0.001)", "ADAM(0.005)", "ADAM(0.05)",
                                "RMSProp(0.001)", "RMSProp(0.005)", "RMSProp(0.05)",
                                "BFGS()", "LBFGS()", "ADAM + BFGS"])
optimizer_colors = permutedims(["#2660A4", "#4CD0F4", "#FEC32F",
                                "#F763CD", "#44BD79", "#831894",
                                "#A6ED18", "#980000", "#FF912B"])

all_losses = [loss_1, loss_2, loss_3, loss_4, loss_5, loss_6, loss_7, loss_8, loss_9]
all_times = [times_1, times_2, times_3, times_4, times_5, times_6, times_7, times_8,
             times_9]

# Loss versus wall-clock time on log-log axes (`solve` moves the first timestamp
# off zero so it can appear on the log time axis).
p = plot(all_times, all_losses, xlabel = "time (s)", ylabel = "loss",
         xscale = :log10, yscale = :log10, labels = optimizer_labels,
         legend = :bottomleft, linecolor = optimizer_colors)

# Loss versus callback index (iteration) with a log-scaled loss axis.
p = plot(all_losses, xlabel = "iterations", ylabel = "loss", yscale = :log10,
         labels = optimizer_labels, legend = :bottomleft,
         linecolor = optimizer_colors)

# Print the last recorded loss of each run, in the same order as the plot legends.
@show loss_1[end], loss_2[end], loss_3[end], loss_4[end], loss_5[end], loss_6[end], loss_7[end], loss_8[end], loss_9[end]
(loss_1[end], loss_2[end], loss_3[end], loss_4[end], loss_5[end], loss_6[end],
loss_7[end], loss_8[end], loss_9[end]) = (29.953570064928265, 0.26817778324726876,
0.14697134452697494, 24.18526653575657, 4.0206669359294365, 10.851500422402758,
0.000428893794136428, 0.014926232630149512, 0.0005568433361956758)
(29.953570064928265, 0.26817778324726876, 0.14697134452697494, 24.18526653575657,
4.0206669359294365, 10.851500422402758, 0.000428893794136428, 0.014926232630149512,
0.0005568433361956758)


These benchmarks are a part of the SciMLBenchmarks.jl repository, found at: https://github.com/SciML/SciMLBenchmarks.jl. For more information on high-performance scientific machine learning, check out the SciML Open Source Software Organization https://sciml.ai.

To run this benchmark locally, use the following commands:

using SciMLBenchmarks

Computer Information:

Julia Version 1.7.3
Commit 742b9abb4d (2022-05-06 12:58 UTC)
Platform Info:
  OS: Linux (x86_64-pc-linux-gnu)
  CPU: AMD EPYC 7502 32-Core Processor
  LIBM: libopenlibm
  LLVM: libLLVM-12.0.1 (ORCJIT, znver2)
  BUILDKITE_PLUGIN_JULIA_CACHE_DIR = /cache/julia-buildkite-plugin
  JULIA_DEPOT_PATH = /cache/julia-buildkite-plugin/depots/5b300254-1738-4989-ae0a-f4d2d937f953

Package Information:

