From 174206388b03026a5e12473be29e6ab27b8af6f9 Mon Sep 17 00:00:00 2001
From: arun-thmn <93976833+arun-thmn@users.noreply.github.com>
Date: Mon, 17 Feb 2025 20:10:09 +0530
Subject: [PATCH] Hoisting vector.transfer operations for bf16 type (#1012)

This PR extends support for hoisting `vector.transfer` read/write
operations outside the `batch` and `k` loops for the `bf16` type with
`vnni` layout.
---
 lib/TPP/Transforms/HoistVectorTransfers.cpp   | 27 +++---
 ...oist-vector-transfer-operation-brgemm.mlir | 89 +++++++++++++++++++
 2 files changed, 104 insertions(+), 12 deletions(-)

diff --git a/lib/TPP/Transforms/HoistVectorTransfers.cpp b/lib/TPP/Transforms/HoistVectorTransfers.cpp
index 0faf49a40..22ba8f9cc 100644
--- a/lib/TPP/Transforms/HoistVectorTransfers.cpp
+++ b/lib/TPP/Transforms/HoistVectorTransfers.cpp
@@ -128,6 +128,18 @@ struct HoistVectorTransferOp : OpRewritePattern<vector::ContractionOp> {
         std::count(contractIteratorTypes.begin(), contractIteratorTypes.end(),
                    vector::IteratorType::reduction);
 
+    if (reductionCount == 0)
+      return rewriter.notifyMatchFailure(
+          contractOp, "Matmul operation not supported yet");
+
+    if (reductionCount == 1)
+      return rewriter.notifyMatchFailure(
+          contractOp, "Batch matmul operation not supported yet");
+
+    if (reductionCount > 3)
+      return rewriter.notifyMatchFailure(
+          contractOp, "The vector contract operation is not a gemm");
+
     auto vectorReadOpLhsType = cast<VectorType>(vectorReadOpLhs.getType());
     auto vectorReadOpRhsRank =
         (cast<VectorType>(vectorReadOpRhs.getType())).getRank();
@@ -137,19 +149,10 @@ struct HoistVectorTransferOp : OpRewritePattern<vector::ContractionOp> {
       return rewriter.notifyMatchFailure(
           contractOp, "Invalid rank for batch reduce operation");
 
-    if (reductionCount == 1)
+    if (reductionCount == 3 &&
+        (vectorReadOpLhsType.getRank() != 4 || vectorReadOpRhsRank != 4))
       return rewriter.notifyMatchFailure(
-          contractOp, "Batch matmul operation not supported yet");
-
-    if (reductionCount > 2)
-      return rewriter.notifyMatchFailure(
-          contractOp, "The vector contract operation is not a gemm");
-
-    // Check the K-dim to be 1
-    int64_t K =
-        vectorReadOpLhsType.getDimSize(vectorReadOpLhsType.getRank() - 1);
-    if (K != 1)
-      return rewriter.notifyMatchFailure(contractOp, "K dim is not 1");
+          contractOp, "Invalid rank for batch reduce operation with vnni layout");
 
     // Check whether the linalg tiling + vector contract pattern matches for the
     // 4-nested loop structure
diff --git a/test/Passes/pass-hoist-vector-transfer-operation-brgemm.mlir b/test/Passes/pass-hoist-vector-transfer-operation-brgemm.mlir
index a58e5b7cb..34ed1176e 100644
--- a/test/Passes/pass-hoist-vector-transfer-operation-brgemm.mlir
+++ b/test/Passes/pass-hoist-vector-transfer-operation-brgemm.mlir
@@ -265,6 +265,95 @@ module {
 
 // -----
 
+
+#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4, d1)>
+#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)>
+#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)>
+module {
+  memref.global "private" constant @__constant_16x32x64x2xbf16 : memref<16x32x64x2xbf16> = dense<1.000000e+00> {alignment = 64 : i64}
+  func.func @hoist_gemm_bf16_vnni_layout(%arg0: memref<4x16x64x64xbf16>) -> memref<4x16x64x64xbf16> {
+    %cst = arith.constant 0.000000e+00 : bf16
+    %cst_0 = arith.constant dense<0.000000e+00> : vector<64x64xbf16>
+    %c1 = arith.constant 1 : index
+    %c16 = arith.constant 16 : index
+    %c32 = arith.constant 32 : index
+    %c64 = arith.constant 64 : index
+    %c0 = arith.constant 0 : index
+    %0 = memref.get_global @__constant_16x32x64x2xbf16 : memref<16x32x64x2xbf16>
+    %alloc = memref.alloc() {alignment = 64 : i64} : memref<4x16x64x64xbf16>
+    %expand_shape = memref.expand_shape %arg0 [[0], [1], [2], [3, 4]] output_shape [4, 16, 64, 32, 2] : memref<4x16x64x64xbf16> into memref<4x16x64x32x2xbf16>
+    scf.forall (%arg1, %arg2) in (4, 16) {
+      %subview = memref.subview %alloc[%arg1, %arg2, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<4x16x64x64xbf16> to memref<64x64xbf16, strided<[64, 1], offset: ?>>
+      vector.transfer_write %cst_0, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<64x64xbf16>, memref<64x64xbf16, strided<[64, 1], offset: ?>>
+      %subview_1 = memref.subview %expand_shape[%arg1, 0, 0, 0, 0] [1, 16, 64, 32, 2] [1, 1, 1, 1, 1] : memref<4x16x64x32x2xbf16> to memref<16x64x32x2xbf16, strided<[4096, 64, 2, 1], offset: ?>>
+      scf.for %arg3 = %c0 to %c64 step %c32 {
+        scf.for %arg4 = %c0 to %c64 step %c32 {
+          %subview_2 = memref.subview %subview[%arg3, %arg4] [32, 32] [1, 1] : memref<64x64xbf16, strided<[64, 1], offset: ?>> to memref<32x32xbf16, strided<[64, 1], offset: ?>>
+          scf.for %arg5 = %c0 to %c16 step %c1 {
+            scf.for %arg6 = %c0 to %c32 step %c16 {
+              %subview_3 = memref.subview %subview_1[%arg5, %arg3, %arg6, 0] [1, 32, 16, 2] [1, 1, 1, 1] : memref<16x64x32x2xbf16, strided<[4096, 64, 2, 1], offset: ?>> to memref<1x32x16x2xbf16, strided<[4096, 64, 2, 1], offset: ?>>
+              %subview_4 = memref.subview %0[%arg5, %arg6, %arg4, 0] [1, 16, 32, 2] [1, 1, 1, 1] : memref<16x32x64x2xbf16> to memref<1x16x32x2xbf16, strided<[4096, 128, 2, 1], offset: ?>>
+              %1 = vector.transfer_read %subview_3[%c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true]} : memref<1x32x16x2xbf16, strided<[4096, 64, 2, 1], offset: ?>>, vector<1x32x16x2xbf16>
+              %2 = vector.transfer_read %subview_4[%c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true]} : memref<1x16x32x2xbf16, strided<[4096, 128, 2, 1], offset: ?>>, vector<1x16x32x2xbf16>
+              %3 = vector.transfer_read %subview_2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<32x32xbf16, strided<[64, 1], offset: ?>>, vector<32x32xbf16>
+              %4 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %1, %2, %3 : vector<1x32x16x2xbf16>, vector<1x16x32x2xbf16> into vector<32x32xbf16>
+              vector.transfer_write %4, %subview_2[%c0, %c0] {in_bounds = [true, true]} : vector<32x32xbf16>, memref<32x32xbf16, strided<[64, 1], offset: ?>>
+            }
+          }
+        }
+      }
+    }
+    return %alloc : memref<4x16x64x64xbf16>
+  }
+}
+
+
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4, d1)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)>
+// CHECK-LABEL: memref.global "private" constant @__constant_16x32x64x2xbf16 : memref<16x32x64x2xbf16> = dense<1.000000e+00> {alignment = 64 : i64}
+// CHECK-LABEL: func.func @hoist_gemm_bf16_vnni_layout(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<4x16x64x64xbf16>) -> memref<4x16x64x64xbf16> {
+// CHECK: %[[VAL_1:.*]] = arith.constant 0.000000e+00 : bf16
+// CHECK: %[[VAL_2:.*]] = arith.constant dense<0.000000e+00> : vector<64x64xbf16>
+// CHECK: %[[VAL_3:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_4:.*]] = arith.constant 16 : index
+// CHECK: %[[VAL_5:.*]] = arith.constant 32 : index
+// CHECK: %[[VAL_6:.*]] = arith.constant 64 : index
+// CHECK: %[[VAL_7:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_8:.*]] = memref.get_global @__constant_16x32x64x2xbf16 : memref<16x32x64x2xbf16>
+// CHECK: %[[VAL_9:.*]] = memref.alloc() {alignment = 64 : i64} : memref<4x16x64x64xbf16>
+// CHECK: %[[VAL_10:.*]] = memref.expand_shape %[[VAL_0]] {{\[\[}}0], [1], [2], [3, 4]] output_shape [4, 16, 64, 32, 2] : memref<4x16x64x64xbf16> into memref<4x16x64x32x2xbf16>
+// CHECK: scf.forall (%[[VAL_11:.*]], %[[VAL_12:.*]]) in (4, 16) {
+// CHECK: %[[VAL_13:.*]] = memref.subview %[[VAL_9]]{{\[}}%[[VAL_11]], %[[VAL_12]], 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<4x16x64x64xbf16> to memref<64x64xbf16, strided<[64, 1], offset: ?>>
+// CHECK: vector.transfer_write %[[VAL_2]], %[[VAL_13]]{{\[}}%[[VAL_7]], %[[VAL_7]]] {in_bounds = [true, true]} : vector<64x64xbf16>, memref<64x64xbf16, strided<[64, 1], offset: ?>>
+// CHECK: %[[VAL_14:.*]] = memref.subview %[[VAL_10]]{{\[}}%[[VAL_11]], 0, 0, 0, 0] [1, 16, 64, 32, 2] [1, 1, 1, 1, 1] : memref<4x16x64x32x2xbf16> to memref<16x64x32x2xbf16, strided<[4096, 64, 2, 1], offset: ?>>
+// CHECK: scf.for %[[VAL_15:.*]] = %[[VAL_7]] to %[[VAL_6]] step %[[VAL_5]] {
+// CHECK: scf.for %[[VAL_16:.*]] = %[[VAL_7]] to %[[VAL_6]] step %[[VAL_5]] {
+// CHECK: %[[VAL_17:.*]] = memref.subview %[[VAL_13]]{{\[}}%[[VAL_15]], %[[VAL_16]]] [32, 32] [1, 1] : memref<64x64xbf16, strided<[64, 1], offset: ?>> to memref<32x32xbf16, strided<[64, 1], offset: ?>>
+// CHECK: %[[VAL_18:.*]] = vector.transfer_read %[[VAL_17]]{{\[}}%[[VAL_7]], %[[VAL_7]]], %[[VAL_1]] {in_bounds = [true, true]} : memref<32x32xbf16, strided<[64, 1], offset: ?>>, vector<32x32xbf16>
+// CHECK: %[[VAL_19:.*]] = scf.for %[[VAL_20:.*]] = %[[VAL_7]] to %[[VAL_4]] step %[[VAL_3]] iter_args(%[[VAL_21:.*]] = %[[VAL_18]]) -> (vector<32x32xbf16>) {
+// CHECK: %[[VAL_22:.*]] = scf.for %[[VAL_23:.*]] = %[[VAL_7]] to %[[VAL_5]] step %[[VAL_4]] iter_args(%[[VAL_24:.*]] = %[[VAL_21]]) -> (vector<32x32xbf16>) {
+// CHECK: %[[VAL_25:.*]] = memref.subview %[[VAL_14]]{{\[}}%[[VAL_20]], %[[VAL_15]], %[[VAL_23]], 0] [1, 32, 16, 2] [1, 1, 1, 1] : memref<16x64x32x2xbf16, strided<[4096, 64, 2, 1], offset: ?>> to memref<1x32x16x2xbf16, strided<[4096, 64, 2, 1], offset: ?>>
+// CHECK: %[[VAL_26:.*]] = memref.subview %[[VAL_8]]{{\[}}%[[VAL_20]], %[[VAL_23]], %[[VAL_16]], 0] [1, 16, 32, 2] [1, 1, 1, 1] : memref<16x32x64x2xbf16> to memref<1x16x32x2xbf16, strided<[4096, 128, 2, 1], offset: ?>>
+// CHECK: %[[VAL_27:.*]] = vector.transfer_read %[[VAL_25]]{{\[}}%[[VAL_7]], %[[VAL_7]], %[[VAL_7]], %[[VAL_7]]], %[[VAL_1]] {in_bounds = [true, true, true, true]} : memref<1x32x16x2xbf16, strided<[4096, 64, 2, 1], offset: ?>>, vector<1x32x16x2xbf16>
+// CHECK: %[[VAL_28:.*]] = vector.transfer_read %[[VAL_26]]{{\[}}%[[VAL_7]], %[[VAL_7]], %[[VAL_7]], %[[VAL_7]]], %[[VAL_1]] {in_bounds = [true, true, true, true]} : memref<1x16x32x2xbf16, strided<[4096, 128, 2, 1], offset: ?>>, vector<1x16x32x2xbf16>
+// CHECK: %[[VAL_29:.*]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %4, %5, %arg8 : vector<1x32x16x2xbf16>, vector<1x16x32x2xbf16> into vector<32x32xbf16>
+// CHECK: scf.yield %[[VAL_29]] : vector<32x32xbf16>
+// CHECK: }
+// CHECK: scf.yield %[[VAL_22]] : vector<32x32xbf16>
+// CHECK: }
+// CHECK: vector.transfer_write %[[VAL_19]], %[[VAL_17]]{{\[}}%[[VAL_7]], %[[VAL_7]]] {in_bounds = [true, true]} : vector<32x32xbf16>, memref<32x32xbf16, strided<[64, 1], offset: ?>>
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: return %[[VAL_9]] : memref<4x16x64x64xbf16>
+// CHECK: }
+
+// -----
+
+
 #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
 #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
 #map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>