Closures with Fix1 for type stability #69
Comments
Hi @cscherrer, could you give an MWE for why the following Soss model is not type-stable? For(n -> Normal(α + β * x[n], σ), 1:N)
As you mentioned the implementation of GG's Closure here, I'd guess the type stability issue is caused by the use of
GeneralizedGenerated.jl/src/closure_conv.jl Lines 18 to 22 in 6ebfe69
|
I now have time to work on GG and Soss. |
Great! Yesterday's covid booster is hitting me pretty hard, but I'll send you an update when I'm feeling better. Lots going on 🙂 |
Sorry to hear that. Hope you get better soon, and we can work together again! |
Using
julia> using Soss
julia> using MeasureTheory
julia> m = @model n begin
y ~ For(n) do j
Bernoulli(1/j)
end
end;
julia> s = rand(m(10))
(y = Bool[1, 0, 0, 0, 0, 0, 0, 0, 1, 0],)
julia> using JET
julia> @test_opt rand(m(10))
JET-test failed at REPL[44]:1
Expression: #= REPL[44]:1 =# JET.@test_call analyzer = JET.OptAnalyzer rand(m(10))
═════ 7 possible errors found ═════
┌ @ /home/chad/git/Soss.jl/src/primitives/rand.jl:24 Soss.rand(Soss.GLOBAL_RNG, m)
│┌ @ /home/chad/git/Soss.jl/src/primitives/rand.jl:20 Soss._rand(_, m, Soss.argvals(c))(rng)
││┌ @ /home/chad/.julia/packages/GeneralizedGenerated/PV9u7/src/closure.jl:6 GeneralizedGenerated.#_#4(Core.tuple(Base.pairs(Core.NamedTuple()), closure), args...)
│││┌ @ /home/chad/.julia/packages/GeneralizedGenerated/PV9u7/src/closure.jl:6 _(Base.getproperty(closure, :frees), args...)
││││┌ @ /home/chad/.julia/packages/GeneralizedGenerated/PV9u7/src/ngg/runtime_fns.jl:83 GeneralizedGenerated.NGG.#_#10(Core.tuple(Base.pairs(Core.NamedTuple()), #self#), pargs...)
│││││┌ @ /home/chad/.julia/packages/GeneralizedGenerated/PV9u7/src/ngg/runtime_fns.jl:116 Base.getproperty(Main, :For)(function = (j;) -> begin
(Main).Bernoulli((Main).:/(1, j))
end, n)
││││││┌ @ /home/chad/git/MeasureBase.jl/src/combinators/for.jl:199 MeasureBase.For(f, Base.OneTo(n))
│││││││┌ @ /home/chad/git/MeasureBase.jl/src/combinators/for.jl:198 MeasureBase.For(f, inds)
││││││││┌ @ /home/chad/git/MeasureBase.jl/src/combinators/for.jl:18 For(::ggfunc-function, ::Tuple{Base.OneTo{Int64}})
│││││││││ failed to optimize: For(::ggfunc-function, ::Tuple{Base.OneTo{Int64}})
││││││││└───────────────────────────────────────────────────────────
│││││││┌ @ /home/chad/git/MeasureBase.jl/src/combinators/for.jl:198 For(::ggfunc-function, ::Base.OneTo{Int64})
││││││││ failed to optimize: For(::ggfunc-function, ::Base.OneTo{Int64})
│││││││└────────────────────────────────────────────────────────────
││││││┌ @ /home/chad/git/MeasureBase.jl/src/combinators/for.jl:199 For(::ggfunc-function, ::Int64)
│││││││ failed to optimize: For(::ggfunc-function, ::Int64)
││││││└────────────────────────────────────────────────────────────
│││││┌ @ /home/chad/.julia/packages/GeneralizedGenerated/PV9u7/src/ngg/runtime_fns.jl:116 y = Base.getproperty(Main, :rand)(_rng, Base.getproperty(Main, :For)(function = (j;) -> begin
(Main).Bernoulli((Main).:/(1, j))
end, n))
││││││┌ @ /home/chad/git/MeasureBase.jl/src/rand.jl:7 MeasureBase.rand(rng, MeasureBase.Float64, d)
│││││││┌ @ /home/chad/git/MeasureBase.jl/src/combinators/for.jl:212 MeasureBase._rand_product(rng, _, MeasureBase.marginals(d), _)
││││││││┌ @ /home/chad/git/MeasureBase.jl/src/combinators/product.jl:33 MeasureBase.map(#28, mar)
│││││││││┌ @ abstractarray.jl:2849 Base.collect_similar(A, Base.Generator(f, A))
││││││││││┌ @ array.jl:653 Base._collect(cont, itr, Base.IteratorEltype(itr), Base.IteratorSize(itr))
│││││││││││┌ @ array.jl:701 et = Base.promote_typejoin_union(T)
││││││││││││┌ @ promotion.jl:170 Base.promote_typejoin_union(Base.getproperty(_, :a))
│││││││││││││┌ @ promotion.jl:190 unwrapva(%35)
││││││││││││││ runtime dispatch detected: unwrapva(%35::Any)
│││││││││││││└────────────────────
│││││││││││││┌ @ promotion.jl:194 Base.typejoin(%57, %59)
││││││││││││││ runtime dispatch detected: Base.typejoin(%57::Any, %59::Any)
│││││││││││││└────────────────────
│││││││││││││┌ @ promotion.jl:198 Base.promote_typejoin_union(%50)
││││││││││││││ runtime dispatch detected: Base.promote_typejoin_union(%50::Any)
│││││││││││││└────────────────────
│││││┌ @ /home/chad/.julia/packages/GeneralizedGenerated/PV9u7/src/ngg/runtime_fns.jl:83 GeneralizedGenerated.NGG.var"#_#10"(::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::ggfunc-function, ::Int64, ::Random._GLOBAL_RNG)
││││││ failed to optimize: GeneralizedGenerated.NGG.var"#_#10"(::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::ggfunc-function, ::Int64, ::Random._GLOBAL_RNG)
│││││└───────────────────────────────────────────────────────────────────────────────────
ERROR: There was an error during testing
I think this comes down to
@generated function For(f::F, inds::I) where {F,I<:Tuple}
eltypes = Tuple{eltype.(I.types)...}
quote
$(Expr(:meta, :inline))
T = Core.Compiler.return_type(f, $eltypes)
For{T,F,I}(f, inds)
end
end
I've had some comments that using
Interestingly, this fixes it:
julia> f(ctx) = Base.Fix1(ctx) do ctx, j
Bernoulli(1/j)
end
f (generic function with 1 method)
julia> m = @model n begin
y ~ For(f(NamedTuple()), n)
end;
julia> s = rand(m(10))
(y = Bool[1, 0, 0, 1, 0, 0, 0, 0, 0, 0],)
julia> using JET
julia> @test_opt rand(m(10))
Test Passed
Expression: #= REPL[7]:1 =# JET.@test_call analyzer = JET.OptAnalyzer rand(m(10))
In general, |
Looking at this a little closer, I think my approach is kind of clumsy. Do you think it could be possible to instead have an option to just avoid boxing in the first place? Maybe something like an opaque closure, or a faked version of that? Even something throwing an error on type instability would be useful. |
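For readers without Soss or GG installed, the shape of the `Base.Fix1` workaround can be tried in plain Julia. This is only a minimal sketch: `Bern` is a hypothetical stand-in for `Bernoulli`, `make_body` is an illustrative name, and `@inferred` from the `Test` standard library is used in place of JET's `@test_opt`, which is a weaker check.

```julia
using Test  # standard library; provides @inferred

# Hypothetical stand-in for MeasureTheory's Bernoulli, so the sketch has no dependencies.
struct Bern{T}
    p::T
end

# The loop body receives its free variables through an explicit `ctx` argument.
# Base.Fix1 binds that first argument, so nothing is captured implicitly.
make_body(ctx) = Base.Fix1(ctx) do ctx, j
    Bern(ctx.p / j)
end

body = make_body((p = 0.5,))

# @inferred throws if the return type of the call cannot be inferred concretely.
d = @inferred body(2)
```

The type of `ctx` is part of the type of the `Fix1` object, which is what lets inference see through the call.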
Just a little more information on this, then I'm going to try to set it aside for a while...
Say you start with
using Soss, JET
m1 = @model N begin
p ~ Uniform()
x ~ For(N) do j
Bernoulli(p / j)
end
end
Then
**Here's the test for m1**
julia> @test_opt rand(m1(10))
JET-test failed at REPL[38]:1
Expression: #= REPL[38]:1 =# JET.@test_call analyzer = JET.OptAnalyzer rand(m1(10))
═════ 16 possible errors found ═════
┌ @ /home/chad/git/Soss.jl/src/primitives/rand.jl:19 Soss.#rand#49(Base.pairs(Core.NamedTuple()), #self#, m)
│┌ @ /home/chad/git/Soss.jl/src/primitives/rand.jl:19 Soss.rand(Soss.GLOBAL_RNG, m)
││┌ @ /home/chad/git/Soss.jl/src/primitives/rand.jl:34 Soss.#rand#51(Soss.NamedTuple(), Soss.NamedTuple(), #self#, rng, mc)
│││┌ @ /home/chad/git/Soss.jl/src/primitives/rand.jl:36 f(cfg, ctx)
││││┌ @ /home/chad/.julia/packages/GeneralizedGenerated/PV9u7/src/closure.jl:6 GeneralizedGenerated.#_#4(Core.tuple(Base.pairs(Core.NamedTuple()), closure), args...)
│││││┌ @ /home/chad/.julia/packages/GeneralizedGenerated/PV9u7/src/closure.jl:6 _(Base.getproperty(closure, :frees), args...)
││││││┌ @ /home/chad/.julia/packages/GeneralizedGenerated/PV9u7/src/ngg/runtime_fns.jl:83 GeneralizedGenerated.NGG.#_#10(Core.tuple(Base.pairs(Core.NamedTuple()), #self#), pargs...)
│││││││┌ @ /home/chad/.julia/packages/GeneralizedGenerated/PV9u7/src/ngg/runtime_fns.jl:24 tilde_rand(:x, Base.getproperty(Main, :For)(Core.apply_type(Closure, function = (p, j;) -> begin
(Main).Bernoulli((Main).:/(p, (Main).:√(j)))
end, Base.typeof(freevars))(freevars), N), _cfg, _ctx, targs)
││││││││┌ @ /home/chad/git/Soss.jl/src/primitives/rand.jl:43 x = Soss.rand(Base.getproperty(cfg, :rng), d)
│││││││││┌ @ /home/chad/git/MeasureBase.jl/src/rand.jl:7 MeasureBase.rand(rng, MeasureBase.Float64, d)
││││││││││┌ @ /home/chad/git/MeasureBase.jl/src/combinators/for.jl:220 MeasureBase._rand_product(rng, _, MeasureBase.marginals(d), _)
│││││││││││┌ @ /home/chad/git/MeasureBase.jl/src/combinators/product.jl:33 MeasureBase.map(#28, mar)
││││││││││││┌ @ abstractarray.jl:2849 Base.collect_similar(A, Base.Generator(f, A))
│││││││││││││┌ @ array.jl:653 Base._collect(cont, itr, Base.IteratorEltype(itr), Base.IteratorSize(itr))
││││││││││││││┌ @ array.jl:744 y = Base.iterate(itr)
│││││││││││││││┌ @ generator.jl:44 y = Base.iterate(Core.tuple(Base.getproperty(g, :iter)), s...)
││││││││││││││││┌ @ abstractarray.jl:1142 #self#(A, Core.tuple(Base.eachindex(A)))
│││││││││││││││││┌ @ abstractarray.jl:1144 Base.getindex(A, Base.getindex(y, 1))
││││││││││││││││││┌ @ /home/chad/.julia/packages/MappedArrays/bS6Yp/src/MappedArrays.jl:166 Base.getproperty(A, :f)(Base.getindex(Core.tuple(Base.getproperty(A, :data)), i...))
│││││││││││││││││││┌ @ /home/chad/.julia/packages/GeneralizedGenerated/PV9u7/src/closure.jl:6 GeneralizedGenerated.#_#4(Core.tuple(Base.pairs(Core.NamedTuple()), closure), args...)
││││││││││││││││││││┌ @ /home/chad/.julia/packages/GeneralizedGenerated/PV9u7/src/closure.jl:6 Base.merge(Base.NamedTuple(), kwargs)
│││││││││││││││││││││┌ @ namedtuple.jl:303 Core.apply_type(Base.NamedTuple, Core.tuple(names...))(Core.tuple(vals...))
││││││││││││││││││││││┌ @ boot.jl:601 Core.apply_type(Core.NamedTuple, _, Core.typeof(args))(args)
│││││││││││││││││││││││┌ @ namedtuple.jl:96 _(args)
││││││││││││││││││││││││┌ @ tuple.jl:312 Base.convert(_, x)
│││││││││││││││││││││││││┌ @ essentials.jl:344 Base.Val(_)
││││││││││││││││││││││││││┌ @ essentials.jl:701 %1()
│││││││││││││││││││││││││││ runtime dispatch detected: %1::Type{Val{_A}} where _A()
││││││││││││││││││││││││││└─────────────────────
│││││││││││││││││││┌ @ /home/chad/.julia/packages/GeneralizedGenerated/PV9u7/src/closure.jl:5 (::GeneralizedGenerated.Closure{function = (p, j;) -> begin
(Main).Bernoulli((Main).:/(p, (Main).:√(j)))
end, Tuple{Float64}})(::Int64)
││││││││││││││││││││ failed to optimize: (::GeneralizedGenerated.Closure{function = (p, j;) -> begin
(Main).Bernoulli((Main).:/(p, (Main).:√(j)))
end, Tuple{Float64}})(::Int64)
│││││││││││││││││││└──────────────────────────────────────────────────────────────────────────
││││││││││││││││││┌ @ /home/chad/.julia/packages/MappedArrays/bS6Yp/src/MappedArrays.jl:164 getindex(::MappedArrays.ReadonlyMappedArray{Bernoulli{(:p,), Tuple{Float64}}, 1, Base.OneTo{Int64}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
(Main).Bernoulli((Main).:/(p, (Main).:√(j)))
end, Tuple{Float64}}}, ::Int64)
│││││││││││││││││││ failed to optimize: getindex(::MappedArrays.ReadonlyMappedArray{Bernoulli{(:p,), Tuple{Float64}}, 1, Base.OneTo{Int64}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
(Main).Bernoulli((Main).:/(p, (Main).:√(j)))
end, Tuple{Float64}}}, ::Int64)
││││││││││││││││││└─────────────────────────────────────────────────────────────────────────
│││││││││││││││││┌ @ abstractarray.jl:1141 iterate(::MappedArrays.ReadonlyMappedArray{Bernoulli{(:p,), Tuple{Float64}}, 1, Base.OneTo{Int64}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
(Main).Bernoulli((Main).:/(p, (Main).:√(j)))
end, Tuple{Float64}}}, ::Tuple{Base.OneTo{Int64}})
││││││││││││││││││ failed to optimize: iterate(::MappedArrays.ReadonlyMappedArray{Bernoulli{(:p,), Tuple{Float64}}, 1, Base.OneTo{Int64}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
(Main).Bernoulli((Main).:/(p, (Main).:√(j)))
end, Tuple{Float64}}}, ::Tuple{Base.OneTo{Int64}})
│││││││││││││││││└─────────────────────────
││││││││││││││││┌ @ abstractarray.jl:1141 iterate(::MappedArrays.ReadonlyMappedArray{Bernoulli{(:p,), Tuple{Float64}}, 1, Base.OneTo{Int64}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
(Main).Bernoulli((Main).:/(p, (Main).:√(j)))
end, Tuple{Float64}}})
│││││││││││││││││ failed to optimize: iterate(::MappedArrays.ReadonlyMappedArray{Bernoulli{(:p,), Tuple{Float64}}, 1, Base.OneTo{Int64}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
(Main).Bernoulli((Main).:/(p, (Main).:√(j)))
end, Tuple{Float64}}})
││││││││││││││││└─────────────────────────
│││││││││││││││┌ @ generator.jl:42 iterate(::Base.Generator{MappedArrays.ReadonlyMappedArray{Bernoulli{(:p,), Tuple{Float64}}, 1, Base.OneTo{Int64}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
(Main).Bernoulli((Main).:/(p, (Main).:√(j)))
end, Tuple{Float64}}}, MeasureBase.var"#28#29"{Float64, Random._GLOBAL_RNG}})
││││││││││││││││ failed to optimize: iterate(::Base.Generator{MappedArrays.ReadonlyMappedArray{Bernoulli{(:p,), Tuple{Float64}}, 1, Base.OneTo{Int64}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
(Main).Bernoulli((Main).:/(p, (Main).:√(j)))
end, Tuple{Float64}}}, MeasureBase.var"#28#29"{Float64, Random._GLOBAL_RNG}})
│││││││││││││││└───────────────────
││││││││││││││┌ @ array.jl:754 Base.collect_to_with_first!(dest, v1, itr, st)
│││││││││││││││┌ @ array.jl:760 Base.collect_to!(dest, itr, Base.+(i1, 1), st)
││││││││││││││││┌ @ array.jl:782 y = Base.iterate(itr, st)
│││││││││││││││││┌ @ generator.jl:44 y = Base.iterate(Core.tuple(Base.getproperty(g, :iter)), s...)
││││││││││││││││││┌ @ abstractarray.jl:1141 iterate(::MappedArrays.ReadonlyMappedArray{Bernoulli{(:p,), Tuple{Float64}}, 1, Base.OneTo{Int64}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
(Main).Bernoulli((Main).:/(p, (Main).:√(j)))
end, Tuple{Float64}}}, ::Tuple{Base.OneTo{Int64}, Int64})
│││││││││││││││││││ failed to optimize: iterate(::MappedArrays.ReadonlyMappedArray{Bernoulli{(:p,), Tuple{Float64}}, 1, Base.OneTo{Int64}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
(Main).Bernoulli((Main).:/(p, (Main).:√(j)))
end, Tuple{Float64}}}, ::Tuple{Base.OneTo{Int64}, Int64})
││││││││││││││││││└─────────────────────────
│││││││││││││││││┌ @ generator.jl:42 iterate(::Base.Generator{MappedArrays.ReadonlyMappedArray{Bernoulli{(:p,), Tuple{Float64}}, 1, Base.OneTo{Int64}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
(Main).Bernoulli((Main).:/(p, (Main).:√(j)))
end, Tuple{Float64}}}, MeasureBase.var"#28#29"{Float64, Random._GLOBAL_RNG}}, ::Tuple{Base.OneTo{Int64}, Int64})
││││││││││││││││││ failed to optimize: iterate(::Base.Generator{MappedArrays.ReadonlyMappedArray{Bernoulli{(:p,), Tuple{Float64}}, 1, Base.OneTo{Int64}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
(Main).Bernoulli((Main).:/(p, (Main).:√(j)))
end, Tuple{Float64}}}, MeasureBase.var"#28#29"{Float64, Random._GLOBAL_RNG}}, ::Tuple{Base.OneTo{Int64}, Int64})
│││││││││││││││││└───────────────────
││││││││││││││┌ @ array.jl:741 Base._collect(::MappedArrays.ReadonlyMappedArray{Bernoulli{(:p,), Tuple{Float64}}, 1, Base.OneTo{Int64}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
(Main).Bernoulli((Main).:/(p, (Main).:√(j)))
end, Tuple{Float64}}}, ::Base.Generator{MappedArrays.ReadonlyMappedArray{Bernoulli{(:p,), Tuple{Float64}}, 1, Base.OneTo{Int64}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
(Main).Bernoulli((Main).:/(p, (Main).:√(j)))
end, Tuple{Float64}}}, MeasureBase.var"#28#29"{Float64, Random._GLOBAL_RNG}}, ::Base.EltypeUnknown, ::Base.HasShape{1})
│││││││││││││││ failed to optimize: Base._collect(::MappedArrays.ReadonlyMappedArray{Bernoulli{(:p,), Tuple{Float64}}, 1, Base.OneTo{Int64}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
(Main).Bernoulli((Main).:/(p, (Main).:√(j)))
end, Tuple{Float64}}}, ::Base.Generator{MappedArrays.ReadonlyMappedArray{Bernoulli{(:p,), Tuple{Float64}}, 1, Base.OneTo{Int64}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
(Main).Bernoulli((Main).:/(p, (Main).:√(j)))
end, Tuple{Float64}}}, MeasureBase.var"#28#29"{Float64, Random._GLOBAL_RNG}}, ::Base.EltypeUnknown, ::Base.HasShape{1})
││││││││││││││└────────────────
│││││││││││││┌ @ array.jl:653 Base.collect_similar(::MappedArrays.ReadonlyMappedArray{Bernoulli{(:p,), Tuple{Float64}}, 1, Base.OneTo{Int64}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
(Main).Bernoulli((Main).:/(p, (Main).:√(j)))
end, Tuple{Float64}}}, ::Base.Generator{MappedArrays.ReadonlyMappedArray{Bernoulli{(:p,), Tuple{Float64}}, 1, Base.OneTo{Int64}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
(Main).Bernoulli((Main).:/(p, (Main).:√(j)))
end, Tuple{Float64}}}, MeasureBase.var"#28#29"{Float64, Random._GLOBAL_RNG}})
││││││││││││││ failed to optimize: Base.collect_similar(::MappedArrays.ReadonlyMappedArray{Bernoulli{(:p,), Tuple{Float64}}, 1, Base.OneTo{Int64}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
(Main).Bernoulli((Main).:/(p, (Main).:√(j)))
end, Tuple{Float64}}}, ::Base.Generator{MappedArrays.ReadonlyMappedArray{Bernoulli{(:p,), Tuple{Float64}}, 1, Base.OneTo{Int64}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
(Main).Bernoulli((Main).:/(p, (Main).:√(j)))
end, Tuple{Float64}}}, MeasureBase.var"#28#29"{Float64, Random._GLOBAL_RNG}})
│││││││││││││└────────────────
│││││││││││┌ @ /home/chad/git/MeasureBase.jl/src/combinators/product.jl:32 MeasureBase._rand_product(::Random._GLOBAL_RNG, ::Type{Float64}, ::MappedArrays.ReadonlyMappedArray{Bernoulli{(:p,), Tuple{Float64}}, 1, Base.OneTo{Int64}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
(Main).Bernoulli((Main).:/(p, (Main).:√(j)))
end, Tuple{Float64}}}, ::Type{Bernoulli{(:p,), Tuple{Float64}}})
││││││││││││ failed to optimize: MeasureBase._rand_product(::Random._GLOBAL_RNG, ::Type{Float64}, ::MappedArrays.ReadonlyMappedArray{Bernoulli{(:p,), Tuple{Float64}}, 1, Base.OneTo{Int64}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
(Main).Bernoulli((Main).:/(p, (Main).:√(j)))
end, Tuple{Float64}}}, ::Type{Bernoulli{(:p,), Tuple{Float64}}})
│││││││││││└───────────────────────────────────────────────────────────────
││││││││││┌ @ /home/chad/git/MeasureBase.jl/src/combinators/for.jl:219 rand(::Random._GLOBAL_RNG, ::Type{Float64}, ::For{Bernoulli{(:p,), Tuple{Float64}}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
(Main).Bernoulli((Main).:/(p, (Main).:√(j)))
end, Tuple{Float64}}, Tuple{Base.OneTo{Int64}}})
│││││││││││ failed to optimize: rand(::Random._GLOBAL_RNG, ::Type{Float64}, ::For{Bernoulli{(:p,), Tuple{Float64}}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
(Main).Bernoulli((Main).:/(p, (Main).:√(j)))
end, Tuple{Float64}}, Tuple{Base.OneTo{Int64}}})
││││││││││└────────────────────────────────────────────────────────────
│││││││││┌ @ /home/chad/git/MeasureBase.jl/src/rand.jl:7 rand(::Random._GLOBAL_RNG, ::For{Bernoulli{(:p,), Tuple{Float64}}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
(Main).Bernoulli((Main).:/(p, (Main).:√(j)))
end, Tuple{Float64}}, Tuple{Base.OneTo{Int64}}})
││││││││││ failed to optimize: rand(::Random._GLOBAL_RNG, ::For{Bernoulli{(:p,), Tuple{Float64}}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
(Main).Bernoulli((Main).:/(p, (Main).:√(j)))
end, Tuple{Float64}}, Tuple{Base.OneTo{Int64}}})
│││││││││└───────────────────────────────────────────────
││││││││┌ @ /home/chad/git/Soss.jl/src/primitives/rand.jl:42 Soss.tilde_rand(::Symbol, ::For{Bernoulli{(:p,), Tuple{Float64}}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
(Main).Bernoulli((Main).:/(p, (Main).:√(j)))
end, Tuple{Float64}}, Tuple{Base.OneTo{Int64}}}, ::NamedTuple{(:rng, :_args, :_obs), Tuple{Random._GLOBAL_RNG, NamedTuple{(:N,), Tuple{Int64}}, NamedTuple{(), Tuple{}}}}, ::NamedTuple{(:p,), Tuple{Float64}}, ::Soss.TildeArgs{DataType, DataType, NamedTuple{(), Tuple{}}, Static.False, Static.False})
│││││││││ failed to optimize: Soss.tilde_rand(::Symbol, ::For{Bernoulli{(:p,), Tuple{Float64}}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
(Main).Bernoulli((Main).:/(p, (Main).:√(j)))
end, Tuple{Float64}}, Tuple{Base.OneTo{Int64}}}, ::NamedTuple{(:rng, :_args, :_obs), Tuple{Random._GLOBAL_RNG, NamedTuple{(:N,), Tuple{Int64}}, NamedTuple{(), Tuple{}}}}, ::NamedTuple{(:p,), Tuple{Float64}}, ::Soss.TildeArgs{DataType, DataType, NamedTuple{(), Tuple{}}, Static.False, Static.False})
││││││││└────────────────────────────────────────────────────
│││││││┌ @ /home/chad/.julia/packages/GeneralizedGenerated/PV9u7/src/ngg/runtime_fns.jl:83 GeneralizedGenerated.NGG.var"#_#10"(::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::ggfunc-function, ::Soss.ModelClosure{ASTModel{NamedTuple{(:N,)}, GeneralizedGenerated.NGG.TypeLevel{Expr, "Buf{91}()"}, GeneralizedGenerated.NGG.TypeLevel{Module, "Buf{17}()"}}, NamedTuple{(:N,), Tuple{Int64}}}, ::NamedTuple{(:rng,), Tuple{Random._GLOBAL_RNG}}, ::NamedTuple{(), Tuple{}})
││││││││ failed to optimize: GeneralizedGenerated.NGG.var"#_#10"(::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::ggfunc-function, ::Soss.ModelClosure{ASTModel{NamedTuple{(:N,)}, GeneralizedGenerated.NGG.TypeLevel{Expr, "Buf{91}()"}, GeneralizedGenerated.NGG.TypeLevel{Module, "Buf{17}()"}}, NamedTuple{(:N,), Tuple{Int64}}}, ::NamedTuple{(:rng,), Tuple{Random._GLOBAL_RNG}}, ::NamedTuple{(), Tuple{}})
│││││││└───────────────────────────────────────────────────────────────────────────────────
│││││┌ @ /home/chad/.julia/packages/GeneralizedGenerated/PV9u7/src/closure.jl:5 GeneralizedGenerated.var"#_#4"(::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::GeneralizedGenerated.Closure{function = (_mc, _cfg, _ctx;) -> begin
begin
$(Expr(:meta, :((Main).inline)))
local _retn
_args = (Main).Soss.argvals(_mc)
_obs = (Main).Soss.observations(_mc)
_cfg = (Main).merge(_cfg, (_args = _args, _obs = _obs))
let
begin
N = _args.N
(p, _ctx, _retn) = let targs = (Main).Soss.TildeArgs(GeneralizedGenerated.NGG.TypeLevel{Symbol, "Buf{9}()"}, GeneralizedGenerated.NGG.TypeLevel{Expr, "Buf{23}()"}, (Main).NamedTuple{()}(()), static(false), static(false))
begin
(Soss.tilde_rand)(:p, (Main).Uniform(), _cfg, _ctx, targs)
end
end
(x, _ctx, _retn) = let targs = (Main).Soss.TildeArgs(GeneralizedGenerated.NGG.TypeLevel{Symbol, "Buf{9}()"}, GeneralizedGenerated.NGG.TypeLevel{Expr, "Buf{23}()"}, (Main).NamedTuple{()}(()), static(false), static(false))
begin
(Soss.tilde_rand)(:x, (Main).For(begin
let freevars = (p,)
(GeneralizedGenerated.Closure){function = (p, j;) -> begin
(Main).Bernoulli((Main).:/(p, (Main).:√(j)))
end, Base.typeof(freevars)}(freevars)
end
end, N), _cfg, _ctx, targs)
end
end
end
end
_retn
end
end, Tuple{Soss.ModelClosure{ASTModel{NamedTuple{(:N,)}, GeneralizedGenerated.NGG.TypeLevel{Expr, "Buf{91}()"}, GeneralizedGenerated.NGG.TypeLevel{Module, "Buf{17}()"}}, NamedTuple{(:N,), Tuple{Int64}}}}}, ::NamedTuple{(:rng,), Tuple{Random._GLOBAL_RNG}}, ::NamedTuple{(), Tuple{}})
││││││ failed to optimize: GeneralizedGenerated.var"#_#4"(::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::GeneralizedGenerated.Closure{function = (_mc, _cfg, _ctx;) -> begin
begin
$(Expr(:meta, :((Main).inline)))
local _retn
_args = (Main).Soss.argvals(_mc)
_obs = (Main).Soss.observations(_mc)
_cfg = (Main).merge(_cfg, (_args = _args, _obs = _obs))
let
begin
N = _args.N
(p, _ctx, _retn) = let targs = (Main).Soss.TildeArgs(GeneralizedGenerated.NGG.TypeLevel{Symbol, "Buf{9}()"}, GeneralizedGenerated.NGG.TypeLevel{Expr, "Buf{23}()"}, (Main).NamedTuple{()}(()), static(false), static(false))
begin
(Soss.tilde_rand)(:p, (Main).Uniform(), _cfg, _ctx, targs)
end
end
(x, _ctx, _retn) = let targs = (Main).Soss.TildeArgs(GeneralizedGenerated.NGG.TypeLevel{Symbol, "Buf{9}()"}, GeneralizedGenerated.NGG.TypeLevel{Expr, "Buf{23}()"}, (Main).NamedTuple{()}(()), static(false), static(false))
begin
(Soss.tilde_rand)(:x, (Main).For(begin
let freevars = (p,)
(GeneralizedGenerated.Closure){function = (p, j;) -> begin
(Main).Bernoulli((Main).:/(p, (Main).:√(j)))
end, Base.typeof(freevars)}(freevars)
end
end, N), _cfg, _ctx, targs)
end
end
end
end
_retn
end
end, Tuple{Soss.ModelClosure{ASTModel{NamedTuple{(:N,)}, GeneralizedGenerated.NGG.TypeLevel{Expr, "Buf{91}()"}, GeneralizedGenerated.NGG.TypeLevel{Module, "Buf{17}()"}}, NamedTuple{(:N,), Tuple{Int64}}}}}, ::NamedTuple{(:rng,), Tuple{Random._GLOBAL_RNG}}, ::NamedTuple{(), Tuple{}})
│││││└──────────────────────────────────────────────────────────────────────────
ERROR: There was an error during testing
If instead, you do it like this:
f(ctx) = Base.Fix1(ctx) do ctx, j
Bernoulli(ctx.p / j)
end
m2 = @model N begin
p ~ Uniform()
x ~ For(f((p=p,)), N)
end
Then
That got me wondering whether it makes more sense to change the GG implementation to take a similar approach. But again this isn't at all clear to me. For example, this attempt also fails:
m3 = @model N begin
p ~ Uniform()
f(ctx) = Base.Fix1(ctx) do ctx, j
Bernoulli(ctx.p / j)
end
x ~ For(f((p=p,)), N)
end
**Here's the test for m3**
julia> @test_opt rand(m3(10))
JET-test failed at REPL[37]:1
Expression: #= REPL[37]:1 =# JET.@test_call analyzer = JET.OptAnalyzer rand(m3(10))
═════ 2 possible errors found ═════
┌ @ /home/chad/git/Soss.jl/src/primitives/rand.jl:19 Soss.#rand#49(Base.pairs(Core.NamedTuple()), #self#, m)
│┌ @ /home/chad/git/Soss.jl/src/primitives/rand.jl:19 Soss.rand(Soss.GLOBAL_RNG, m)
││┌ @ /home/chad/git/Soss.jl/src/primitives/rand.jl:34 Soss.#rand#51(Soss.NamedTuple(), Soss.NamedTuple(), #self#, rng, mc)
│││┌ @ /home/chad/git/Soss.jl/src/primitives/rand.jl:36 f(cfg, ctx)
││││┌ @ /home/chad/.julia/packages/GeneralizedGenerated/PV9u7/src/closure.jl:6 GeneralizedGenerated.#_#4(Core.tuple(Base.pairs(Core.NamedTuple()), closure), args...)
│││││┌ @ /home/chad/.julia/packages/GeneralizedGenerated/PV9u7/src/closure.jl:6 _(Base.getproperty(closure, :frees), args...)
││││││┌ @ /home/chad/.julia/packages/GeneralizedGenerated/PV9u7/src/ngg/runtime_fns.jl:83 GeneralizedGenerated.NGG.#_#10(Core.tuple(Base.pairs(Core.NamedTuple()), #self#), pargs...)
│││││││┌ @ /home/chad/git/Soss.jl/src/primitives/interpret.jl:25 f(Core.apply_type(Core.NamedTuple, (:p,))(Core.tuple(p)))
││││││││┌ @ /home/chad/.julia/packages/GeneralizedGenerated/PV9u7/src/ngg/runtime_fns.jl:83 (::ggfunc-f)(::NamedTuple{(:p,), Tuple{Float64}})
│││││││││ failed to optimize: (::ggfunc-f)(::NamedTuple{(:p,), Tuple{Float64}})
││││││││└───────────────────────────────────────────────────────────────────────────────────
│││││││┌ @ /home/chad/.julia/packages/GeneralizedGenerated/PV9u7/src/ngg/runtime_fns.jl:83 GeneralizedGenerated.NGG.var"#_#10"(::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::ggfunc-function, ::Soss.ModelClosure{ASTModel{NamedTuple{(:N,)}, GeneralizedGenerated.NGG.TypeLevel{Expr, "Buf{118}()"}, GeneralizedGenerated.NGG.TypeLevel{Module, "Buf{17}()"}}, NamedTuple{(:N,), Tuple{Int64}}}, ::NamedTuple{(:rng,), Tuple{Random._GLOBAL_RNG}}, ::NamedTuple{(), Tuple{}})
││││││││ failed to optimize: GeneralizedGenerated.NGG.var"#_#10"(::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::ggfunc-function, ::Soss.ModelClosure{ASTModel{NamedTuple{(:N,)}, GeneralizedGenerated.NGG.TypeLevel{Expr, "Buf{118}()"}, GeneralizedGenerated.NGG.TypeLevel{Module, "Buf{17}()"}}, NamedTuple{(:N,), Tuple{Int64}}}, ::NamedTuple{(:rng,), Tuple{Random._GLOBAL_RNG}}, ::NamedTuple{(), Tuple{}})
│││││││└───────────────────────────────────────────────────────────────────────────────────
ERROR: There was an error during testing
A couple of things are surprising:
It seems likely to me that dynamic dispatch could become a problem for some models, even if it's not for this one. But I'm feeling kind of stuck anyway, so I'm going to set it aside for a while. Let me know if you have any ideas, and we can pick back up. |
Sorry for the delay; I'll get started working on this.
|
Some rough ideas: is the issue possibly related to a run-time invocation of
@generated function For(f::F, inds::I) where {F,I<:Tuple}
eltypes = Tuple{eltype.(I.types)...}
T = Core.Compiler.return_type(f, eltypes)
quote
$(Expr(:meta, :inline))
For{$T,F,I}(f, inds)
end
end |
It seems |
Sorry, I've been changing more things. I'll get to something stable and send you the Manifest.
I originally had this outside the quote; I forget why I moved it. But anyway, I had to add this for it to work at all:
@generated function MeasureTheory.For(f::GG.Closure{F,Free}, inds::I) where {F,Free,I<:Tuple}
freetypes = Free.types
eltypes = eltype.(I.types)
T = Core.Compiler.return_type(F, Tuple{freetypes..., eltypes...})
quote
$(Expr(:meta, :inline))
For{$T,GG.Closure{F,Free},I}(f, inds)
end
end |
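The argument-type construction in that method, `Tuple{freetypes..., eltypes...}`, can be illustrated outside a generated function. This is a sketch under assumptions: `Bern` and `body` are hypothetical stand-ins, and `Core.Compiler.return_type` is an internal whose result is only an upper bound from inference, not a guarantee.

```julia
# Hypothetical stand-in for MeasureTheory's Bernoulli.
struct Bern{T}
    p::T
end

# The closure body written as a plain function: free variables first, then the index.
body(p, j) = Bern(p / j)

frees = (0.5,)             # captured values, playing the role of the closure's free variables
inds = (Base.OneTo(10),)   # index collections, as passed to For

freetypes = fieldtypes(typeof(frees))   # types of the captured values
eltypes = map(eltype, inds)             # element type of each index collection

# Ask inference for the return type of `body` applied to those argument types.
T = Core.Compiler.return_type(body, Tuple{freetypes..., eltypes...})
```

If inference succeeds, `T` is the concrete measure type that ends up as the first type parameter of `For`; if it fails, `T` widens, possibly all the way to `Any`.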
Thanks, just send them to me when you're done and I can work on this. I will need some time to address the concrete issue. Thanks a lot for providing the cases above. |
Ok, I sent you a Zulip message with lots of details |
This seems maybe helpful:
julia> g = mk_function(Main, :(ctx -> (j -> Bernoulli(ctx.p/j))))
function = (ctx;) -> begin
begin
begin
let freevars = (ctx,)
(GeneralizedGenerated.Closure){function = (ctx, j;) -> begin
begin
(Main).Bernoulli((Main).:/(ctx.p, j))
end
end, Base.typeof(freevars)}(freevars)
end
end
end
end
julia> m3 = @model n begin
p ~ Uniform()
y ~ For(g((p=p,)), n)
end;
julia> rand(m3(3))
(p = 0.9464608557346492, y = Bool[1, 1, 0])
julia> @test_opt rand(m3(3))
JET-test failed at REPL[162]:1
Expression: #= REPL[162]:1 =# JET.@test_call analyzer = JET.OptAnalyzer rand(m3(3))
═════ 3 possible errors found ═════
┌ @ /home/chad/git/Soss.jl/src/primitives/rand.jl:19 Soss.#rand#43(Base.pairs(Core.NamedTuple()), #self#, m)
│┌ @ /home/chad/git/Soss.jl/src/primitives/rand.jl:19 Soss.rand(Soss.GLOBAL_RNG, m)
││┌ @ /home/chad/git/Soss.jl/src/primitives/rand.jl:34 Soss.#rand#45(Soss.NamedTuple(), Soss.NamedTuple(), #self#, rng, mc)
│││┌ @ /home/chad/git/Soss.jl/src/primitives/rand.jl:36 f(cfg, ctx)
││││┌ @ /home/chad/git/GeneralizedGenerated.jl/src/closure.jl:6 GeneralizedGenerated.#_#4(Core.tuple(Base.pairs(Core.NamedTuple()), closure), args...)
│││││┌ @ /home/chad/git/GeneralizedGenerated.jl/src/closure.jl:6 _(Base.getproperty(closure, :frees), args...)
││││││┌ @ /home/chad/git/GeneralizedGenerated.jl/src/ngg/runtime_fns.jl:83 GeneralizedGenerated.NGG.#_#10(Core.tuple(Base.pairs(Core.NamedTuple()), #self#), pargs...)
│││││││┌ @ /home/chad/git/Soss.jl/src/primitives/interpret.jl:22 %48(%49)
││││││││ runtime dispatch detected: %48::Any(%49::NamedTuple{(:p,), Tuple{Float64}})
│││││││└─────────────────────────────────────────────────────────
│││││││┌ @ /home/chad/git/MeasureTheory.jl/src/combinators/for.jl:263 MeasureTheory.For(%50, %53)
││││││││ runtime dispatch detected: MeasureTheory.For(%50::Any, %53::Tuple{ArrayInterface.OptionallyStaticUnitRange{Static.StaticInt{1}, Int64}})
│││││││└──────────────────────────────────────────────────────────────
│││││││┌ @ /home/chad/git/GeneralizedGenerated.jl/src/ngg/runtime_fns.jl:21 tilde_rand(:y, %54, %3, %45, %47)
││││││││ runtime dispatch detected: tilde_rand(:y::Symbol, %54::Any, %3::NamedTuple{(:rng, :args, :obs), Tuple{Random._GLOBAL_RNG, NamedTuple{(:n,), Tuple{Int64}}, NamedTuple{(), Tuple{}}}}, %45::NamedTuple{(:p,), Tuple{Float64}}, %47::Soss.TildeArgs{DataType, DataType, NamedTuple{(), Tuple{}}})
│││││││└────────────────────────────────────────────────────────────────────
ERROR: There was an error during testing
It seems to get pretty close, but then there's this
│││││││┌ @ /home/chad/git/MeasureTheory.jl/src/combinators/for.jl:263 MeasureTheory.For(%50, %53)
││││││││ runtime dispatch detected: MeasureTheory.For(%50::Any, %53::Tuple{ArrayInterface.OptionallyStaticUnitRange{Static.StaticInt{1}, Int64}})
The first argument passed to
julia> Core.Compiler.return_type(g, Tuple{typeof((p = 0.2,))})
GeneralizedGenerated.Closure{function = (ctx, j;) -> begin
begin
(Main).Bernoulli((Main).:/(ctx.p, j))
end
end, Tuple{NamedTuple{(:p,), Tuple{Float64}}}}
But when it's inside another GG function, it falls back to |
Hi @thautwarm,
In a new version of Soss, I'm trying to really nail down type stability. Using GG I was having lots of trouble getting constructions like
to be type-stable.
After lots of exploring, I realized this could be solved by putting this function outside of GG:
Then within GG, it could be called as
That works great, but it's tricky to automate, since the f needs to be made available separately within the module, and I'd rather not use eval to put it there.
I eventually found this to work:
So now, I think I can just rewrite things this way before calling GG. But I wanted to check with you before doing this, since it seems plausible the current Closure approach could be updated to something similar, and maybe it would help other cases with type stability. What do you think?
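The code from the post did not survive on this page, so the following is not the author's implementation. It is a rough sketch of what an expression-level rewrite into `Base.Fix1` could look like, assuming the free variable names are already known. `fix1_rewrite`, `@fix1`, and `Bern` are hypothetical names introduced for illustration.

```julia
using Test  # standard library; provides @inferred

# Hypothetical stand-in for MeasureTheory's Bernoulli.
struct Bern{T}
    p::T
end

# Hypothetical helper: turn `arg -> body` into a function of `(ctx, arg)` that
# unpacks the names in `frees` from `ctx`, then wrap it in Base.Fix1.
function fix1_rewrite(lambda::Expr, frees::Vector{Symbol})
    (lambda.head === :-> && lambda.args[1] isa Symbol) ||
        error("expected a single-argument `arg -> body` expression")
    arg, body = lambda.args
    ctx = gensym(:ctx)
    # One `p = getproperty(ctx, :p)` binding per free variable.
    unpack = [:($v = Base.getproperty($ctx, $(QuoteNode(v)))) for v in frees]
    inner = Expr(:->, Expr(:tuple, ctx, arg),
                 Expr(:let, Expr(:block, unpack...), body))
    # Build the context `(; p = p)` at the point where the closure was written.
    ctxval = Expr(:tuple, Expr(:parameters, [Expr(:kw, v, v) for v in frees]...))
    return :(Base.Fix1($inner, $ctxval))
end

# Macro wrapper so the rewrite happens at expansion time, without `eval`.
macro fix1(frees, lambda)
    names = Symbol[frees.args...]
    return esc(fix1_rewrite(lambda, names))
end

f = let p = 0.5
    @fix1((p,), j -> Bern(p / j))
end

# Throws if the return type of the call cannot be inferred concretely.
d = @inferred f(2)
```

The inner function takes everything it needs as arguments, so the only state lives in the `NamedTuple` held by `Fix1`. A real version would also need to find the free variables automatically, which this sketch leaves to the caller.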