Code refactoring + 1 updated test case
Arun Thangamani committed Feb 19, 2025
1 parent 0923d3e commit 34e5283
Showing 2 changed files with 17 additions and 23 deletions.
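In short, the refactor hoists the two operand-type casts into shared shapeTypeLhs/shapeTypeRhs variables that feed every later shape and rank query, folds the reductionCount == 0 and reductionCount > 3 bail-outs into a single check, and sinks the duplicated linalg::tileLinalgOp call below the vnniOpt if/else. A minimal sketch of the reuse pattern, with a hypothetical helper name and assumed includes (not the exact pass code):

    #include "mlir/IR/BuiltinTypes.h"
    #include "mlir/IR/Value.h"

    // Hypothetical helper, for illustration only: cast each operand type to
    // ShapedType once and reuse the result, instead of re-running dyn_cast
    // at every shape or rank query.
    static bool hasRankThreeOperands(mlir::Value lhs, mlir::Value rhs) {
      auto shapeTypeLhs = mlir::dyn_cast<mlir::ShapedType>(lhs.getType());
      auto shapeTypeRhs = mlir::dyn_cast<mlir::ShapedType>(rhs.getType());
      if (!shapeTypeLhs || !shapeTypeRhs)
        return false;
      // One cast per operand now serves both queries.
      return shapeTypeLhs.getRank() == 3 && shapeTypeRhs.getRank() == 3;
    }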
32 changes: 13 additions & 19 deletions lib/TPP/Transforms/BrgemmLinalgTiling.cpp
@@ -64,22 +64,21 @@ struct LinalgOpTiling : OpRewritePattern<BrgemmOp> {
           std::count(brgemmIteratorTypes.begin(), brgemmIteratorTypes.end(),
                      utils::IteratorType::reduction);
 
-  if (reductionCount == 0)
+  if (reductionCount == 0 || reductionCount > 3)
     return rewriter.notifyMatchFailure(brgemmOp,
-                                       "Matmul operation not supported yet");
+                                       "Expected GEMM like operation");
 
   if (reductionCount == 1)
     return rewriter.notifyMatchFailure(
         brgemmOp, "Batch matmul operation not supported yet");
 
-  if (reductionCount > 3)
-    return rewriter.notifyMatchFailure(brgemmOp,
-                                       "The operation is not a gemm");
+  auto shapeTypeLhs =
+      dyn_cast<ShapedType>(brgemmOp.getOperand(0).getType());
+  auto shapeTypeRhs =
+      dyn_cast<ShapedType>(brgemmOp.getOperand(1).getType());
 
-  auto shapeLhs =
-      dyn_cast<MemRefType>(brgemmOp.getOperand(0).getType()).getShape();
-  auto shapeRhs =
-      dyn_cast<MemRefType>(brgemmOp.getOperand(1).getType()).getShape();
+  auto shapeLhs = shapeTypeLhs.getShape();
+  auto shapeRhs = shapeTypeRhs.getShape();
 
   if (reductionCount == 2 &&
       (shapeLhs.size() != 3 || shapeRhs.size() != 3))
@@ -98,24 +97,20 @@ struct LinalgOpTiling : OpRewritePattern<BrgemmOp> {
   FailureOr<linalg::TiledLinalgOp> tiledOp;
 
   // Get rank and map of linalg op
-  unsigned rankA =
-      (dyn_cast<ShapedType>((brgemmOp->getOperand(0)).getType())).getRank();
-  unsigned rankB =
-      (dyn_cast<ShapedType>((brgemmOp->getOperand(1)).getType())).getRank();
+  unsigned rankA = shapeTypeLhs.getRank();
+  unsigned rankB = shapeTypeRhs.getRank();
   AffineMap mapA =
       brgemmOp.getMatchingIndexingMap(&brgemmOp->getOpOperand(0));
   AffineMap mapB =
       brgemmOp.getMatchingIndexingMap(&brgemmOp->getOpOperand(1));
 
   if (vnniOpt) {
     // k-tile size adjusted based on the vnni layout for bf16 type
-    auto shape =
-        dyn_cast<MemRefType>(brgemmOp.getOperand(0).getType()).getShape();
-    auto kTileVnni = options.registerTileShape[2] / shape[3];
+    auto kTileVnni = options.registerTileShape[2] / shapeLhs[3];
 
     // Note: We assume that the k tile size is divisible by
     // a power of 2.
-    if (kTileVnni < 1 || (options.registerTileShape[2] % shape[3] != 0))
+    if (kTileVnni < 1 || (options.registerTileShape[2] % shapeLhs[3] != 0))
       return rewriter.notifyMatchFailure(
           brgemmOp, "Failed matching K tile size for batch reduce operation "
                     "with vnni layout. K tile size should be >= vnni layout "
@@ -144,7 +139,6 @@ struct LinalgOpTiling : OpRewritePattern<BrgemmOp> {
     tilingOptions.setTileSizes({tileSizes[0], tileSizes[1], tileSizes[2],
                                 tileSizes[3], tileSizes[4]});
     tilingOptions.setInterchange({dimM, dimN, dimBR, dimK, vnniDim});
-    tiledOp = linalg::tileLinalgOp(rewriter, brgemmOp, tilingOptions);
 
   } else {
 
@@ -176,9 +170,9 @@ struct LinalgOpTiling : OpRewritePattern<BrgemmOp> {
     tilingOptions.setTileSizes(
         {tileSizes[0], tileSizes[1], tileSizes[2], tileSizes[3]});
     tilingOptions.setInterchange({dimM, dimN, dimBR, dimK});
-    tiledOp = linalg::tileLinalgOp(rewriter, brgemmOp, tilingOptions);
   }
 
+  tiledOp = linalg::tileLinalgOp(rewriter, brgemmOp, tilingOptions);
   if (failed(tiledOp)) {
     return failure();
   }
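A note on the VNNI guard in the hunk above: for bf16 the LHS is packed as [batch, M, K / vnni, vnni] (see the memref<16x32x64x2xbf16> global in the test below), so shapeLhs[3] is the VNNI blocking factor and the requested K tile must be a multiple of it. A small runnable sketch of the guard's arithmetic, with assumed tile values rather than anything taken from the pass:

    #include <array>
    #include <cstdio>

    int main() {
      // Assumed for illustration: 32x32x32 M/N/K register tiles and a VNNI
      // factor of 2, the bf16 value seen as shapeLhs[3] in the test input.
      std::array<long, 3> registerTileShape = {32, 32, 32};
      long vnniFactor = 2;

      // Mirrors the pass's check: kTileVnni = registerTileShape[2] / shapeLhs[3],
      // rejected when kTileVnni < 1 or K is not a multiple of the VNNI factor.
      long kTileVnni = registerTileShape[2] / vnniFactor;
      bool valid = kTileVnni >= 1 && registerTileShape[2] % vnniFactor == 0;
      std::printf("kTileVnni = %ld, valid = %s\n", kTileVnni, valid ? "yes" : "no");
      return 0;
    }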
8 changes: 4 additions & 4 deletions test/Passes/pass-tile-brgemm-linalg-matmul.mlir
@@ -79,8 +79,8 @@ module {

 // -----
 
-#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4, d1)>
-#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)>
+#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d1, d4)>
+#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>
 #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)>
 module {
   memref.global "private" constant @__constant_16x32x64x2xbf16 : memref<16x32x64x2xbf16> = dense<1.000000e+00> {alignment = 64 : i64}
@@ -135,6 +135,7 @@ module {
 // CONF1-LABEL: func.func @brgemm_tensor_type_no_tiling
 func.func @brgemm_tensor_type_no_tiling(%arg0: tensor<128x256x512xf32>, %arg1: tensor<128x512x256xf32>, %arg2: tensor<256x256xf32>) -> tensor<256x256xf32> {
   // CONF1-NOT: scf.for
+  // CONF2-NOT: scf.for
   %0 = linalg.batch_reduce_matmul ins(%arg0, %arg1 : tensor<128x256x512xf32>, tensor<128x512x256xf32>) outs(%arg2 : tensor<256x256xf32>) -> tensor<256x256xf32>
   return %0 : tensor<256x256xf32>
 }
@@ -153,9 +154,8 @@ module {
 // CONF1-LABEL: func.func @matmul_no_tiling
 func.func @matmul_no_tiling(%arg0: memref<64x64xf32>, %arg1: memref<64x64xf32>, %arg2: memref<64x64xf32>) {
   // CONF1-NOT: scf.for
-  // CONF2-NOT: scf.for
   linalg.matmul ins(%arg0, %arg1 : memref<64x64xf32>, memref<64x64xf32>)
                 outs(%arg2 : memref<64x64xf32>)
   return
 }
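Net effect of the test update: the CONF2-NOT: scf.for FileCheck directive moves from matmul_no_tiling to brgemm_tensor_type_no_tiling, so the CONF2 configuration now also asserts that the tensor-typed batch-reduce matmul is left untiled (no scf.for loops in the output), while plain linalg.matmul remains covered by the CONF1 check alone.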

