Commit

Merge pull request #853 from AstitvaAggarwal/BPINNode
Better BPINN ODE solver
ChrisRackauckas authored Sep 8, 2024
2 parents 70ceedc + fab83e9 commit e3ee467
Showing 5 changed files with 272 additions and 53 deletions.
14 changes: 11 additions & 3 deletions src/BPINN_ode.jl
@@ -100,6 +100,8 @@ struct BNNODE{C, K, IT <: NamedTuple,
init_params::I
Adaptorkwargs::A
Integratorkwargs::IT
numensemble::Int64
estim_collocate::Bool
autodiff::Bool
progress::Bool
verbose::Bool
@@ -112,6 +114,8 @@ function BNNODE(chain, Kernel = HMC; strategy = nothing, draw_samples = 2000,
Metric = DiagEuclideanMetric,
targetacceptancerate = 0.8),
Integratorkwargs = (Integrator = Leapfrog,),
numensemble = floor(Int, draw_samples / 3),
estim_collocate = false,
autodiff = false, progress = false, verbose = false)
!(chain isa Lux.AbstractExplicitLayer) &&
(chain = adapt(FromFluxAdaptor(false, false), chain))
@@ -120,6 +124,7 @@ function BNNODE(chain, Kernel = HMC; strategy = nothing, draw_samples = 2000,
phystd, dataset, physdt, MCMCkwargs,
nchains, init_params,
Adaptorkwargs, Integratorkwargs,
numensemble, estim_collocate,
autodiff, progress, verbose)
end

@@ -186,7 +191,8 @@ function SciMLBase.__solve(prob::SciMLBase.ODEProblem,
@unpack chain, l2std, phystd, param, priorsNNw, Kernel, strategy,
draw_samples, dataset, init_params,
nchains, physdt, Adaptorkwargs, Integratorkwargs,
MCMCkwargs, autodiff, progress, verbose = alg
MCMCkwargs, numensemble, estim_collocate, autodiff, progress,
verbose = alg

# ahmc_bayesian_pinn_ode needs param=[] for easier vcat operation for full vector of parameters
param = param === nothing ? [] : param
@@ -211,7 +217,8 @@ function SciMLBase.__solve(prob::SciMLBase.ODEProblem,
Integratorkwargs = Integratorkwargs,
MCMCkwargs = MCMCkwargs,
progress = progress,
verbose = verbose)
verbose = verbose,
estim_collocate = estim_collocate)

fullsolution = BPINNstats(mcmcchain, samples, statistics)
ninv = length(param)
@@ -220,7 +227,8 @@ function SciMLBase.__solve(prob::SciMLBase.ODEProblem,
if chain isa Lux.AbstractExplicitLayer
θinit, st = Lux.setup(Random.default_rng(), chain)
θ = [vector_to_parameters(samples[i][1:(end - ninv)], θinit)
for i in (draw_samples - numensemble):draw_samples]
for i in 1:max(draw_samples - draw_samples ÷ 10, draw_samples - 1000)]

luxar = [chain(t', θ[i], st)[1] for i in 1:numensemble]
# only need for size
θinit = collect(ComponentArrays.ComponentArray(θinit))
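
For context, a minimal sketch of how the two new keyword arguments surface in the solver interface. The ODE problem, network, and dataset below are hypothetical placeholders, not part of this diff; the dataset layout `[û, t]` is inferred from `L2loss2` further down, `numensemble` defaults to `floor(Int, draw_samples / 3)` as shown above, and `estim_collocate` defaults to `false`.

using SciMLBase, NeuralPDE, Lux, Distributions, Random

# hypothetical logistic-growth ODE with one unknown parameter p
f(u, p, t) = p[1] * u * (1 - u)
prob = ODEProblem(f, 0.01, (0.0, 5.0), [1.5])

chain = Lux.Chain(Lux.Dense(1, 8, tanh), Lux.Dense(8, 1))

# synthetic observations; dataset[1] = û, dataset[end] = t
ts = collect(range(0.0, 5.0; length = 100))
u_obs = @. 0.01 * exp(1.5 * ts) / (1 + 0.01 * (exp(1.5 * ts) - 1))
dataset = [u_obs, ts]

alg = BNNODE(chain;
    draw_samples = 2000,
    l2std = [0.05], phystd = [0.05],
    priorsNNw = (0.0, 3.0),
    param = [Normal(1.0, 2.0)],
    dataset = dataset,
    numensemble = 500,        # posterior samples used to build the ensemble solution
    estim_collocate = true)   # enables the new collocation log-likelihood term

sol = solve(prob, alg)
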
5 changes: 0 additions & 5 deletions src/PDE_BPINN.jl
@@ -69,11 +69,6 @@ function LogDensityProblems.logdensity(Tar::PDELogTargetDensity, θ)
# + L2loss2(Tar, θ)
end

# function L2loss2(Tar::PDELogTargetDensity, θ)
# return Tar.full_loglikelihood(setparameters(Tar, θ),
# Tar.allstd)
# end

function setparameters(Tar::PDELogTargetDensity, θ)
names = Tar.names
ps_new = θ[1:(end - Tar.extraparams)]
89 changes: 76 additions & 13 deletions src/advancedHMC_MCMC.jl
@@ -16,11 +16,12 @@ mutable struct LogTargetDensity{C, S, ST <: AbstractTrainingStrategy, I,
physdt::Float64
extraparams::Int
init_params::I
estim_collocate::Bool

function LogTargetDensity(dim, prob, chain::Optimisers.Restructure, st, strategy,
dataset,
priors, phystd, l2std, autodiff, physdt, extraparams,
init_params::AbstractVector)
init_params::AbstractVector, estim_collocate)
new{
typeof(chain),
Nothing,
@@ -39,12 +40,13 @@ mutable struct LogTargetDensity{C, S, ST <: AbstractTrainingStrategy, I,
autodiff,
physdt,
extraparams,
init_params)
init_params,
estim_collocate)
end
function LogTargetDensity(dim, prob, chain::Lux.AbstractExplicitLayer, st, strategy,
dataset,
priors, phystd, l2std, autodiff, physdt, extraparams,
init_params::NamedTuple)
init_params::NamedTuple, estim_collocate)
new{
typeof(chain),
typeof(st),
@@ -60,7 +62,8 @@ mutable struct LogTargetDensity{C, S, ST <: AbstractTrainingStrategy, I,
autodiff,
physdt,
extraparams,
init_params)
init_params,
estim_collocate)
end
end

@@ -83,7 +86,12 @@ end
vector_to_parameters(ps_new::AbstractVector, ps::AbstractVector) = ps_new

function LogDensityProblems.logdensity(Tar::LogTargetDensity, θ)
return physloglikelihood(Tar, θ) + priorweights(Tar, θ) + L2LossData(Tar, θ)
if Tar.estim_collocate
return physloglikelihood(Tar, θ) + priorweights(Tar, θ) + L2LossData(Tar, θ) +
L2loss2(Tar, θ)
else
return physloglikelihood(Tar, θ) + priorweights(Tar, θ) + L2LossData(Tar, θ)
end
end

LogDensityProblems.dimension(Tar::LogTargetDensity) = Tar.dim
@@ -92,6 +100,55 @@ function LogDensityProblems.capabilities(::LogTargetDensity)
LogDensityProblems.LogDensityOrder{1}()
end

"""
Suggested extra loss function for the ODE solver case.
"""
function L2loss2(Tar::LogTargetDensity, θ)
f = Tar.prob.f

# parameter estimation chosen or not
if Tar.extraparams > 0
autodiff = Tar.autodiff
# Timepoints to enforce Physics
t = Tar.dataset[end]
u1 = Tar.dataset[2]
û = Tar.dataset[1]

nnsol = NNodederi(Tar, t, θ[1:(length(θ) - Tar.extraparams)], autodiff)

ode_params = Tar.extraparams == 1 ?
θ[((length(θ) - Tar.extraparams) + 1):length(θ)][1] :
θ[((length(θ) - Tar.extraparams) + 1):length(θ)]

if length(Tar.prob.u0) == 1
physsol = [f(û[i],
ode_params,
t[i])
for i in 1:length(û[:, 1])]
else
physsol = [f([û[i], u1[i]],
ode_params,
t[i])
for i in 1:length(û)]
end
# form of NN output matrix: output dim x n
deri_physsol = reduce(hcat, physsol)

physlogprob = 0
for i in 1:length(Tar.prob.u0)
# can add phystd[i] for u[i]
physlogprob += logpdf(MvNormal(deri_physsol[i, :],
LinearAlgebra.Diagonal(map(abs2,
(Tar.l2std[i] * 4.0) .*
ones(length(nnsol[i, :]))))),
nnsol[i, :])
end
return physlogprob
else
return 0
end
end
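
In informal notation, the new L2loss2 term adds a Gaussian log-likelihood that scores the network's time derivative (via NNodederi) against the ODE right-hand side evaluated at the dataset timepoints; the 4 * l2std scaling is read directly off the MvNormal covariance in the code above:

\mathcal{L}_{\mathrm{grad}}(\theta) = \sum_{i=1}^{\dim u} \log \mathcal{N}\!\left( \dot{\hat u}_i(t) \,\middle|\, f_i(\hat u(t), \theta_{\mathrm{ODE}}, t),\ (4\,\mathtt{l2std}_i)^2 I \right)

When estim_collocate is true, this term is summed with physloglikelihood, priorweights, and L2LossData in LogDensityProblems.logdensity; it is only active when extra ODE parameters are being estimated (Tar.extraparams > 0) and returns 0 otherwise.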

"""
L2 loss log-likelihood (needed for ODE parameter estimation).
"""
@@ -247,7 +304,7 @@ function innerdiff(Tar::LogTargetDensity, f, autodiff::Bool, t::AbstractVector,

vals = nnsol .- physsol

# N dimensional vector if N outputs for NN(each row has logpdf of i[i] where u is vector of dependant variables)
# N dimensional vector if N outputs for NN(each row has logpdf of u[i] where u is vector of dependant variables)
return [logpdf(
MvNormal(vals[i, :],
LinearAlgebra.Diagonal(abs2.(Tar.phystd[i] .*
@@ -442,7 +499,8 @@ function ahmc_bayesian_pinn_ode(prob::SciMLBase.ODEProblem, chain;
Metric = DiagEuclideanMetric, targetacceptancerate = 0.8),
Integratorkwargs = (Integrator = Leapfrog,),
MCMCkwargs = (n_leapfrog = 30,),
progress = false, verbose = false)
progress = false, verbose = false,
estim_collocate = false)
!(chain isa Lux.AbstractExplicitLayer) &&
(chain = adapt(FromFluxAdaptor(false, false), chain))
# NN parameter prior mean and variance (PriorsNN must be a tuple)
@@ -467,7 +525,7 @@ function ahmc_bayesian_pinn_ode(prob::SciMLBase.ODEProblem, chain;
# Lux-Named Tuple
initial_nnθ, recon, st = generate_Tar(chain, init_params)
else
error("Only Lux.AbstractExplicitLayer neural networks are supported")
error("Only Lux.AbstractExplicitLayer Neural networks are supported")
end

if nchains > Threads.nthreads()
Expand Down Expand Up @@ -500,7 +558,7 @@ function ahmc_bayesian_pinn_ode(prob::SciMLBase.ODEProblem, chain;
t0 = prob.tspan[1]
# dimensions would be total no. of params; initial_nnθ for Lux NamedTuples
ℓπ = LogTargetDensity(nparameters, prob, recon, st, strategy, dataset, priors,
phystd, l2std, autodiff, physdt, ninv, initial_nnθ)
phystd, l2std, autodiff, physdt, ninv, initial_nnθ, estim_collocate)

try
ℓπ(t0, initial_θ[1:(nparameters - ninv)])
@@ -515,6 +573,9 @@ function ahmc_bayesian_pinn_ode(prob::SciMLBase.ODEProblem, chain;
@info("Current Physics Log-likelihood : ", physloglikelihood(ℓπ, initial_θ))
@info("Current Prior Log-likelihood : ", priorweights(ℓπ, initial_θ))
@info("Current MSE against dataset Log-likelihood : ", L2LossData(ℓπ, initial_θ))
if estim_collocate
@info("Current gradient loss against dataset Log-likelihood : ", L2loss2(ℓπ, initial_θ))
end

Adaptor, Metric, targetacceptancerate = Adaptorkwargs[:Adaptor],
Adaptorkwargs[:Metric], Adaptorkwargs[:targetacceptancerate]
@@ -565,12 +626,14 @@ function ahmc_bayesian_pinn_ode(prob::SciMLBase.ODEProblem, chain;
@info("Sampling Complete.")
@info("Current Physics Log-likelihood : ", physloglikelihood(ℓπ, samples[end]))
@info("Current Prior Log-likelihood : ", priorweights(ℓπ, samples[end]))
@info("Current MSE against dataset Log-likelihood : ",
L2LossData(ℓπ, samples[end]))
@info("Current MSE against dataset Log-likelihood : ", L2LossData(ℓπ, samples[end]))
if estim_collocate
@info("Current gradient loss against dataset Log-likelihood : ", L2loss2(ℓπ, samples[end]))
end

# return a chain(basic chain),samples and stats
matrix_samples = hcat(samples...)
mcmc_chain = MCMCChains.Chains(matrix_samples')
matrix_samples = reshape(hcat(samples...), (length(samples[1]), length(samples), 1))
mcmc_chain = MCMCChains.Chains(matrix_samples)
return mcmc_chain, samples, stats
end
end
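
A minimal sketch of calling the low-level sampler directly with the new flag (prob, chain, and dataset as in the hypothetical example above; every keyword here appears in the diff or in the @unpack list, but treat the exact values as illustrative assumptions):

mcmc_chain, samples, stats = ahmc_bayesian_pinn_ode(prob, chain;
    draw_samples = 2000,
    dataset = dataset,
    l2std = [0.05], phystd = [0.05],
    priorsNNw = (0.0, 3.0),
    param = [Normal(1.0, 2.0)],
    estim_collocate = true,   # new keyword, defaults to false
    progress = true, verbose = true)

With the flag set, the sampler additionally logs the gradient-loss log-likelihood before and after sampling, and the returned chain is now built from a 3-D sample array per the reshape in the last hunk above.
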
14 changes: 2 additions & 12 deletions test/BPINN_PDEinvsol_tests.jl
@@ -7,7 +7,7 @@ using ComponentArrays

Random.seed!(100)

@testset "Example 1: 2D Periodic System with parameter estimation" begin
@testset "Example 1: 1D Periodic System with parameter estimation" begin
# Cos(pi*t) periodic curve
@parameters t, p
@variables u(..)
@@ -59,17 +59,7 @@ Random.seed!(100)
saveats = [1 / 50.0],
param = [LogNormal(6.0, 0.5)])

discretization = BayesianPINN([chainl], QuadratureTraining(), param_estim = true,
dataset = [dataset, nothing])

ahmc_bayesian_pinn_pde(pde_system,
discretization;
draw_samples = 1500,
bcstd = [0.05],
phystd = [0.01], l2std = [0.01],
priorsNNw = (0.0, 1.0),
saveats = [1 / 50.0],
param = [LogNormal(6.0, 0.5)])
# alternative to QuadratureTraining [WIP]

discretization = BayesianPINN([chainl], GridTraining([0.02]), param_estim = true,
dataset = [dataset, nothing])