Hoisting vector.transfer operations for bf16 type (#1012)
This PR extends support for hoisting `vector.transfer_read`/`vector.transfer_write`
operations outside the `batch` and `k` loops for the `bf16` type with `vnni`
layout, as sketched below.
arun-thmn authored Feb 17, 2025
1 parent bac688f commit 1742063
Showing 2 changed files with 104 additions and 12 deletions.
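What the hoisting does, as a minimal sketch (function names, shapes, and the
arith.addf standing in for the actual vector.contract are illustrative, not the
pass's exact output): the accumulator's transfer_read/transfer_write pair moves
out of the reduction loop, and the running value is threaded through iter_args.

// Before: the accumulator is re-read and re-written on every k iteration.
func.func @before(%acc: memref<32x32xbf16>, %v: vector<32x32xbf16>) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c16 = arith.constant 16 : index
  %pad = arith.constant 0.0 : bf16
  scf.for %k = %c0 to %c16 step %c1 {
    %a = vector.transfer_read %acc[%c0, %c0], %pad {in_bounds = [true, true]} : memref<32x32xbf16>, vector<32x32xbf16>
    %r = arith.addf %a, %v : vector<32x32xbf16>
    vector.transfer_write %r, %acc[%c0, %c0] {in_bounds = [true, true]} : vector<32x32xbf16>, memref<32x32xbf16>
  }
  return
}

// After: one read before the loop and one write after it; the accumulator
// stays in a vector value carried through iter_args.
func.func @after(%acc: memref<32x32xbf16>, %v: vector<32x32xbf16>) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c16 = arith.constant 16 : index
  %pad = arith.constant 0.0 : bf16
  %init = vector.transfer_read %acc[%c0, %c0], %pad {in_bounds = [true, true]} : memref<32x32xbf16>, vector<32x32xbf16>
  %res = scf.for %k = %c0 to %c16 step %c1 iter_args(%it = %init) -> (vector<32x32xbf16>) {
    %r = arith.addf %it, %v : vector<32x32xbf16>
    scf.yield %r : vector<32x32xbf16>
  }
  vector.transfer_write %res, %acc[%c0, %c0] {in_bounds = [true, true]} : vector<32x32xbf16>, memref<32x32xbf16>
  return
}

Keeping the accumulator register-resident across the reduction iterations
avoids a memory round trip per step, which is the payoff of the hoist.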
27 changes: 15 additions & 12 deletions lib/TPP/Transforms/HoistVectorTransfers.cpp
@@ -128,6 +128,18 @@ struct HoistVectorTransferOp : OpRewritePattern<vector::ContractionOp> {
         std::count(contractIteratorTypes.begin(), contractIteratorTypes.end(),
                    vector::IteratorType::reduction);
 
+    if (reductionCount == 0)
+      return rewriter.notifyMatchFailure(
+          contractOp, "Matmul operation not supported yet");
+
+    if (reductionCount == 1)
+      return rewriter.notifyMatchFailure(
+          contractOp, "Batch matmul operation not supported yet");
+
+    if (reductionCount > 3)
+      return rewriter.notifyMatchFailure(
+          contractOp, "The vector contract operation is not a gemm");
+
     auto vectorReadOpLhsType = cast<ShapedType>(vectorReadOpLhs.getType());
     auto vectorReadOpRhsRank =
         (cast<ShapedType>(vectorReadOpRhs.getType())).getRank();
@@ -137,19 +149,10 @@ struct HoistVectorTransferOp : OpRewritePattern<vector::ContractionOp> {
       return rewriter.notifyMatchFailure(
           contractOp, "Invalid rank for batch reduce operation");
 
-    if (reductionCount == 1)
-      return rewriter.notifyMatchFailure(
-          contractOp, "Batch matmul operation not supported yet");
-
-    if (reductionCount > 2)
-      return rewriter.notifyMatchFailure(
-          contractOp, "The vector contract operation is not a gemm");
-
-    // Check the K-dim to be 1
-    int64_t K =
-        vectorReadOpLhsType.getDimSize(vectorReadOpLhsType.getRank() - 1);
-    if (K != 1)
-      return rewriter.notifyMatchFailure(contractOp, "K dim is not 1");
+    if (reductionCount == 3 &&
+        (vectorReadOpLhsType.getRank() != 4 || vectorReadOpRhsRank != 4))
+      return rewriter.notifyMatchFailure(
+          contractOp, "Invalid rank for batch reduce operation with vnni layout");
 
     // Check whether the linalg tiling + vector contract pattern matches for the
     // 4-nested loop structure
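The guards above key on the number of reduction iterators in the
`vector.contract`: two reductions (batch and k) identify a plain batch-reduce
GEMM over rank-3 operands, while the bf16 `vnni` form carries a third reduction
over the 2-element vnni pair and therefore expects rank-4 operands, as
exercised by the new test below. The old restriction that the K dimension
equal 1 is removed.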
89 changes: 89 additions & 0 deletions test/Passes/pass-hoist-vector-transfer-operation-brgemm.mlir
@@ -265,6 +265,95 @@ module {

// -----


#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4, d1)>
#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)>
#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)>
module {
  memref.global "private" constant @__constant_16x32x64x2xbf16 : memref<16x32x64x2xbf16> = dense<1.000000e+00> {alignment = 64 : i64}
  func.func @hoist_gemm_bf16_vnni_layout(%arg0: memref<4x16x64x64xbf16>) -> memref<4x16x64x64xbf16> {
    %cst = arith.constant 0.000000e+00 : bf16
    %cst_0 = arith.constant dense<0.000000e+00> : vector<64x64xbf16>
    %c1 = arith.constant 1 : index
    %c16 = arith.constant 16 : index
    %c32 = arith.constant 32 : index
    %c64 = arith.constant 64 : index
    %c0 = arith.constant 0 : index
    %0 = memref.get_global @__constant_16x32x64x2xbf16 : memref<16x32x64x2xbf16>
    %alloc = memref.alloc() {alignment = 64 : i64} : memref<4x16x64x64xbf16>
    %expand_shape = memref.expand_shape %arg0 [[0], [1], [2], [3, 4]] output_shape [4, 16, 64, 32, 2] : memref<4x16x64x64xbf16> into memref<4x16x64x32x2xbf16>
    scf.forall (%arg1, %arg2) in (4, 16) {
      %subview = memref.subview %alloc[%arg1, %arg2, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<4x16x64x64xbf16> to memref<64x64xbf16, strided<[64, 1], offset: ?>>
      vector.transfer_write %cst_0, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<64x64xbf16>, memref<64x64xbf16, strided<[64, 1], offset: ?>>
      %subview_1 = memref.subview %expand_shape[%arg1, 0, 0, 0, 0] [1, 16, 64, 32, 2] [1, 1, 1, 1, 1] : memref<4x16x64x32x2xbf16> to memref<16x64x32x2xbf16, strided<[4096, 64, 2, 1], offset: ?>>
      scf.for %arg3 = %c0 to %c64 step %c32 {
        scf.for %arg4 = %c0 to %c64 step %c32 {
          %subview_2 = memref.subview %subview[%arg3, %arg4] [32, 32] [1, 1] : memref<64x64xbf16, strided<[64, 1], offset: ?>> to memref<32x32xbf16, strided<[64, 1], offset: ?>>
          scf.for %arg5 = %c0 to %c16 step %c1 {
            scf.for %arg6 = %c0 to %c32 step %c16 {
              %subview_3 = memref.subview %subview_1[%arg5, %arg3, %arg6, 0] [1, 32, 16, 2] [1, 1, 1, 1] : memref<16x64x32x2xbf16, strided<[4096, 64, 2, 1], offset: ?>> to memref<1x32x16x2xbf16, strided<[4096, 64, 2, 1], offset: ?>>
              %subview_4 = memref.subview %0[%arg5, %arg6, %arg4, 0] [1, 16, 32, 2] [1, 1, 1, 1] : memref<16x32x64x2xbf16> to memref<1x16x32x2xbf16, strided<[4096, 128, 2, 1], offset: ?>>
              %1 = vector.transfer_read %subview_3[%c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true]} : memref<1x32x16x2xbf16, strided<[4096, 64, 2, 1], offset: ?>>, vector<1x32x16x2xbf16>
              %2 = vector.transfer_read %subview_4[%c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true]} : memref<1x16x32x2xbf16, strided<[4096, 128, 2, 1], offset: ?>>, vector<1x16x32x2xbf16>
              %3 = vector.transfer_read %subview_2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<32x32xbf16, strided<[64, 1], offset: ?>>, vector<32x32xbf16>
              %4 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %1, %2, %3 : vector<1x32x16x2xbf16>, vector<1x16x32x2xbf16> into vector<32x32xbf16>
              vector.transfer_write %4, %subview_2[%c0, %c0] {in_bounds = [true, true]} : vector<32x32xbf16>, memref<32x32xbf16, strided<[64, 1], offset: ?>>
            }
          }
        }
      }
    }
    return %alloc : memref<4x16x64x64xbf16>
  }
}



// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4, d1)>
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)>
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)>
// CHECK-LABEL: memref.global "private" constant @__constant_16x32x64x2xbf16 : memref<16x32x64x2xbf16> = dense<1.000000e+00> {alignment = 64 : i64}
// CHECK-LABEL: func.func @hoist_gemm_bf16_vnni_layout(
// CHECK-SAME: %[[VAL_0:.*]]: memref<4x16x64x64xbf16>) -> memref<4x16x64x64xbf16> {
// CHECK: %[[VAL_1:.*]] = arith.constant 0.000000e+00 : bf16
// CHECK: %[[VAL_2:.*]] = arith.constant dense<0.000000e+00> : vector<64x64xbf16>
// CHECK: %[[VAL_3:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_4:.*]] = arith.constant 16 : index
// CHECK: %[[VAL_5:.*]] = arith.constant 32 : index
// CHECK: %[[VAL_6:.*]] = arith.constant 64 : index
// CHECK: %[[VAL_7:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_8:.*]] = memref.get_global @__constant_16x32x64x2xbf16 : memref<16x32x64x2xbf16>
// CHECK: %[[VAL_9:.*]] = memref.alloc() {alignment = 64 : i64} : memref<4x16x64x64xbf16>
// CHECK: %[[VAL_10:.*]] = memref.expand_shape %[[VAL_0]] {{\[\[}}0], [1], [2], [3, 4]] output_shape [4, 16, 64, 32, 2] : memref<4x16x64x64xbf16> into memref<4x16x64x32x2xbf16>
// CHECK: scf.forall (%[[VAL_11:.*]], %[[VAL_12:.*]]) in (4, 16) {
// CHECK: %[[VAL_13:.*]] = memref.subview %[[VAL_9]]{{\[}}%[[VAL_11]], %[[VAL_12]], 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<4x16x64x64xbf16> to memref<64x64xbf16, strided<[64, 1], offset: ?>>
// CHECK: vector.transfer_write %[[VAL_2]], %[[VAL_13]]{{\[}}%[[VAL_7]], %[[VAL_7]]] {in_bounds = [true, true]} : vector<64x64xbf16>, memref<64x64xbf16, strided<[64, 1], offset: ?>>
// CHECK: %[[VAL_14:.*]] = memref.subview %[[VAL_10]]{{\[}}%[[VAL_11]], 0, 0, 0, 0] [1, 16, 64, 32, 2] [1, 1, 1, 1, 1] : memref<4x16x64x32x2xbf16> to memref<16x64x32x2xbf16, strided<[4096, 64, 2, 1], offset: ?>>
// CHECK: scf.for %[[VAL_15:.*]] = %[[VAL_7]] to %[[VAL_6]] step %[[VAL_5]] {
// CHECK: scf.for %[[VAL_16:.*]] = %[[VAL_7]] to %[[VAL_6]] step %[[VAL_5]] {
// CHECK: %[[VAL_17:.*]] = memref.subview %[[VAL_13]]{{\[}}%[[VAL_15]], %[[VAL_16]]] [32, 32] [1, 1] : memref<64x64xbf16, strided<[64, 1], offset: ?>> to memref<32x32xbf16, strided<[64, 1], offset: ?>>
// CHECK: %[[VAL_18:.*]] = vector.transfer_read %[[VAL_17]]{{\[}}%[[VAL_7]], %[[VAL_7]]], %[[VAL_1]] {in_bounds = [true, true]} : memref<32x32xbf16, strided<[64, 1], offset: ?>>, vector<32x32xbf16>
// CHECK: %[[VAL_19:.*]] = scf.for %[[VAL_20:.*]] = %[[VAL_7]] to %[[VAL_4]] step %[[VAL_3]] iter_args(%[[VAL_21:.*]] = %[[VAL_18]]) -> (vector<32x32xbf16>) {
// CHECK: %[[VAL_22:.*]] = scf.for %[[VAL_23:.*]] = %[[VAL_7]] to %[[VAL_5]] step %[[VAL_4]] iter_args(%[[VAL_24:.*]] = %[[VAL_21]]) -> (vector<32x32xbf16>) {
// CHECK: %[[VAL_25:.*]] = memref.subview %[[VAL_14]]{{\[}}%[[VAL_20]], %[[VAL_15]], %[[VAL_23]], 0] [1, 32, 16, 2] [1, 1, 1, 1] : memref<16x64x32x2xbf16, strided<[4096, 64, 2, 1], offset: ?>> to memref<1x32x16x2xbf16, strided<[4096, 64, 2, 1], offset: ?>>
// CHECK: %[[VAL_26:.*]] = memref.subview %[[VAL_8]]{{\[}}%[[VAL_20]], %[[VAL_23]], %[[VAL_16]], 0] [1, 16, 32, 2] [1, 1, 1, 1] : memref<16x32x64x2xbf16> to memref<1x16x32x2xbf16, strided<[4096, 128, 2, 1], offset: ?>>
// CHECK: %[[VAL_27:.*]] = vector.transfer_read %[[VAL_25]]{{\[}}%[[VAL_7]], %[[VAL_7]], %[[VAL_7]], %[[VAL_7]]], %[[VAL_1]] {in_bounds = [true, true, true, true]} : memref<1x32x16x2xbf16, strided<[4096, 64, 2, 1], offset: ?>>, vector<1x32x16x2xbf16>
// CHECK: %[[VAL_28:.*]] = vector.transfer_read %[[VAL_26]]{{\[}}%[[VAL_7]], %[[VAL_7]], %[[VAL_7]], %[[VAL_7]]], %[[VAL_1]] {in_bounds = [true, true, true, true]} : memref<1x16x32x2xbf16, strided<[4096, 128, 2, 1], offset: ?>>, vector<1x16x32x2xbf16>
// CHECK: %[[VAL_29:.*]] = vector.contract {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[VAL_27]], %[[VAL_28]], %[[VAL_24]] : vector<1x32x16x2xbf16>, vector<1x16x32x2xbf16> into vector<32x32xbf16>
// CHECK: scf.yield %[[VAL_29]] : vector<32x32xbf16>
// CHECK: }
// CHECK: scf.yield %[[VAL_22]] : vector<32x32xbf16>
// CHECK: }
// CHECK: vector.transfer_write %[[VAL_19]], %[[VAL_17]]{{\[}}%[[VAL_7]], %[[VAL_7]]] {in_bounds = [true, true]} : vector<32x32xbf16>, memref<32x32xbf16, strided<[64, 1], offset: ?>>
// CHECK: }
// CHECK: }
// CHECK: }
// CHECK: return %[[VAL_9]] : memref<4x16x64x64xbf16>
// CHECK: }

// -----


#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
