From da29045df1f956023ac0f483f0affacab3c68099 Mon Sep 17 00:00:00 2001 From: nefrathenrici Date: Mon, 17 Jun 2024 11:26:09 -0700 Subject: [PATCH] Add kwargs for EKP to `initialize` --- src/backends.jl | 31 ++++++++++++++++++++----------- src/ekp_interface.jl | 2 ++ test/ekp_interface.jl | 18 ++++++++++++++++-- 3 files changed, 38 insertions(+), 13 deletions(-) diff --git a/src/backends.jl b/src/backends.jl index 8c201485..29becec6 100644 --- a/src/backends.jl +++ b/src/backends.jl @@ -43,17 +43,24 @@ include(joinpath(experiment_dir, "model_interface.jl")) eki = ClimaCalibrate.calibrate(experiment_dir) ``` """ -calibrate(config::ExperimentConfig; kwargs...) = - calibrate(get_backend(), config; kwargs...) +calibrate(config::ExperimentConfig; ekp_kwargs...) = + calibrate(get_backend(), config; ekp_kwargs...) -calibrate(experiment_dir::AbstractString) = - calibrate(get_backend(), ExperimentConfig(experiment_dir)) +calibrate(experiment_dir::AbstractString; ekp_kwargs...) = + calibrate(get_backend(), ExperimentConfig(experiment_dir); ekp_kwargs...) -calibrate(b::Type{JuliaBackend}, experiment_dir::AbstractString) = - calibrate(b, ExperimentConfig(experiment_dir)) +calibrate( + b::Type{JuliaBackend}, + experiment_dir::AbstractString; + ekp_kwargs..., +) = calibrate(b, ExperimentConfig(experiment_dir); ekp_kwargs...) -function calibrate(::Type{JuliaBackend}, config::ExperimentConfig) - initialize(config) +function calibrate( + ::Type{JuliaBackend}, + config::ExperimentConfig; + ekp_kwargs..., +) + initialize(config; ekp_kwargs...) (; n_iterations, ensemble_size) = config eki = nothing for i in 0:(n_iterations - 1) @@ -103,9 +110,10 @@ eki = calibrate(CaltechHPC, experiment_dir; model_interface, slurm_kwargs); function calibrate( b::Type{CaltechHPC}, experiment_dir::AbstractString; - kwargs..., + slurm_kwargs, + ekp_kwargs..., ) - calibrate(b, ExperimentConfig(experiment_dir); kwargs...) + calibrate(b, ExperimentConfig(experiment_dir); slurm_kwargs, ekp_kwargs...) end function calibrate( @@ -117,11 +125,12 @@ function calibrate( ), verbose = false, slurm_kwargs = Dict(:time_limit => 45, :ntasks => 1), + ekp_kwargs..., ) # ExperimentConfig is created from a YAML file within the experiment_dir (; n_iterations, output_dir, ensemble_size) = config @info "Initializing calibration" n_iterations ensemble_size output_dir - initialize(config) + initialize(config; ekp_kwargs...) eki = nothing for iter in 0:(n_iterations - 1) diff --git a/src/ekp_interface.jl b/src/ekp_interface.jl index ba72bfd4..59ebd822 100644 --- a/src/ekp_interface.jl +++ b/src/ekp_interface.jl @@ -221,6 +221,7 @@ function initialize( prior, output_dir; rng_seed = 1234, + ekp_kwargs..., ) Random.seed!(rng_seed) rng_ekp = Random.MersenneTwister(rng_seed) @@ -234,6 +235,7 @@ function initialize( EKP.Inversion(); rng = rng_ekp, failure_handler_method = EKP.SampleSuccGauss(), + ekp_kwargs..., ) param_dict = get_param_dict(prior) diff --git a/test/ekp_interface.jl b/test/ekp_interface.jl index 95973eef..39fb0755 100644 --- a/test/ekp_interface.jl +++ b/test/ekp_interface.jl @@ -1,4 +1,5 @@ using Distributions +import EnsembleKalmanProcesses as EKP using EnsembleKalmanProcesses.ParameterDistributions import ClimaCalibrate as CAL import ClimaParams as CP @@ -7,7 +8,7 @@ using Test FT = Float64 output_dir = "test_init" -prior_path = joinpath("test_case_inputs", "prior.toml") +prior_path = joinpath(pkgdir(CAL), "test", "test_case_inputs", "prior.toml") param_names = ["one", "two"] prior = CAL.get_prior(prior_path) @@ -25,7 +26,20 @@ config = CAL.ExperimentConfig( output_dir, ) -CAL.initialize(config) +eki = CAL.initialize(config) +eki_with_kwargs = CAL.initialize( + config; + scheduler = EKP.MutableScheduler(2), + accelerator = EKP.NesterovAccelerator(), +) + +@testset "Test passing kwargs to EKP struct" begin + @test eki_with_kwargs.scheduler != eki.scheduler + @test eki_with_kwargs.scheduler isa EKP.MutableScheduler + + @test eki_with_kwargs.accelerator != eki.accelerator + @test eki_with_kwargs.accelerator isa EKP.NesterovAccelerator +end override_file = joinpath( config.output_dir,