Skip to content

Commit

Permalink
Merge pull request #705 from SciML/Vaibhavdixit02-patch-1
Browse files Browse the repository at this point in the history
Keep support for `cb` to avoid breakage
  • Loading branch information
Vaibhavdixit02 authored May 3, 2022
2 parents de6e2df + 7680bc8 commit 19cecc4
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 4 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DiffEqFlux"
uuid = "aae7a2af-3d4f-5e19-a356-7da93b79d9d0"
authors = ["Chris Rackauckas <[email protected]>"]
version = "1.47.0"
version = "1.47.1"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
7 changes: 5 additions & 2 deletions src/train.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ By default, if the loss function is deterministic then an optimizer chain of
ADAM -> BFGS is used, otherwise ADAM is used (and a choice of maxiters is required).
"""
function sciml_train(loss, θ, opt=nothing, adtype=nothing, args...;
lower_bounds=nothing, upper_bounds=nothing,
lower_bounds=nothing, upper_bounds=nothing, cb = nothing,
callback = (args...) -> (false),
maxiters=nothing, kwargs...)
if adtype === nothing
Expand Down Expand Up @@ -83,7 +83,10 @@ function sciml_train(loss, θ, opt=nothing, adtype=nothing, args...;
adtype = GalacticOptim.AutoZygote()
end
end

if !isnothing(cb)
callback = cb
end

optf = GalacticOptim.OptimizationFunction((x, p) -> loss(x), adtype)
optfunc = GalacticOptim.instantiate_function(optf, θ, adtype, nothing)
optprob = GalacticOptim.OptimizationProblem(optfunc, θ; lb=lower_bounds, ub=upper_bounds, kwargs...)
Expand Down
2 changes: 1 addition & 1 deletion test/newton_neural_ode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ nODE = NeuralODE(NN, tspan, ROCK4(), reltol=1e-4, saveat=[tspan[end]])
loss_function(θ) = Flux.Losses.mse(y, nODE(x, θ)[end])
l1 = loss_function(nODE.p)

res = DiffEqFlux.sciml_train(loss_function, nODE.p, NewtonTrustRegion(), GalacticOptim.AutoZygote(), maxiters = 100, callback=cb)
res = DiffEqFlux.sciml_train(loss_function, nODE.p, NewtonTrustRegion(), GalacticOptim.AutoZygote(), maxiters=100, cb=cb) #ensure backwards compatibility of `cb`
@test loss_function(res.minimizer) < l1
res = DiffEqFlux.sciml_train(loss_function, nODE.p, Optim.KrylovTrustRegion(), GalacticOptim.AutoZygote(), maxiters = 100, callback=cb)
@test loss_function(res.minimizer) < l1
Expand Down

0 comments on commit 19cecc4

Please sign in to comment.