
Wrap and test some more Float16 intrinsics #2644
Open: kshyatt wants to merge 5 commits into master
Conversation

kshyatt (Contributor) commented Feb 6, 2025

Since I don't have access to an A100 this is going to be a bit "debugging via CI".

kshyatt (Contributor Author) commented Feb 7, 2025

Too much to hope for that they had secretly included these in libdevice. I'll work on doing the @asmcalls tomorrow.

maleadt (Member) commented Feb 10, 2025

The libdevice file is simply LLVM bitcode, so you can disassemble it to look at the functions in there. Or check out the documentation: https://docs.nvidia.com/cuda/libdevice-users-guide/index.html (there are some h suffixed functions, at least).
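
For example, a minimal way to do that inspection from Julia (a sketch only, assuming llvm-dis from an LLVM toolchain is on PATH and that the path below matches your CUDA installation; neither is provided by this PR):

libdevice = "/usr/local/cuda/nvvm/libdevice/libdevice.10.bc"   # hypothetical path; adjust to your toolkit
ir = read(`llvm-dis $libdevice -o -`, String)                  # disassemble the bitcode to textual LLVM IR
for line in split(ir, '\n')
    # every defined function shows up as a "define ..." line; the libdevice entry points are __nv_*
    startswith(line, "define") && println(line)
end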

kshyatt (Contributor Author) commented Feb 10, 2025

OK! We now have some working Float16 methods for log, exp, log2, log10, exp2 and exp10!
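
For illustration (not from the PR itself), the effect is that broadcasting the Base functions over a Float16 CuArray now dispatches to these device methods:

using CUDA
x = CuArray(rand(Float16, 8))
log.(x)     # hits the new Float16 device method
exp10.(x)   # likewise for exp, exp2, log2 and log10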

kshyatt marked this pull request as ready for review on February 10, 2025 at 17:29
github-actions bot commented Feb 10, 2025

Your PR requires formatting changes to meet the project's style guidelines.
Please consider running Runic (git runic master) to apply these changes.

Suggested changes:
diff --git a/test/core/device/intrinsics/math.jl b/test/core/device/intrinsics/math.jl
index 474570a4f..6f4e60c1e 100644
--- a/test/core/device/intrinsics/math.jl
+++ b/test/core/device/intrinsics/math.jl
@@ -3,7 +3,7 @@ using SpecialFunctions
 @testset "math" begin
     @testset "log10" begin
         for T in (Float32, Float64)
-            @test testf(a->log10.(a), T[100])
+            @test testf(a -> log10.(a), T[100])
         end
     end
 
@@ -14,22 +14,22 @@ using SpecialFunctions
             @test testf((x,y)->x.^y, rand(Float32, 1), -rand(range, 1))
         end
     end
-    
+
     @testset "min/max" begin
         for T in (Float32, Float64)
-            @test testf((x,y)->max.(x, y), rand(Float32, 1), rand(T, 1))
-            @test testf((x,y)->min.(x, y), rand(Float32, 1), rand(T, 1))
+            @test testf((x, y) -> max.(x, y), rand(Float32, 1), rand(T, 1))
+            @test testf((x, y) -> min.(x, y), rand(Float32, 1), rand(T, 1))
         end
     end
 
     @testset "isinf" begin
-      for x in (Inf32, Inf, NaN16, NaN32, NaN)
+        for x in (Inf32, Inf, NaN16, NaN32, NaN)
         @test testf(x->isinf.(x), [x])
       end
     end
 
     @testset "isnan" begin
-      for x in (Inf32, Inf, NaN16, NaN32, NaN)
+        for x in (Inf32, Inf, NaN16, NaN32, NaN)
         @test testf(x->isnan.(x), [x])
       end
     end
@@ -104,16 +104,16 @@ using SpecialFunctions
         # JuliaGPU/CUDA.jl#1085: exp uses Base.sincos performing a global CPU load
         @test testf(x->exp.(x), [1e7im])
     end
-    
+
     @testset "Real - $op" for op in (exp, abs, abs2, exp10, log10)
         @testset "$T" for T in (Float16, Float32, Float64)
-            @test testf(x->op.(x), rand(T, 1))
+            @test testf(x -> op.(x), rand(T, 1))
         end
     end
-    
-    @testset "Float16 - $op" for op in (log,exp,exp2,exp10,log2,log10)
-        @testset "$T" for T in (Float16, )
-            @test testf(x->op.(x), rand(T, 1))
+
+    @testset "Float16 - $op" for op in (log, exp, exp2, exp10, log2, log10)
+        @testset "$T" for T in (Float16,)
+            @test testf(x -> op.(x), rand(T, 1))
         end
     end
 

maleadt (Member) commented Feb 11, 2025

What about more native implementations?

using CUDA, LLVM, LLVM.Interop

function asmlog(x::Float16)
    log_x = @asmcall("""{
                  .reg.b32        f, C;
                  .reg.b16        r,h;
                  mov.b16         h,\$1;
                  cvt.f32.f16     f,h;
                  lg2.approx.ftz.f32  f,f;
                  mov.b32         C, 0x3f317218U;
                  mul.f32         f,f,C;
                  cvt.rn.f16.f32  r,f;
                  .reg.b16 spc, ulp, p;
                  mov.b16 spc, 0X160DU;
                  mov.b16 ulp, 0x9C00U;
                  set.eq.f16.f16 p, h, spc;
                  fma.rn.f16 r,p,ulp,r;
                  mov.b16 spc, 0X3BFEU;
                  mov.b16 ulp, 0x8010U;
                  set.eq.f16.f16 p, h, spc;
                  fma.rn.f16 r,p,ulp,r;
                  mov.b16 spc, 0X3C0BU;
                  mov.b16 ulp, 0x8080U;
                  set.eq.f16.f16 p, h, spc;
                  fma.rn.f16 r,p,ulp,r;
                  mov.b16 spc, 0X6051U;
                  mov.b16 ulp, 0x1C00U;
                  set.eq.f16.f16 p, h, spc;
                  fma.rn.f16 r,p,ulp,r;
                  mov.b16         \$0,r;
          }""", "=h,h", Float16, Tuple{Float16}, x)
    return log_x
end

function nativelog(h::Float16)
    # perform computation in Float32 domain
    f = Float32(h)
    f = @fastmath log(f)
    r = Float16(f)

    # handle degenerate cases
    r = fma(Float16(h == reinterpret(Float16, 0x160D)), reinterpret(Float16, 0x9C00), r)
    r = fma(Float16(h == reinterpret(Float16, 0x3BFE)), reinterpret(Float16, 0x8010), r)
    r = fma(Float16(h == reinterpret(Float16, 0x3C0B)), reinterpret(Float16, 0x8080), r)
    r = fma(Float16(h == reinterpret(Float16, 0x6051)), reinterpret(Float16, 0x1C00), r)

    return r
end

function main()
    CUDA.code_ptx(asmlog, Tuple{Float16})
    CUDA.code_ptx(nativelog, Tuple{Float16})
    return
end
.visible .func  (.param .b32 func_retval0) julia_asmlog_18627(
	.param .b32 julia_asmlog_18627_param_0
)
{
	.reg .b16 	%h<2>;
	.reg .b16 	%rs<3>;

// %bb.0:                               // %top
	ld.param.u16 	%rs2, [julia_asmlog_18627_param_0];
	// begin inline asm
	{
        .reg.b32        f, C;
        .reg.b16        r,h;
        mov.b16         h,%rs2;
        cvt.f32.f16     f,h;
        lg2.approx.ftz.f32  f,f;
        mov.b32         C, 0x3f317218U;
        mul.f32         f,f,C;
        cvt.rn.f16.f32  r,f;
        .reg.b16 spc, ulp, p;
        mov.b16 spc, 0X160DU;
        mov.b16 ulp, 0x9C00U;
        set.eq.f16.f16 p, h, spc;
        fma.rn.f16 r,p,ulp,r;
        mov.b16 spc, 0X3BFEU;
        mov.b16 ulp, 0x8010U;
        set.eq.f16.f16 p, h, spc;
        fma.rn.f16 r,p,ulp,r;
        mov.b16 spc, 0X3C0BU;
        mov.b16 ulp, 0x8080U;
        set.eq.f16.f16 p, h, spc;
        fma.rn.f16 r,p,ulp,r;
        mov.b16 spc, 0X6051U;
        mov.b16 ulp, 0x1C00U;
        set.eq.f16.f16 p, h, spc;
        fma.rn.f16 r,p,ulp,r;
        mov.b16         %rs1,r;
}
	// end inline asm
	mov.b16 	%h1, %rs1;
	st.param.b16 	[func_retval0+0], %h1;
	ret;
                                        // -- End function
}


.visible .func  (.param .b32 func_retval0) julia_nativelog_18635(
	.param .b32 julia_nativelog_18635_param_0
)
{
	.reg .pred 	%p<5>;
	.reg .b16 	%h<19>;
	.reg .f32 	%f<4>;

// %bb.0:                               // %top
	ld.param.b16 	%h1, [julia_nativelog_18635_param_0];
	cvt.f32.f16 	%f1, %h1;
	lg2.approx.f32 	%f2, %f1;
	mul.f32 	%f3, %f2, 0f3F317218;
	cvt.rn.f16.f32 	%h2, %f3;
	mov.b16 	%h3, 0x160D;
	setp.eq.f16 	%p1, %h1, %h3;
	selp.b16 	%h4, 0x3C00, 0x0000, %p1;
	mov.b16 	%h5, 0x9C00;
	fma.rn.f16 	%h6, %h4, %h5, %h2;
	mov.b16 	%h7, 0x3BFE;
	setp.eq.f16 	%p2, %h1, %h7;
	selp.b16 	%h8, 0x3C00, 0x0000, %p2;
	mov.b16 	%h9, 0x8010;
	fma.rn.f16 	%h10, %h8, %h9, %h6;
	mov.b16 	%h11, 0x3C0B;
	setp.eq.f16 	%p3, %h1, %h11;
	selp.b16 	%h12, 0x3C00, 0x0000, %p3;
	mov.b16 	%h13, 0x8080;
	fma.rn.f16 	%h14, %h12, %h13, %h10;
	mov.b16 	%h15, 0x6051;
	setp.eq.f16 	%p4, %h1, %h15;
	selp.b16 	%h16, 0x3C00, 0x0000, %p4;
	mov.b16 	%h17, 0x1C00;
	fma.rn.f16 	%h18, %h16, %h17, %h14;
	st.param.b16 	[func_retval0+0], %h18;
	ret;
                                        // -- End function
}

function nativelog10(h::Float16)
    # perform computation in Float32 domain
    f = Float32(h)
    f = @fastmath log10(f)
    r = Float16(f)

    # handle degenerate cases
    r = fma(Float16(h == reinterpret(Float16, 0x338F)), reinterpret(Float16, 0x1000), r)
    r = fma(Float16(h == reinterpret(Float16, 0x33F8)), reinterpret(Float16, 0x9000), r)
    r = fma(Float16(h == reinterpret(Float16, 0x57E1)), reinterpret(Float16, 0x9800), r)
    r = fma(Float16(h == reinterpret(Float16, 0x719D)), reinterpret(Float16, 0x9C00), r)

    return r
end
.visible .func  (.param .b32 func_retval0) julia_nativelog10_18750(
	.param .b32 julia_nativelog10_18750_param_0
)
{
	.reg .pred 	%p<5>;
	.reg .b16 	%h<19>;
	.reg .f32 	%f<4>;

// %bb.0:                               // %top
	ld.param.b16 	%h1, [julia_nativelog10_18750_param_0];
	cvt.f32.f16 	%f1, %h1;
	lg2.approx.f32 	%f2, %f1;
	mul.f32 	%f3, %f2, 0f3E9A209B;
	cvt.rn.f16.f32 	%h2, %f3;
	mov.b16 	%h3, 0x338F;
	setp.eq.f16 	%p1, %h1, %h3;
	selp.b16 	%h4, 0x3C00, 0x0000, %p1;
	mov.b16 	%h5, 0x1000;
	fma.rn.f16 	%h6, %h4, %h5, %h2;
	mov.b16 	%h7, 0x33F8;
	setp.eq.f16 	%p2, %h1, %h7;
	selp.b16 	%h8, 0x3C00, 0x0000, %p2;
	mov.b16 	%h9, 0x9000;
	fma.rn.f16 	%h10, %h8, %h9, %h6;
	mov.b16 	%h11, 0x57E1;
	setp.eq.f16 	%p3, %h1, %h11;
	selp.b16 	%h12, 0x3C00, 0x0000, %p3;
	mov.b16 	%h13, 0x9800;
	fma.rn.f16 	%h14, %h12, %h13, %h10;
	mov.b16 	%h15, 0x719D;
	setp.eq.f16 	%p4, %h1, %h15;
	selp.b16 	%h16, 0x3C00, 0x0000, %p4;
	mov.b16 	%h17, 0x9C00;
	fma.rn.f16 	%h18, %h16, %h17, %h14;
	st.param.b16 	[func_retval0+0], %h18;
	ret;
                                        // -- End function
}

function nativelog2(h::Float16)
    # perform computation in Float32 domain
    f = Float32(h)
    f = @fastmath log2(f)
    r = Float16(f)

    # handle degenerate cases
    r = fma(Float16(r == reinterpret(Float16, 0xA2E2)), reinterpret(Float16, 0x8080), r)
    r = fma(Float16(r == reinterpret(Float16, 0xBF46)), reinterpret(Float16, 0x9400), r)

    return r
end
.visible .func  (.param .b32 func_retval0) julia_nativelog2_18847(
	.param .b32 julia_nativelog2_18847_param_0
)
{
	.reg .pred 	%p<3>;
	.reg .b16 	%h<11>;
	.reg .f32 	%f<3>;

// %bb.0:                               // %top
	ld.param.b16 	%h1, [julia_nativelog2_18847_param_0];
	cvt.f32.f16 	%f1, %h1;
	lg2.approx.f32 	%f2, %f1;
	cvt.rn.f16.f32 	%h2, %f2;
	mov.b16 	%h3, 0xA2E2;
	setp.eq.f16 	%p1, %h2, %h3;
	selp.b16 	%h4, 0x3C00, 0x0000, %p1;
	mov.b16 	%h5, 0x8080;
	fma.rn.f16 	%h6, %h4, %h5, %h2;
	mov.b16 	%h7, 0xBF46;
	setp.eq.f16 	%p2, %h6, %h7;
	selp.b16 	%h8, 0x3C00, 0x0000, %p2;
	mov.b16 	%h9, 0x9400;
	fma.rn.f16 	%h10, %h8, %h9, %h6;
	st.param.b16 	[func_retval0+0], %h10;
	ret;
                                        // -- End function
}

It's weird that the special cases are checked against r here, and not the input h.


function nativeexp(h::Float16)
    # perform computation in Float32 domain
    f = Float32(h)
    f = fma(f, reinterpret(Float32, 0x3fb8aa3b), reinterpret(Float32, Base.sign_mask(Float32)))
    f = @fastmath exp2(f)
    r = Float16(f)

    # handle degenerate cases
    r = fma(Float16(h == reinterpret(Float16, 0x1F79)), reinterpret(Float16, 0x9400), r)
    r = fma(Float16(h == reinterpret(Float16, 0x25CF)), reinterpret(Float16, 0x9400), r)
    r = fma(Float16(h == reinterpret(Float16, 0xC13B)), reinterpret(Float16, 0x0400), r)
    r = fma(Float16(h == reinterpret(Float16, 0xC1EF)), reinterpret(Float16, 0x0200), r)

    return r
end
.visible .func  (.param .b32 func_retval0) julia_nativeexp_19019(
	.param .b32 julia_nativeexp_19019_param_0
)
{
	.reg .pred 	%p<5>;
	.reg .b16 	%h<18>;
	.reg .f32 	%f<4>;

// %bb.0:                               // %top
	ld.param.b16 	%h1, [julia_nativeexp_19019_param_0];
	cvt.f32.f16 	%f1, %h1;
	mul.f32 	%f2, %f1, 0f3FB8AA3B;
	ex2.approx.f32 	%f3, %f2;
	cvt.rn.f16.f32 	%h2, %f3;
	mov.b16 	%h3, 0x1F79;
	setp.eq.f16 	%p1, %h1, %h3;
	selp.b16 	%h4, 0x3C00, 0x0000, %p1;
	mov.b16 	%h5, 0x9400;
	fma.rn.f16 	%h6, %h4, %h5, %h2;
	mov.b16 	%h7, 0x25CF;
	setp.eq.f16 	%p2, %h1, %h7;
	selp.b16 	%h8, 0x3C00, 0x0000, %p2;
	fma.rn.f16 	%h9, %h8, %h5, %h6;
	mov.b16 	%h10, 0xC13B;
	setp.eq.f16 	%p3, %h1, %h10;
	selp.b16 	%h11, 0x3C00, 0x0000, %p3;
	mov.b16 	%h12, 0x0400;
	fma.rn.f16 	%h13, %h11, %h12, %h9;
	mov.b16 	%h14, 0xC1EF;
	setp.eq.f16 	%p4, %h1, %h14;
	selp.b16 	%h15, 0x3C00, 0x0000, %p4;
	mov.b16 	%h16, 0x0200;
	fma.rn.f16 	%h17, %h15, %h16, %h13;
	st.param.b16 	[func_retval0+0], %h17;
	ret;
                                        // -- End function
}

function nativeexp2(h::Float16)
    # perform computation in Float32 domain
    f = Float32(h)
    f = @fastmath exp2(f)

    # one ULP adjustment
    f = muladd(f, reinterpret(Float32, 0x33800000), f)
    r = Float16(f)

    return r
end
.visible .func  (.param .b32 func_retval0) julia_nativeexp2_19066(
	.param .b32 julia_nativeexp2_19066_param_0
)
{
	.reg .b16 	%h<3>;
	.reg .f32 	%f<4>;

// %bb.0:                               // %top
	ld.param.b16 	%h1, [julia_nativeexp2_19066_param_0];
	cvt.f32.f16 	%f1, %h1;
	ex2.approx.f32 	%f2, %f1;
	fma.rn.f32 	%f3, %f2, 0f33800000, %f2;
	cvt.rn.f16.f32 	%h2, %f3;
	st.param.b16 	[func_retval0+0], %h2;
	ret;
                                        // -- End function
}

function nativeexp10(h::Float16)
    # perform computation in Float32 domain
    f = Float32(h)
    f = fma(f, reinterpret(Float32, 0x40549A78), reinterpret(Float32, Base.sign_mask(Float32)))
    f = @fastmath exp2(f)
    r = Float16(f)

    # handle degenerate cases
    r = fma(Float16(h == reinterpret(Float16, 0x34DE)), reinterpret(Float16, 0x9800), r)
    r = fma(Float16(h == reinterpret(Float16, 0x9766)), reinterpret(Float16, 0x9000), r)
    r = fma(Float16(h == reinterpret(Float16, 0x9972)), reinterpret(Float16, 0x1000), r)
    r = fma(Float16(h == reinterpret(Float16, 0xA5C4)), reinterpret(Float16, 0x1000), r)
    r = fma(Float16(h == reinterpret(Float16, 0xBF0A)), reinterpret(Float16, 0x8100), r)

    return r
end
.visible .func  (.param .b32 func_retval0) julia_nativeexp10_19202(
	.param .b32 julia_nativeexp10_19202_param_0
)
{
	.reg .pred 	%p<6>;
	.reg .b16 	%h<22>;
	.reg .f32 	%f<4>;

// %bb.0:                               // %top
	ld.param.b16 	%h1, [julia_nativeexp10_19202_param_0];
	cvt.f32.f16 	%f1, %h1;
	mul.f32 	%f2, %f1, 0f40549A78;
	ex2.approx.f32 	%f3, %f2;
	cvt.rn.f16.f32 	%h2, %f3;
	mov.b16 	%h3, 0x34DE;
	setp.eq.f16 	%p1, %h1, %h3;
	selp.b16 	%h4, 0x3C00, 0x0000, %p1;
	mov.b16 	%h5, 0x9800;
	fma.rn.f16 	%h6, %h4, %h5, %h2;
	mov.b16 	%h7, 0x9766;
	setp.eq.f16 	%p2, %h1, %h7;
	selp.b16 	%h8, 0x3C00, 0x0000, %p2;
	mov.b16 	%h9, 0x9000;
	fma.rn.f16 	%h10, %h8, %h9, %h6;
	mov.b16 	%h11, 0x9972;
	setp.eq.f16 	%p3, %h1, %h11;
	selp.b16 	%h12, 0x3C00, 0x0000, %p3;
	mov.b16 	%h13, 0x1000;
	fma.rn.f16 	%h14, %h12, %h13, %h10;
	mov.b16 	%h15, 0xA5C4;
	setp.eq.f16 	%p4, %h1, %h15;
	selp.b16 	%h16, 0x3C00, 0x0000, %p4;
	fma.rn.f16 	%h17, %h16, %h13, %h14;
	mov.b16 	%h18, 0xBF0A;
	setp.eq.f16 	%p5, %h1, %h18;
	selp.b16 	%h19, 0x3C00, 0x0000, %p5;
	mov.b16 	%h20, 0x8100;
	fma.rn.f16 	%h21, %h19, %h20, %h17;
	st.param.b16 	[func_retval0+0], %h21;
	ret;
                                        // -- End function
}

I only did these ports by looking at the assembly, and they would still need to be tested properly.
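
One way to sanity-check a port (a sketch only, assuming a CUDA device plus the using line and the nativelog definition above; inputs are restricted to positive values because Base.log throws a DomainError for negative Float16 on the CPU):

xs  = [reinterpret(Float16, p) for p in UInt16(0):typemax(UInt16)]
xs  = filter(x -> isfinite(x) && x > 0, xs)
cpu = log.(xs)                          # Julia's Float16 log on the CPU
gpu = Array(nativelog.(CuArray(xs)))    # the port, compiled for and run on the GPU
count(cpu .!= gpu)                      # number of inputs where the two disagree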

maleadt (Member) commented Feb 11, 2025

I wonder if some of the degenerate cases in the ASM-to-Julia ports above could be written differently (I couldn't find any existing predicates returning those bit values); maybe some floating-point wizards know (cc @oscardssmith)?

.reg.b16 r,h;
mov.b16 h,\$1;
cvt.f32.f16 f,h;
lg2.approx.ftz.f32 f,f;

Member:

Seeing .approx here; are these the fastmath versions?

kshyatt (Contributor Author):

It's from their own log implementation, which is not fast math (https://github.com/cupy/cupy/blob/620183256d25eb463081a8bac2a7a965d35db66b/cupy/_core/include/cupy/_cuda/cuda-12/cuda_fp16.hpp#L2478). I guess they use the approximate method for fp32 and assume it doesn't hurt fp16 accuracy too much?

kshyatt (Contributor Author) commented Feb 11, 2025

What about more native implementations?

I'm in favour! My assembly skills are weak enough I didn't want to venture too far out on my own, but others should feel free to push more to this branch or point me at some references :)

github-actions bot left a comment:

CUDA.jl Benchmarks

Benchmark suite Current: 3081234 Previous: f62af73 Ratio
latency/precompile 46406094960.5 ns 46667053506.5 ns 0.99
latency/ttfp 7033081502 ns 6954689688 ns 1.01
latency/import 3654105472 ns 3631856123 ns 1.01
integration/volumerhs 9625724 ns 9624743.5 ns 1.00
integration/byval/slices=1 146529 ns 146953 ns 1.00
integration/byval/slices=3 425109 ns 425334 ns 1.00
integration/byval/reference 144733 ns 145208 ns 1.00
integration/byval/slices=2 285804 ns 286016 ns 1.00
integration/cudadevrt 103080.5 ns 103424 ns 1.00
kernel/indexing 14017 ns 14214 ns 0.99
kernel/indexing_checked 14594 ns 14910 ns 0.98
kernel/occupancy 670.1 ns 637.5449101796407 ns 1.05
kernel/launch 2036.6 ns 2102 ns 0.97
kernel/rand 17997 ns 18239 ns 0.99
array/reverse/1d 19676 ns 19474 ns 1.01
array/reverse/2d 23915 ns 23910 ns 1.00
array/reverse/1d_inplace 10273 ns 10670 ns 0.96
array/reverse/2d_inplace 11856 ns 12291 ns 0.96
array/copy 21205 ns 20955 ns 1.01
array/iteration/findall/int 154962 ns 155336 ns 1.00
array/iteration/findall/bool 133464 ns 133979 ns 1.00
array/iteration/findfirst/int 153204 ns 154049 ns 0.99
array/iteration/findfirst/bool 153040.5 ns 153056 ns 1.00
array/iteration/scalar 59613 ns 61530 ns 0.97
array/iteration/logical 203196.5 ns 202309 ns 1.00
array/iteration/findmin/1d 37943 ns 37878 ns 1.00
array/iteration/findmin/2d 93588 ns 93537 ns 1.00
array/reductions/reduce/1d 39200.5 ns 37060.5 ns 1.06
array/reductions/reduce/2d 50930 ns 50765 ns 1.00
array/reductions/mapreduce/1d 36233.5 ns 36727 ns 0.99
array/reductions/mapreduce/2d 48224.5 ns 42618.5 ns 1.13
array/broadcast 20666 ns 20743 ns 1.00
array/copyto!/gpu_to_gpu 11744 ns 13730.5 ns 0.86
array/copyto!/cpu_to_gpu 208210 ns 207788 ns 1.00
array/copyto!/gpu_to_cpu 241342 ns 243117 ns 0.99
array/accumulate/1d 108670 ns 108517 ns 1.00
array/accumulate/2d 80054 ns 79641 ns 1.01
array/construct 1240.05 ns 1306.5 ns 0.95
array/random/randn/Float32 43773 ns 43234.5 ns 1.01
array/random/randn!/Float32 26796 ns 26328 ns 1.02
array/random/rand!/Int64 26960 ns 27074 ns 1.00
array/random/rand!/Float32 8559.333333333334 ns 8647.666666666666 ns 0.99
array/random/rand/Int64 29942 ns 29948 ns 1.00
array/random/rand/Float32 13087 ns 13039 ns 1.00
array/permutedims/4d 61006 ns 60777 ns 1.00
array/permutedims/2d 55371 ns 55571 ns 1.00
array/permutedims/3d 56266 ns 55866 ns 1.01
array/sorting/1d 2776175.5 ns 2764795 ns 1.00
array/sorting/by 3367127.5 ns 3367795 ns 1.00
array/sorting/2d 1084227 ns 1084334 ns 1.00
cuda/synchronization/stream/auto 1035.6 ns 1052.3 ns 0.98
cuda/synchronization/stream/nonblocking 6419.4 ns 6404.4 ns 1.00
cuda/synchronization/stream/blocking 801.6881720430108 ns 810.0736842105263 ns 0.99
cuda/synchronization/context/auto 1173.2 ns 1185.6 ns 0.99
cuda/synchronization/context/nonblocking 6674.6 ns 6726.6 ns 0.99
cuda/synchronization/context/blocking 909.0869565217391 ns 925.975 ns 0.98

This comment was automatically generated by workflow using github-action-benchmark.

@device_override function Base.exp(h::Float16)
# perform computation in Float32 domain
f = Float32(h)
f = fma(f, reinterpret(Float32, 0x3fb8aa3b), reinterpret(Float32, Base.sign_mask(Float32)))

Member:

julia> reinterpret(UInt32, log2(Float32(ℯ)))
0x3fb8aa3b

and:

julia> reinterpret(Float32, Base.sign_mask(Float32))
-0.0f0

We probably can't rely on the constant evaluation of this, but this code is essentially: f *= log2(Float32(ℯ))
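
If constant evaluation is the worry, one option (a sketch only; the names are illustrative and this is not what the PR currently does) is to hoist the expressions into typed consts, which codegen embeds as literals:

const LOG2_E_F32   = log2(Float32(ℯ))                               # same bits as 0x3fb8aa3b
const NEG_ZERO_F32 = reinterpret(Float32, Base.sign_mask(Float32))  # -0.0f0

# hypothetical helper mirroring the fma line quoted above
to_log2_domain(h::Float16) = fma(Float32(h), LOG2_E_F32, NEG_ZERO_F32)   # i.e. Float32(h) * log2(Float32(ℯ))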

@device_override function Base.exp10(h::Float16)
# perform computation in Float32 domain
f = Float32(h)
f = fma(f, reinterpret(Float32, 0x40549A78), reinterpret(Float32, Base.sign_mask(Float32)))

Member:

Same as above, but:

julia> reinterpret(UInt32, log2(10.f0))
0x40549a78

Comment on lines +114 to +118
@testset "Float16 - $op" for op in (log,exp,exp2,exp10,log2,log10)
@testset "$T" for T in (Float16, )
@test testf(x->op.(x), rand(T, 1))
end
end

Member:

We could test all values here:

julia> all_float_16 = collect(reinterpret(Float16, pattern) for pattern in  UInt16(0):UInt16(1):typemax(UInt16))
65536-element Vector{Float16}:

(there might be a better way that avoids some of the duplicated patterns, but it is only 65k in the end)

Otherwise for some of the degenerate cases we might randomly fail if we disagree with Julia.
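
A possible shape for such an exhaustive test (a sketch, reusing the style of the testsets in test/core/device/intrinsics/math.jl; the domain restriction for the log-family avoids the DomainError Base throws on negative inputs):

@testset "Float16 exhaustive - $op" for op in (log, exp, exp2, exp10, log2, log10)
    xs = filter(!isnan, [reinterpret(Float16, p) for p in UInt16(0):typemax(UInt16)])
    op in (log, log2, log10) && (xs = filter(x -> x >= 0, xs))
    @test op.(xs) == Array(op.(CuArray(xs)))
end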

Member:

all_float_16 = collect(reinterpret(Float16, pattern) for pattern in  UInt16(0):UInt16(1):typemax(UInt16))
all_float_16 = filter(!isnan, all_float_16)

julia> findall(==(0), exp.(all_float_16) .== Array(exp.(CuArray(all_float_16))))
2-element Vector{Int64}:
 8058
 9680

julia> all_float_16[8058]
Float16(0.007298)

julia> all_float_16[9680]
Float16(0.02269)

julia> reinterpret(UInt16, all_float_16[8058])
0x1f79

julia> reinterpret(UInt16, all_float_16[9680])
0x25cf

Comment on lines +182 to +183
r = fma(Float16(h == reinterpret(Float16, 0x1F79)), reinterpret(Float16, 0x9400), r)
r = fma(Float16(h == reinterpret(Float16, 0x25CF)), reinterpret(Float16, 0x9400), r)

Member:

These two cause us to disagree with Julia.

julia> findall(==(0), exp.(all_float_16) .== Array(exp.(CuArray(all_float_16))))
2-element Vector{Int64}:
 8058
 9680

julia> all_float_16[8058]
Float16(0.007298)

julia> all_float_16[9680]
Float16(0.02269)

julia> reinterpret(UInt16, all_float_16[8058])
0x1f79

julia> reinterpret(UInt16, all_float_16[9680])
0x25cf

Member:

julia> Float16(exp(Float32(all_float_16[8058])))
Float16(1.008)

julia> exp_cu[8058]
Float16(1.007)

julia> exp_cu[8058] - Float16(exp(Float32(all_float_16[8058])))
Float16(-0.000977)
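
That gap is one Float16 ULP at that magnitude, i.e. exactly the 0x9400 correction applied for input 0x1F79 in nativeexp above. A quick sketch to verify the constants (deterministic, but worth re-running locally):

reinterpret(Float16, 0x9400)   # -0.000977, i.e. -eps(Float16(1)), a single ULP near 1
eps(Float16(1))                # 0.000977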

kshyatt (Contributor Author) commented Feb 14, 2025

OK I got a bit turned around here - what's the ask? Can we come up with a list of what needs to be done to get this merged?

maleadt (Member) commented Feb 18, 2025

I think we should replace the magic constants with the expressions Valentin figured out. Beyond that, we should think about whether we want to mimic CUDA or Julia here, i.e., whether to keep the adjustments or not. And in any case, make sure there are tests covering the added snippets.
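
For concreteness, a sketch of that first point (illustrative only; the constant name is made up here, and it deliberately omits the degenerate-case fma fixups, since keeping them is exactly the open question):

const LOG2_E_F32 = log2(Float32(ℯ))   # same bits as reinterpret(Float32, 0x3fb8aa3b)

@device_override function Base.exp(h::Float16)
    # compute in the Float32 domain, as in the PR diff above
    f = fma(Float32(h), LOG2_E_F32, -0.0f0)   # -0.0f0 == reinterpret(Float32, Base.sign_mask(Float32))
    f = @fastmath exp2(f)
    return Float16(f)
end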
