Wrap and test some more Float16 intrinsics #2644
Conversation
Too much to hope for that they had secretly included these in |
The |
OK! We now have some working |
Your PR requires formatting changes to meet the project's style guidelines. Suggested changes:
diff --git a/test/core/device/intrinsics/math.jl b/test/core/device/intrinsics/math.jl
index 474570a4f..6f4e60c1e 100644
--- a/test/core/device/intrinsics/math.jl
+++ b/test/core/device/intrinsics/math.jl
@@ -3,7 +3,7 @@ using SpecialFunctions
@testset "math" begin
@testset "log10" begin
for T in (Float32, Float64)
- @test testf(a->log10.(a), T[100])
+ @test testf(a -> log10.(a), T[100])
end
end
@@ -14,22 +14,22 @@ using SpecialFunctions
@test testf((x,y)->x.^y, rand(Float32, 1), -rand(range, 1))
end
end
-
+
@testset "min/max" begin
for T in (Float32, Float64)
- @test testf((x,y)->max.(x, y), rand(Float32, 1), rand(T, 1))
- @test testf((x,y)->min.(x, y), rand(Float32, 1), rand(T, 1))
+ @test testf((x, y) -> max.(x, y), rand(Float32, 1), rand(T, 1))
+ @test testf((x, y) -> min.(x, y), rand(Float32, 1), rand(T, 1))
end
end
@testset "isinf" begin
- for x in (Inf32, Inf, NaN16, NaN32, NaN)
+ for x in (Inf32, Inf, NaN16, NaN32, NaN)
@test testf(x->isinf.(x), [x])
end
end
@testset "isnan" begin
- for x in (Inf32, Inf, NaN16, NaN32, NaN)
+ for x in (Inf32, Inf, NaN16, NaN32, NaN)
@test testf(x->isnan.(x), [x])
end
end
@@ -104,16 +104,16 @@ using SpecialFunctions
# JuliaGPU/CUDA.jl#1085: exp uses Base.sincos performing a global CPU load
@test testf(x->exp.(x), [1e7im])
end
-
+
@testset "Real - $op" for op in (exp, abs, abs2, exp10, log10)
@testset "$T" for T in (Float16, Float32, Float64)
- @test testf(x->op.(x), rand(T, 1))
+ @test testf(x -> op.(x), rand(T, 1))
end
end
-
- @testset "Float16 - $op" for op in (log,exp,exp2,exp10,log2,log10)
- @testset "$T" for T in (Float16, )
- @test testf(x->op.(x), rand(T, 1))
+
+ @testset "Float16 - $op" for op in (log, exp, exp2, exp10, log2, log10)
+ @testset "$T" for T in (Float16,)
+ @test testf(x -> op.(x), rand(T, 1))
end
end
What about more native implementations?
using CUDA, LLVM, LLVM.Interop
function asmlog(x::Float16)
log_x = @asmcall("""{
.reg.b32 f, C;
.reg.b16 r,h;
mov.b16 h,\$1;
cvt.f32.f16 f,h;
lg2.approx.ftz.f32 f,f;
mov.b32 C, 0x3f317218U;
mul.f32 f,f,C;
cvt.rn.f16.f32 r,f;
.reg.b16 spc, ulp, p;
mov.b16 spc, 0X160DU;
mov.b16 ulp, 0x9C00U;
set.eq.f16.f16 p, h, spc;
fma.rn.f16 r,p,ulp,r;
mov.b16 spc, 0X3BFEU;
mov.b16 ulp, 0x8010U;
set.eq.f16.f16 p, h, spc;
fma.rn.f16 r,p,ulp,r;
mov.b16 spc, 0X3C0BU;
mov.b16 ulp, 0x8080U;
set.eq.f16.f16 p, h, spc;
fma.rn.f16 r,p,ulp,r;
mov.b16 spc, 0X6051U;
mov.b16 ulp, 0x1C00U;
set.eq.f16.f16 p, h, spc;
fma.rn.f16 r,p,ulp,r;
mov.b16 \$0,r;
}""", "=h,h", Float16, Tuple{Float16}, x)
return log_x
end
function nativelog(h::Float16)
# perform computation in Float32 domain
f = Float32(h)
f = @fastmath log(f)
r = Float16(f)
# handle degenerate cases
r = fma(Float16(h == reinterpret(Float16, 0x160D)), reinterpret(Float16, 0x9C00), r)
r = fma(Float16(h == reinterpret(Float16, 0x3BFE)), reinterpret(Float16, 0x8010), r)
r = fma(Float16(h == reinterpret(Float16, 0x3C0B)), reinterpret(Float16, 0x8080), r)
r = fma(Float16(h == reinterpret(Float16, 0x6051)), reinterpret(Float16, 0x1C00), r)
return r
end
function main()
CUDA.code_ptx(asmlog, Tuple{Float16})
CUDA.code_ptx(nativelog, Tuple{Float16})
return
end
function nativelog10(h::Float16)
# perform computation in Float32 domain
f = Float32(h)
f = @fastmath log10(f)
r = Float16(f)
# handle degenerate cases
r = fma(Float16(h == reinterpret(Float16, 0x338F)), reinterpret(Float16, 0x1000), r)
r = fma(Float16(h == reinterpret(Float16, 0x33F8)), reinterpret(Float16, 0x9000), r)
r = fma(Float16(h == reinterpret(Float16, 0x57E1)), reinterpret(Float16, 0x9800), r)
r = fma(Float16(h == reinterpret(Float16, 0x719D)), reinterpret(Float16, 0x9C00), r)
return r
end
function nativelog2(h::Float16)
# perform computation in Float32 domain
f = Float32(h)
f = @fastmath log2(f)
r = Float16(f)
# handle degenerate cases
r = fma(Float16(r == reinterpret(Float16, 0xA2E2)), reinterpret(Float16, 0x8080), r)
r = fma(Float16(r == reinterpret(Float16, 0xBF46)), reinterpret(Float16, 0x9400), r)
return r
end
It's weird that the special cases are checked against
function nativeexp(h::Float16)
# perform computation in Float32 domain
f = Float32(h)
f = fma(f, reinterpret(Float32, 0x3fb8aa3b), reinterpret(Float32, Base.sign_mask(Float32)))
f = @fastmath exp2(f)
r = Float16(f)
# handle degenerate cases
r = fma(Float16(h == reinterpret(Float16, 0x1F79)), reinterpret(Float16, 0x9400), r)
r = fma(Float16(h == reinterpret(Float16, 0x25CF)), reinterpret(Float16, 0x9400), r)
r = fma(Float16(h == reinterpret(Float16, 0xC13B)), reinterpret(Float16, 0x0400), r)
r = fma(Float16(h == reinterpret(Float16, 0xC1EF)), reinterpret(Float16, 0x0200), r)
return r
end
function nativeexp2(h::Float16)
# perform computation in Float32 domain
f = Float32(h)
f = @fastmath exp2(f)
# one ULP adjustment
f = muladd(f, reinterpret(Float32, 0x33800000), f)
r = Float16(f)
return r
end
function nativeexp10(h::Float16)
# perform computation in Float32 domain
f = Float32(h)
f = fma(f, reinterpret(Float32, 0x40549A78), reinterpret(Float32, Base.sign_mask(Float32)))
f = @fastmath exp2(f)
r = Float16(f)
# handle degenerate cases
r = fma(Float16(h == reinterpret(Float16, 0x34DE)), reinterpret(Float16, 0x9800), r)
r = fma(Float16(h == reinterpret(Float16, 0x9766)), reinterpret(Float16, 0x9000), r)
r = fma(Float16(h == reinterpret(Float16, 0x9972)), reinterpret(Float16, 0x1000), r)
r = fma(Float16(h == reinterpret(Float16, 0xA5C4)), reinterpret(Float16, 0x1000), r)
r = fma(Float16(h == reinterpret(Float16, 0xBF0A)), reinterpret(Float16, 0x8100), r)
return r
end
I only did these ports by looking at the assembly, and they would still need to be tested properly.
I wonder if some of the degenerate cases in the ASM-to-Julia ports above could be written differently (I couldn't find any existing predicates returning those bit values); maybe some floating-point wizards know (cc @oscardssmith)?
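For illustration only, here is a hedged sketch of an equivalent spelling of one correction step (the helper name adjust_log is made up, and this has not been benchmarked as a drop-in replacement): since Float16(h == c) is 1.0 or 0.0, each fma line either adds the fixed correction once or leaves the result unchanged, so it can also be written with ifelse.
# Same effect as r = fma(Float16(h == c), ulp, r), shown for the first
# special case of nativelog above.
function adjust_log(h::Float16, r::Float16)
    c   = reinterpret(Float16, 0x160D)   # input that needs a fix-up
    ulp = reinterpret(Float16, 0x9C00)   # correction value (-2^-8)
    return ifelse(h == c, r + ulp, r)    # add the correction only when h == c
end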
src/device/intrinsics/math.jl (Outdated)
.reg.b16 r,h;
mov.b16 h,\$1;
cvt.f32.f16 f,h;
lg2.approx.ftz.f32 f,f;
Seeing .approx here; are these the fastmath versions?
It's from their own log implementation, which is not fast math (https://github.com/cupy/cupy/blob/620183256d25eb463081a8bac2a7a965d35db66b/cupy/_core/include/cupy/_cuda/cuda-12/cuda_fp16.hpp#L2478); I guess they use the approximate method for fp32 and assume it doesn't hurt fp16 accuracy too much?
I'm in favour! My assembly skills are weak enough I didn't want to venture too far out on my own, but others should feel free to push more to this branch or point me at some references :)
CUDA.jl Benchmarks
Benchmark suite | Current: 3081234 | Previous: f62af73 | Ratio
---|---|---|---
latency/precompile | 46406094960.5 ns | 46667053506.5 ns | 0.99
latency/ttfp | 7033081502 ns | 6954689688 ns | 1.01
latency/import | 3654105472 ns | 3631856123 ns | 1.01
integration/volumerhs | 9625724 ns | 9624743.5 ns | 1.00
integration/byval/slices=1 | 146529 ns | 146953 ns | 1.00
integration/byval/slices=3 | 425109 ns | 425334 ns | 1.00
integration/byval/reference | 144733 ns | 145208 ns | 1.00
integration/byval/slices=2 | 285804 ns | 286016 ns | 1.00
integration/cudadevrt | 103080.5 ns | 103424 ns | 1.00
kernel/indexing | 14017 ns | 14214 ns | 0.99
kernel/indexing_checked | 14594 ns | 14910 ns | 0.98
kernel/occupancy | 670.1 ns | 637.5449101796407 ns | 1.05
kernel/launch | 2036.6 ns | 2102 ns | 0.97
kernel/rand | 17997 ns | 18239 ns | 0.99
array/reverse/1d | 19676 ns | 19474 ns | 1.01
array/reverse/2d | 23915 ns | 23910 ns | 1.00
array/reverse/1d_inplace | 10273 ns | 10670 ns | 0.96
array/reverse/2d_inplace | 11856 ns | 12291 ns | 0.96
array/copy | 21205 ns | 20955 ns | 1.01
array/iteration/findall/int | 154962 ns | 155336 ns | 1.00
array/iteration/findall/bool | 133464 ns | 133979 ns | 1.00
array/iteration/findfirst/int | 153204 ns | 154049 ns | 0.99
array/iteration/findfirst/bool | 153040.5 ns | 153056 ns | 1.00
array/iteration/scalar | 59613 ns | 61530 ns | 0.97
array/iteration/logical | 203196.5 ns | 202309 ns | 1.00
array/iteration/findmin/1d | 37943 ns | 37878 ns | 1.00
array/iteration/findmin/2d | 93588 ns | 93537 ns | 1.00
array/reductions/reduce/1d | 39200.5 ns | 37060.5 ns | 1.06
array/reductions/reduce/2d | 50930 ns | 50765 ns | 1.00
array/reductions/mapreduce/1d | 36233.5 ns | 36727 ns | 0.99
array/reductions/mapreduce/2d | 48224.5 ns | 42618.5 ns | 1.13
array/broadcast | 20666 ns | 20743 ns | 1.00
array/copyto!/gpu_to_gpu | 11744 ns | 13730.5 ns | 0.86
array/copyto!/cpu_to_gpu | 208210 ns | 207788 ns | 1.00
array/copyto!/gpu_to_cpu | 241342 ns | 243117 ns | 0.99
array/accumulate/1d | 108670 ns | 108517 ns | 1.00
array/accumulate/2d | 80054 ns | 79641 ns | 1.01
array/construct | 1240.05 ns | 1306.5 ns | 0.95
array/random/randn/Float32 | 43773 ns | 43234.5 ns | 1.01
array/random/randn!/Float32 | 26796 ns | 26328 ns | 1.02
array/random/rand!/Int64 | 26960 ns | 27074 ns | 1.00
array/random/rand!/Float32 | 8559.333333333334 ns | 8647.666666666666 ns | 0.99
array/random/rand/Int64 | 29942 ns | 29948 ns | 1.00
array/random/rand/Float32 | 13087 ns | 13039 ns | 1.00
array/permutedims/4d | 61006 ns | 60777 ns | 1.00
array/permutedims/2d | 55371 ns | 55571 ns | 1.00
array/permutedims/3d | 56266 ns | 55866 ns | 1.01
array/sorting/1d | 2776175.5 ns | 2764795 ns | 1.00
array/sorting/by | 3367127.5 ns | 3367795 ns | 1.00
array/sorting/2d | 1084227 ns | 1084334 ns | 1.00
cuda/synchronization/stream/auto | 1035.6 ns | 1052.3 ns | 0.98
cuda/synchronization/stream/nonblocking | 6419.4 ns | 6404.4 ns | 1.00
cuda/synchronization/stream/blocking | 801.6881720430108 ns | 810.0736842105263 ns | 0.99
cuda/synchronization/context/auto | 1173.2 ns | 1185.6 ns | 0.99
cuda/synchronization/context/nonblocking | 6674.6 ns | 6726.6 ns | 0.99
cuda/synchronization/context/blocking | 909.0869565217391 ns | 925.975 ns | 0.98
This comment was automatically generated by workflow using github-action-benchmark.
@device_override function Base.exp(h::Float16)
# perform computation in Float32 domain
f = Float32(h)
f = fma(f, reinterpret(Float32, 0x3fb8aa3b), reinterpret(Float32, Base.sign_mask(Float32)))
julia> reinterpret(UInt32, log2(Float32(ℯ)))
0x3fb8aa3b
and:
julia> reinterpret(Float32, Base.sign_mask(Float32))
-0.0f0
We probably can't rely on the constant evaluation of this, but this code is essentially: f *= log2(Float32(ℯ))
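Spelled out, that could look like the following (just a sketch; LOG2_E_F32 is a made-up name, and whether the expression actually constant-folds in device code is untested):
const LOG2_E_F32 = log2(Float32(ℯ))   # same bits as the literal: 0x3fb8aa3b
# the fma line above could then read
# f = fma(f, LOG2_E_F32, -0.0f0)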
@device_override function Base.exp10(h::Float16)
# perform computation in Float32 domain
f = Float32(h)
f = fma(f, reinterpret(Float32, 0x40549A78), reinterpret(Float32, Base.sign_mask(Float32)))
Same as above, but:
julia> reinterpret(UInt32, log2(10.f0))
0x40549a78
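And likewise for the exp10 constant (same caveats as above; the name is made up):
const LOG2_10_F32 = log2(10.0f0)      # same bits as the literal: 0x40549a78
# f = fma(f, LOG2_10_F32, -0.0f0)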
@testset "Float16 - $op" for op in (log,exp,exp2,exp10,log2,log10) | ||
@testset "$T" for T in (Float16, ) | ||
@test testf(x->op.(x), rand(T, 1)) | ||
end | ||
end |
We could test all values here:
julia> all_float_16 = collect(reinterpret(Float16, pattern) for pattern in UInt16(0):UInt16(1):typemax(UInt16))
65536-element Vector{Float16}:
(there might be a better way that avoids some of the duplicated patterns, but it is only 65k in the end)
Otherwise for some of the degenerate cases we might randomly fail if we disagree with Julia.
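A sketch of what that could look like in the testset (assumes the CUDA.jl test environment with CuArray in scope; as written it would flag the known exp disagreements discussed below):
all_float_16 = filter(!isnan, reinterpret.(Float16, UInt16(0):typemax(UInt16)))
@testset "Float16 exhaustive - $op" for op in (log, exp, exp2, exp10, log2, log10)
    # restrict the log family to non-negative inputs: Base throws a DomainError
    # for negative reals on the CPU side
    xs = op in (log, log2, log10) ? filter(>=(Float16(0)), all_float_16) : all_float_16
    @test op.(xs) == Array(op.(CuArray(xs)))
end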
all_float_16 = collect(reinterpret(Float16, pattern) for pattern in UInt16(0):UInt16(1):typemax(UInt16))
all_float_16 = filter(!isnan, all_float_16)
julia> findall(==(0), exp.(all_float_16) .== Array(exp.(CuArray(all_float_16))))
2-element Vector{Int64}:
8058
9680
julia> all_float_16[8058]
Float16(0.007298)
julia> all_float_16[9680]
Float16(0.02269)
julia> reinterpret(UInt16, all_float_16[8058])
0x1f79
julia> reinterpret(UInt16, all_float_16[9680])
0x25cf
r = fma(Float16(h == reinterpret(Float16, 0x1F79)), reinterpret(Float16, 0x9400), r)
r = fma(Float16(h == reinterpret(Float16, 0x25CF)), reinterpret(Float16, 0x9400), r)
These two cause us to disagree with Julia.
julia> findall(==(0), exp.(all_float_16) .== Array(exp.(CuArray(all_float_16))))
2-element Vector{Int64}:
8058
9680
julia> all_float_16[8058]
Float16(0.007298)
julia> all_float_16[9680]
Float16(0.02269)
julia> reinterpret(UInt16, all_float_16[8058])
0x1f79
julia> reinterpret(UInt16, all_float_16[9680])
0x25cf
julia> Float16(exp(Float32(all_float_16[8058])))
Float16(1.008)
julia> exp_cu[8058]
Float16(1.007)
julia> exp_cu[8058] - Float16(exp(Float32(all_float_16[8058])))
Float16(-0.000977)
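If the question is which side is closer to the true value, a quick check against a BigFloat reference could settle it (a sketch; exp_cu is assumed to be Array(exp.(CuArray(all_float_16))) from the snippet above):
x   = all_float_16[8058]              # reinterprets to 0x1f79
ref = exp(big(x))                     # high-precision reference value
abs(big(exp(x)) - ref)                # error of Julia's Float16 exp
abs(big(exp_cu[8058]) - ref)          # error of the device result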
OK I got a bit turned around here - what's the ask? Can we come up with a list of what needs to be done to get this merged?
I think we should replace the magic constants with the expressions Valentin figured out. Beyond that, we should think about whether we want to mimic CUDA or Julia here, i.e., whether to keep the adjustments or not. And in any case, make sure there are tests covering those added snippets.
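For the exp case that could look roughly like this (a sketch of the direction only; it keeps the degenerate-case corrections verbatim pending the CUDA-vs-Julia decision, and assumes the constant expression is acceptable in device code, otherwise it could be hoisted into a const):
@device_override function Base.exp(h::Float16)
    # perform computation in Float32 domain
    f = Float32(h)
    f = fma(f, log2(Float32(ℯ)), -0.0f0)   # was reinterpret(Float32, 0x3fb8aa3b) and sign_mask
    f = @fastmath exp2(f)
    r = Float16(f)
    # handle degenerate cases (unchanged)
    r = fma(Float16(h == reinterpret(Float16, 0x1F79)), reinterpret(Float16, 0x9400), r)
    r = fma(Float16(h == reinterpret(Float16, 0x25CF)), reinterpret(Float16, 0x9400), r)
    r = fma(Float16(h == reinterpret(Float16, 0xC13B)), reinterpret(Float16, 0x0400), r)
    r = fma(Float16(h == reinterpret(Float16, 0xC1EF)), reinterpret(Float16, 0x0200), r)
    return r
end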
Since I don't have access to an A100, this is going to be a bit of "debugging via CI".