Skip to content

Commit

Permalink
Use the new API
Browse files Browse the repository at this point in the history
  • Loading branch information
amontoison committed Jul 7, 2024
1 parent aa97e3a commit 5139cb3
Show file tree
Hide file tree
Showing 9 changed files with 304 additions and 152 deletions.
7 changes: 4 additions & 3 deletions src/MKLSparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,13 @@ end
# Wrappers generated by Clang.jl
include("libmklsparse.jl")
include("types.jl")
include("mklsparsematrix.jl")

# TODO: BLAS1

# BLAS2 and BLAS3
include("matdescra.jl")
include("generator.jl")
include("matmul.jl")
include("deprecated.jl")
include("generic.jl")
include("interface.jl")

end # module
12 changes: 11 additions & 1 deletion src/generator.jl → src/deprecated.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
matdescra(A::LowerTriangular) = "TLNF"
matdescra(A::UpperTriangular) = "TUNF"
matdescra(A::Diagonal) = "DUNF"
matdescra(A::UnitLowerTriangular) = "TLUF"
matdescra(A::UnitUpperTriangular) = "TUUF"
matdescra(A::Symmetric) = string('S', A.uplo, 'N', 'F')
matdescra(A::Hermitian) = string('H', A.uplo, 'N', 'F')
matdescra(A::SparseMatrixCSC) = "GUUF"
matdescra(A::Transpose) = matdescra(A.parent)
matdescra(A::Adjoint) = matdescra(A.parent)

# The increments to the `__counter` variable is for testing purposes

function _check_transa(t::Char)
Expand Down Expand Up @@ -33,7 +44,6 @@ function cscmv!(transa::Char, α::T, matdescra::String,
_check_transa(transa)
_check_mat_mult_matvec(y, A, x, transa)
__counter[] += 1

T == Float32 && (mkl_scscmv(transa, A.m, A.n, α, matdescra, A.nzval, A.rowval, A.colptr, pointer(A.colptr, 2), x, β, y))
T == Float64 && (mkl_dcscmv(transa, A.m, A.n, α, matdescra, A.nzval, A.rowval, A.colptr, pointer(A.colptr, 2), x, β, y))
T == ComplexF32 && (mkl_ccscmv(transa, A.m, A.n, α, matdescra, A.nzval, A.rowval, A.colptr, pointer(A.colptr, 2), x, β, y))
Expand Down
51 changes: 51 additions & 0 deletions src/generic.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
for T in (:Float32, :Float64, :ComplexF32, :ComplexF64)
for SparseMatrix in (:(SparseMatrixCSC{$T,BlasInt}), :(MKLSparse.SparseMatrixCSR{$T,BlasInt}), :(MKLSparse.SparseMatrixCOO{$T,BlasInt}))

fname_mv = Symbol("mkl_sparse_", mkl_type_specifier(T), "_mv")
fname_mm = Symbol("mkl_sparse_", mkl_type_specifier(T), "_mm")
fname_trsv = Symbol("mkl_sparse_", mkl_type_specifier(T), "_trsv")
fname_trsm = Symbol("mkl_sparse_", mkl_type_specifier(T), "_trsm")

@eval begin
function mv!(operation::Char, alpha::$T, A::$SparseMatrix, descr::matrix_descr, x::StridedVector{$T}, beta::$T, y::StridedVector{$T})
_check_transa(operation)
_check_mat_mult_matvec(y, A, x, operation)
__counter[] += 1
$fname_mv(operation, alpha, MKLSparseMatrix(A), descr, x, beta, y)
return y
end

function mm!(operation::Char, alpha::$T, A::$SparseMatrix, descr::matrix_descr, x::StridedMatrix{$T}, beta::$T, y::StridedMatrix{$T})
_check_transa(operation)
_check_mat_mult_matvec(y, A, x, operation)
__counter[] += 1
columns = size(y, 2)
ldx = stride(x, 2)
ldy = stride(y, 2)
$fname_mm(operation, alpha, MKLSparseMatrix(A), descr, 'C', x, columns, ldx, beta, y, ldy)
return y
end

function trsv!(operation::Char, alpha::$T, A::$SparseMatrix, descr::matrix_descr, x::StridedVector{$T}, y::StridedVector{$T})
checksquare(A)
_check_transa(operation)
_check_mat_mult_matvec(y, A, x, operation)
__counter[] += 1
$fname_trsv(operation, alpha, MKLSparseMatrix(A), descr, x, y)
return y
end

function trsm!(operation::Char, alpha::$T, A::$SparseMatrix, descr::matrix_descr, x::StridedMatrix{$T}, y::StridedMatrix{$T})
checksquare(A)
_check_transa(operation)
_check_mat_mult_matvec(y, A, x, operation)
__counter[] += 1
columns = size(y, 2)
ldx = stride(x, 2)
ldy = stride(y, 2)
$fname_trsm(operation, alpha, MKLSparseMatrix(A), descr, 'C', x, columns, ldx, y, ldy)
return y
end
end
end
end
87 changes: 87 additions & 0 deletions src/interface.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import Base: \, *
import LinearAlgebra: mul!, ldiv!

for T in (Float32, Float64, ComplexF32, ComplexF64)

tag_wrappers = ((identity , identity ),
(M -> :(Symmetric{$T, $M}), A -> :(parent($A))),
(M -> :(Hermitian{$T, $M}), A -> :(parent($A))))

triangle_wrappers = ((M -> :(LowerTriangular{$T, $M}) , A -> :(parent($A))),
(M -> :(UnitLowerTriangular{$T, $M}), A -> :(parent($A))),
(M -> :(UpperTriangular{$T, $M}) , A -> :(parent($A))),
(M -> :(UnitUpperTriangular{$T, $M}), A -> :(parent($A))))

op_wrappers = ((identity , 'N', identity ),
(M -> :(Transpose{$T, $M}), 'T', A -> :(parent($A))),
(M -> :(Adjoint{$T, $M}) , 'C', A -> :(parent($A))))

for SparseMatrixType in (:(SparseMatrixCSC{$T, $BlasInt}), :(MKLSparse.SparseMatrixCOO{$T, $BlasInt}), :(MKLSparse.SparseMatrixCSR{$T, $BlasInt}))
for (taga, untaga) in tag_wrappers, (wrapa, transa, unwrapa) in op_wrappers
TypeA = wrapa(taga(SparseMatrixType))

@eval begin
function LinearAlgebra.mul!(y::StridedVector{$T}, A::$TypeA, x::StridedVector{$T}, alpha::Number, beta::Number)
# return cscmv!($transa, $T(alpha), $matdescra(A), $(untaga(unwrapa(:A))), x, $T(beta), y)
return mv!($transa, $T(alpha), $(untaga(unwrapa(:A))), $matrixdescra(A), x, $T(beta), y)
end

function LinearAlgebra.mul!(C::StridedMatrix{$T}, A::$TypeA, B::StridedMatrix{$T}, alpha::Number, beta::Number)
# return cscmm!($transa, $T(alpha), $matdescra(A), $(untaga(unwrapa(:A))), B, $T(beta), C)
return mm!($transa, $T(alpha), $(untaga(unwrapa(:A))), $matrixdescra(A), B, $T(beta), C)
end
end
end

for (trianglea, untrianglea) in triangle_wrappers, (wrapa, transa, unwrapa) in op_wrappers
TypeA = wrapa(trianglea(SparseMatrixType))

@eval begin
function LinearAlgebra.mul!(y::StridedVector{$T}, A::$TypeA, x::StridedVector{$T}, alpha::Number, beta::Number)
# return cscmv!($transa, $T(alpha), $matdescra(A), $(untrianglea(unwrapa(:A))), x, $T(beta), y)
return mv!($transa, $T(alpha), $(untrianglea(unwrapa(:A))), $matrixdescra(A), x, $T(beta), y)
end

function LinearAlgebra.mul!(C::StridedMatrix{$T}, A::$TypeA, B::StridedMatrix{$T}, alpha::Number, beta::Number)
# return cscmm!($transa, $T(alpha), $matdescra(A), $(untrianglea(unwrapa(:A))), B, $T(beta), C)
return mm!($transa, $T(alpha), $(untrianglea(unwrapa(:A))), $matrixdescra(A), B, $T(beta), C)
end

function LinearAlgebra.ldiv!(y::StridedVector{$T}, A::$TypeA, x::StridedVector{$T})
# return cscsv!($transa, one($T), $matdescra(A), $(untrianglea(unwrapa(:A))), x, y)
return trsv!($transa, one($T), $(untrianglea(unwrapa(:A))), $matrixdescra(A), x, y)
end

function LinearAlgebra.ldiv!(C::StridedMatrix{$T}, A::$TypeA, B::StridedMatrix{$T})
# return cscsm!($transa, one($T), $matdescra(A), $(untrianglea(unwrapa(:A))), B, C)
return trsm!($transa, one($T), $(untrianglea(unwrapa(:A))), $matrixdescra(A), B, C)
end

function (*)(A::$TypeA, x::StridedVector{$T})
m, n = size(A)
y = Vector{$T}(undef, m)
return mul!(y, A, x, one($T), zero($T))
end

function (*)(A::$TypeA, B::StridedMatrix{$T})
m, k = size(A)
p, n = size(B)
C = Matrix{$T}(undef, m, n)
return mul!(C, A, B, one($T), zero($T))
end

function (\)(A::$TypeA, x::StridedVector{$T})
n = length(x)
y = Vector{$T}(undef, n)
return ldiv!(y, A, x)
end

function (\)(A::$TypeA, B::StridedMatrix{$T})
m, n = size(B)
C = Matrix{$T}(undef, m, n)
return ldiv!(C, A, B)
end
end
end
end
end
8 changes: 0 additions & 8 deletions src/matdescra.jl

This file was deleted.

101 changes: 0 additions & 101 deletions src/matmul.jl

This file was deleted.

67 changes: 67 additions & 0 deletions src/mklsparsematrix.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
## MKL sparse matrix

# https://github.com/JuliaSmoothOptimizers/SparseMatricesCOO.jl
mutable struct SparseMatrixCOO{Tv,Ti} <: AbstractSparseMatrix{Tv,Ti}
m::Int
n::Int
rows::Vector{Ti}
cols::Vector{Ti}
vals::Vector{Tv}
end

# https://github.com/gridap/SparseMatricesCSR.jl
mutable struct SparseMatrixCSR{Tv,Ti} <: AbstractSparseMatrix{Tv,Ti}
m::Int
n::Int
rowptr::Vector{Ti}
colval::Vector{Ti}
nzval::Vector{Tv}
end

SparseArrays.nnz(A::MKLSparse.SparseMatrixCOO) = length(A.vals)
SparseArrays.nnz(A::MKLSparse.SparseMatrixCSR) = length(A.nzval)

matrixdescra(A::MKLSparse.SparseMatrixCSR) = matrix_descr('G', 'F', 'N')
matrixdescra(A::MKLSparse.SparseMatrixCOO) = matrix_descr('G', 'F', 'N')

mutable struct MKLSparseMatrix
handle::sparse_matrix_t
end

Base.unsafe_convert(::Type{sparse_matrix_t}, desc::MKLSparseMatrix) = desc.handle

for T in (:Float32, :Float64, :ComplexF32, :ComplexF64)

create_coo = Symbol("mkl_sparse_", mkl_type_specifier(T), "_create_coo")
create_csc = Symbol("mkl_sparse_", mkl_type_specifier(T), "_create_csc")
create_csr = Symbol("mkl_sparse_", mkl_type_specifier(T), "_create_csr")

@eval begin
# SparseMatrixCOO
function MKLSparseMatrix(A::MKLSparse.SparseMatrixCOO{$T, BlasInt}, IndexBase::Char='O')
descr_ref = Ref{sparse_matrix_t}()
$create_coo(descr_ref, IndexBase, A.m, A.n, nnz(A), A.rows, A.cols, A.vals)
obj = MKLSparseMatrix(descr_ref[])
finalizer(mkl_sparse_destroy, obj)
return obj
end

# SparseMatrixCSR
function MKLSparseMatrix(A::MKLSparse.SparseMatrixCSR{$T, BlasInt}, IndexBase::Char='O')
descr_ref = Ref{sparse_matrix_t}()
$create_csr(descr_ref, IndexBase, A.m, A.n, A.rowptr, pointer(A.rowptr, 2), A.colval, A.nzval)
obj = MKLSparseMatrix(descr_ref[])
finalizer(mkl_sparse_destroy, obj)
return obj
end

# SparseMatrixCSC
function MKLSparseMatrix(A::SparseMatrixCSC{$T, BlasInt}, IndexBase::Char='O')
descr_ref = Ref{sparse_matrix_t}()
$create_csc(descr_ref, IndexBase, A.m, A.n, A.colptr, pointer(A.colptr, 2), A.rowval, A.nzval)
obj = MKLSparseMatrix(descr_ref[])
finalizer(mkl_sparse_destroy, obj)
return obj
end
end
end
25 changes: 25 additions & 0 deletions src/types.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,30 @@
# MKL sparse types

function mkl_type_specifier(T::Symbol)
if T == :Float32
's'
elseif T == :Float64
'd'
elseif T == :ComplexF32
'c'
elseif T == :ComplexF64
'z'
else
throw(ArgumentError("Unsupported numeric type $T"))
end
end

matrixdescra(A::LowerTriangular) = matrix_descr('T','L','N')
matrixdescra(A::UpperTriangular) = matrix_descr('T','U','N')
matrixdescra(A::Diagonal) = matrix_descr('D','F','N')
matrixdescra(A::UnitLowerTriangular) = matrix_descr('T','L','U')
matrixdescra(A::UnitUpperTriangular) = matrix_descr('T','U','U')
matrixdescra(A::Symmetric) = matrix_descr('S', A.uplo, 'N')
matrixdescra(A::Hermitian) = matrix_descr('H', A.uplo, 'N')
matrixdescra(A::SparseMatrixCSC) = matrix_descr('G', 'F', 'N')
matrixdescra(A::Transpose) = matrixdescra(A.parent)
matrixdescra(A::Adjoint) = matrixdescra(A.parent)

function Base.convert(::Type{sparse_operation_t}, trans::Char)
if trans == 'N'
SPARSE_OPERATION_NON_TRANSPOSE
Expand Down
Loading

0 comments on commit 5139cb3

Please sign in to comment.