diff --git a/src/functions_math.jl b/src/functions_math.jl index b263532..62a2387 100644 --- a/src/functions_math.jl +++ b/src/functions_math.jl @@ -22,15 +22,20 @@ function matrix_prod_names(a, b) compile_time_return_hack(res) end - -for (NA, NB) in ((1,2), (2,1), (2,2)) #Vector * Vector, is not allowed - @eval function Base.:*(a::NamedDimsArray{A,T,$NA}, b::NamedDimsArray{B,S,$NB}) where {A,B,T,S} - L = matrix_prod_names(A,B) - data = *(parent(a), parent(b)) - return NamedDimsArray{L}(data) +matrix_rdiv_names(a, b) = matrix_prod_names(a, reverse(b)) +matrix_ldiv_names(a, b) = matrix_prod_names(reverse(a), b) +for (NA, NB) in ((1,2), (2,1), (2,2)) + for (func, namemap) in ((:*, :matrix_prod_names), (:/, :matrix_rdiv_names), (:\, :matrix_ldiv_names),) + func==:* && NA==NB==1 && continue #Vector * Vector, is not allowed + @eval function Base.$func(a::NamedDimsArray{A,T,$NA}, b::NamedDimsArray{B,S,$NB}) where {A,B,T,S} + L = $namemap(A,B) + data = $func(parent(a), parent(b)) + return NamedDimsArray{L}(data) + end end end + """ @declare_matmul(MatrixT, VectorT=nothing)