tile sizes and interchange options are adjusted with respect to maps
Arun Thangamani committed Feb 19, 2025
1 parent 3306413 commit e7682e2
Showing 1 changed file with 77 additions and 26 deletions.
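As the commit message says, the tile sizes and the loop interchange are now derived from the op's indexing maps instead of hard-coded dimension orderings. Below is a minimal standalone sketch of that idea in plain C++ (no MLIR dependency): the dimension positions and the {32, 32, 32} register tile are illustrative assumptions standing in for what AffineDimExpr::getPosition() and the pass option would provide, not values taken from the pass.

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Assumed loop-dimension positions for a plain f32 batch-reduce GEMM with
  // loops (batch, M, N, K) = (d0, d1, d2, d3); a permuted indexing map would
  // simply report different positions here.
  unsigned dimBR = 0, dimM = 1, dimN = 2, dimK = 3;

  // Assumed user-requested register tile shape in (M, N, K) order.
  std::vector<int64_t> registerTileShape = {32, 32, 32};

  // Tile sizes are stored at the loop positions reported by the map, so the
  // code that fills the vector never depends on a fixed dimension order.
  std::vector<int64_t> tileSizes(4);
  tileSizes[dimBR] = 1;
  tileSizes[dimM] = registerTileShape[0];
  tileSizes[dimN] = registerTileShape[1];
  tileSizes[dimK] = registerTileShape[2];

  // The interchange asks for (M, N, batch, K) order, again expressed through
  // the map-derived positions rather than literal indices.
  std::vector<unsigned> interchange = {dimM, dimN, dimBR, dimK};

  for (std::size_t i = 0; i < tileSizes.size(); ++i)
    std::printf("tileSizes[%zu] = %lld\n", i, (long long)tileSizes[i]);
  std::printf("interchange = {%u, %u, %u, %u}\n", interchange[0],
              interchange[1], interchange[2], interchange[3]);
  return 0;
}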
lib/TPP/Transforms/BrgemmLinalgTiling.cpp: 103 changes (77 additions, 26 deletions)
@@ -92,45 +92,96 @@ struct LinalgOpTiling : OpRewritePattern<BrgemmOp> {
           brgemmOp,
           "Failed matching for batch reduce operation with vnni layout");

-    // Get the register blocking tile shape from the user input
-    SmallVector<int64_t> mxnxkTile(options.registerTileShape.begin(),
-                                   options.registerTileShape.end());
-
-    linalg::LinalgTilingOptions options;
-    options.setLoopType(linalg::LinalgTilingLoopType::Loops);
+    // Tiling with the help of upstream APIs
+    linalg::LinalgTilingOptions tilingOptions;
+    tilingOptions.setLoopType(linalg::LinalgTilingLoopType::Loops);
     FailureOr<linalg::TiledLinalgOp> tiledOp;

+    // Get rank and map of linalg op
+    unsigned rankA =
+        (dyn_cast<ShapedType>((brgemmOp->getOperand(0)).getType())).getRank();
+    unsigned rankB =
+        (dyn_cast<ShapedType>((brgemmOp->getOperand(1)).getType())).getRank();
+    AffineMap mapA =
+        brgemmOp.getMatchingIndexingMap(&brgemmOp->getOpOperand(0));
+    AffineMap mapB =
+        brgemmOp.getMatchingIndexingMap(&brgemmOp->getOpOperand(1));
+
     if (vnniOpt) {
       // k-tile size adjusted based on the vnni layout for bf16 type
-      auto tensorShape =
+      auto shape =
           dyn_cast<MemRefType>(brgemmOp.getOperand(0).getType()).getShape();
-      auto kTileVnni = mxnxkTile[2] / tensorShape[3];
+      auto kTileVnni = options.registerTileShape[2] / shape[3];

-      // Note: We make an assumption that the k tile size is divisible to
+      // Note: We make an assumption that the k tile size is divisible to
       // the powers of 2.
-      if (kTileVnni < 1)
-        return rewriter.notifyMatchFailure(
+      if (kTileVnni < 1 || (kTileVnni % 2 != 0))
+        return rewriter.notifyMatchFailure(
             brgemmOp, "Failed matching K tile size for batch reduce operation "
-                      "with vnni layout. K tile size should be >= vnni layout");
-
-      mxnxkTile[2] = kTileVnni;
-      // Tile options for bf16 type with vnni layout
-      options.setTileSizes({1, 0, mxnxkTile[0], mxnxkTile[1], mxnxkTile[2]});
-      options.setInterchange({2, 3, 0, 4, 1});
-      tiledOp =
-          linalg::tileLinalgOp(rewriter, brgemmOp, options);
+                      "with vnni layout. K tile size should be >= vnni layout "
+                      "and divisible by 2");
+
+      // Calculating the tile sizes based on affine map for bf16 type with vnni
+      auto vnniDim =
+          (dyn_cast<AffineDimExpr>(mapA.getResult(rankA - 1))).getPosition();
+      auto dimM =
+          (dyn_cast<AffineDimExpr>(mapA.getResult(rankA - 3))).getPosition();
+      auto dimN =
+          (dyn_cast<AffineDimExpr>(mapB.getResult(rankB - 2))).getPosition();
+      auto dimBR =
+          (dyn_cast<AffineDimExpr>(mapA.getResult(rankA - 4))).getPosition();
+      auto dimK =
+          (dyn_cast<AffineDimExpr>(mapA.getResult(rankA - 2))).getPosition();
+
+      // To set the loop interchange options
+      SmallVector<int64_t> tileSizes(5);
+      tileSizes[dimBR] = 1;
+      tileSizes[dimM] = options.registerTileShape[0];
+      tileSizes[dimN] = options.registerTileShape[1];
+      tileSizes[dimK] = kTileVnni;
+      tileSizes[vnniDim] = 0;
+
+      tilingOptions.setTileSizes({tileSizes[0], tileSizes[1], tileSizes[2],
+                                  tileSizes[3], tileSizes[4]});
+      tilingOptions.setInterchange({dimM, dimN, dimBR, dimK, vnniDim});
+      tiledOp = linalg::tileLinalgOp(rewriter, brgemmOp, tilingOptions);

     } else {
-      // Tile options for f32 type.
-      options.setTileSizes({1, mxnxkTile[0], mxnxkTile[1], mxnxkTile[2]});
-      options.setInterchange({1, 2, 0, 3});
-      tiledOp =
-          linalg::tileLinalgOp(rewriter, brgemmOp, options);
+
+      // Calculating the tile sizes based on affine map for fp32 type
+      auto dimM =
+          (dyn_cast<AffineDimExpr>(mapA.getResult(rankA - 2))).getPosition();
+      auto dimN =
+          (dyn_cast<AffineDimExpr>(mapB.getResult(rankB - 1))).getPosition();
+      auto dimBR =
+          (dyn_cast<AffineDimExpr>(mapA.getResult(rankA - 3))).getPosition();
+      auto dimK =
+          (dyn_cast<AffineDimExpr>(mapA.getResult(rankA - 1))).getPosition();
+
+      // Checks dimensions are aligned with the iterator types
+      if (brgemmIteratorTypes[dimM] != mlir::utils::IteratorType::parallel ||
+          brgemmIteratorTypes[dimN] != mlir::utils::IteratorType::parallel ||
+          brgemmIteratorTypes[dimBR] != mlir::utils::IteratorType::reduction ||
+          brgemmIteratorTypes[dimK] != mlir::utils::IteratorType::reduction)
+        return rewriter.notifyMatchFailure(
+            brgemmOp, "Failed matching with iterator types and dimension");
+
+      // To set the loop interchange options
+      SmallVector<int64_t> tileSizes(4);
+      tileSizes[dimBR] = 1;
+      tileSizes[dimM] = options.registerTileShape[0];
+      tileSizes[dimN] = options.registerTileShape[1];
+      tileSizes[dimK] = options.registerTileShape[2];
+
+      tilingOptions.setTileSizes(
+          {tileSizes[0], tileSizes[1], tileSizes[2], tileSizes[3]});
+      tilingOptions.setInterchange({dimM, dimN, dimBR, dimK});
+      tiledOp = linalg::tileLinalgOp(rewriter, brgemmOp, tilingOptions);
     }

     if (failed(tiledOp)) {
-      return failure();
-    }
+      return failure();
+    }
     rewriter.replaceOp(brgemmOp, tiledOp->op->getResults());

     return success();
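For a concrete feel for the bf16/VNNI branch above, a small worked example (the numbers are illustrative assumptions, not taken from the commit): with registerTileShape = {32, 32, 32} and an A operand whose innermost VNNI dimension has shape[3] = 2, the adjusted K tile is kTileVnni = registerTileShape[2] / shape[3] = 32 / 2 = 16, which is >= 1 and even, so the check passes. The resulting tile sizes are 1 for the batch-reduce dimension, 32 for M, 32 for N, 16 for K, and 0 (untiled) for the VNNI dimension, and the loops are interchanged into (M, N, batch-reduce, K, VNNI) order.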
