Skip to content

Commit

Permalink
[mlir][x86vector] AVX512-BF16 Convert packed F32 to BF16 (llvm#125685)
Browse files Browse the repository at this point in the history
Adds AVX512 bf16 conversion from packed f32 to bf16 elements.

Tests are slightly refactored to better follow file's convention.
  • Loading branch information
adam-smnk authored and wldfngrs committed Feb 19, 2025
1 parent 09c0dd5 commit 47a888a
Show file tree
Hide file tree
Showing 6 changed files with 171 additions and 11 deletions.
40 changes: 40 additions & 0 deletions mlir/include/mlir/Dialect/X86Vector/X86Vector.td
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,46 @@ def DotBF16Ps512IntrOp : AVX512_IntrOp<"dpbf16ps.512", 1, [Pure,
let results = (outs VectorOfLengthAndType<[16], [F32]>:$res);
}

//----------------------------------------------------------------------------//
// Convert packed F32 to packed BF16
//----------------------------------------------------------------------------//

def CvtPackedF32ToBF16Op : AVX512_Op<"cvt.packed.f32_to_bf16", [Pure,
AllElementCountsMatch<["a", "dst"]>]> {
let summary = "Convert packed F32 to packed BF16 Data.";
let description = [{
The `convert_f32_to_bf16` op is an AVX512-BF16 specific op that can lower
to the proper LLVMAVX512BF16 operation `llvm.cvtneps2bf16` depending on
the width of MLIR vectors it is applied to.

#### From the Intel Intrinsics Guide:

Convert packed single-precision (32-bit) floating-point elements in `a` to
packed BF16 (16-bit) floating-point elements, and store the results in `dst`.

Example:
```mlir
%dst = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<8xf32> -> vector<8xbf16>
```
}];
let arguments = (ins VectorOfLengthAndType<[8, 16], [F32]>:$a);
let results = (outs VectorOfLengthAndType<[8, 16], [BF16]>:$dst);
let assemblyFormat =
"$a attr-dict `:` type($a) `->` type($dst)";
}

def CvtNeF32ToBF16Ps256IntrOp : AVX512_IntrOp<"cvtneps2bf16.256", 1, [Pure],
/*extension=*/"bf16"> {
let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$a);
let results = (outs VectorOfLengthAndType<[8], [BF16]>:$res);
}

def CvtNeF32ToBF16Ps512IntrOp : AVX512_IntrOp<"cvtneps2bf16.512", 1, [Pure],
/*extension=*/"bf16"> {
let arguments = (ins VectorOfLengthAndType<[16], [F32]>:$a);
let results = (outs VectorOfLengthAndType<[16], [BF16]>:$res);
}

//===----------------------------------------------------------------------===//
// AVX op definitions
//===----------------------------------------------------------------------===//
Expand Down
42 changes: 40 additions & 2 deletions mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,39 @@ struct DotBF16OpConversion : public ConvertOpToLLVMPattern<DotBF16Op> {
}
};

struct CvtPackedF32ToBF16Conversion
: public ConvertOpToLLVMPattern<CvtPackedF32ToBF16Op> {
using ConvertOpToLLVMPattern<CvtPackedF32ToBF16Op>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(CvtPackedF32ToBF16Op op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto typeA = dyn_cast<VectorType>(op.getA().getType());
unsigned elemBitWidth = typeA.getElementTypeBitWidth();
unsigned opBitWidth = typeA.getShape()[0] * elemBitWidth;

auto opType = op.getDst().getType();
auto opA = op.getA();

switch (opBitWidth) {
case 256: {
rewriter.replaceOpWithNewOp<CvtNeF32ToBF16Ps256IntrOp>(op, opType, opA);
break;
}
case 512: {
rewriter.replaceOpWithNewOp<CvtNeF32ToBF16Ps512IntrOp>(op, opType, opA);
break;
}
default: {
return rewriter.notifyMatchFailure(
op, "unsupported AVX512-BF16 packed f32 to bf16 variant");
}
}

return success();
}
};

struct RsqrtOpConversion : public ConvertOpToLLVMPattern<RsqrtOp> {
using ConvertOpToLLVMPattern<RsqrtOp>::ConvertOpToLLVMPattern;

Expand Down Expand Up @@ -202,8 +235,10 @@ using Registry = RegistryImpl<
void mlir::populateX86VectorLegalizeForLLVMExportPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
Registry::registerPatterns(converter, patterns);
patterns.add<MaskCompressOpConversion, DotBF16OpConversion, RsqrtOpConversion,
DotOpConversion>(converter);
patterns
.add<MaskCompressOpConversion, DotBF16OpConversion,
CvtPackedF32ToBF16Conversion, RsqrtOpConversion, DotOpConversion>(
converter);
}

void mlir::configureX86VectorLegalizeForExportTarget(
Expand All @@ -215,6 +250,9 @@ void mlir::configureX86VectorLegalizeForExportTarget(
target.addLegalOp<DotBF16Ps256IntrOp>();
target.addLegalOp<DotBF16Ps512IntrOp>();
target.addIllegalOp<DotBF16Op>();
target.addLegalOp<CvtNeF32ToBF16Ps256IntrOp>();
target.addLegalOp<CvtNeF32ToBF16Ps512IntrOp>();
target.addIllegalOp<CvtPackedF32ToBF16Op>();
target.addLegalOp<RsqrtIntrOp>();
target.addIllegalOp<RsqrtOp>();
target.addLegalOp<DotIntrOp>();
Expand Down
24 changes: 24 additions & 0 deletions mlir/test/Dialect/X86Vector/cvt-packed-f32-to-bf16.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// REQUIRES: target=x86{{.*}}

// RUN: mlir-opt %s \
// RUN: -convert-vector-to-llvm="enable-x86vector" -convert-to-llvm \
// RUN: -reconcile-unrealized-casts | \
// RUN: mlir-translate --mlir-to-llvmir | \
// RUN: llc -mcpu=sapphirerapids | \
// RUN: FileCheck %s

func.func @avx512bf16_cvt_packed_f32_to_bf16_256(
%a: vector<8xf32>) -> vector<8xbf16> {
%0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<8xf32> -> vector<8xbf16>
return %0 : vector<8xbf16>
}
// CHECK-LABEL: avx512bf16_cvt_packed_f32_to_bf16_256:
// CHECK: vcvtneps2bf16{{.*}}%xmm

func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
%a: vector<16xf32>) -> vector<16xbf16> {
%0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<16xf32> -> vector<16xbf16>
return %0 : vector<16xbf16>
}
// CHECK-LABEL: avx512bf16_cvt_packed_f32_to_bf16_512:
// CHECK: vcvtneps2bf16{{.*}}%ymm
18 changes: 18 additions & 0 deletions mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,24 @@ func.func @avx512bf16_dot_512(%src: vector<16xf32>, %a: vector<32xbf16>,
return %0 : vector<16xf32>
}

// CHECK-LABEL: func @avx512bf16_cvt_packed_f32_to_bf16_256
func.func @avx512bf16_cvt_packed_f32_to_bf16_256(
%a: vector<8xf32>) -> (vector<8xbf16>)
{
// CHECK: x86vector.avx512.intr.cvtneps2bf16.256
%0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<8xf32> -> vector<8xbf16>
return %0 : vector<8xbf16>
}

// CHECK-LABEL: func @avx512bf16_cvt_packed_f32_to_bf16_512
func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
%a: vector<16xf32>) -> (vector<16xbf16>)
{
// CHECK: x86vector.avx512.intr.cvtneps2bf16.512
%0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<16xf32> -> vector<16xbf16>
return %0 : vector<16xbf16>
}

// CHECK-LABEL: func @avx_rsqrt
func.func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>)
{
Expand Down
20 changes: 20 additions & 0 deletions mlir/test/Dialect/X86Vector/roundtrip.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,26 @@ func.func @avx512bf16_dot_512(%src: vector<16xf32>, %a: vector<32xbf16>,
return %0 : vector<16xf32>
}

// CHECK-LABEL: func @avx512bf16_cvt_packed_f32_to_bf16_256
func.func @avx512bf16_cvt_packed_f32_to_bf16_256(
%a: vector<8xf32>) -> (vector<8xbf16>)
{
// CHECK: x86vector.avx512.cvt.packed.f32_to_bf16 {{.*}} :
// CHECK-SAME: vector<8xf32> -> vector<8xbf16>
%0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<8xf32> -> vector<8xbf16>
return %0 : vector<8xbf16>
}

// CHECK-LABEL: func @avx512bf16_cvt_packed_f32_to_bf16_512
func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
%a: vector<16xf32>) -> (vector<16xbf16>)
{
// CHECK: x86vector.avx512.cvt.packed.f32_to_bf16 {{.*}} :
// CHECK-SAME: vector<16xf32> -> vector<16xbf16>
%0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<16xf32> -> vector<16xbf16>
return %0 : vector<16xbf16>
}

// CHECK-LABEL: func @avx_rsqrt
func.func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>)
{
Expand Down
38 changes: 29 additions & 9 deletions mlir/test/Target/LLVMIR/x86vector.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -62,37 +62,57 @@ llvm.func @LLVM_x86_vp2intersect_q_512(%a: vector<8xi64>, %b: vector<8xi64>)

// CHECK-LABEL: define <4 x float> @LLVM_x86_avx512bf16_dpbf16ps_128
llvm.func @LLVM_x86_avx512bf16_dpbf16ps_128(
%arg0: vector<4xf32>, %arg1: vector<8xbf16>, %arg2: vector<8xbf16>
%src: vector<4xf32>, %a: vector<8xbf16>, %b: vector<8xbf16>
) -> vector<4xf32>
{
// CHECK: call <4 x float> @llvm.x86.avx512bf16.dpbf16ps.128(
%0 = "x86vector.avx512.intr.dpbf16ps.128"(%arg0, %arg1, %arg2)
%0 = "x86vector.avx512.intr.dpbf16ps.128"(%src, %a, %b)
: (vector<4xf32>, vector<8xbf16>, vector<8xbf16>) -> vector<4xf32>
llvm.return %0 : vector<4xf32>
}

// CHECK-LABEL: define <8 x float> @LLVM_x86_avx512bf16_dpbf16ps_256
llvm.func @LLVM_x86_avx512bf16_dpbf16ps_256(
%arg0: vector<8xf32>, %arg1: vector<16xbf16>, %arg2: vector<16xbf16>
%src: vector<8xf32>, %a: vector<16xbf16>, %b: vector<16xbf16>
) -> vector<8xf32>
{
// CHECK: call <8 x float> @llvm.x86.avx512bf16.dpbf16ps.256(
%0 = "x86vector.avx512.intr.dpbf16ps.256"(%arg0, %arg1, %arg2)
%0 = "x86vector.avx512.intr.dpbf16ps.256"(%src, %a, %b)
: (vector<8xf32>, vector<16xbf16>, vector<16xbf16>) -> vector<8xf32>
llvm.return %0 : vector<8xf32>
}

// CHECK-LABEL: define <16 x float> @LLVM_x86_avx512bf16_dpbf16ps_512
llvm.func @LLVM_x86_avx512bf16_dpbf16ps_512(
%arg0: vector<16xf32>, %arg1: vector<32xbf16>, %arg2: vector<32xbf16>
%src: vector<16xf32>, %a: vector<32xbf16>, %b: vector<32xbf16>
) -> vector<16xf32>
{
// CHECK: call <16 x float> @llvm.x86.avx512bf16.dpbf16ps.512(
%0 = "x86vector.avx512.intr.dpbf16ps.512"(%arg0, %arg1, %arg2)
%0 = "x86vector.avx512.intr.dpbf16ps.512"(%src, %a, %b)
: (vector<16xf32>, vector<32xbf16>, vector<32xbf16>) -> vector<16xf32>
llvm.return %0 : vector<16xf32>
}

// CHECK-LABEL: define <8 x bfloat> @LLVM_x86_avx512bf16_cvtneps2bf16_256
llvm.func @LLVM_x86_avx512bf16_cvtneps2bf16_256(
%a: vector<8xf32>) -> vector<8xbf16>
{
// CHECK: call <8 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.256(
%0 = "x86vector.avx512.intr.cvtneps2bf16.256"(%a)
: (vector<8xf32>) -> vector<8xbf16>
llvm.return %0 : vector<8xbf16>
}

// CHECK-LABEL: define <16 x bfloat> @LLVM_x86_avx512bf16_cvtneps2bf16_512
llvm.func @LLVM_x86_avx512bf16_cvtneps2bf16_512(
%a: vector<16xf32>) -> vector<16xbf16>
{
// CHECK: call <16 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.512(
%0 = "x86vector.avx512.intr.cvtneps2bf16.512"(%a)
: (vector<16xf32>) -> vector<16xbf16>
llvm.return %0 : vector<16xbf16>
}

// CHECK-LABEL: define <8 x float> @LLVM_x86_avx_rsqrt_ps_256
llvm.func @LLVM_x86_avx_rsqrt_ps_256(%a: vector <8xf32>) -> vector<8xf32>
{
Expand All @@ -103,11 +123,11 @@ llvm.func @LLVM_x86_avx_rsqrt_ps_256(%a: vector <8xf32>) -> vector<8xf32>

// CHECK-LABEL: define <8 x float> @LLVM_x86_avx_dp_ps_256
llvm.func @LLVM_x86_avx_dp_ps_256(
%arg0: vector<8xf32>, %arg1: vector<8xf32>
%a: vector<8xf32>, %b: vector<8xf32>
) -> vector<8xf32>
{
// CHECK: call <8 x float> @llvm.x86.avx.dp.ps.256(
%0 = llvm.mlir.constant(-1 : i8) : i8
%1 = "x86vector.avx.intr.dp.ps.256"(%arg0, %arg1, %0) : (vector<8xf32>, vector<8xf32>, i8) -> vector<8xf32>
%c = llvm.mlir.constant(-1 : i8) : i8
%1 = "x86vector.avx.intr.dp.ps.256"(%a, %b, %c) : (vector<8xf32>, vector<8xf32>, i8) -> vector<8xf32>
llvm.return %1 : vector<8xf32>
}

0 comments on commit 47a888a

Please sign in to comment.