Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for an external synchronous compiler to discrete and hybrid systems #3399

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion src/systems/clock_inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@
c = BitSet(c′)
idxs = intersect(c, inferred)
isempty(idxs) && continue
if !allequal(var_domain[i] for i in idxs)
if !allequal(iscontinuous(var_domain[i]) for i in idxs)
display(fullvars[c′])
throw(ClockInferenceException("Clocks are not consistent in connected component $(fullvars[c′])"))
end
Expand Down Expand Up @@ -144,6 +144,9 @@
var_to_cid = Vector{Int}(undef, ndsts(graph))
cid_to_var = Vector{Int}[]
cid_counter = Ref(0)

# populates clock_to_id and id_to_clock
# checks if there is a continuous_id (for some reason? clock to id does this too)
for (i, d) in enumerate(eq_domain)
cid = let cid_counter = cid_counter, id_to_clock = id_to_clock,
continuous_id = continuous_id
Expand All @@ -161,9 +164,13 @@
resize_or_push!(cid_to_eq, i, cid)
end
continuous_id = continuous_id[]
# for each clock partition what are the input (indexes/vars)
input_idxs = map(_ -> Int[], 1:cid_counter[])
inputs = map(_ -> Any[], 1:cid_counter[])
# var_domain corresponds to fullvars/all variables in the system
nvv = length(var_domain)
# put variables into the right clock partition
# keep track of inputs to each partition
for i in 1:nvv
d = var_domain[i]
cid = get(clock_to_id, d, 0)
Expand All @@ -177,6 +184,7 @@
resize_or_push!(cid_to_var, i, cid)
end

# breaks the system up into a continous and 0 or more discrete systems

Check warning on line 187 in src/systems/clock_inference.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"continous" should be "continuous".
tss = similar(cid_to_eq, S)
for (id, ieqs) in enumerate(cid_to_eq)
ts_i = system_subset(ts, ieqs)
Expand All @@ -186,6 +194,7 @@
end
tss[id] = ts_i
end
# put the continous system at the back
if continuous_id != 0
tss[continuous_id], tss[end] = tss[end], tss[continuous_id]
inputs[continuous_id], inputs[end] = inputs[end], inputs[continuous_id]
Expand Down
8 changes: 7 additions & 1 deletion src/systems/systems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ function structural_simplify(
kwargs...)
isscheduled(sys) && throw(RepeatedStructuralSimplificationError())
newsys′ = __structural_simplify(sys, io; simplify,
allow_symbolic, allow_parameter, conservative, fully_determined,
allow_symbolic, allow_parameter, conservative, fully_determined, additional_passes,
kwargs...)
if newsys′ isa Tuple
@assert length(newsys′) == 2
Expand Down Expand Up @@ -169,3 +169,9 @@ function __structural_simplify(sys::AbstractSystem, io = nothing; simplify = fal
guesses = guesses(sys), initialization_eqs = initialization_equations(sys))
end
end

"""
Mark whether an extra pass `p` can support compiling discrete systems.
"""
discrete_compile_pass(p) = false

49 changes: 27 additions & 22 deletions src/systems/systemstructure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -626,40 +626,45 @@ function structural_simplify!(state::TearingState, io = nothing; simplify = fals
check_consistency = true, fully_determined = true, warn_initialize_determined = true,
kwargs...)
if state.sys isa ODESystem
# split_system returns one or two systems and the inputs for each
# mod clock inference to be binary
# if it's continous keep going, if not then error unless given trait impl in additional passes
ci = ModelingToolkit.ClockInference(state)
ci = ModelingToolkit.infer_clocks!(ci)
time_domains = merge(Dict(state.fullvars .=> ci.var_domain),
Dict(default_toterm.(state.fullvars) .=> ci.var_domain))
tss, inputs, continuous_id, id_to_clock = ModelingToolkit.split_system(ci)
if continuous_id == 0
# do a trait check here - handle fully discrete system
additional_passes = get(kwargs, :additional_passes, nothing)
if !isnothing(additional_passes) && any(discrete_compile_pass, additional_passes)
# take the first discrete compilation pass given for now
discrete_pass_idx = findfirst(discrete_compile_pass, additional_passes)
discrete_compile = additional_passes[discrete_pass_idx]
deleteat!(additional_passes, discrete_pass_idx)
return discrete_compile(tss, inputs)
else
# error goes here! this is a purely discrete system
throw(HybridSystemNotSupportedException("Discrete systems without JuliaSimCompiler are currently not supported in ODESystem."))
end
end
# puts the ios passed in to the call into the continous system
cont_io = merge_io(io, inputs[continuous_id])
# simplify as normal
sys, input_idxs = _structural_simplify!(tss[continuous_id], cont_io; simplify,
check_consistency, fully_determined,
kwargs...)
if length(tss) > 1
if continuous_id > 0
if !isnothing(additional_passes) && any(discrete_compile_pass, additional_passes)
discrete_pass_idx = findfirst(discrete_compile_pass, additional_passes)
discrete_compile = additional_passes[discrete_pass_idx]
deleteat!(additional_passes, discrete_pass_idx)
# in the case of a hybrid system, the discrete_compile pass should take the currents of sys.discrete_subsystems
# and modifies discrete_subsystems to bea tuple of the io and anything else, while adding or manipulating the rest of sys as needed
sys = discrete_compile(sys, tss[2:end], inputs)
else
throw(HybridSystemNotSupportedException("Hybrid continuous-discrete systems are currently not supported with the standard MTK compiler. This system requires JuliaSimCompiler.jl, see https://help.juliahub.com/juliasimcompiler/stable/"))
end
# TODO: rename it to something else
discrete_subsystems = Vector{ODESystem}(undef, length(tss))
# Note that the appended_parameters must agree with
# `generate_discrete_affect`!
appended_parameters = parameters(sys)
for (i, state) in enumerate(tss)
if i == continuous_id
discrete_subsystems[i] = sys
continue
end
dist_io = merge_io(io, inputs[i])
ss, = _structural_simplify!(state, dist_io; simplify, check_consistency,
fully_determined, kwargs...)
append!(appended_parameters, inputs[i], unknowns(ss))
discrete_subsystems[i] = ss
end
@set! sys.discrete_subsystems = discrete_subsystems, inputs, continuous_id,
id_to_clock
@set! sys.ps = appended_parameters
@set! sys.defaults = merge(ModelingToolkit.defaults(sys),
Dict(v => 0.0 for v in Iterators.flatten(inputs)))
end
ps = [sym isa CallWithMetadata ? sym :
setmetadata(sym, VariableTimeDomain, get(time_domains, sym, Continuous()))
Expand Down
Loading