diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
index 16181d7e760db..566013e73f4b8 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
@@ -341,6 +341,46 @@ def DotBF16Ps512IntrOp : AVX512_IntrOp<"dpbf16ps.512", 1, [Pure,
   let results = (outs VectorOfLengthAndType<[16], [F32]>:$res);
 }
 
+//----------------------------------------------------------------------------//
+// Convert packed F32 to packed BF16
+//----------------------------------------------------------------------------//
+
+def CvtPackedF32ToBF16Op : AVX512_Op<"cvt.packed.f32_to_bf16", [Pure,
+  AllElementCountsMatch<["a", "dst"]>]> {
+  let summary = "Convert packed F32 to packed BF16 data.";
+  let description = [{
+    The `cvt.packed.f32_to_bf16` op is an AVX512-BF16 specific op that can
+    lower to the proper LLVMAVX512BF16 operation `llvm.cvtneps2bf16` depending
+    on the width of MLIR vectors it is applied to.
+
+    #### From the Intel Intrinsics Guide:
+
+    Convert packed single-precision (32-bit) floating-point elements in `a` to
+    packed BF16 (16-bit) floating-point elements, and store the results in `dst`.
+
+    Example:
+    ```mlir
+    %dst = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<8xf32> -> vector<8xbf16>
+    ```
+  }];
+  let arguments = (ins VectorOfLengthAndType<[8, 16], [F32]>:$a);
+  let results = (outs VectorOfLengthAndType<[8, 16], [BF16]>:$dst);
+  let assemblyFormat =
+    "$a attr-dict `:` type($a) `->` type($dst)";
+}
+
+def CvtNeF32ToBF16Ps256IntrOp : AVX512_IntrOp<"cvtneps2bf16.256", 1, [Pure],
+    /*extension=*/"bf16"> {
+  let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$a);
+  let results = (outs VectorOfLengthAndType<[8], [BF16]>:$res);
+}
+
+def CvtNeF32ToBF16Ps512IntrOp : AVX512_IntrOp<"cvtneps2bf16.512", 1, [Pure],
+    /*extension=*/"bf16"> {
+  let arguments = (ins VectorOfLengthAndType<[16], [F32]>:$a);
+  let results = (outs VectorOfLengthAndType<[16], [BF16]>:$res);
+}
+
 //===----------------------------------------------------------------------===//
 // AVX op definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
index 260ac9ce589a3..f1fbb39b97fc4 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
@@ -131,6 +131,39 @@ struct DotBF16OpConversion : public ConvertOpToLLVMPattern<DotBF16Op> {
   }
 };
 
+struct CvtPackedF32ToBF16Conversion
+    : public ConvertOpToLLVMPattern<CvtPackedF32ToBF16Op> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(CvtPackedF32ToBF16Op op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto typeA = dyn_cast<VectorType>(op.getA().getType());
+    unsigned elemBitWidth = typeA.getElementTypeBitWidth();
+    unsigned opBitWidth = typeA.getShape()[0] * elemBitWidth;
+
+    auto opType = op.getDst().getType();
+    auto opA = op.getA();
+
+    switch (opBitWidth) {
+    case 256: {
+      rewriter.replaceOpWithNewOp<CvtNeF32ToBF16Ps256IntrOp>(op, opType, opA);
+      break;
+    }
+    case 512: {
+      rewriter.replaceOpWithNewOp<CvtNeF32ToBF16Ps512IntrOp>(op, opType, opA);
+      break;
+    }
+    default: {
+      return rewriter.notifyMatchFailure(
+          op, "unsupported AVX512-BF16 packed f32 to bf16 variant");
+    }
+    }
+
+    return success();
+  }
+};
+
 struct RsqrtOpConversion : public ConvertOpToLLVMPattern<RsqrtOp> {
   using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
 
@@ -202,8 +235,10 @@ using Registry = RegistryImpl<
 void mlir::populateX86VectorLegalizeForLLVMExportPatterns(
     const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
   Registry::registerPatterns(converter, patterns);
-  patterns.add<MaskCompressOpConversion, DotBF16OpConversion, RsqrtOpConversion,
-               DotOpConversion>(converter);
+  patterns
+      .add<MaskCompressOpConversion, DotBF16OpConversion,
+           CvtPackedF32ToBF16Conversion, RsqrtOpConversion, DotOpConversion>(
+          converter);
 }
 
 void mlir::configureX86VectorLegalizeForExportTarget(
@@ -215,6 +250,9 @@ void mlir::configureX86VectorLegalizeForExportTarget(
   target.addLegalOp<DotBF16Ps256IntrOp>();
   target.addLegalOp<DotBF16Ps512IntrOp>();
   target.addIllegalOp<DotBF16Op>();
+  target.addLegalOp<CvtNeF32ToBF16Ps256IntrOp>();
+  target.addLegalOp<CvtNeF32ToBF16Ps512IntrOp>();
+  target.addIllegalOp<CvtPackedF32ToBF16Op>();
   target.addLegalOp<RsqrtIntrOp>();
   target.addIllegalOp<RsqrtOp>();
   target.addLegalOp<DotIntrOp>();
diff --git a/mlir/test/Dialect/X86Vector/cvt-packed-f32-to-bf16.mlir b/mlir/test/Dialect/X86Vector/cvt-packed-f32-to-bf16.mlir
new file mode 100644
index 0000000000000..c97c52f01c3b0
--- /dev/null
+++ b/mlir/test/Dialect/X86Vector/cvt-packed-f32-to-bf16.mlir
@@ -0,0 +1,24 @@
+// REQUIRES: target=x86{{.*}}
+
+// RUN: mlir-opt %s \
+// RUN:   -convert-vector-to-llvm="enable-x86vector" -convert-to-llvm \
+// RUN:   -reconcile-unrealized-casts | \
+// RUN: mlir-translate --mlir-to-llvmir | \
+// RUN: llc -mcpu=sapphirerapids | \
+// RUN: FileCheck %s
+
+func.func @avx512bf16_cvt_packed_f32_to_bf16_256(
+    %a: vector<8xf32>) -> vector<8xbf16> {
+  %0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<8xf32> -> vector<8xbf16>
+  return %0 : vector<8xbf16>
+}
+// CHECK-LABEL: avx512bf16_cvt_packed_f32_to_bf16_256:
+// CHECK: vcvtneps2bf16{{.*}}%xmm
+
+func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
+    %a: vector<16xf32>) -> vector<16xbf16> {
+  %0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<16xf32> -> vector<16xbf16>
+  return %0 : vector<16xbf16>
+}
+// CHECK-LABEL: avx512bf16_cvt_packed_f32_to_bf16_512:
+// CHECK: vcvtneps2bf16{{.*}}%ymm
diff --git a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
index ed9177eaec9ce..59be7dd75b3b0 100644
--- a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
@@ -70,6 +70,24 @@ func.func @avx512bf16_dot_512(%src: vector<16xf32>, %a: vector<32xbf16>,
   return %0 : vector<16xf32>
 }
 
+// CHECK-LABEL: func @avx512bf16_cvt_packed_f32_to_bf16_256
+func.func @avx512bf16_cvt_packed_f32_to_bf16_256(
+  %a: vector<8xf32>) -> (vector<8xbf16>)
+{
+  // CHECK: x86vector.avx512.intr.cvtneps2bf16.256
+  %0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<8xf32> -> vector<8xbf16>
+  return %0 : vector<8xbf16>
+}
+
+// CHECK-LABEL: func @avx512bf16_cvt_packed_f32_to_bf16_512
+func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
+  %a: vector<16xf32>) -> (vector<16xbf16>)
+{
+  // CHECK: x86vector.avx512.intr.cvtneps2bf16.512
+  %0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<16xf32> -> vector<16xbf16>
+  return %0 : vector<16xbf16>
+}
+
 // CHECK-LABEL: func @avx_rsqrt
 func.func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>)
 {
diff --git a/mlir/test/Dialect/X86Vector/roundtrip.mlir b/mlir/test/Dialect/X86Vector/roundtrip.mlir
index cf74a7ee60255..0d00448c63da8 100644
--- a/mlir/test/Dialect/X86Vector/roundtrip.mlir
+++ b/mlir/test/Dialect/X86Vector/roundtrip.mlir
@@ -74,6 +74,26 @@ func.func @avx512bf16_dot_512(%src: vector<16xf32>, %a: vector<32xbf16>,
   return %0 : vector<16xf32>
 }
 
+// CHECK-LABEL: func @avx512bf16_cvt_packed_f32_to_bf16_256
+func.func @avx512bf16_cvt_packed_f32_to_bf16_256(
+  %a: vector<8xf32>) -> (vector<8xbf16>)
+{
+  // CHECK: x86vector.avx512.cvt.packed.f32_to_bf16 {{.*}} :
+  // CHECK-SAME: vector<8xf32> -> vector<8xbf16>
+  %0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<8xf32> -> vector<8xbf16>
+  return %0 : vector<8xbf16>
+}
+
+// CHECK-LABEL: func @avx512bf16_cvt_packed_f32_to_bf16_512
+func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
+  %a: vector<16xf32>) -> (vector<16xbf16>)
+{
+  // CHECK: x86vector.avx512.cvt.packed.f32_to_bf16 {{.*}} :
+  // CHECK-SAME: vector<16xf32> -> vector<16xbf16>
+  %0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<16xf32> -> vector<16xbf16>
+  return %0 : vector<16xbf16>
+}
+
 // CHECK-LABEL: func @avx_rsqrt
 func.func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>)
 {
diff --git a/mlir/test/Target/LLVMIR/x86vector.mlir b/mlir/test/Target/LLVMIR/x86vector.mlir
index 1df03f10c9321..db1c10cd5cd37 100644
--- a/mlir/test/Target/LLVMIR/x86vector.mlir
+++ b/mlir/test/Target/LLVMIR/x86vector.mlir
@@ -62,37 +62,57 @@ llvm.func @LLVM_x86_vp2intersect_q_512(%a: vector<8xi64>, %b: vector<8xi64>)
 
 // CHECK-LABEL: define <4 x float> @LLVM_x86_avx512bf16_dpbf16ps_128
 llvm.func @LLVM_x86_avx512bf16_dpbf16ps_128(
-    %arg0: vector<4xf32>, %arg1: vector<8xbf16>, %arg2: vector<8xbf16>
+    %src: vector<4xf32>, %a: vector<8xbf16>, %b: vector<8xbf16>
 ) -> vector<4xf32>
 {
   // CHECK: call <4 x float> @llvm.x86.avx512bf16.dpbf16ps.128(
-  %0 = "x86vector.avx512.intr.dpbf16ps.128"(%arg0, %arg1, %arg2)
+  %0 = "x86vector.avx512.intr.dpbf16ps.128"(%src, %a, %b)
     : (vector<4xf32>, vector<8xbf16>, vector<8xbf16>) -> vector<4xf32>
   llvm.return %0 : vector<4xf32>
 }
 
 // CHECK-LABEL: define <8 x float> @LLVM_x86_avx512bf16_dpbf16ps_256
 llvm.func @LLVM_x86_avx512bf16_dpbf16ps_256(
-    %arg0: vector<8xf32>, %arg1: vector<16xbf16>, %arg2: vector<16xbf16>
+    %src: vector<8xf32>, %a: vector<16xbf16>, %b: vector<16xbf16>
 ) -> vector<8xf32>
 {
   // CHECK: call <8 x float> @llvm.x86.avx512bf16.dpbf16ps.256(
-  %0 = "x86vector.avx512.intr.dpbf16ps.256"(%arg0, %arg1, %arg2)
+  %0 = "x86vector.avx512.intr.dpbf16ps.256"(%src, %a, %b)
    : (vector<8xf32>, vector<16xbf16>, vector<16xbf16>) -> vector<8xf32>
   llvm.return %0 : vector<8xf32>
 }
 
 // CHECK-LABEL: define <16 x float> @LLVM_x86_avx512bf16_dpbf16ps_512
 llvm.func @LLVM_x86_avx512bf16_dpbf16ps_512(
-    %arg0: vector<16xf32>, %arg1: vector<32xbf16>, %arg2: vector<32xbf16>
+    %src: vector<16xf32>, %a: vector<32xbf16>, %b: vector<32xbf16>
 ) -> vector<16xf32>
 {
   // CHECK: call <16 x float> @llvm.x86.avx512bf16.dpbf16ps.512(
-  %0 = "x86vector.avx512.intr.dpbf16ps.512"(%arg0, %arg1, %arg2)
+  %0 = "x86vector.avx512.intr.dpbf16ps.512"(%src, %a, %b)
     : (vector<16xf32>, vector<32xbf16>, vector<32xbf16>) -> vector<16xf32>
   llvm.return %0 : vector<16xf32>
 }
 
+// CHECK-LABEL: define <8 x bfloat> @LLVM_x86_avx512bf16_cvtneps2bf16_256
+llvm.func @LLVM_x86_avx512bf16_cvtneps2bf16_256(
+    %a: vector<8xf32>) -> vector<8xbf16>
+{
+  // CHECK: call <8 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.256(
+  %0 = "x86vector.avx512.intr.cvtneps2bf16.256"(%a)
+    : (vector<8xf32>) -> vector<8xbf16>
+  llvm.return %0 : vector<8xbf16>
+}
+
+// CHECK-LABEL: define <16 x bfloat> @LLVM_x86_avx512bf16_cvtneps2bf16_512
+llvm.func @LLVM_x86_avx512bf16_cvtneps2bf16_512(
+    %a: vector<16xf32>) -> vector<16xbf16>
+{
+  // CHECK: call <16 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.512(
+  %0 = "x86vector.avx512.intr.cvtneps2bf16.512"(%a)
+    : (vector<16xf32>) -> vector<16xbf16>
+  llvm.return %0 : vector<16xbf16>
+}
+
 // CHECK-LABEL: define <8 x float> @LLVM_x86_avx_rsqrt_ps_256
 llvm.func @LLVM_x86_avx_rsqrt_ps_256(%a: vector <8xf32>) -> vector<8xf32>
 {
@@ -103,11 +123,11 @@ llvm.func @LLVM_x86_avx_rsqrt_ps_256(%a: vector <8xf32>) -> vector<8xf32>
 
 // CHECK-LABEL: define <8 x float> @LLVM_x86_avx_dp_ps_256
 llvm.func @LLVM_x86_avx_dp_ps_256(
-    %arg0: vector<8xf32>, %arg1: vector<8xf32>
+    %a: vector<8xf32>, %b: vector<8xf32>
 ) -> vector<8xf32>
 {
   // CHECK: call <8 x float> @llvm.x86.avx.dp.ps.256(
-  %0 = llvm.mlir.constant(-1 : i8) : i8
-  %1 = "x86vector.avx.intr.dp.ps.256"(%arg0, %arg1, %0) : (vector<8xf32>, vector<8xf32>, i8) -> vector<8xf32>
+  %c = llvm.mlir.constant(-1 : i8) : i8
+  %1 = "x86vector.avx.intr.dp.ps.256"(%a, %b, %c) : (vector<8xf32>, vector<8xf32>, i8) -> vector<8xf32>
   llvm.return %1 : vector<8xf32>
 }
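
For reference, a minimal sketch of the lowering chain that the tests above exercise; the IR spellings are taken from the roundtrip, legalization, and LLVM IR translation tests in this patch, shown here for the 512-bit variant.

```mlir
// Source form (roundtrip.mlir).
%0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<16xf32> -> vector<16xbf16>

// After legalization (legalize-for-llvm.mlir), CvtPackedF32ToBF16Conversion
// picks the width-specific intrinsic op based on the input bit width.
%0 = "x86vector.avx512.intr.cvtneps2bf16.512"(%a) : (vector<16xf32>) -> vector<16xbf16>

// After mlir-translate --mlir-to-llvmir (x86vector.mlir), as LLVM IR:
//   %0 = call <16 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.512(<16 x float> %a)
```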