
Closures with Fix1 for type stability #69

Open
cscherrer opened this issue Jan 26, 2022 · 14 comments


@cscherrer
Collaborator

Hi @thautwarm ,

In a new version of Soss, I'm trying to really nail down type stability. Using GG, I was having lots of trouble getting constructions like

For(n -> Normal(α + β * x[n], σ), 1:N)

to be type-stable.

After lots of exploring, I realized this could be solved by putting this function outside of GG:

f(ctx, n) = Normal(ctx.α + ctx.β * ctx.x[n], ctx.σ)
f(ctx::NamedTuple) = Base.Fix1(f, ctx)

Then within GG, it could be called as

For(f((;α,β,σ,x)), 1:N)

That works great, but it's tricky to automate: f has to be defined separately within the module, and I'd rather not use eval to put it there.

I eventually found this to work:

For(let f = ctx -> Base.Fix1(ctx) do ctx, n
            Normal(ctx.α + ctx.β * ctx.x[n], ctx.σ)
        end
        f((; α, β, σ, x))
    end, 1:N)

So now, I think I can just rewrite things this way before calling GG. But I wanted to check with you first, since it seems plausible that the current Closure approach could be updated to something similar, which might help type stability in other cases too. What do you think?
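A minimal, Base-only sketch of the Fix1 pattern described above. It assumes a hypothetical `MyNormal` struct as a stand-in for MeasureTheory's `Normal`, so it runs without Soss or MeasureTheory. Because `Base.Fix1{F,T}` records both the function and the context type in its type parameters, inference sees concrete types for every captured value.

```julia
# MyNormal is a hypothetical stand-in for MeasureTheory's Normal
struct MyNormal{T}
    μ::T
    σ::T
end

# The "body": all captured state arrives explicitly through the NamedTuple ctx
f(ctx, n) = MyNormal(ctx.α + ctx.β * ctx.x[n], ctx.σ)

# Partially apply the context; Fix1's type parameters record typeof(ctx)
f(ctx::NamedTuple) = Base.Fix1(f, ctx)

α, β, σ = 1.0, 2.0, 0.5
x = [0.1, 0.2, 0.3]
g = f((; α, β, σ, x))

# Ask inference for the return type of g(::Int)
rt = only(Base.return_types(g, (Int,)))
println(rt)
println(map(g, 1:3))
```

The same closure written as `n -> MyNormal(α + β * x[n], σ)` inside a function would usually also infer fine; the point of the Fix1 form is that it stays concrete even when the code generator cannot guarantee that captured variables are never reassigned.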

@thautwarm
Member

Hi @cscherrer, could you give an MWE for why the following Soss model is not type-stable?

For(n -> Normal(α + β * x[n], σ), 1:N)

> since it seems plausible the current Closure approach could be updated to something similar, and maybe it would help other cases with type stability. What do you think?

Since you mentioned GG's Closure implementation, my guess is that the type stability issue is caused by the use of Core.Box.

Core.Box allows capturing type-unstable free variables, so that you can reassign heterogeneous data to a free variable. However, it seems to break type stability: a Core.Box stores its contents as Any, so the compiler cannot infer the type of a boxed variable.

if var.name in scope.bound_inits
    push!(block, :($name = Core.Box($name)))
else
    push!(block, :($name = Core.Box()))
end
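A small Base-only sketch of the effect described here, using ordinary Julia closures rather than GG's Closure (the function names are hypothetical). Roughly, Julia lowers a captured variable into a Core.Box when it is reassigned after the closure captures it; the box's contents are typed Any, so calls through the closure cannot be inferred. Rebinding with `let` before capturing avoids the box.

```julia
# Captured variable reassigned after capture: lowered to a Core.Box
function make_boxed()
    x = 1
    g = () -> x
    x = 2.0
    return g
end

# Fresh `let` binding that is never reassigned: captured with a concrete type
function make_unboxed()
    x = 1
    g = let x = x
        () -> x
    end
    x = 2.0   # reassigning the outer x does not touch the binding g captured
    return g
end

boxed, unboxed = make_boxed(), make_unboxed()

println(fieldtype(typeof(boxed), 1))     # declared type of the captured field
println(fieldtype(typeof(unboxed), 1))
println(Base.return_types(boxed, ()))    # inferred return type of calling each closure
println(Base.return_types(unboxed, ()))
```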

@thautwarm
Member

I now have time to work on GG and Soss.

@cscherrer
Collaborator Author

Great! Yesterday's covid booster is hitting me pretty hard, but I'll send you an update when I'm feeling better. Lots going on 🙂

@thautwarm
Member

thautwarm commented Jan 29, 2022

Sorry to hear that. Hope you feel better soon, and we'll work together again!

@cscherrer
Collaborator Author

> Hi @cscherrer, could you give an MWE for why the following Soss model is not type-stable?
>
> For(n -> Normal(α + β * x[n], σ), 1:N)

Using MeasureBase#dev, MeasureTheory#dev2, and Soss#dev, I get

julia> using Soss

julia> using MeasureTheory

julia> m = @model n begin
           y ~ For(n) do j
               Bernoulli(1/j)
           end
       end;

julia> s = rand(m(10))
(y = Bool[1, 0, 0, 0, 0, 0, 0, 0, 1, 0],)

julia> using JET

julia> @test_opt rand(m(10))
JET-test failed at REPL[44]:1
  Expression: #= REPL[44]:1 =# JET.@test_call analyzer = JET.OptAnalyzer rand(m(10))
  ═════ 7 possible errors found ═════
  ┌ @ /home/chad/git/Soss.jl/src/primitives/rand.jl:24 Soss.rand(Soss.GLOBAL_RNG, m)
  │┌ @ /home/chad/git/Soss.jl/src/primitives/rand.jl:20 Soss._rand(_, m, Soss.argvals(c))(rng)
  ││┌ @ /home/chad/.julia/packages/GeneralizedGenerated/PV9u7/src/closure.jl:6 GeneralizedGenerated.#_#4(Core.tuple(Base.pairs(Core.NamedTuple()), closure), args...)
  │││┌ @ /home/chad/.julia/packages/GeneralizedGenerated/PV9u7/src/closure.jl:6 _(Base.getproperty(closure, :frees), args...)
  ││││┌ @ /home/chad/.julia/packages/GeneralizedGenerated/PV9u7/src/ngg/runtime_fns.jl:83 GeneralizedGenerated.NGG.#_#10(Core.tuple(Base.pairs(Core.NamedTuple()), #self#), pargs...)
  │││││┌ @ /home/chad/.julia/packages/GeneralizedGenerated/PV9u7/src/ngg/runtime_fns.jl:116 Base.getproperty(Main, :For)(function = (j;) -> begin
      (Main).Bernoulli((Main).:/(1, j))
  end, n)
  ││││││┌ @ /home/chad/git/MeasureBase.jl/src/combinators/for.jl:199 MeasureBase.For(f, Base.OneTo(n))
  │││││││┌ @ /home/chad/git/MeasureBase.jl/src/combinators/for.jl:198 MeasureBase.For(f, inds)
  ││││││││┌ @ /home/chad/git/MeasureBase.jl/src/combinators/for.jl:18 For(::ggfunc-function, ::Tuple{Base.OneTo{Int64}})
  │││││││││ failed to optimize: For(::ggfunc-function, ::Tuple{Base.OneTo{Int64}})
  ││││││││└───────────────────────────────────────────────────────────
  │││││││┌ @ /home/chad/git/MeasureBase.jl/src/combinators/for.jl:198 For(::ggfunc-function, ::Base.OneTo{Int64})
  ││││││││ failed to optimize: For(::ggfunc-function, ::Base.OneTo{Int64})
  │││││││└────────────────────────────────────────────────────────────
  ││││││┌ @ /home/chad/git/MeasureBase.jl/src/combinators/for.jl:199 For(::ggfunc-function, ::Int64)
  │││││││ failed to optimize: For(::ggfunc-function, ::Int64)
  ││││││└────────────────────────────────────────────────────────────
  │││││┌ @ /home/chad/.julia/packages/GeneralizedGenerated/PV9u7/src/ngg/runtime_fns.jl:116 y = Base.getproperty(Main, :rand)(_rng, Base.getproperty(Main, :For)(function = (j;) -> begin
      (Main).Bernoulli((Main).:/(1, j))
  end, n))
  ││││││┌ @ /home/chad/git/MeasureBase.jl/src/rand.jl:7 MeasureBase.rand(rng, MeasureBase.Float64, d)
  │││││││┌ @ /home/chad/git/MeasureBase.jl/src/combinators/for.jl:212 MeasureBase._rand_product(rng, _, MeasureBase.marginals(d), _)
  ││││││││┌ @ /home/chad/git/MeasureBase.jl/src/combinators/product.jl:33 MeasureBase.map(#28, mar)
  │││││││││┌ @ abstractarray.jl:2849 Base.collect_similar(A, Base.Generator(f, A))
  ││││││││││┌ @ array.jl:653 Base._collect(cont, itr, Base.IteratorEltype(itr), Base.IteratorSize(itr))
  │││││││││││┌ @ array.jl:701 et = Base.promote_typejoin_union(T)
  ││││││││││││┌ @ promotion.jl:170 Base.promote_typejoin_union(Base.getproperty(_, :a))
  │││││││││││││┌ @ promotion.jl:190 unwrapva(%35)
  ││││││││││││││ runtime dispatch detected: unwrapva(%35::Any)
  │││││││││││││└────────────────────
  │││││││││││││┌ @ promotion.jl:194 Base.typejoin(%57, %59)
  ││││││││││││││ runtime dispatch detected: Base.typejoin(%57::Any, %59::Any)
  │││││││││││││└────────────────────
  │││││││││││││┌ @ promotion.jl:198 Base.promote_typejoin_union(%50)
  ││││││││││││││ runtime dispatch detected: Base.promote_typejoin_union(%50::Any)
  │││││││││││││└────────────────────
  │││││┌ @ /home/chad/.julia/packages/GeneralizedGenerated/PV9u7/src/ngg/runtime_fns.jl:83 GeneralizedGenerated.NGG.var"#_#10"(::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::ggfunc-function, ::Int64, ::Random._GLOBAL_RNG)
  ││││││ failed to optimize: GeneralizedGenerated.NGG.var"#_#10"(::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::ggfunc-function, ::Int64, ::Random._GLOBAL_RNG)
  │││││└───────────────────────────────────────────────────────────────────────────────────
  
ERROR: There was an error during testing

I think this comes down to

@generated function For(f::F, inds::I) where {F,I<:Tuple}
    eltypes = Tuple{eltype.(I.types)...}
    quote
        $(Expr(:meta, :inline))
        T = Core.Compiler.return_type(f, $eltypes)
        For{T,F,I}(f, inds)
    end
end

I've received comments that using return_type is a bad idea, but I really need as much of this as possible to happen at compile time, and so far I don't see a good alternative. It's also worth noting that MappedArrays.jl takes a very similar approach; see e.g. https://github.com/JuliaArrays/MappedArrays.jl/blob/46bf47f3388d011419fe43404214c1b7a44a49cc/src/MappedArrays.jl#L61

Interestingly, this fixes it:

julia> f(ctx) = Base.Fix1(ctx) do ctx, j
           Bernoulli(1/j)
       end
f (generic function with 1 method)

julia> m = @model n begin
           y ~ For(f(NamedTuple()), n)
       end;

julia> s = rand(m(10))
(y = Bool[1, 0, 0, 1, 0, 0, 0, 0, 0, 0],)

julia> using JET

julia> @test_opt rand(m(10))
Test Passed
  Expression: #= REPL[7]:1 =# JET.@test_call analyzer = JET.OptAnalyzer rand(m(10))

In general, ctx can be set to a named tuple of local variables, and f could be a runtime-generated function. Besides type stability, this gives a way to break a GG function into smaller pieces, which may also get around GG's source code size limitations :)
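To illustrate why the Fix1 version helps the For constructor quoted above, here is a rough Base-only sketch: the element type T comes from Core.Compiler.return_type, so anything inference cannot see through ends up in T. `MyFor`, `make_boxed` and `make_fix1` are hypothetical simplifications, not MeasureBase's actual code, and Core.Compiler.return_type is an internal, unexported function whose results may vary between Julia versions.

```julia
# Simplified, hypothetical stand-in for MeasureBase's For{T,F,I}
struct MyFor{T,F,I}
    f::F
    inds::I
end

function MyFor(f::F, inds::I) where {F,I<:Tuple}
    # Ask inference what f returns for the index element types (internal API)
    T = Core.Compiler.return_type(f, Tuple{map(eltype, inds)...})
    return MyFor{T,F,I}(f, inds)
end

# Closure whose captured p is reassigned after capture, so p lives in a Core.Box
function make_boxed(p)
    g = j -> p / j
    p = p * 1
    return g
end

# Same computation with the captured state passed explicitly through Fix1
make_fix1(p) = Base.Fix1((ctx, j) -> ctx.p / j, (p = p,))

inds = (Base.OneTo(3),)
d_boxed = MyFor(make_boxed(0.5), inds)
d_fix1 = MyFor(make_fix1(0.5), inds)
println(typeof(d_boxed).parameters[1])  # element type inferred through the box
println(typeof(d_fix1).parameters[1])   # element type inferred through Fix1
```

With the boxed closure, T degrades to Any, and downstream code such as rand over the marginals then has to dispatch at run time.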

@cscherrer
Collaborator Author

Looking at this a little more closely, I think my approach is kind of clumsy. Do you think it would be possible to instead have an option to avoid boxing in the first place? Maybe something like an opaque closure, or a faked version of one? Even something that throws an error on type instability would be useful.
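A rough sketch of the "throw an error" option, assuming ordinary Julia closures (GG's Closure stores its free variables differently, so this would need adapting): a hypothetical `assert_unboxed` helper that rejects any callable whose fields include a Core.Box.

```julia
# Hypothetical helper: error if any captured field of f is declared as Core.Box
function assert_unboxed(f)
    T = typeof(f)
    for i in 1:fieldcount(T)
        if fieldtype(T, i) === Core.Box
            error("closure $(T) captures `$(fieldname(T, i))` in a Core.Box")
        end
    end
    return f
end

function make_closures()
    a = 1
    ok = y -> a + y     # a is never reassigned: captured with a concrete type
    b = 1
    bad = y -> b + y
    b = 2               # reassigned after capture: b becomes a Core.Box
    return ok, bad
end

ok, bad = make_closures()
assert_unboxed(ok)      # returns ok unchanged
try
    assert_unboxed(bad)
catch err
    println(sprint(showerror, err))
end
```

A check like this only catches boxing, not every source of type instability; JET's @test_opt remains the broader check.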

@cscherrer
Collaborator Author

Just a little more information on this, then I'm going to try to set it aside for a while...

Say you start with

using Soss, JET

m1 = @model N begin
    p ~ Uniform()
    x ~ For(N) do j
        Bernoulli(p / j)
    end
end

Then JET.@test_opt rand(m1(10)) fails. This may be caused by boxing the closure's captured state, though I don't understand the GG implementation well enough to confirm it.

**Here's the test for m1**
julia> @test_opt rand(m1(10))
JET-test failed at REPL[38]:1
  Expression: #= REPL[38]:1 =# JET.@test_call analyzer = JET.OptAnalyzer rand(m1(10))
  ═════ 16 possible errors found ═════
  ┌ @ /home/chad/git/Soss.jl/src/primitives/rand.jl:19 Soss.#rand#49(Base.pairs(Core.NamedTuple()), #self#, m)
  │┌ @ /home/chad/git/Soss.jl/src/primitives/rand.jl:19 Soss.rand(Soss.GLOBAL_RNG, m)
  ││┌ @ /home/chad/git/Soss.jl/src/primitives/rand.jl:34 Soss.#rand#51(Soss.NamedTuple(), Soss.NamedTuple(), #self#, rng, mc)
  │││┌ @ /home/chad/git/Soss.jl/src/primitives/rand.jl:36 f(cfg, ctx)
  ││││┌ @ /home/chad/.julia/packages/GeneralizedGenerated/PV9u7/src/closure.jl:6 GeneralizedGenerated.#_#4(Core.tuple(Base.pairs(Core.NamedTuple()), closure), args...)
  │││││┌ @ /home/chad/.julia/packages/GeneralizedGenerated/PV9u7/src/closure.jl:6 _(Base.getproperty(closure, :frees), args...)
  ││││││┌ @ /home/chad/.julia/packages/GeneralizedGenerated/PV9u7/src/ngg/runtime_fns.jl:83 GeneralizedGenerated.NGG.#_#10(Core.tuple(Base.pairs(Core.NamedTuple()), #self#), pargs...)
  │││││││┌ @ /home/chad/.julia/packages/GeneralizedGenerated/PV9u7/src/ngg/runtime_fns.jl:24 tilde_rand(:x, Base.getproperty(Main, :For)(Core.apply_type(Closure, function = (p, j;) -> begin
      (Main).Bernoulli((Main).:/(p, (Main).:(j)))
  end, Base.typeof(freevars))(freevars), N), _cfg, _ctx, targs)
  ││││││││┌ @ /home/chad/git/Soss.jl/src/primitives/rand.jl:43 x = Soss.rand(Base.getproperty(cfg, :rng), d)
  │││││││││┌ @ /home/chad/git/MeasureBase.jl/src/rand.jl:7 MeasureBase.rand(rng, MeasureBase.Float64, d)
  ││││││││││┌ @ /home/chad/git/MeasureBase.jl/src/combinators/for.jl:220 MeasureBase._rand_product(rng, _, MeasureBase.marginals(d), _)
  │││││││││││┌ @ /home/chad/git/MeasureBase.jl/src/combinators/product.jl:33 MeasureBase.map(#28, mar)
  ││││││││││││┌ @ abstractarray.jl:2849 Base.collect_similar(A, Base.Generator(f, A))
  │││││││││││││┌ @ array.jl:653 Base._collect(cont, itr, Base.IteratorEltype(itr), Base.IteratorSize(itr))
  ││││││││││││││┌ @ array.jl:744 y = Base.iterate(itr)
  │││││││││││││││┌ @ generator.jl:44 y = Base.iterate(Core.tuple(Base.getproperty(g, :iter)), s...)
  ││││││││││││││││┌ @ abstractarray.jl:1142 #self#(A, Core.tuple(Base.eachindex(A)))
  │││││││││││││││││┌ @ abstractarray.jl:1144 Base.getindex(A, Base.getindex(y, 1))
  ││││││││││││││││││┌ @ /home/chad/.julia/packages/MappedArrays/bS6Yp/src/MappedArrays.jl:166 Base.getproperty(A, :f)(Base.getindex(Core.tuple(Base.getproperty(A, :data)), i...))
  │││││││││││││││││││┌ @ /home/chad/.julia/packages/GeneralizedGenerated/PV9u7/src/closure.jl:6 GeneralizedGenerated.#_#4(Core.tuple(Base.pairs(Core.NamedTuple()), closure), args...)
  ││││││││││││││││││││┌ @ /home/chad/.julia/packages/GeneralizedGenerated/PV9u7/src/closure.jl:6 Base.merge(Base.NamedTuple(), kwargs)
  │││││││││││││││││││││┌ @ namedtuple.jl:303 Core.apply_type(Base.NamedTuple, Core.tuple(names...))(Core.tuple(vals...))
  ││││││││││││││││││││││┌ @ boot.jl:601 Core.apply_type(Core.NamedTuple, _, Core.typeof(args))(args)
  │││││││││││││││││││││││┌ @ namedtuple.jl:96 _(args)
  ││││││││││││││││││││││││┌ @ tuple.jl:312 Base.convert(_, x)
  │││││││││││││││││││││││││┌ @ essentials.jl:344 Base.Val(_)
  ││││││││││││││││││││││││││┌ @ essentials.jl:701 %1()
  │││││││││││││││││││││││││││ runtime dispatch detected: %1::Type{Val{_A}} where _A()
  ││││││││││││││││││││││││││└─────────────────────
  │││││││││││││││││││┌ @ /home/chad/.julia/packages/GeneralizedGenerated/PV9u7/src/closure.jl:5 (::GeneralizedGenerated.Closure{function = (p, j;) -> begin
      (Main).Bernoulli((Main).:/(p, (Main).:(j)))
  end, Tuple{Float64}})(::Int64)
  ││││││││││││││││││││ failed to optimize: (::GeneralizedGenerated.Closure{function = (p, j;) -> begin
      (Main).Bernoulli((Main).:/(p, (Main).:(j)))
  end, Tuple{Float64}})(::Int64)
  │││││││││││││││││││└──────────────────────────────────────────────────────────────────────────
  ││││││││││││││││││┌ @ /home/chad/.julia/packages/MappedArrays/bS6Yp/src/MappedArrays.jl:164 getindex(::MappedArrays.ReadonlyMappedArray{Bernoulli{(:p,), Tuple{Float64}}, 1, Base.OneTo{Int64}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
      (Main).Bernoulli((Main).:/(p, (Main).:(j)))
  end, Tuple{Float64}}}, ::Int64)
  │││││││││││││││││││ failed to optimize: getindex(::MappedArrays.ReadonlyMappedArray{Bernoulli{(:p,), Tuple{Float64}}, 1, Base.OneTo{Int64}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
      (Main).Bernoulli((Main).:/(p, (Main).:(j)))
  end, Tuple{Float64}}}, ::Int64)
  ││││││││││││││││││└─────────────────────────────────────────────────────────────────────────
  │││││││││││││││││┌ @ abstractarray.jl:1141 iterate(::MappedArrays.ReadonlyMappedArray{Bernoulli{(:p,), Tuple{Float64}}, 1, Base.OneTo{Int64}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
      (Main).Bernoulli((Main).:/(p, (Main).:(j)))
  end, Tuple{Float64}}}, ::Tuple{Base.OneTo{Int64}})
  ││││││││││││││││││ failed to optimize: iterate(::MappedArrays.ReadonlyMappedArray{Bernoulli{(:p,), Tuple{Float64}}, 1, Base.OneTo{Int64}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
      (Main).Bernoulli((Main).:/(p, (Main).:(j)))
  end, Tuple{Float64}}}, ::Tuple{Base.OneTo{Int64}})
  │││││││││││││││││└─────────────────────────
  ││││││││││││││││┌ @ abstractarray.jl:1141 iterate(::MappedArrays.ReadonlyMappedArray{Bernoulli{(:p,), Tuple{Float64}}, 1, Base.OneTo{Int64}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
      (Main).Bernoulli((Main).:/(p, (Main).:(j)))
  end, Tuple{Float64}}})
  │││││││││││││││││ failed to optimize: iterate(::MappedArrays.ReadonlyMappedArray{Bernoulli{(:p,), Tuple{Float64}}, 1, Base.OneTo{Int64}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
      (Main).Bernoulli((Main).:/(p, (Main).:(j)))
  end, Tuple{Float64}}})
  ││││││││││││││││└─────────────────────────
  │││││││││││││││┌ @ generator.jl:42 iterate(::Base.Generator{MappedArrays.ReadonlyMappedArray{Bernoulli{(:p,), Tuple{Float64}}, 1, Base.OneTo{Int64}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
      (Main).Bernoulli((Main).:/(p, (Main).:(j)))
  end, Tuple{Float64}}}, MeasureBase.var"#28#29"{Float64, Random._GLOBAL_RNG}})
  ││││││││││││││││ failed to optimize: iterate(::Base.Generator{MappedArrays.ReadonlyMappedArray{Bernoulli{(:p,), Tuple{Float64}}, 1, Base.OneTo{Int64}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
      (Main).Bernoulli((Main).:/(p, (Main).:(j)))
  end, Tuple{Float64}}}, MeasureBase.var"#28#29"{Float64, Random._GLOBAL_RNG}})
  │││││││││││││││└───────────────────
  ││││││││││││││┌ @ array.jl:754 Base.collect_to_with_first!(dest, v1, itr, st)
  │││││││││││││││┌ @ array.jl:760 Base.collect_to!(dest, itr, Base.+(i1, 1), st)
  ││││││││││││││││┌ @ array.jl:782 y = Base.iterate(itr, st)
  │││││││││││││││││┌ @ generator.jl:44 y = Base.iterate(Core.tuple(Base.getproperty(g, :iter)), s...)
  ││││││││││││││││││┌ @ abstractarray.jl:1141 iterate(::MappedArrays.ReadonlyMappedArray{Bernoulli{(:p,), Tuple{Float64}}, 1, Base.OneTo{Int64}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
      (Main).Bernoulli((Main).:/(p, (Main).:(j)))
  end, Tuple{Float64}}}, ::Tuple{Base.OneTo{Int64}, Int64})
  │││││││││││││││││││ failed to optimize: iterate(::MappedArrays.ReadonlyMappedArray{Bernoulli{(:p,), Tuple{Float64}}, 1, Base.OneTo{Int64}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
      (Main).Bernoulli((Main).:/(p, (Main).:(j)))
  end, Tuple{Float64}}}, ::Tuple{Base.OneTo{Int64}, Int64})
  ││││││││││││││││││└─────────────────────────
  │││││││││││││││││┌ @ generator.jl:42 iterate(::Base.Generator{MappedArrays.ReadonlyMappedArray{Bernoulli{(:p,), Tuple{Float64}}, 1, Base.OneTo{Int64}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
      (Main).Bernoulli((Main).:/(p, (Main).:(j)))
  end, Tuple{Float64}}}, MeasureBase.var"#28#29"{Float64, Random._GLOBAL_RNG}}, ::Tuple{Base.OneTo{Int64}, Int64})
  ││││││││││││││││││ failed to optimize: iterate(::Base.Generator{MappedArrays.ReadonlyMappedArray{Bernoulli{(:p,), Tuple{Float64}}, 1, Base.OneTo{Int64}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
      (Main).Bernoulli((Main).:/(p, (Main).:(j)))
  end, Tuple{Float64}}}, MeasureBase.var"#28#29"{Float64, Random._GLOBAL_RNG}}, ::Tuple{Base.OneTo{Int64}, Int64})
  │││││││││││││││││└───────────────────
  ││││││││││││││┌ @ array.jl:741 Base._collect(::MappedArrays.ReadonlyMappedArray{Bernoulli{(:p,), Tuple{Float64}}, 1, Base.OneTo{Int64}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
      (Main).Bernoulli((Main).:/(p, (Main).:(j)))
  end, Tuple{Float64}}}, ::Base.Generator{MappedArrays.ReadonlyMappedArray{Bernoulli{(:p,), Tuple{Float64}}, 1, Base.OneTo{Int64}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
      (Main).Bernoulli((Main).:/(p, (Main).:(j)))
  end, Tuple{Float64}}}, MeasureBase.var"#28#29"{Float64, Random._GLOBAL_RNG}}, ::Base.EltypeUnknown, ::Base.HasShape{1})
  │││││││││││││││ failed to optimize: Base._collect(::MappedArrays.ReadonlyMappedArray{Bernoulli{(:p,), Tuple{Float64}}, 1, Base.OneTo{Int64}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
      (Main).Bernoulli((Main).:/(p, (Main).:(j)))
  end, Tuple{Float64}}}, ::Base.Generator{MappedArrays.ReadonlyMappedArray{Bernoulli{(:p,), Tuple{Float64}}, 1, Base.OneTo{Int64}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
      (Main).Bernoulli((Main).:/(p, (Main).:(j)))
  end, Tuple{Float64}}}, MeasureBase.var"#28#29"{Float64, Random._GLOBAL_RNG}}, ::Base.EltypeUnknown, ::Base.HasShape{1})
  ││││││││││││││└────────────────
  │││││││││││││┌ @ array.jl:653 Base.collect_similar(::MappedArrays.ReadonlyMappedArray{Bernoulli{(:p,), Tuple{Float64}}, 1, Base.OneTo{Int64}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
      (Main).Bernoulli((Main).:/(p, (Main).:(j)))
  end, Tuple{Float64}}}, ::Base.Generator{MappedArrays.ReadonlyMappedArray{Bernoulli{(:p,), Tuple{Float64}}, 1, Base.OneTo{Int64}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
      (Main).Bernoulli((Main).:/(p, (Main).:(j)))
  end, Tuple{Float64}}}, MeasureBase.var"#28#29"{Float64, Random._GLOBAL_RNG}})
  ││││││││││││││ failed to optimize: Base.collect_similar(::MappedArrays.ReadonlyMappedArray{Bernoulli{(:p,), Tuple{Float64}}, 1, Base.OneTo{Int64}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
      (Main).Bernoulli((Main).:/(p, (Main).:(j)))
  end, Tuple{Float64}}}, ::Base.Generator{MappedArrays.ReadonlyMappedArray{Bernoulli{(:p,), Tuple{Float64}}, 1, Base.OneTo{Int64}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
      (Main).Bernoulli((Main).:/(p, (Main).:(j)))
  end, Tuple{Float64}}}, MeasureBase.var"#28#29"{Float64, Random._GLOBAL_RNG}})
  │││││││││││││└────────────────
  │││││││││││┌ @ /home/chad/git/MeasureBase.jl/src/combinators/product.jl:32 MeasureBase._rand_product(::Random._GLOBAL_RNG, ::Type{Float64}, ::MappedArrays.ReadonlyMappedArray{Bernoulli{(:p,), Tuple{Float64}}, 1, Base.OneTo{Int64}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
      (Main).Bernoulli((Main).:/(p, (Main).:(j)))
  end, Tuple{Float64}}}, ::Type{Bernoulli{(:p,), Tuple{Float64}}})
  ││││││││││││ failed to optimize: MeasureBase._rand_product(::Random._GLOBAL_RNG, ::Type{Float64}, ::MappedArrays.ReadonlyMappedArray{Bernoulli{(:p,), Tuple{Float64}}, 1, Base.OneTo{Int64}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
      (Main).Bernoulli((Main).:/(p, (Main).:(j)))
  end, Tuple{Float64}}}, ::Type{Bernoulli{(:p,), Tuple{Float64}}})
  │││││││││││└───────────────────────────────────────────────────────────────
  ││││││││││┌ @ /home/chad/git/MeasureBase.jl/src/combinators/for.jl:219 rand(::Random._GLOBAL_RNG, ::Type{Float64}, ::For{Bernoulli{(:p,), Tuple{Float64}}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
      (Main).Bernoulli((Main).:/(p, (Main).:(j)))
  end, Tuple{Float64}}, Tuple{Base.OneTo{Int64}}})
  │││││││││││ failed to optimize: rand(::Random._GLOBAL_RNG, ::Type{Float64}, ::For{Bernoulli{(:p,), Tuple{Float64}}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
      (Main).Bernoulli((Main).:/(p, (Main).:(j)))
  end, Tuple{Float64}}, Tuple{Base.OneTo{Int64}}})
  ││││││││││└────────────────────────────────────────────────────────────
  │││││││││┌ @ /home/chad/git/MeasureBase.jl/src/rand.jl:7 rand(::Random._GLOBAL_RNG, ::For{Bernoulli{(:p,), Tuple{Float64}}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
      (Main).Bernoulli((Main).:/(p, (Main).:(j)))
  end, Tuple{Float64}}, Tuple{Base.OneTo{Int64}}})
  ││││││││││ failed to optimize: rand(::Random._GLOBAL_RNG, ::For{Bernoulli{(:p,), Tuple{Float64}}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
      (Main).Bernoulli((Main).:/(p, (Main).:(j)))
  end, Tuple{Float64}}, Tuple{Base.OneTo{Int64}}})
  │││││││││└───────────────────────────────────────────────
  ││││││││┌ @ /home/chad/git/Soss.jl/src/primitives/rand.jl:42 Soss.tilde_rand(::Symbol, ::For{Bernoulli{(:p,), Tuple{Float64}}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
      (Main).Bernoulli((Main).:/(p, (Main).:(j)))
  end, Tuple{Float64}}, Tuple{Base.OneTo{Int64}}}, ::NamedTuple{(:rng, :_args, :_obs), Tuple{Random._GLOBAL_RNG, NamedTuple{(:N,), Tuple{Int64}}, NamedTuple{(), Tuple{}}}}, ::NamedTuple{(:p,), Tuple{Float64}}, ::Soss.TildeArgs{DataType, DataType, NamedTuple{(), Tuple{}}, Static.False, Static.False})
  │││││││││ failed to optimize: Soss.tilde_rand(::Symbol, ::For{Bernoulli{(:p,), Tuple{Float64}}, GeneralizedGenerated.Closure{function = (p, j;) -> begin
      (Main).Bernoulli((Main).:/(p, (Main).:(j)))
  end, Tuple{Float64}}, Tuple{Base.OneTo{Int64}}}, ::NamedTuple{(:rng, :_args, :_obs), Tuple{Random._GLOBAL_RNG, NamedTuple{(:N,), Tuple{Int64}}, NamedTuple{(), Tuple{}}}}, ::NamedTuple{(:p,), Tuple{Float64}}, ::Soss.TildeArgs{DataType, DataType, NamedTuple{(), Tuple{}}, Static.False, Static.False})
  ││││││││└────────────────────────────────────────────────────
  │││││││┌ @ /home/chad/.julia/packages/GeneralizedGenerated/PV9u7/src/ngg/runtime_fns.jl:83 GeneralizedGenerated.NGG.var"#_#10"(::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::ggfunc-function, ::Soss.ModelClosure{ASTModel{NamedTuple{(:N,)}, GeneralizedGenerated.NGG.TypeLevel{Expr, "Buf{91}()"}, GeneralizedGenerated.NGG.TypeLevel{Module, "Buf{17}()"}}, NamedTuple{(:N,), Tuple{Int64}}}, ::NamedTuple{(:rng,), Tuple{Random._GLOBAL_RNG}}, ::NamedTuple{(), Tuple{}})
  ││││││││ failed to optimize: GeneralizedGenerated.NGG.var"#_#10"(::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::ggfunc-function, ::Soss.ModelClosure{ASTModel{NamedTuple{(:N,)}, GeneralizedGenerated.NGG.TypeLevel{Expr, "Buf{91}()"}, GeneralizedGenerated.NGG.TypeLevel{Module, "Buf{17}()"}}, NamedTuple{(:N,), Tuple{Int64}}}, ::NamedTuple{(:rng,), Tuple{Random._GLOBAL_RNG}}, ::NamedTuple{(), Tuple{}})
  │││││││└───────────────────────────────────────────────────────────────────────────────────
  │││││┌ @ /home/chad/.julia/packages/GeneralizedGenerated/PV9u7/src/closure.jl:5 GeneralizedGenerated.var"#_#4"(::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::GeneralizedGenerated.Closure{function = (_mc, _cfg, _ctx;) -> begin
      begin
          $(Expr(:meta, :((Main).inline)))
          local _retn
          _args = (Main).Soss.argvals(_mc)
          _obs = (Main).Soss.observations(_mc)
          _cfg = (Main).merge(_cfg, (_args = _args, _obs = _obs))
          let
              begin
                  N = _args.N
                  (p, _ctx, _retn) = let targs = (Main).Soss.TildeArgs(GeneralizedGenerated.NGG.TypeLevel{Symbol, "Buf{9}()"}, GeneralizedGenerated.NGG.TypeLevel{Expr, "Buf{23}()"}, (Main).NamedTuple{()}(()), static(false), static(false))
                          begin
                              (Soss.tilde_rand)(:p, (Main).Uniform(), _cfg, _ctx, targs)
                          end
                      end
                  (x, _ctx, _retn) = let targs = (Main).Soss.TildeArgs(GeneralizedGenerated.NGG.TypeLevel{Symbol, "Buf{9}()"}, GeneralizedGenerated.NGG.TypeLevel{Expr, "Buf{23}()"}, (Main).NamedTuple{()}(()), static(false), static(false))
                          begin
                              (Soss.tilde_rand)(:x, (Main).For(begin
                                          let freevars = (p,)
                                              (GeneralizedGenerated.Closure){function = (p, j;) -> begin
      (Main).Bernoulli((Main).:/(p, (Main).:(j)))
  end, Base.typeof(freevars)}(freevars)
                                          end
                                      end, N), _cfg, _ctx, targs)
                          end
                      end
              end
          end
          _retn
      end
  end, Tuple{Soss.ModelClosure{ASTModel{NamedTuple{(:N,)}, GeneralizedGenerated.NGG.TypeLevel{Expr, "Buf{91}()"}, GeneralizedGenerated.NGG.TypeLevel{Module, "Buf{17}()"}}, NamedTuple{(:N,), Tuple{Int64}}}}}, ::NamedTuple{(:rng,), Tuple{Random._GLOBAL_RNG}}, ::NamedTuple{(), Tuple{}})
  ││││││ failed to optimize: GeneralizedGenerated.var"#_#4"(::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::GeneralizedGenerated.Closure{function = (_mc, _cfg, _ctx;) -> begin
      begin
          $(Expr(:meta, :((Main).inline)))
          local _retn
          _args = (Main).Soss.argvals(_mc)
          _obs = (Main).Soss.observations(_mc)
          _cfg = (Main).merge(_cfg, (_args = _args, _obs = _obs))
          let
              begin
                  N = _args.N
                  (p, _ctx, _retn) = let targs = (Main).Soss.TildeArgs(GeneralizedGenerated.NGG.TypeLevel{Symbol, "Buf{9}()"}, GeneralizedGenerated.NGG.TypeLevel{Expr, "Buf{23}()"}, (Main).NamedTuple{()}(()), static(false), static(false))
                          begin
                              (Soss.tilde_rand)(:p, (Main).Uniform(), _cfg, _ctx, targs)
                          end
                      end
                  (x, _ctx, _retn) = let targs = (Main).Soss.TildeArgs(GeneralizedGenerated.NGG.TypeLevel{Symbol, "Buf{9}()"}, GeneralizedGenerated.NGG.TypeLevel{Expr, "Buf{23}()"}, (Main).NamedTuple{()}(()), static(false), static(false))
                          begin
                              (Soss.tilde_rand)(:x, (Main).For(begin
                                          let freevars = (p,)
                                              (GeneralizedGenerated.Closure){function = (p, j;) -> begin
      (Main).Bernoulli((Main).:/(p, (Main).:(j)))
  end, Base.typeof(freevars)}(freevars)
                                          end
                                      end, N), _cfg, _ctx, targs)
                          end
                      end
              end
          end
          _retn
      end
  end, Tuple{Soss.ModelClosure{ASTModel{NamedTuple{(:N,)}, GeneralizedGenerated.NGG.TypeLevel{Expr, "Buf{91}()"}, GeneralizedGenerated.NGG.TypeLevel{Module, "Buf{17}()"}}, NamedTuple{(:N,), Tuple{Int64}}}}}, ::NamedTuple{(:rng,), Tuple{Random._GLOBAL_RNG}}, ::NamedTuple{(), Tuple{}})
  │││││└──────────────────────────────────────────────────────────────────────────
  
ERROR: There was an error during testing

If instead you do it like this:

f(ctx) = Base.Fix1(ctx) do ctx, j
    Bernoulli(ctx.p / j)
end

m2 = @model N begin
    p ~ Uniform()
    x ~ For(f((p=p,)), N)
end

Then @test_opt rand(m2(10)) passes just fine. So my next thought was to generate f(ctx) at runtime using GG. But I'm not sure that's possible, since it seems it would rely on the same mechanism that failed in the first case.

That got me wondering whether it makes more sense to change the GG implementation to take a similar approach. But again, this isn't at all clear to me. For example, this attempt also fails:

m3 = @model N begin
    p ~ Uniform()
    f(ctx) = Base.Fix1(ctx) do ctx, j
        Bernoulli(ctx.p / j)
    end
    x ~ For(f((p=p,)), N)
end
**Here's the test for m3**
julia> @test_opt rand(m3(10))
JET-test failed at REPL[37]:1
  Expression: #= REPL[37]:1 =# JET.@test_call analyzer = JET.OptAnalyzer rand(m3(10))
  ═════ 2 possible errors found ═════
  ┌ @ /home/chad/git/Soss.jl/src/primitives/rand.jl:19 Soss.#rand#49(Base.pairs(Core.NamedTuple()), #self#, m)
  │┌ @ /home/chad/git/Soss.jl/src/primitives/rand.jl:19 Soss.rand(Soss.GLOBAL_RNG, m)
  ││┌ @ /home/chad/git/Soss.jl/src/primitives/rand.jl:34 Soss.#rand#51(Soss.NamedTuple(), Soss.NamedTuple(), #self#, rng, mc)
  │││┌ @ /home/chad/git/Soss.jl/src/primitives/rand.jl:36 f(cfg, ctx)
  ││││┌ @ /home/chad/.julia/packages/GeneralizedGenerated/PV9u7/src/closure.jl:6 GeneralizedGenerated.#_#4(Core.tuple(Base.pairs(Core.NamedTuple()), closure), args...)
  │││││┌ @ /home/chad/.julia/packages/GeneralizedGenerated/PV9u7/src/closure.jl:6 _(Base.getproperty(closure, :frees), args...)
  ││││││┌ @ /home/chad/.julia/packages/GeneralizedGenerated/PV9u7/src/ngg/runtime_fns.jl:83 GeneralizedGenerated.NGG.#_#10(Core.tuple(Base.pairs(Core.NamedTuple()), #self#), pargs...)
  │││││││┌ @ /home/chad/git/Soss.jl/src/primitives/interpret.jl:25 f(Core.apply_type(Core.NamedTuple, (:p,))(Core.tuple(p)))
  ││││││││┌ @ /home/chad/.julia/packages/GeneralizedGenerated/PV9u7/src/ngg/runtime_fns.jl:83 (::ggfunc-f)(::NamedTuple{(:p,), Tuple{Float64}})
  │││││││││ failed to optimize: (::ggfunc-f)(::NamedTuple{(:p,), Tuple{Float64}})
  ││││││││└───────────────────────────────────────────────────────────────────────────────────
  │││││││┌ @ /home/chad/.julia/packages/GeneralizedGenerated/PV9u7/src/ngg/runtime_fns.jl:83 GeneralizedGenerated.NGG.var"#_#10"(::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::ggfunc-function, ::Soss.ModelClosure{ASTModel{NamedTuple{(:N,)}, GeneralizedGenerated.NGG.TypeLevel{Expr, "Buf{118}()"}, GeneralizedGenerated.NGG.TypeLevel{Module, "Buf{17}()"}}, NamedTuple{(:N,), Tuple{Int64}}}, ::NamedTuple{(:rng,), Tuple{Random._GLOBAL_RNG}}, ::NamedTuple{(), Tuple{}})
  ││││││││ failed to optimize: GeneralizedGenerated.NGG.var"#_#10"(::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::ggfunc-function, ::Soss.ModelClosure{ASTModel{NamedTuple{(:N,)}, GeneralizedGenerated.NGG.TypeLevel{Expr, "Buf{118}()"}, GeneralizedGenerated.NGG.TypeLevel{Module, "Buf{17}()"}}, NamedTuple{(:N,), Tuple{Int64}}}, ::NamedTuple{(:rng,), Tuple{Random._GLOBAL_RNG}}, ::NamedTuple{(), Tuple{}})
  │││││││└───────────────────────────────────────────────────────────────────────────────────
  
ERROR: There was an error during testing

A couple of things are surprising:

  1. Although m1 and m3 both fail, there seem to be far fewer issues with m3.
  2. Dynamic dispatch usually causes big slowdowns in benchmarks, but here I don't see any difference at all.

It seems likely to me that dynamic dispatch could become a problem for some models, even if it isn't for this one. But I'm feeling kind of stuck anyway, so I'm going to set it aside for a while. Let me know if you have any ideas, and we can pick it back up.
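One possible explanation for point 2, offered as a guess rather than something the JET report confirms: if the dynamically dispatched call happens only once per rand call, the work behind it still runs in code compiled for concrete types (a "function barrier"). A toy sketch of that pattern, with hypothetical names unrelated to Soss:

```julia
# Deliberately type-unstable: the vector's element type depends on a runtime flag,
# so inference can only conclude Union{Vector{Float64}, Vector{Int}}.
make_data(flag::Bool, n::Int) = flag ? fill(0.5, n) : collect(1:n)

# The function behind the barrier. Julia compiles a separate specialization
# for each concrete argument type, so this loop runs without dynamic dispatch.
function kernel(xs::AbstractVector)
    s = zero(eltype(xs))
    for x in xs
        s += x
    end
    return s
end

# At most one dynamic (or union-split) call happens here per call to `outer`,
# not one per loop iteration.
outer(flag::Bool, n::Int) = kernel(make_data(flag, n))
```

Here outer(false, 4) returns 10 and outer(true, 3) returns 1.5. The cost of the unstable call is paid once per call, which is easy to miss in a benchmark when the work inside the barrier dominates.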

@thautwarm
Member

Sorry for the delay; I'll get started working on this.
I've cloned Soss#dev, but I'm running into problems when running instantiate:

ERROR: Unsatisfiable requirements detected for package TransformVariables [84d833dd]:
 TransformVariables [84d833dd] log:
 ├─possible versions are: 0.1.0-0.5.0 or uninstalled
 ├─restricted to versions 0.5 by Soss [8ce77f84], leaving only versions 0.5.0
 │ └─Soss [8ce77f84] log:
 │   ├─possible versions are: 0.20.9 or uninstalled
 │   └─Soss [8ce77f84] is fixed to version 0.20.9
 └─restricted by compatibility requirements with MeasureTheory [eadaa1a4] to versions: 0.4.0-0.4.1 — no versions left
   └─MeasureTheory [eadaa1a4] log:
     ├─possible versions are: 0.2.1-0.13.2 or uninstalled
     └─restricted to versions 0.13 by Soss [8ce77f84], leaving only versions 0.13.0-0.13.2
       └─Soss [8ce77f84] log: see above

@thautwarm
Member

thautwarm commented Feb 9, 2022

I think the problem comes down to this method:

@generated function For(f::F, inds::I) where {F,I<:Tuple}
    eltypes = Tuple{eltype.(I.types)...}
    quote
        $(Expr(:meta, :inline))
        T = Core.Compiler.return_type(f, $eltypes)
        For{T,F,I}(f, inds)
    end
end

Some rough ideas: is the issue possibly related to a run-time invocation of return_type? Would there be any problem with lifting the inference to code-generation time, as follows?

@generated function For(f::F, inds::I) where {F,I<:Tuple}
    eltypes = Tuple{eltype.(I.types)...}
    T = Core.Compiler.return_type(f, eltypes)
    quote
        $(Expr(:meta, :inline))
        For{$T,F,I}(f, inds)
    end
end

@thautwarm
Member

It seems f is expected to be extended later with new methods. Couldn't Core.Compiler.return_type then return different results?
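A quick way to see the concern: Core.Compiler.return_type answers based on whatever methods exist at the moment it is called. This sketch queries it before and after adding a more specific method. ext is a hypothetical name, and return_type is an internal compiler API, so treat this as illustrative only:

```julia
# `ext` is a hypothetical function; `Core.Compiler.return_type` is internal API.
ext(x) = x                                              # generic fallback: returns its input
T_before = Core.Compiler.return_type(ext, Tuple{Int})   # inferred as Int

ext(x::Int) = 1.0                                       # more specific method, added later
T_after = Core.Compiler.return_type(ext, Tuple{Int})    # now inferred as Float64
```

A @generated function that calls return_type inside its generator could therefore bake in the older answer, since Julia does not promise to rerun a generator when ext gains new methods.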

@cscherrer
Collaborator Author

Right now I cloned Soss#dev but have issues in executing instantiate.

Sorry, I've been changing more things. I'll get to something stable and send you the Manifest.

is the issue possibly related to a run-time invocation of return_type?

I originally had this outside the quote; I forget why I moved it. Anyway, I had to add this method for it to work at all:

@generated function MeasureTheory.For(f::GG.Closure{F,Free}, inds::I) where {F,Free,I<:Tuple}
    freetypes = Free.types
    eltypes = eltype.(I.types)
    T = Core.Compiler.return_type(F, Tuple{freetypes..., eltypes...})
    quote
        $(Expr(:meta, :inline))
        For{$T,GG.Closure{F,Free},I}(f, inds)
    end
end
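The same idea can be expressed with Base.Fix1, the construct from the top of this issue: the result type is computed from the wrapped function plus the types of the captured value and the index. This is a sketch, not Soss or GG code; scale_by and fix1_return_type are hypothetical names, and Core.Compiler.return_type is an internal API whose behavior may change across Julia versions:

```julia
# Hypothetical names; `Core.Compiler.return_type` is an internal compiler API.
scale_by(ctx, n) = ctx.a * n

# Calling a Base.Fix1 `c` with `n` evaluates `c.f(c.x, n)`, so its result type can be
# computed from the wrapped function and the types of the captured value and index,
# much like the GG.Closure method above combines Free.types with the index eltypes.
fix1_return_type(c::Base.Fix1{F,X}, ::Type{T}) where {F,X,T} =
    Core.Compiler.return_type(c.f, Tuple{X,T})

clo = Base.Fix1(scale_by, (a = 2.0,))
fix1_return_type(clo, Int)                  # Float64
Core.Compiler.return_type(clo, Tuple{Int})  # also Float64, inferred through the wrapper
```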

@thautwarm
Member

Sorry, I've been changing more things. I'll get to something stable and send you the Manifest.

Thanks, just send them to me when you're done and I can work on this.

I'll need some time to address the concrete issue. Thanks a lot for providing the cases above.

@cscherrer
Collaborator Author

OK, I sent you a Zulip message with lots of details.

@cscherrer
Collaborator Author

This might be helpful:

julia> g = mk_function(Main, :(ctx -> (j -> Bernoulli(ctx.p/j))))
function = (ctx;) -> begin
    begin
        begin
            let freevars = (ctx,)
                (GeneralizedGenerated.Closure){function = (ctx, j;) -> begin
    begin
        (Main).Bernoulli((Main).:/(ctx.p, j))
    end
end, Base.typeof(freevars)}(freevars)
            end
        end
    end
end

julia> m3 = @model n begin
           p ~ Uniform()
           y ~ For(g((p=p,)), n)
       end;

julia> rand(m3(3))
(p = 0.9464608557346492, y = Bool[1, 1, 0])

julia> @test_opt rand(m3(3))
JET-test failed at REPL[162]:1
  Expression: #= REPL[162]:1 =# JET.@test_call analyzer = JET.OptAnalyzer rand(m3(3))
  ═════ 3 possible errors found ═════
  ┌ @ /home/chad/git/Soss.jl/src/primitives/rand.jl:19 Soss.#rand#43(Base.pairs(Core.NamedTuple()), #self#, m)
  │┌ @ /home/chad/git/Soss.jl/src/primitives/rand.jl:19 Soss.rand(Soss.GLOBAL_RNG, m)
  ││┌ @ /home/chad/git/Soss.jl/src/primitives/rand.jl:34 Soss.#rand#45(Soss.NamedTuple(), Soss.NamedTuple(), #self#, rng, mc)
  │││┌ @ /home/chad/git/Soss.jl/src/primitives/rand.jl:36 f(cfg, ctx)
  ││││┌ @ /home/chad/git/GeneralizedGenerated.jl/src/closure.jl:6 GeneralizedGenerated.#_#4(Core.tuple(Base.pairs(Core.NamedTuple()), closure), args...)
  │││││┌ @ /home/chad/git/GeneralizedGenerated.jl/src/closure.jl:6 _(Base.getproperty(closure, :frees), args...)
  ││││││┌ @ /home/chad/git/GeneralizedGenerated.jl/src/ngg/runtime_fns.jl:83 GeneralizedGenerated.NGG.#_#10(Core.tuple(Base.pairs(Core.NamedTuple()), #self#), pargs...)
  │││││││┌ @ /home/chad/git/Soss.jl/src/primitives/interpret.jl:22 %48(%49)
  ││││││││ runtime dispatch detected: %48::Any(%49::NamedTuple{(:p,), Tuple{Float64}})
  │││││││└─────────────────────────────────────────────────────────
  │││││││┌ @ /home/chad/git/MeasureTheory.jl/src/combinators/for.jl:263 MeasureTheory.For(%50, %53)
  ││││││││ runtime dispatch detected: MeasureTheory.For(%50::Any, %53::Tuple{ArrayInterface.OptionallyStaticUnitRange{Static.StaticInt{1}, Int64}})
  │││││││└──────────────────────────────────────────────────────────────
  │││││││┌ @ /home/chad/git/GeneralizedGenerated.jl/src/ngg/runtime_fns.jl:21 tilde_rand(:y, %54, %3, %45, %47)
  ││││││││ runtime dispatch detected: tilde_rand(:y::Symbol, %54::Any, %3::NamedTuple{(:rng, :args, :obs), Tuple{Random._GLOBAL_RNG, NamedTuple{(:n,), Tuple{Int64}}, NamedTuple{(), Tuple{}}}}, %45::NamedTuple{(:p,), Tuple{Float64}}, %47::Soss.TildeArgs{DataType, DataType, NamedTuple{(), Tuple{}}})
  │││││││└────────────────────────────────────────────────────────────────────
  
ERROR: There was an error during testing

It gets pretty close, but then there's this:

  │││││││┌ @ /home/chad/git/MeasureTheory.jl/src/combinators/for.jl:263 MeasureTheory.For(%50, %53)
  ││││││││ runtime dispatch detected: MeasureTheory.For(%50::Any, %53::Tuple{ArrayInterface.OptionallyStaticUnitRange{Static.StaticInt{1}, Int64}})

The first argument passed to For is g((p=p,)). The compiler can infer its type just fine for a given p:

julia> Core.Compiler.return_type(g, Tuple{typeof((p = 0.2,))})
GeneralizedGenerated.Closure{function = (ctx, j;) -> begin
    begin
        (Main).Bernoulli((Main).:/(ctx.p, j))
    end
end, Tuple{NamedTuple{(:p,), Tuple{Float64}}}}

But when that call happens inside another GG function, inference falls back to Any.
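For comparison, a plain-Julia version of the same nesting stays fully inferred when the closure factory is an ordinary top-level function, as in the f/Base.Fix1 workaround at the start of this issue. This is a sketch with hypothetical names; bern_p stands in for the Bernoulli call so that no distributions package is needed:

```julia
# Hypothetical names; plain Julia, no GeneralizedGenerated involved.
bern_p(ctx, j) = ctx.p / j                     # stand-in for Bernoulli(ctx.p / j)
mk(ctx::NamedTuple) = Base.Fix1(bern_p, ctx)   # closure factory defined at top level

# Build the closure inside another function, as the model does with g((p=p,)).
inner_closure(p) = mk((p = p,))
use_closure(p) = inner_closure(p)(2)

Base.return_types(inner_closure, (Float64,))   # a single concrete Base.Fix1{...} type
Base.return_types(use_closure, (Float64,))     # Any[Float64]
```

If the equivalent check on the GG-generated outer function gives Any, that would suggest the type information is lost at the GG function boundary rather than in the closure construction itself.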
