Skip to content

Commit

Permalink
Type inference free comprehensions
Browse files Browse the repository at this point in the history
- Move lowering of the comprehension to a julia macro
- Use typejoin to compute the array/dict type while
  being optimistic about the first element
- Fix comprehension colon logic
- Make tests pass
  • Loading branch information
carnaval committed May 25, 2015
1 parent c6d8ace commit c4c6817
Show file tree
Hide file tree
Showing 26 changed files with 365 additions and 338 deletions.
20 changes: 14 additions & 6 deletions base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,14 @@ function reinterpret{T,S,N}(::Type{T}, a::Array{S}, dims::NTuple{N,Int})
ccall(:jl_reshape_array, Array{T,N}, (Any, Any, Any), Array{T,N}, a, dims)
end

function reinterpret{T,N}(::Type{T}, a::Array{Union()}, dims::NTuple{N,Int})
if length(a) == 0 && prod(dims) == 0
Array(T, dims)
else
error("cannot reinterpret non-empty Union() array")
end
end

# reshaping to same # of dimensions
function reshape{T,N}(a::Array{T,N}, dims::NTuple{N,Int})
if prod(dims) != length(a)
Expand Down Expand Up @@ -319,15 +327,15 @@ function getindex(A::Array, I::UnitRange{Int})
return X
end

function getindex{T<:Real}(A::Array, I::AbstractVector{T})
return [ A[i] for i in to_index(I) ]
function getindex{T<:Real,R}(A::Array{R}, I::AbstractVector{T})
return R[ A[i] for i in to_index(I) ]
end
function getindex{T<:Real}(A::Range, I::AbstractVector{T})
return [ A[i] for i in to_index(I) ]
function getindex{T<:Real,R}(A::Range{R}, I::AbstractVector{T})
return R[ A[i] for i in to_index(I) ]
end
function getindex(A::Range, I::AbstractVector{Bool})
function getindex{R}(A::Range{R}, I::AbstractVector{Bool})
checkbounds(A, I)
return [ A[i] for i in to_index(I) ]
return R[ A[i] for i in to_index(I) ]
end


Expand Down
280 changes: 277 additions & 3 deletions base/essentials.jl
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,13 @@ function length_checked_equal(args...)
n
end

map(f::Function, a::Array{Any,1}) = Any[ f(a[i]) for i=1:length(a) ]
function map(f::Function, a::Array{Any,1})
A = Array(Any,length(a))
for i=1:length(a)
A[i] = f(a[i])
end
A
end

function precompile(f::ANY, args::Tuple)
if isa(f,DataType)
Expand Down Expand Up @@ -212,9 +218,15 @@ function ==(v1::SimpleVector, v2::SimpleVector)
return true
end

map(f, v::SimpleVector) = Any[ f(v[i]) for i = 1:length(v) ]
function map(f, v::SimpleVector)
A = Array(Any,length(v))
for i=1:length(v)
A[i] = f(v[i])
end
A
end

getindex(v::SimpleVector, I::AbstractArray) = svec(Any[ v[i] for i in I ]...)
getindex(v::SimpleVector, I::AbstractArray) = svec(map(i->v[i],I)...)

function isassigned(v::SimpleVector, i::Int)
1 <= i <= length(v) || return false
Expand All @@ -226,3 +238,265 @@ end
type Colon
end
const (:) = Colon()


start(A::Array) = 1
next(a::Array,i) = (a[i],i+1)
done(a::Array,i) = (i > length(a))
length(a::Array) = arraylen(a)
gensym() = ccall(:jl_gensym, Any, ())::Symbol
getindex(A::Array, i0::Int) = arrayref(A,i0)
function vect{T}(args::T...)
A = Array(T, length(args))
for i=1:length(A)
A[i] = args[i]
end
A
end

# this should be only tuples/vars
function gen_var(f, ex, acc)
if isa(ex, Expr)
for a in ex.args
acc = gen_var(f, a,acc)
end
acc
elseif isa(ex,Symbol)
:($acc; $(f(ex)))
else
error("unknown iteration spec ?")
end
end

function gen_comp(T, body, iter0, isdict::Bool)
niter = length(iter0)
sz = Array(Any,niter) # dim tuple
sz0 = Array(Any,niter) # dim tuple if an input is empty
iterblock = Array(Any,niter) # canonicalized iteration spec
preheader = :() # assignment for iterables to avoid repeated evaluation
preheader_colon = :()
isempty = :(false)
itst = Array(Any,niter) # names for iterator states
ncolon = 0

for i=1:niter # precompute this, we don't have push!
ncolon += (iter0[i] === :(:))
end

colons = Array(Any,ncolon)
iscolon = Array(Any,niter)
needs_firsteval = T === nothing # do we need to unroll the 1st iteration
needs_fallback = T === nothing # do we need to check types
icol = 1
for i = 1:niter
it = iter0[i]
iscol = false
if it === :(:)
colonname = gensym()
colons[icol] = colonname
it = :($colonname = 1:size(v0,$icol))
icol += 1
iscol = true
needs_firsteval = true
end
iscolon[i] = iscol
# generate anonymous name for unused iterators
if !isa(it,Expr) || it.head !== :(=)
it = :($(gensym()) = $it)
end

name = gensym()
itst[i] = gensym()
sz[i] = :(length($name))
itname = it.args[2]
preheader = gen_var(ex -> :(local $(esc(ex))), it.args[1], preheader)
if iscol
preheader_colon = quote
$preheader_colon
$name = $itname
$(itst[i]) = start($name)
$(it.args[1]), $(itst[i]) = next($name, $(itst[i]))
end
sz0[i] = :(0)
iterblock[i] = :($(it.args[1]) = $name)
else
sz0[i] = :(length($name))
preheader = :($preheader; $name = $(esc(itname)))
isempty = :($isempty | (length($name) == 0))
iterblock[i] = :($(esc(it.args[1])) = $name)
end
end

if isdict
key = body.args[1]
body = body.args[2]
if T !== nothing
KT = T.args[1]
T = T.args[2]
end
end

#isempty = foldl((x,i) -> :($x | ($i == 0)), :(false), sz)

loopexpr = if isdict
quote
index = $(esc(key))
v = $(esc(body))
end
else
if ncolon > 0
:(v = $(Expr(:call, TopNode(:getindex), esc(body), colons...)))
else
:(v = $(esc(body)))
end
end
if needs_firsteval
first_it = :()
for i = niter:-1:1
iscolon[i] && continue
it = iterblock[i]
itname = it.args[2]
first_it = quote
$first_it
$(itst[i]) = start($itname)
$(it.args[1]), $(itst[i]) = next($itname, $(itst[i]))
end
end
init = if isdict
quote
index = $(esc(key))
v = $(esc(body))
KT = typeof(index)
T = typeof(v)
result = Dict{KT,T}()
end
else
colon_eval = if ncolon > 0
quote
$preheader_colon
v = $(Expr(:call, TopNode(:getindex), esc(body), colons...))
end
else
:(v = v0)
end
quote
index = 1
v0 = $(esc(body))
$colon_eval
$(if needs_fallback
quote
T = typeof(v)
result = Array(T, $(sz...))
end
else
:(result = Array($(esc(T)), $(sz...)))
end)
end
end
fallback =
if isdict
quote
S = typeof(v)
KS = typeof(index)
if !(S <: T && KS <: KT)
T = typejoin(S,T)
KT = typejoin(KS,KT)
result = convert(Dict{KT,T}, result)
end
end
else
quote
S = typeof(v)
if !(S <: T)
T = typejoin(S,T)
result_next = similar(result,T)
copy!(result_next, 1, result, 1, index-1)
result = result_next
end
end
end

header = quote
$first_it
$init
@goto inner_loop
end
loopexpr = quote
$loopexpr
$(needs_fallback ? fallback : :())
@label inner_loop
end
else
header = if isdict
quote
result = Dict{$(esc(KT)),$(esc(T))}()
end
else
quote
index = 1
result = Array($(esc(T)), $(sz...))
end
end
end
loopexpr = quote
$loopexpr
result[index] = v
end
if !isdict
loopexpr = :($loopexpr; index += 1)
end
for i = 1:niter
it = iterblock[i]
itname = it.args[2]
loopexpr = quote
$(itst[i]) = start($itname)
while !done($itname, $(itst[i]))
tup = next($itname, $(itst[i]))
$(itst[i]) = tup[2]
$(gen_var(ex -> esc(NewvarNode(ex)), it.args[1], :())) # ensure new bindings if we capture the it variable
$(it.args[1]) = tup[1]
$loopexpr
end
end
end
q = if needs_firsteval
quote
let
$preheader
if $isempty
$(isdict ? :(Dict{Union(),Union()}()) : :(Array(Union(), $(sz0...))))
else
$header
$loopexpr
result
end
end
end
else
quote
let
$preheader
$header
$loopexpr
result
end
end
end
q
end

macro comprehension(body, iter...)
gen_comp(nothing,body,iter,false)
end

macro typed_comprehension(T, body, iter...)
gen_comp(T,body,iter,false)
end

macro dict_comprehension(body,iter...)
gen_comp(nothing,body,iter,true)
end

macro typed_dict_comprehension(T,body,iter...)
gen_comp(T,body,iter,true)
end
4 changes: 1 addition & 3 deletions base/expr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@ symbol(a::Array{UInt8,1}) =
ccall(:jl_symbol_n, Any, (Ptr{UInt8}, Int32), a, length(a))::Symbol
symbol(x...) = symbol(string(x...))

gensym() = ccall(:jl_gensym, Any, ())::Symbol

gensym(s::ASCIIString) = gensym(s.data)
gensym(s::UTF8String) = gensym(s.data)
gensym(a::Array{UInt8,1}) =
Expand Down Expand Up @@ -39,7 +37,7 @@ copy(s::SymbolNode) = SymbolNode(s.name, s.typ)

# copy parts of an AST that the compiler mutates
astcopy(x::Union(SymbolNode,Expr)) = copy(x)
astcopy(x::Array{Any,1}) = Any[astcopy(a) for a in x]
astcopy(x::Array{Any,1}) = map(a->astcopy(a),x)
astcopy(x) = x

==(x::Expr, y::Expr) = x.head === y.head && x.args == y.args
Expand Down
Loading

0 comments on commit c4c6817

Please sign in to comment.