Skip to content

Commit

Permalink
Add a dispatch for LinearAlgebra.norm2
Browse files Browse the repository at this point in the history
`norm(@view x[..], 2)` was previously leading to a call of `LinearAlgebra.generic_norm2` which led to a scalar indexing. This catches such cuda subarray norm2 calls earlier.

Inf-norm and p-norm with cuda subarrays still leads to the following dispatches:
```julia
LinearAlgebra.generic_normInf(x) = float(mapreduce(norm, max, x))
LinearAlgebra.generic_norm1(x) = mapreduce(float ∘ norm, +, x)
```

I am not sure if there is a better way to dispatch them.
  • Loading branch information
sharanry committed Mar 22, 2024
1 parent f5100a1 commit 413c397
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 0 deletions.
4 changes: 4 additions & 0 deletions lib/cublas/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,10 @@ function LinearAlgebra.norm(x::DenseCuArray{<:Union{Float16, ComplexF16, CublasF
end
end

function LinearAlgebra.norm2(x::SubArray{T,N,P} where {T<:Union{Float16, ComplexF16, CublasFloat}, N, P<:DenseCuArray{<:T}})
return nrm2(x)
end

LinearAlgebra.BLAS.asum(x::StridedCuArray{<:CublasFloat}) = asum(length(x), x)

function LinearAlgebra.axpy!(alpha::Number, x::StridedCuArray{T}, y::StridedCuArray{T}) where T<:Union{Float16, ComplexF16, CublasFloat}
Expand Down
14 changes: 14 additions & 0 deletions test/libraries/cublas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1767,6 +1767,20 @@ end
@view(p[reshape(1:(out*inn),out,inn)]) * x
end
end

@testset "nrm2 with strided inputs" begin # JuliaGPU/CUDA.jl#2280
cudaTypes = (Float16, Complex{Float16}, BFloat16, Complex{BFloat16}, Float32, Complex{Float32},
Float64, Complex{Float64}, Int8, Complex{Int8}, UInt8, Complex{UInt8},
Int16, Complex{Int16}, UInt16, Complex{UInt16}, Int32, Complex{Int32},
UInt32, Complex{UInt32}, Int64, Complex{Int64}, UInt64, Complex{UInt64})
for CT in cudaTypes
x = rand(CT, 10, 10, 10)
dx = CuArray(x)
dx_ = @view dx[3:6, 1:5, :]
x_ = @view x[3:6, 1:5, :]
@test norm(dx_, 2) norm(x_, 2)
end
end
end

############################################################################################
Expand Down

0 comments on commit 413c397

Please sign in to comment.