Skip to content

Commit

Permalink
Updated factor bounds setting method
Browse files Browse the repository at this point in the history
Previous implementation looped over each factor and set these individually, which is very slow.

Updated to set the factors in a single pass.
  • Loading branch information
ConnectedSystems committed Jan 27, 2025
1 parent 52aaded commit 705c65e
Showing 1 changed file with 43 additions and 10 deletions.
53 changes: 43 additions & 10 deletions src/io/sampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -395,10 +395,18 @@ function fix_factor!(d::Domain, factor::Symbol, val::Real)::Nothing
return nothing
end
function fix_factor!(d::Domain, factors::Vector{Symbol})::Nothing
for f in factors
fix_factor!(d, f)
end
params = DataFrame(d.model)
factor_rows = findall(in(factors), params.fieldname)

# Get current values and dist_params lengths
vals = params[factor_rows, :val]
dist_lens = length.(params[factor_rows, :dist_params])

# Create new dist_params tuples
new_params = [Tuple(fill(v, len)) for (v, len) in zip(vals, dist_lens)]
params[factor_rows, :dist_params] .= new_params

update!(d, params)
return nothing
end
function fix_factor!(d::Domain; factors...)::Nothing
Expand Down Expand Up @@ -480,8 +488,8 @@ function get_attr(dom::Domain, factor::Symbol, attr::Symbol)
end

"""
set_factor_bounds(dom::Domain, factor::Symbol, new_bounds::Tuple)::Nothing
set_factor_bounds(dom::Domain; factors...)::Nothing
set_factor_bounds!(dom::Domain, factor::Symbol, new_bounds::Tuple)::Nothing
set_factor_bounds!(dom::Domain; factors...)::Nothing
Set new bound values for a given parameter. Sampled values for a parameter will lie
within the range `lower_bound ≤ s ≤ upper_bound`, for every sample value `s`.
Expand All @@ -504,6 +512,12 @@ set_factor_bounds(dom, :wave_stress, (0.1, 0.2))
```
"""
function set_factor_bounds(dom::Domain, factor::Symbol, new_dist_params::Tuple)::Domain
Base.@warn "set_factor_bounds is deprecated, use set_factor_bounds! instead" maxlog=1 _category=:deprecation

set_factor_bounds!(dom, factor, new_dist_params)
return dom
end
function set_factor_bounds!(dom::Domain, factor::Symbol, new_dist_params::Tuple)::Domain
old_val = get_attr(dom, factor, :val)
new_val = mean(new_dist_params[1:2])

Expand All @@ -514,16 +528,35 @@ function set_factor_bounds(dom::Domain, factor::Symbol, new_dist_params::Tuple):
ms[!, :is_constant] .= (ms[!, :lower_bound] .== ms[!, :upper_bound])

update!(dom, ms)

return dom
return nothing
end

function set_factor_bounds(dom::Domain; factors...)::Domain
for (factor, bounds) in factors
dom = set_factor_bounds(dom, factor, bounds)
end
Base.@warn "set_factor_bounds is deprecated, use set_factor_bounds! instead" maxlog=1 _category=:deprecation

set_factor_bounds!(dom; factors...)
return dom
end
function set_factor_bounds!(dom::Domain; factors...)::Nothing
ms = model_spec(dom)
factor_symbols = collect(keys(factors))
factor_rows = findall(in(factor_symbols), ms.fieldname)

# Update dist_params and values in bulk
new_params = collect(values(factors))
ms[factor_rows, :dist_params] .= new_params

# Calculate new values preserving types
old_vals = ms[factor_rows, :val]
new_vals = mean.(zip(first.(new_params), last.(new_params)))
ms[factor_rows, :val] .= [v isa Int ? round(n) : n for (v, n) in zip(old_vals, new_vals)]

# Update `is_constant` column
ms[!, :is_constant] .= (ms[!, :lower_bound] .== ms[!, :upper_bound])

update!(dom, ms)
return nothing
end

"""
_validate_new_bounds(dom::Domain, factor::Symbol, new_dist_params::Tuple)::Nothing
Expand Down

0 comments on commit 705c65e

Please sign in to comment.