Support LinearAlgebra.kron for all GPU backends #558

Open
ytdHuang opened this issue Sep 19, 2024 · 7 comments

Comments

ytdHuang commented Sep 19, 2024

As @maleadt mentioned in JuliaGPU/Metal.jl#422, I am opening a new issue here.

Currently, LinearAlgebra.kron has a dedicated implementation only for CuArray; the other GPU array types fall back to scalar indexing.

Also, methods for the Kronecker product of Transpose{T,<:AbstractGPUArray} or Adjoint{T,<:AbstractGPUArray} are missing.
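
For reference, the operation being requested can be pinned down on the CPU. The sketch below is not GPUArrays.jl code: `kron_reference` and `materialize` are hypothetical helper names. It only illustrates the index mapping any backend implementation would need to reproduce, plus one possible way to accept Transpose/Adjoint inputs by materializing them first. Whether `copy` of a wrapped GPU array avoids scalar indexing on every backend is an assumption that would need checking.

```julia
using LinearAlgebra

# Hypothetical helper: turn lazy Transpose/Adjoint wrappers into plain arrays.
materialize(A::Union{Transpose,Adjoint}) = copy(A)
materialize(A::AbstractMatrix) = A

# Reference Kronecker product written with explicit loops (CPU only).
function kron_reference(A::AbstractMatrix, B::AbstractMatrix)
    A, B = materialize(A), materialize(B)
    m, n = size(A)
    p, q = size(B)
    C = similar(A, promote_type(eltype(A), eltype(B)), m * p, n * q)
    for j in 1:n, l in 1:q, i in 1:m, k in 1:p
        # Block (i, j) of C holds A[i, j] * B
        C[(i - 1) * p + k, (j - 1) * q + l] = A[i, j] * B[k, l]
    end
    return C
end

A = [1 2; 3 4]
B = [0 1 2; 3 4 5]
C_plain = kron_reference(A, B)
C_wrapped = kron_reference(transpose(A), B')
```

On CPU arrays both results agree with `LinearAlgebra.kron`, which makes the loop version usable as a test oracle for a GPU port.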

ytdHuang commented Sep 19, 2024

@maleadt
You mentioned just copying the implementation from CUDA.jl and pasting it here.

Do you mean something like the following?

I just changed CuMatrix to AbstractGPUMatrix, but what about the kernel = @cuda part?

Sorry, I'm not very familiar with GPU code development...

function LinearAlgebra.kron!(C::AbstractGPUMatrix{TC}, A::AbstractGPUMatrix{TA}, B::AbstractGPUMatrix{TB}) where {TA,TB,TC}

    function _kron_mat_kernelA!(C, A, B, m, n, p, q)
        index_i = (blockIdx().x - 1) * blockDim().x + threadIdx().x
        index_j = (blockIdx().y - 1) * blockDim().y + threadIdx().y

        stride_i = blockDim().x * gridDim().x
        stride_j = blockDim().y * gridDim().y

        index_i > m && return
        index_j > n && return

        for i in index_i:stride_i:m
            for j in index_j:stride_j:n
                for k in 1:p
                    for l in 1:q
                        @inbounds C[(i-1)*p+k, (j-1)*q+l] = A[i,j] * B[k,l]
                    end
                end
            end
        end
        return nothing
    end

    function _kron_mat_kernelB!(C, A, B, m, n, p, q)
        index_p = (blockIdx().x - 1) * blockDim().x + threadIdx().x
        index_q = (blockIdx().y - 1) * blockDim().y + threadIdx().y

        stride_p = blockDim().x * gridDim().x
        stride_q = blockDim().y * gridDim().y

        index_p > p && return
        index_q > q && return

        for i in 1:m
            for j in 1:n
                for k in index_p:stride_p:p
                    for l in index_q:stride_q:q
                        @inbounds C[(i-1)*p+k, (j-1)*q+l] = A[i,j] * B[k,l]
                    end
                end
            end
        end
        return nothing
    end

    m, n = size(A)
    p, q = size(B)

    # Use different kernels depending on the size of the matrices
    # choosing to parallelize the matrix with the largest number of elements
    m*n >= p*q ? (kernel = @cuda launch=false _kron_mat_kernelA!(C, A, B, m, n, p, q)) :
                 (kernel = @cuda launch=false _kron_mat_kernelB!(C, A, B, m, n, p, q))

    m*n >= p*q ? (sizes = (m, n)) : (sizes = (p, q))

    config = launch_configuration(kernel.fun)
    dim_ratio = sizes[1] / sizes[2]
    max_threads_i = max(1, floor(Int, sqrt(config.threads * dim_ratio)))
    max_threads_j = max(1, floor(Int, sqrt(config.threads / dim_ratio)))
    max_blocks_i = max(1, floor(Int, sqrt(config.blocks * dim_ratio)))
    max_blocks_j = max(1, floor(Int, sqrt(config.blocks / dim_ratio)))

    threads_i = min(sizes[1], max_threads_i)
    threads_j = min(sizes[2], max_threads_j)
    threads = (threads_i, threads_j)
    blocks_i = min(cld(sizes[1], threads_i), max_blocks_i)
    blocks_j = min(cld(sizes[2], threads_j), max_blocks_j)
    blocks = (blocks_i, blocks_j)

    kernel(C, A, B, m, n, p, q; threads=threads, blocks=blocks)

    return C
end

function LinearAlgebra.kron(A::AbstractGPUMatrix{TA}, B::AbstractGPUMatrix{TB}) where {TA,TB}
    m, n = size(A)
    p, q = size(B)

    T = promote_type(TA, TB)
    C = similar(A, T, m*p, n*q)

    kron!(C, A, B)
end
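
The thread/block arithmetic in `kron!` above does not itself depend on CUDA, so it can be checked on the CPU. The sketch below copies that arithmetic into a hypothetical helper, `launch_geometry`; the limits of 1024 threads and 64 blocks are made-up stand-ins for whatever `launch_configuration` would report on a real device.

```julia
# Hypothetical helper mirroring the 2D launch arithmetic from `kron!` above.
# `max_threads` and `max_blocks` stand in for `config.threads` and `config.blocks`.
function launch_geometry(sizes::NTuple{2,Int}, max_threads::Int, max_blocks::Int)
    dim_ratio = sizes[1] / sizes[2]
    max_threads_i = max(1, floor(Int, sqrt(max_threads * dim_ratio)))
    max_threads_j = max(1, floor(Int, sqrt(max_threads / dim_ratio)))
    max_blocks_i = max(1, floor(Int, sqrt(max_blocks * dim_ratio)))
    max_blocks_j = max(1, floor(Int, sqrt(max_blocks / dim_ratio)))

    threads_i = min(sizes[1], max_threads_i)
    threads_j = min(sizes[2], max_threads_j)
    blocks_i = min(cld(sizes[1], threads_i), max_blocks_i)
    blocks_j = min(cld(sizes[2], threads_j), max_blocks_j)
    return (threads = (threads_i, threads_j), blocks = (blocks_i, blocks_j))
end

geom = launch_geometry((4096, 16), 1024, 64)         # threads (512, 2), blocks (8, 1)
skewed = launch_geometry((4096, 1), 1024, 64)        # threads (2048, 1)
```

For the (4096, 16) case the grid covers only 2 of the 16 columns directly, and the grid-stride loops in the kernels pick up the rest. For the very skewed (4096, 1) case this arithmetic yields 2048 threads in one block, more than the 1024 assumed here, so a port may want an extra clamp; whether that matters on a real device is not verified here.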

ytdHuang changed the title from "Support for LinearAlgebra.kron" to "Support LinearAlgebra.kron for all GPU backends" Sep 19, 2024

maleadt commented Sep 19, 2024

Ah, I glanced over the kernels. Yeah those will have to be ported to GPUArrays' DSL (gpu_call), or wait for the impending switch over to KernelAbstractions.jl.
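
As a rough idea of what a KernelAbstractions.jl version might look like, here is a minimal sketch. It is not the GPUArrays.jl implementation: `kron_kernel!` and `kron_ka!` are hypothetical names, it uses one work-item per element of C instead of the two size-dependent kernels from CUDA.jl, and it has only been reasoned through for the CPU backend that `get_backend` returns for a plain `Array`.

```julia
using KernelAbstractions
using LinearAlgebra

# One work-item per element of C; the index mapping matches the kernels above.
@kernel function kron_kernel!(C, @Const(A), @Const(B), p, q)
    row, col = @index(Global, NTuple)
    i, k = divrem(row - 1, p)   # zero-based block row and offset inside the block
    j, l = divrem(col - 1, q)   # zero-based block column and offset inside the block
    @inbounds C[row, col] = A[i + 1, j + 1] * B[k + 1, l + 1]
end

# Hypothetical wrapper; not part of GPUArrays.jl.
function kron_ka!(C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix)
    m, n = size(A)
    p, q = size(B)
    size(C) == (m * p, n * q) || throw(DimensionMismatch("C has the wrong size"))
    backend = get_backend(C)
    kron_kernel!(backend)(C, A, B, p, q; ndrange = size(C))
    KernelAbstractions.synchronize(backend)
    return C
end

A = rand(Float32, 3, 4)
B = rand(Float32, 2, 5)
C = similar(A, 6, 20)
kron_ka!(C, A, B)
```

Because the backend is taken from `C`, the same call would in principle dispatch to whichever GPU backend owns the array, which is the point of moving to KernelAbstractions.jl.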

ytdHuang commented Sep 19, 2024

Hmm, I'm not sure how to contribute.

I checked the GPUArrays.jl documentation, but I don't think I understand it, lol.

I think you guys define blockidx, blockdim, and threadidx here, and the matching implementation gets called depending on the backend during execution?

ytdHuang commented Sep 19, 2024

> Ah, I glanced over the kernels. Yeah those will have to be ported to GPUArrays' DSL (gpu_call), or wait for the impending switch over to KernelAbstractions.jl.

Okay, maybe I will just wait

maleadt commented Sep 19, 2024

> I think you guys define blockidx, blockdim, and threadidx here, and the matching implementation gets called depending on the backend during execution?

The docs are pretty bad, it's easier to look at some existing kernels. It's a very shallow abstraction over the CUDA stuff you're probably familiar with.

But waiting for KA.jl is probably fine as well, if you don't need this urgently.
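
If something is needed before that switch lands, one possible stopgap is to express the Kronecker product with broadcasting and `reshape` only, so no custom kernel is involved. This is a sketch under assumptions, not something proposed in this thread: `kron_broadcast` is a hypothetical name, it is only exercised on CPU arrays here, and it allocates a fresh result instead of writing into a preallocated C.

```julia
using LinearAlgebra

# Hypothetical stopgap: Kronecker product via broadcasting and reshape only.
function kron_broadcast(A::AbstractMatrix, B::AbstractMatrix)
    m, n = size(A)
    p, q = size(B)
    # D[k, i, l, j] = A[i, j] * B[k, l]; the column-major reshape then gives
    # row (i-1)*p + k and column (j-1)*q + l, matching kron's layout.
    D = reshape(A, 1, m, 1, n) .* reshape(B, p, 1, q, 1)
    return reshape(D, m * p, n * q)
end

A = [1 2; 3 4]
B = [0 1 2; 3 4 5]
C = kron_broadcast(A, B)
```

Whether `reshape` and broadcasting stay free of scalar indexing for every GPU array type, and for Transpose/Adjoint wrappers in particular, would need to be tested per backend.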

ytdHuang commented Sep 19, 2024

> > I think you guys define blockidx, blockdim, and threadidx here, and the matching implementation gets called depending on the backend during execution?
>
> The docs are pretty bad, it's easier to look at some existing kernels. It's a very shallow abstraction over the CUDA stuff you're probably familiar with.
>
> But waiting for KA.jl is probably fine as well, if you don't need this urgently.

So you guys are making GPUArrays.jl depend on KA.jl?
May I ask roughly how long it would take? I just want to get the picture.

maleadt commented Sep 19, 2024

Yes, see the open PRs here and on the various back-ends.
Hard to tell how long it would take. Hopefully on the order of weeks, at most months.
