diff --git a/src/symbolic.jl b/src/symbolic.jl index c847a32..ecdb6f9 100644 --- a/src/symbolic.jl +++ b/src/symbolic.jl @@ -102,6 +102,31 @@ function simplify(ex::Expr) return new_ex end +function sum_numeric_args(args) + sum = 0 + sym_args = {} + for arg in args + if isa(arg, Number) + sum += arg + else + sym_args = [sym_args, arg] + end + end + (sum, sym_args) +end + +function mul_numeric_args(args) + prod = 1 + sym_args = {} + for arg in args + if isa(arg, Number) + prod *= arg + else + sym_args = [sym_args, arg] + end + end + (prod, sym_args) +end # Handles all lengths for ex.args # Removes any 0's in a sum @@ -113,6 +138,8 @@ function simplify(::SymbolParameter{:+}, args) elseif length(new_args) == 1 return new_args[1] else + (sum, sym_args) = sum_numeric_args(new_args) + new_args = sum==0 ? sym_args : [sum, sym_args] return Expr(:call, :+, new_args...) end end @@ -144,6 +171,8 @@ function simplify(::SymbolParameter{:*}, args) elseif any(new_args .== 0) return 0 else + (prod, sym_args) = mul_numeric_args(new_args) + new_args = prod==1 ? sym_args : [prod, sym_args] return Expr(:call, :*, new_args...) end end diff --git a/test/symbolic.jl b/test/symbolic.jl index 1332a08..5da0e80 100644 --- a/test/symbolic.jl +++ b/test/symbolic.jl @@ -15,7 +15,7 @@ @assert isequal(differentiate(:(x * a), :x), :a) @assert isequal(differentiate(:(x ^ 2), :x), :(2 * x)) @assert isequal(differentiate(:(a * x ^ 2), :x), :(a * (2 * x))) -@assert isequal(differentiate(:(2 ^ x), :x), :(*(^(2, x), 0.6931471805599453))) +@assert isequal(differentiate(:(2 ^ x), :x), :(*(0.6931471805599453, ^(2, x)))) @assert isequal(differentiate(:(sin(x)), :x), :(cos(x))) @assert isequal(differentiate(:(cos(x)), :x), :(*(-1,sin(x)))) # needs a better simplify @assert isequal(differentiate(:(tan(x)), :x), :(1 + tan(x)^2)) @@ -74,3 +74,17 @@ end @assert isequal(testfun(x), :(^($(x),2))) @assert isequal(testfun(3), 9) @assert isequal(testfun(@sexpr(x+y)), :(^(+($x,$y),2))) + +# +# Simplify tests +# + +@assert isequal(simplify(:(x+y)), :(+(x,y))) +@assert isequal(simplify(:(x+3)), :(+(3,x))) +@assert isequal(simplify(:(x+3+4)), :(+(7,x))) +@assert isequal(simplify(:(2+y+x+3)), :(+(5,y,x))) + +@assert isequal(simplify(:(x*y)), :(*(x,y))) +@assert isequal(simplify(:(x*3)), :(*(3,x))) +@assert isequal(simplify(:(x*3*4)), :(*(12,x))) +@assert isequal(simplify(:(2*y*x*3)), :(*(6,y,x)))