Directed rounding #2576

Draft · wants to merge 16 commits into master
Changes from 8 commits
101 changes: 101 additions & 0 deletions docs/src/tutorials/exposing_new_intrinsics.jl
@@ -0,0 +1,101 @@
# # Introduction

# ## Adding new GPU intrinsics

# In this tutorial we expose GPU intrinsics that allow directed rounding in the fused multiply-add (fma)
# floating-point operation.
# We start by identifying the intrinsic we want to expose; to do so, we read the PTX (Parallel Thread Execution)
# documentation at [PTX - Floating Point Instructions](https://docs.nvidia.com/cuda/parallel-thread-execution/#floating-point-instructions).
# Table 32 presents a summary of floating-point operations, from which we can construct the intrinsic string.
# The FMA instruction for Float32 is listed as `{mad,fma}.rnd.f32`, where `rnd` can assume the values `.rnd = { .rn, .rz, .rm, .rp }`:
# `rn` is round to nearest, `rz` round toward zero, `rm` round toward minus infinity, and `rp` round toward plus infinity.
# When building the intrinsic name for the call, we need to replace the type suffix `.f64` with `.d` and `.f32` with `.f`.
# Therefore, to call the round-toward-plus-infinity `fma` for `.f64`, we call the intrinsic `llvm.nvvm.fma.rp.d`.

fma_rp(x::Float64, y::Float64, z::Float64) = ccall("llvm.nvvm.fma.rp.d", llvmcall, Cdouble, (Cdouble, Cdouble, Cdouble), x, y, z)
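# The same naming rule gives the `Float32` variant by swapping the `.d` suffix for `.f`.
# (A sketch following the convention above; it assumes the single-precision
# counterpart `llvm.nvvm.fma.rp.f`, which the PTX table also lists.)
fma_rp(x::Float32, y::Float32, z::Float32) = ccall("llvm.nvvm.fma.rp.f", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, y, z)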
fma(x::T, y::T, z::T, ::RoundingMode{:Up}) where {T <: Union{Float32, Float64}} = fma_rp(x, y, z)

# We inspect the generated PTX code:
CUDA.code_ptx(fma_rp, Tuple{Float64,Float64,Float64})

# We can see that the PTX code contains a call to the intrinsic `fma.rp.f64`; we now add this function
# to `src/device/intrinsics/math.jl`

function test_fma!(out, x, y)
    I = threadIdx().x
    z = 2.0 ^ (-(I + 53))

    out[I]      = fma(x, y, z, RoundNearest)
    out[I + 4]  = fma(x, y, z, RoundToZero)
    out[I + 8]  = fma(x, y, z, RoundUp)
    out[I + 12] = fma(x, y, z, RoundDown)

    return
end

# The first four entries of the output are rounded to nearest, entries 5 to 8 are rounded toward zero,
# entries 9 to 12 are rounded up, and entries 13 to 16 are rounded down.

out_d = CuArray(zeros(16))
@cuda threads = 4 test_fma!(out_d, 1.0, 1.0)
out_h = Array(out_d)

out_d = CuArray(zeros(16))
@cuda threads = 4 test_fma!(out_d, -1.0, 1.0)
out_h = Array(out_d)
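# As a sanity check on the expected values (assuming IEEE-754 `Float64` semantics):
# for thread 1, `z = 2.0^-54` is half an ulp of `1.0`, so `1.0 + z` stays `1.0`
# under `RoundNearest`, `RoundToZero`, and `RoundDown`, but becomes
# `nextfloat(1.0) = 1 + 2^-52` under `RoundUp`. Negating the product (second call)
# separates the modes differently: `-1.0 + z` rounds toward zero to `-prevfloat(1.0)`
# but down to `-1.0`.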

# The binary operations `add`, `sub`, `mul`, and `div` have been implemented in
# `src/device/intrinsics/math.jl` through an `@eval` loop over operation, rounding mode, and type.

function test_add!(out, x, y)
    I = threadIdx().x
    if I == 1
        out[I] = CUDA.add(x, y, RoundNearest)
    elseif I == 2
        out[I] = CUDA.add(x, y, RoundToZero)
    elseif I == 3
        out[I] = CUDA.add(x, y, RoundUp)
    elseif I == 4
        out[I] = CUDA.add(x, y, RoundDown)
    end
    return
end

out_d = CuArray(zeros(4))
@cuda threads = 4 test_add!(out_d, 1.0, 2.0^(-54))
out_h = Array(out_d)
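# Again `2.0^-54` is half an ulp of `1.0`: only `RoundUp` should return
# `nextfloat(1.0)`; the other three modes return `1.0` exactly.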

function test_sub!(out, x, y)
    I = threadIdx().x
    if I == 1
        out[I] = CUDA.sub(x, y, RoundNearest)
    elseif I == 2
        out[I] = CUDA.sub(x, y, RoundToZero)
    elseif I == 3
        out[I] = CUDA.sub(x, y, RoundUp)
    elseif I == 4
        out[I] = CUDA.sub(x, y, RoundDown)
    end
    return
end

out_d = CuArray(zeros(4))
@cuda threads = 4 test_sub!(out_d, 1.0, 2.0^(-53))
out_h = Array(out_d)
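# Note that `1.0 - 2.0^-53` equals `prevfloat(1.0)` and is exactly representable,
# so all four rounding modes should agree on this input.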

function test_mul!(out, x, y)
    I = threadIdx().x
    if I == 1
        out[I] = CUDA.mul(x, y, RoundNearest)
    elseif I == 2
        out[I] = CUDA.mul(x, y, RoundToZero)
    elseif I == 3
        out[I] = CUDA.mul(x, y, RoundUp)
    elseif I == 4
        out[I] = CUDA.mul(x, y, RoundDown)
    end
    return
end

out_d = CuArray(zeros(4))
@cuda threads = 4 test_mul!(out_d, 1.0 - 2.0^(-52), 1.0 + 2.0^(-52))
out_h = Array(out_d)
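# The exact product `(1 - 2^-52) * (1 + 2^-52) = 1 - 2^-104` is not representable:
# `RoundNearest` and `RoundUp` should return `1.0`, while `RoundToZero` and
# `RoundDown` should return `prevfloat(1.0) = 1 - 2^-53`.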
Member:
Not sure how this part is still relevant to the 'defining an intrinsic' tutorial?

Author:
Left only one example.

63 changes: 61 additions & 2 deletions src/device/intrinsics/math.jl
@@ -390,18 +390,77 @@ end
@device_function normcdfinv(x::Float64) = ccall("extern __nv_normcdfinv", llvmcall, Cdouble, (Cdouble,), x)
@device_function normcdfinv(x::Float32) = ccall("extern __nv_normcdfinvf", llvmcall, Cfloat, (Cfloat,), x)



Comment on lines -393 to -394
Member:
Unrelated change.

#
# Unsorted
#

@device_override Base.hypot(x::Float64, y::Float64) = ccall("extern __nv_hypot", llvmcall, Cdouble, (Cdouble, Cdouble), x, y)
@device_override Base.hypot(x::Float32, y::Float32) = ccall("extern __nv_hypotf", llvmcall, Cfloat, (Cfloat, Cfloat), x, y)


for type in [:f, :d]
    for round in [:rn, :rz, :rm, :rp]
        for op in [:add, :mul, :div]
            inp_type = Symbol("Float64")
            c_type = Symbol("Cdouble")
            if type == :f
                inp_type = Symbol("Float32")
                c_type = Symbol("Cfloat")
            end

            func_name = Symbol("$(op)_$(round)")
            intrinsic_name = "llvm.nvvm.$(op).$(round).$(type)"

            @eval @device_function $func_name(x::$inp_type, y::$inp_type) = ccall($intrinsic_name, llvmcall, $c_type, ($c_type, $c_type), x, y)
        end
    end
end
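# The loop above generates one device method per (op, rounding, type) triple,
# e.g. `add_rn(::Float32, ::Float32)` lowering to `llvm.nvvm.add.rn.f` and
# `div_rp(::Float64, ::Float64)` lowering to `llvm.nvvm.div.rp.d`.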

@device_function sub_rn(x, y) = add_rn(x, -y)
@device_function sub_rz(x, y) = add_rz(x, -y)
@device_function sub_rm(x, y) = add_rm(x, -y)
@device_function sub_rp(x, y) = add_rp(x, -y)
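# Negation is exact in IEEE-754 arithmetic, so subtraction implemented as the
# correctly-rounded addition of the negated operand honors the requested
# rounding mode.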

@device_function add(x::T, y::T, ::RoundingMode{:Nearest}) where {T <: Union{Float32, Float64}} = add_rn(x, y)
@device_function add(x::T, y::T, ::RoundingMode{:ToZero}) where {T <: Union{Float32, Float64}} = add_rz(x, y)
@device_function add(x::T, y::T, ::RoundingMode{:Down}) where {T <: Union{Float32, Float64}} = add_rm(x, y)
@device_function add(x::T, y::T, ::RoundingMode{:Up}) where {T <: Union{Float32, Float64}} = add_rp(x, y)

@device_function sub(x::T, y::T, ::RoundingMode{:Nearest}) where {T <: Union{Float32, Float64}} = sub_rn(x, y)
@device_function sub(x::T, y::T, ::RoundingMode{:ToZero}) where {T <: Union{Float32, Float64}} = sub_rz(x, y)
@device_function sub(x::T, y::T, ::RoundingMode{:Down}) where {T <: Union{Float32, Float64}} = sub_rm(x, y)
@device_function sub(x::T, y::T, ::RoundingMode{:Up}) where {T <: Union{Float32, Float64}} = sub_rp(x, y)

@device_function mul(x::T, y::T, ::RoundingMode{:Nearest}) where {T <: Union{Float32, Float64}} = mul_rn(x, y)
@device_function mul(x::T, y::T, ::RoundingMode{:ToZero}) where {T <: Union{Float32, Float64}} = mul_rz(x, y)
@device_function mul(x::T, y::T, ::RoundingMode{:Down}) where {T <: Union{Float32, Float64}} = mul_rm(x, y)
@device_function mul(x::T, y::T, ::RoundingMode{:Up}) where {T <: Union{Float32, Float64}} = mul_rp(x, y)

@device_function div(x::T, y::T, ::RoundingMode{:Nearest}) where {T <: Union{Float32, Float64}} = div_rn(x, y)
@device_function div(x::T, y::T, ::RoundingMode{:ToZero}) where {T <: Union{Float32, Float64}} = div_rz(x, y)
@device_function div(x::T, y::T, ::RoundingMode{:Down}) where {T <: Union{Float32, Float64}} = div_rm(x, y)
@device_function div(x::T, y::T, ::RoundingMode{:Up}) where {T <: Union{Float32, Float64}} = div_rp(x, y)



@device_override Base.fma(x::Float64, y::Float64, z::Float64) = ccall("extern __nv_fma", llvmcall, Cdouble, (Cdouble, Cdouble, Cdouble), x, y, z)
@device_override Base.fma(x::Float32, y::Float32, z::Float32) = ccall("extern __nv_fmaf", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, y, z)
@device_override Base.fma(x::Float16, y::Float16, z::Float16) = ccall("llvm.fma.f16", llvmcall, Float16, (Float16, Float16, Float16), x, y, z)
@device_function fma_rn(x::Float64, y::Float64, z::Float64) = ccall("llvm.nvvm.fma.rn.d", llvmcall, Cdouble, (Cdouble, Cdouble, Cdouble), x, y, z)
@device_function fma_rn(x::Float32, y::Float32, z::Float32) = ccall("llvm.nvvm.fma.rn.f", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, y, z)
@device_function fma_rz(x::Float64, y::Float64, z::Float64) = ccall("llvm.nvvm.fma.rz.d", llvmcall, Cdouble, (Cdouble, Cdouble, Cdouble), x, y, z)
@device_function fma_rz(x::Float32, y::Float32, z::Float32) = ccall("llvm.nvvm.fma.rz.f", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, y, z)
@device_function fma_rm(x::Float64, y::Float64, z::Float64) = ccall("llvm.nvvm.fma.rm.d", llvmcall, Cdouble, (Cdouble, Cdouble, Cdouble), x, y, z)
@device_function fma_rm(x::Float32, y::Float32, z::Float32) = ccall("llvm.nvvm.fma.rm.f", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, y, z)
@device_function fma_rp(x::Float64, y::Float64, z::Float64) = ccall("llvm.nvvm.fma.rp.d", llvmcall, Cdouble, (Cdouble, Cdouble, Cdouble), x, y, z)
@device_function fma_rp(x::Float32, y::Float32, z::Float32) = ccall("llvm.nvvm.fma.rp.f", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, y, z)

@device_override Base.fma(x::T, y::T, z::T, ::RoundingMode{:Nearest}) where {T <: Union{Float32, Float64}} = fma_rn(x, y, z)
@device_override Base.fma(x::T, y::T, z::T, ::RoundingMode{:ToZero}) where {T <: Union{Float32, Float64}} = fma_rz(x, y, z)
@device_override Base.fma(x::T, y::T, z::T, ::RoundingMode{:Down}) where {T <: Union{Float32, Float64}} = fma_rm(x, y, z)
@device_override Base.fma(x::T, y::T, z::T, ::RoundingMode{:Up}) where {T <: Union{Float32, Float64}} = fma_rp(x, y, z)

@device_function sad(x::Int32, y::Int32, z::Int32) = ccall("extern __nv_sad", llvmcall, Int32, (Int32, Int32, Int32), x, y, z)
@device_function sad(x::UInt32, y::UInt32, z::UInt32) = convert(UInt32, ccall("extern __nv_usad", llvmcall, Int32, (Int32, Int32, Int32), x, y, z))
112 changes: 101 additions & 11 deletions src/device/intrinsics/wmma.jl
@@ -15,7 +15,8 @@ const map_ptx_to_jl_array = Dict(
"s8" => Int8,
"s32" => Int32,
"f16" => Float16,
"f32" => Float32
"f32" => Float32,
"f64" => Float64
Member:
Unrelated changes?

Author:
I added intrinsics calls for WMMA with directed rounding modes.

Member:
Can you keep that to a separate PR? We also currently don't support Float64 WMMA, see #1426.
)

# Maps PTX types to Julia fragment types
@@ -24,10 +25,13 @@ const map_ptx_to_jl_frag = Dict(
"s8" => UInt32,
"s32" => Int32,
"f16" => NTuple{2, VecElement{Float16}},
"f32" => Float32
"f32" => Float32,
"f64" => Float64
)

# Maps matrix & PTX types to fragment sizes
# Maps matrix & PTX types to fragment sizes, information retrieved from
# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html?highlight=wmma#matrix-fragments-for-wmma

const map_frag_sizes = Dict(
# A
"a.u8.m16n16k16" => 2,
@@ -41,6 +45,9 @@ const map_frag_sizes = Dict(
"a.f16.m16n16k16" => 8,
"a.f16.m8n32k16" => 8,
"a.f16.m32n8k16" => 8,

"a.f64.m8n8k4" => 1,

# B
"b.u8.m16n16k16" => 2,
"b.u8.m8n32k16" => 4,
@@ -53,6 +60,9 @@ const map_frag_sizes = Dict(
"b.f16.m16n16k16" => 8,
"b.f16.m8n32k16" => 8,
"b.f16.m32n8k16" => 8,

"b.f64.m8n8k4" => 1,

# C
"c.s32.m16n16k16" => 8,
"c.s32.m8n32k16" => 8,
@@ -65,6 +75,12 @@ const map_frag_sizes = Dict(
"c.f32.m16n16k16" => 8,
"c.f32.m8n32k16" => 8,
"c.f32.m32n8k16" => 8,

"c.f64.m8n8k4" => 2, # there is a clash of documentation here:
# https://docs.nvidia.com/cuda/parallel-thread-execution/#matrix-fragments-for-mma-m8n8k4-with-f64-floating-point-type
# says `A vector expression containing of two .f64 registers containing two .f64 elements from the matrix C.`
# while https://docs.nvidia.com/cuda/parallel-thread-execution/#matrix-fragments-for-wmma says 1

# D
"d.s32.m16n16k16" => 8,
"d.s32.m8n32k16" => 8,
@@ -77,6 +93,8 @@ const map_frag_sizes = Dict(
"d.f32.m16n16k16" => 8,
"d.f32.m8n32k16" => 8,
"d.f32.m32n8k16" => 8,

"d.f64.m8n8k4" => 2,
)

# Maps PTX AS to CUDA.AS
@@ -96,13 +114,19 @@ const wmma_half_ops = [(16,16,16), (32,8,16), (8,32,16)], ["f16"], ["f16", "f
const ldst_int_ab_ops = [(16,16,16), (32,8,16), (8,32,16)], ["a", "b"], ["u8", "s8"]
const ldst_int_cd_ops = [(16,16,16), (32,8,16), (8,32,16)], ["c", "d"], ["s32"]
const wmma_int_ops = [(16,16,16), (32,8,16), (8,32,16)], ["s8", "u8"], ["s32"], ["s32"]

const all_ldst_ops = vcat(ldst_half_ab_ops, ldst_half_cd_ops,
ldst_int_ab_ops, ldst_int_cd_ops)
# Double
const ldst_double_ab_ops = [(8, 8, 4)], ["a", "b"], ["f64"]
const ldst_double_cd_ops = [(8, 8, 4)], ["c", "d"], ["f64"]
const wmma_double_ops = [(8, 8, 4)], ["f64"], ["f64"], ["f64"]

const all_ldst_ops = vcat(ldst_half_ab_ops, ldst_half_cd_ops, ldst_double_ab_ops,
ldst_int_ab_ops, ldst_int_cd_ops, ldst_double_cd_ops)

# the wmma_double_ops will be treated separately due to rounding
const all_wmma_ops = vcat(wmma_half_ops, wmma_int_ops)

# Valid WMMA operation shapes
const valid_shapes = [(16, 16, 16), (32, 8, 16), (8, 32, 16)]
const valid_shapes = [(16, 16, 16), (32, 8, 16), (8, 32, 16), (8, 8, 4)]

################################################################################
# HELPER FUNCTIONS
@@ -256,20 +280,21 @@ export llvm_wmma_store
    func_name = Symbol(join(filter(!isempty, ["llvm", "wmma", "store", mat, layout, shape, addr_space, stride, elem_type]), "_"))

    # Name of the LLVM intrinsic,
    # e.g. llvm.nvvm.wmma.m8n8k4.store.d.col.stride.f64
    llvm_intr = "llvm.nvvm.wmma.$shape.store.$mat.$layout.stride.$elem_type.p$(addr_space_int)"
    if LLVM.version() < v"17"
        llvm_intr *= "i8"
    end

    # Determine types + size for this (matrix, elem_type) combination
    arr_ty, frag_ty, sz = get_frag_info(mat, elem_type, shape)

    ccall_name = "$llvm_intr"
    frag_types = ntuple(i -> frag_ty, sz)
    frag_vars = ntuple(i -> :(data[$i]), sz)

    ptr_ty = :(LLVMPtr{$arr_ty, $addr_space_int})

    @eval $func_name(dst_addr, data, stride) = ccall($ccall_name, llvmcall, Nothing, ($ptr_ty, $(frag_types...), Int32), dst_addr, $(frag_vars...), stride)
    @eval export $func_name
    @eval @doc (@doc llvm_wmma_store) $func_name
@@ -283,6 +308,7 @@ end
WMMA.llvm_wmma_mma_{a_layout}_{b_layout}_{shape}_{d_elem_type}_{c_elem_type}(a, b, c) or
WMMA.llvm_wmma_mma_{a_layout}_{b_layout}_{shape}_{a_elem_type}(a, b, c)

For double operations: wrapper around the LLVM intrinsic `@llvm.nvvm.wmma.mma.sync.{a_layout}.{b_layout}.{shape}.{rnd}.{d_elem_type}.{c_elem_type}`
For floating point operations: wrapper around the LLVM intrinsic `@llvm.nvvm.wmma.mma.sync.{a_layout}.{b_layout}.{shape}.{d_elem_type}.{c_elem_type}`
For all other operations: wrapper around the LLVM intrinsic `@llvm.nvvm.wmma.mma.sync.{a_layout}.{b_layout}.{shape}.{a_elem_type}`

@@ -356,6 +382,68 @@ for ops in all_wmma_ops,
@eval @doc (@doc llvm_wmma_mma) $func_name
end

const wmma_double_rounding = ["", "rn", "rz", "rm", "rp"]

for ops in [wmma_double_ops],
    a_layout in ["col", "row"],
    b_layout in ["col", "row"],
    mnk in ops[1],
    rnd in wmma_double_rounding

    a_elem_type = "f64"
    b_elem_type = "f64"
    c_elem_type = "f64"
    d_elem_type = "f64"

    shape = get_hl_shape(mnk[1], mnk[2], mnk[3])

    llvm_intr = "llvm.nvvm.wmma.$shape.mma.$a_layout.$b_layout.$rnd.f64"
    if rnd == ""
        llvm_intr = "llvm.nvvm.wmma.$shape.mma.$a_layout.$b_layout.f64"
    end

    # Name of the Julia wrapper function
    func_name = Symbol(join(filter(!isempty, ["llvm", "wmma", "mma", a_layout, b_layout, shape, a_elem_type, rnd]), "_"))

    # Determine types + size for the (matrix, elem_type) combinations for matrix A, B, C and D
    a_arr_ty, a_frag_ty, a_sz = get_frag_info("a", a_elem_type, shape)
    b_arr_ty, b_frag_ty, b_sz = get_frag_info("b", b_elem_type, shape)
    c_arr_ty, c_frag_ty, c_sz = get_frag_info("c", c_elem_type, shape)
    d_arr_ty, d_frag_ty, d_sz = get_frag_info("d", d_elem_type, shape)

    ccall_name = "$llvm_intr"

    a_types = ntuple(i -> a_frag_ty, a_sz)
    b_types = ntuple(i -> b_frag_ty, b_sz)
    c_types = ntuple(i -> c_frag_ty, c_sz)

    a_vars = ntuple(i -> :(a[$i]), a_sz)
    b_vars = ntuple(i -> :(b[$i]), b_sz)
    c_vars = ntuple(i -> :(c[$i]), c_sz)

    if d_sz == 1
        @eval $func_name(a, b, c) = tuple(ccall($ccall_name, llvmcall, $d_frag_ty, ($(a_types...), $(b_types...), $(c_types...)), $(a_vars...), $(b_vars...), $(c_vars...)))
    else
        struct_ty = Symbol("LLVMStruct$d_sz")
        @eval $func_name(a, b, c) = convert(NTuple{$d_sz, $d_frag_ty}, ccall($ccall_name, llvmcall, $struct_ty{$d_frag_ty}, ($(a_types...), $(b_types...), $(c_types...)), $(a_vars...), $(b_vars...), $(c_vars...)))
    end
    @eval export $func_name
    @eval @doc (@doc llvm_wmma_mma) $func_name
end

llvm_wmma_mma_col_col_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Nearest}) = llvm_wmma_mma_col_col_m8n8k4_f64_rn(a_frag, b_frag, c_frag)
llvm_wmma_mma_col_col_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:ToZero}) = llvm_wmma_mma_col_col_m8n8k4_f64_rz(a_frag, b_frag, c_frag)
llvm_wmma_mma_col_col_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Up}) = llvm_wmma_mma_col_col_m8n8k4_f64_rp(a_frag, b_frag, c_frag)
llvm_wmma_mma_col_col_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Down}) = llvm_wmma_mma_col_col_m8n8k4_f64_rm(a_frag, b_frag, c_frag)
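# Only the col/col layout combination is wrapped with RoundingMode dispatch here;
# analogous wrappers for the other generated layout combinations
# (e.g. `llvm_wmma_mma_row_col_m8n8k4_f64_rn`) could be defined the same way.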




################################################################################
# FLATTENING/UNFLATTENING LOGIC
################################################################################
@@ -491,7 +579,9 @@ julia> config = WMMA.Config{16, 16, 16, Float32}
CUDA.WMMA.Config{16, 16, 16, Float32}
```
"""
struct Config{M, N, K, d_type} end
struct ConfigRounding{M, N, K, d_type, rounding} end

const Config{M, N, K, d_type} = ConfigRounding{M, N, K, d_type, RoundNearest}
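# With this alias, existing code such as `WMMA.Config{16, 16, 16, Float32}` keeps
# working and now implies round-to-nearest, while a directed mode can be requested
# via `WMMA.ConfigRounding{8, 8, 4, Float64, RoundToZero}`.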

# ---------
# Constants